simple_semantics/language/
mutations.rs

1use std::{
2    cmp::Reverse,
3    collections::{BinaryHeap, VecDeque},
4    fmt::Debug,
5};
6
7use ahash::HashMap;
8use chumsky::container::Container;
9use rand::{
10    Rng,
11    distr::{Distribution, weighted::WeightedIndex},
12    seq::{IndexedRandom, IteratorRandom},
13};
14use thiserror::Error;
15
16use super::*;
17use crate::lambda::{
18    LambdaError, LambdaExpr, LambdaExprRef, LambdaLanguageOfThought, LambdaPool,
19    types::{LambdaType, TypeError},
20};
21
22mod context;
23mod samplers;
24pub use context::Context;
25use samplers::PossibleExpr;
26pub use samplers::PossibleExpressions;
27
28#[derive(Debug, Error, Clone)]
29pub struct ExprOrTypeError();
30
31impl Display for ExprOrTypeError {
32    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
33        write!(f, "This ExprOrType is not an Expr!")
34    }
35}
36
37#[derive(Debug, Clone, Eq, PartialEq)]
38enum ExprOrType<'src, T: LambdaLanguageOfThought> {
39    Type {
40        lambda_type: LambdaType,
41        parent: Option<usize>,
42        is_app_subformula: bool,
43    },
44    Expr {
45        lambda_expr: LambdaExpr<'src, T>,
46        parent: Option<usize>,
47    },
48}
49
50impl<'src, T: LambdaLanguageOfThought> TryFrom<ExprOrType<'src, T>> for LambdaExpr<'src, T> {
51    type Error = ExprOrTypeError;
52
53    fn try_from(value: ExprOrType<'src, T>) -> Result<Self, Self::Error> {
54        match value {
55            ExprOrType::Type { .. } => Err(ExprOrTypeError()),
56            ExprOrType::Expr { lambda_expr, .. } => Ok(lambda_expr),
57        }
58    }
59}
60
61impl<'src, T: LambdaLanguageOfThought> ExprOrType<'src, T> {
62    fn parent(&self) -> Option<usize> {
63        match self {
64            ExprOrType::Type { parent, .. } | ExprOrType::Expr { parent, .. } => *parent,
65        }
66    }
67
68    fn is_type(&self) -> bool {
69        matches!(self, ExprOrType::Type { .. })
70    }
71}
72
73#[derive(Debug, Clone, Eq, PartialEq)]
74struct UnfinishedLambdaPool<'src, T: LambdaLanguageOfThought> {
75    pool: Vec<ExprOrType<'src, T>>,
76}
77
78impl<'src, T: LambdaLanguageOfThought> Default for UnfinishedLambdaPool<'src, T> {
79    fn default() -> Self {
80        Self { pool: vec![] }
81    }
82}
83
84impl<'src, T: LambdaLanguageOfThought + Clone> UnfinishedLambdaPool<'src, T> {
85    fn add_expr<'a>(&mut self, expr: PossibleExpr<'a, 'src, T>, c: &mut Context, t: &LambdaType) {
86        let (mut lambda_expr, app_details) = expr.into_expr();
87        c.depth += 1;
88        c.open_nodes += lambda_expr.n_children();
89        c.open_nodes -= 1;
90        let parent = self.pool[c.position].parent();
91        match &mut lambda_expr {
92            LambdaExpr::Lambda(body, arg) => {
93                c.add_lambda(arg);
94                *body = LambdaExprRef(self.pool.len() as u32);
95                self.pool.push(ExprOrType::Type {
96                    lambda_type: t.rhs().unwrap().clone(),
97                    parent: Some(c.position),
98                    is_app_subformula: false,
99                })
100            }
101            LambdaExpr::BoundVariable(b, _) => {
102                c.use_bvar(*b);
103            }
104            LambdaExpr::FreeVariable(..) => (),
105            LambdaExpr::Application {
106                subformula,
107                argument,
108            } => {
109                *subformula = LambdaExprRef(self.pool.len() as u32);
110                *argument = LambdaExprRef((self.pool.len() + 1) as u32);
111                let (subformula, argument) = app_details.unwrap();
112                self.pool.push(ExprOrType::Type {
113                    lambda_type: subformula,
114                    parent: Some(c.position),
115                    is_app_subformula: true,
116                });
117                self.pool.push(ExprOrType::Type {
118                    lambda_type: argument,
119                    parent: Some(c.position),
120                    is_app_subformula: false,
121                });
122            }
123            LambdaExpr::LanguageOfThoughtExpr(e) => {
124                let children_start = self.pool.len();
125                if let Some(t) = e.var_type() {
126                    c.add_lambda(t);
127                }
128                self.pool
129                    .extend(e.get_arguments().map(|lambda_type| ExprOrType::Type {
130                        lambda_type,
131                        parent: Some(c.position),
132                        is_app_subformula: false,
133                    }));
134                e.change_children(
135                    (children_start..self.pool.len()).map(|x| LambdaExprRef(x as u32)),
136                );
137            }
138        }
139        self.pool[c.position] = ExprOrType::Expr {
140            lambda_expr,
141            parent,
142        };
143    }
144}
145
146#[derive(Debug)]
147pub struct NormalEnumeration(BinaryHeap<Reverse<Context>>, VecDeque<ExprDetails>);
148
149impl EnumerationType for NormalEnumeration {
150    fn pop(&mut self) -> Option<Context> {
151        self.0.pop().map(|x| x.0)
152    }
153
154    fn push(&mut self, context: Context, _: bool) {
155        self.0.push(Reverse(context))
156    }
157
158    fn get_yield(&mut self) -> Option<ExprDetails> {
159        self.1.pop_front()
160    }
161
162    fn push_yield(&mut self, e: ExprDetails) {
163        self.1.push(e);
164    }
165
166    fn include(&mut self, n: usize) -> impl Iterator<Item = bool> + 'static {
167        std::iter::repeat_n(true, n)
168    }
169}
170
171impl ExprDetails {
172    fn score(&self) -> f64 {
173        (1.0 / (self.size as f64))
174            + match self.constant_function {
175                true => 0.0,
176                false => 10.0,
177            }
178    }
179
180    pub fn has_constant_function(&self) -> bool {
181        self.constant_function
182    }
183}
184
185#[derive(Debug, Clone, Copy, PartialEq)]
186struct KeyedExprDetails {
187    expr_details: ExprDetails,
188    k: f64,
189}
190
191impl Eq for KeyedExprDetails {}
192
193impl PartialOrd for KeyedExprDetails {
194    fn partial_cmp(&self, other: &Self) -> Option<std::cmp::Ordering> {
195        Some(self.cmp(other))
196    }
197}
198
199impl Ord for KeyedExprDetails {
200    fn cmp(&self, other: &Self) -> std::cmp::Ordering {
201        //reversed since we need a min-heap not a max-heap
202        other.k.partial_cmp(&self.k).unwrap()
203    }
204}
205impl KeyedExprDetails {
206    fn new(expr_details: ExprDetails, rng: &mut impl Rng) -> Self {
207        let u: f64 = rng.random();
208        KeyedExprDetails {
209            expr_details,
210            k: u.powf(1.0 / expr_details.score()),
211        }
212    }
213}
214
215#[derive(Debug, Clone, PartialEq)]
216struct RandomPQ(Context, f64);
217
218impl Eq for RandomPQ {}
219
220impl RandomPQ {
221    fn new(c: Context, rng: &mut impl Rng) -> Self {
222        RandomPQ(c, rng.random())
223    }
224}
225
226#[derive(Debug)]
227struct ProbabilisticEnumeration<'a, R: Rng, F>
228where
229    F: Fn(&ExprDetails) -> bool,
230{
231    rng: &'a mut R,
232    reservoir_size: usize,
233    reservoir: BinaryHeap<KeyedExprDetails>,
234    backups: Vec<Context>,
235    pq: BinaryHeap<RandomPQ>,
236    filter: F,
237    n_seen: usize,
238    done: bool,
239}
240impl<R: Rng, F> ProbabilisticEnumeration<'_, R, F>
241where
242    F: Fn(&ExprDetails) -> bool,
243{
244    fn threshold(&self) -> Option<f64> {
245        self.reservoir.peek().map(|x| x.k)
246    }
247
248    fn new<'a, 'src, T: LambdaLanguageOfThought, E: Fn(&Context) -> bool>(
249        reservoir_size: usize,
250        t: &LambdaType,
251        possible_expressions: &'a PossibleExpressions<'src, T>,
252        eager_filter: E,
253        filter: F,
254        rng: &'a mut R,
255    ) -> LambdaEnumerator<'a, 'src, T, E, ProbabilisticEnumeration<'a, R, F>> {
256        let context = Context::new(0, vec![]);
257        let mut pq = BinaryHeap::default();
258        pq.push(RandomPQ::new(context, rng));
259        let pools = vec![UnfinishedLambdaPool {
260            pool: vec![ExprOrType::Type {
261                lambda_type: t.clone(),
262                parent: None,
263                is_app_subformula: false,
264            }],
265        }];
266
267        LambdaEnumerator {
268            pools,
269            possible_expressions,
270            eager_filter,
271            pq: ProbabilisticEnumeration {
272                rng,
273                reservoir_size,
274                reservoir: BinaryHeap::default(),
275                backups: vec![],
276                filter,
277                pq,
278                n_seen: 0,
279                done: false,
280            },
281        }
282    }
283}
284
285impl<R: Rng, F> EnumerationType for ProbabilisticEnumeration<'_, R, F>
286where
287    F: Fn(&ExprDetails) -> bool,
288{
289    fn pop(&mut self) -> Option<Context> {
290        //Pop from min-heap, or grab a random back up if the min-heap is exhausted
291        self.pq.pop().map(|x| x.0).or_else(|| {
292            (0..self.backups.len()).choose(self.rng).and_then(|index| {
293                let last_item = self.backups.len() - 1;
294                self.backups.swap(index, last_item);
295                self.backups.pop()
296            })
297        })
298    }
299
300    fn push(&mut self, context: Context, included: bool) {
301        if included {
302            self.pq.push(RandomPQ::new(context, &mut self.rng));
303        } else {
304            self.backups.push(context);
305        }
306    }
307
308    fn get_yield(&mut self) -> Option<ExprDetails> {
309        if (self.done || self.pq.is_empty())
310            && let Some(x) = self.reservoir.pop()
311        {
312            Some(x.expr_details)
313        } else {
314            None
315        }
316    }
317
318    fn push_yield(&mut self, e: ExprDetails) {
319        let e = KeyedExprDetails::new(e, &mut self.rng);
320        if (self.filter)(&e.expr_details) {
321            self.n_seen += 1;
322            if self.reservoir_size > self.reservoir.len() {
323                self.reservoir.push(e)
324            } else if let Some(t) = self.threshold()
325                && e.k > t
326            {
327                self.reservoir.pop();
328                self.reservoir.push(e)
329            }
330            if self.n_seen >= self.reservoir_size * 20 {
331                self.pq.clear();
332                self.done = true;
333            }
334        }
335    }
336
337    fn include(&mut self, n: usize) -> impl Iterator<Item = bool> + 'static {
338        let x = (0..n).choose_multiple(self.rng, (n / 2).max(1));
339        let mut v = vec![false; n];
340        for i in x {
341            v[i] = true;
342        }
343        v.into_iter()
344    }
345}
346
347#[derive(Debug)]
348///An iterator that enumerates over all possible expressions of a given type.
349pub struct LambdaEnumerator<'a, 'src, T: LambdaLanguageOfThought, F, E = NormalEnumeration> {
350    pools: Vec<UnfinishedLambdaPool<'src, T>>,
351    possible_expressions: &'a PossibleExpressions<'src, T>,
352    eager_filter: F,
353    pq: E,
354}
355
356///Provides detail about a generated lambda expression
357#[derive(Debug, Clone, Copy, Eq, PartialEq)]
358pub struct ExprDetails {
359    id: usize,
360    constant_function: bool,
361    root: LambdaExprRef,
362    size: usize,
363}
364
365impl ExprDetails {
366    ///Get the size of the associated [`RootedLambdaPool`].
367    pub fn size(&self) -> usize {
368        self.size
369    }
370}
371
372#[derive(Debug, Clone, Eq, PartialEq)]
373///A re-usable sampler for sampling expressions of arbitrary types while caching frequent types
374pub struct TypeAgnosticSampler<'src, T: LambdaLanguageOfThought> {
375    type_to_sampler: HashMap<LambdaType, (usize, LambdaSampler<'src, T>)>,
376    max_expr: usize,
377    max_types: usize,
378    possible_expressions: PossibleExpressions<'src, T>,
379}
380
381impl<'src> TypeAgnosticSampler<'src, Expr<'src>> {
382    ///Samples an expression of a given type
383    pub fn sample(
384        &mut self,
385        lambda_type: LambdaType,
386        rng: &mut impl Rng,
387    ) -> RootedLambdaPool<'src, Expr<'src>> {
388        let (counts, exprs) = self
389            .type_to_sampler
390            .entry(lambda_type)
391            .or_insert_with_key(|t| {
392                (
393                    1,
394                    RootedLambdaPool::sampler(t, &self.possible_expressions, self.max_expr),
395                )
396            });
397        *counts += 1;
398        let sample = exprs.sample(rng);
399
400        if self.type_to_sampler.len() > self.max_types {
401            let (_, k) = self
402                .type_to_sampler
403                .iter()
404                .map(|(k, (n_visits, _))| (n_visits, k))
405                .min_by_key(|x| x.0)
406                .unwrap();
407
408            let t = k.clone();
409            self.type_to_sampler.remove(&t);
410        }
411
412        sample
413    }
414
415    ///Get a reference to the [`PossibleExpressions`] used by the model
416    pub fn possibles(&self) -> &PossibleExpressions<'src, Expr<'src>> {
417        &self.possible_expressions
418    }
419}
420
421impl<'src, T: LambdaLanguageOfThought + Clone> RootedLambdaPool<'src, T> {
422    ///Create a sampler which can sample arbitrary types.
423    pub fn typeless_sampler(
424        possible_expressions: PossibleExpressions<'src, T>,
425        max_expr: usize,
426        max_types: usize,
427    ) -> TypeAgnosticSampler<'src, T> {
428        assert!(max_types >= 1);
429        assert!(max_expr >= 1);
430        TypeAgnosticSampler {
431            possible_expressions,
432            max_expr,
433            max_types,
434            type_to_sampler: HashMap::default(),
435        }
436    }
437}
438
439///A struct which samples expressions from a distribution.
440#[derive(Debug, Clone, Eq, PartialEq)]
441pub struct LambdaSampler<'src, T: LambdaLanguageOfThought> {
442    lambdas: Vec<RootedLambdaPool<'src, T>>,
443    expr_details: Vec<ExprDetails>,
444}
445
446impl<'src, T: LambdaLanguageOfThought + Clone> Distribution<RootedLambdaPool<'src, T>>
447    for LambdaSampler<'src, T>
448{
449    fn sample<R: Rng + ?Sized>(&self, rng: &mut R) -> RootedLambdaPool<'src, T> {
450        let w = WeightedIndex::new(self.expr_details.iter().map(|x| x.score())).unwrap();
451        let i = w.sample(rng);
452        self.lambdas
453            .get(i)
454            .expect("The Lambda Sampler has no lambdas to sample :(")
455            .clone()
456    }
457}
458
459///A trait that handles how enumeration is processed (either normal enumeration or by doing
460///reservoir sampling).
461pub trait EnumerationType {
462    fn pop(&mut self) -> Option<Context>;
463    fn push(&mut self, context: Context, included: bool);
464    fn get_yield(&mut self) -> Option<ExprDetails>;
465    fn push_yield(&mut self, e: ExprDetails);
466    fn include(&mut self, n: usize) -> impl Iterator<Item = bool> + 'static;
467}
468
469fn try_yield<'a, 'src, T, F, E>(
470    x: &mut LambdaEnumerator<'a, 'src, T, F, E>,
471) -> Option<(RootedLambdaPool<'src, T>, ExprDetails)>
472where
473    T: LambdaLanguageOfThought,
474    E: EnumerationType,
475{
476    if let Some(item) = x.pq.get_yield() {
477        let p = std::mem::take(&mut x.pools[item.id]);
478        return Some((
479            RootedLambdaPool {
480                pool: LambdaPool(
481                    p.pool
482                        .into_iter()
483                        .map(|x| LambdaExpr::try_from(x).unwrap())
484                        .collect(),
485                ),
486                root: item.root,
487            },
488            item,
489        ));
490    }
491    None
492}
493
494impl<'a, 'src, T, F, E> LambdaEnumerator<'a, 'src, T, F, E>
495where
496    T: LambdaLanguageOfThought + Clone + Debug,
497    F: Fn(&Context) -> bool,
498    E: EnumerationType,
499{
500    fn push(&mut self, c: Context, included: bool) {
501        if (self.eager_filter)(&c) {
502            self.pq.push(c, included);
503        } else {
504            self.pools[c.pool_index] = UnfinishedLambdaPool::default();
505        }
506    }
507
508    ///Change the eager_filter function for this enumerator
509    pub fn eager_filter<F2>(self, eager_filter: F2) -> LambdaEnumerator<'a, 'src, T, F2, E> {
510        let LambdaEnumerator {
511            pools,
512            possible_expressions,
513            eager_filter: _,
514            pq,
515        } = self;
516
517        LambdaEnumerator {
518            pools,
519            possible_expressions,
520            eager_filter,
521            pq,
522        }
523    }
524}
525
526impl<'a, 'src, F, E> Iterator for LambdaEnumerator<'a, 'src, Expr<'src>, F, E>
527where
528    F: Fn(&Context) -> bool,
529    E: EnumerationType,
530{
531    type Item = (RootedLambdaPool<'src, Expr<'src>>, ExprDetails);
532
533    fn next(&mut self) -> Option<Self::Item> {
534        if let Some(x) = try_yield(self) {
535            return Some(x);
536        }
537
538        while let Some(mut c) = self.pq.pop() {
539            if let Some(x) = try_yield(self) {
540                self.push(c, true);
541                return Some(x);
542            }
543            let (possibles, lambda_type) = match &self.pools[c.pool_index].pool[c.position] {
544                ExprOrType::Type {
545                    lambda_type,
546                    is_app_subformula,
547                    parent,
548                } => {
549                    let mut possibles = self.possible_expressions.possibilities(
550                        lambda_type,
551                        *is_app_subformula,
552                        &c,
553                    );
554
555                    //Super hacky way to introduce all_e and all_a in quantifiers even though the
556                    //types are messed up.
557                    if let Some(p) = parent
558                        && let ExprOrType::Expr {
559                            lambda_expr:
560                                LambdaExpr::LanguageOfThoughtExpr(Expr::Quantifier {
561                                    var_type,
562                                    restrictor,
563                                    ..
564                                }),
565                            ..
566                        } = self.pools[c.pool_index].pool[*p]
567                        && restrictor.0 == c.position as u32
568                    {
569                        possibles.push(PossibleExpr::new_borrowed(match var_type {
570                            ActorOrEvent::Actor => &LambdaExpr::LanguageOfThoughtExpr(
571                                Expr::Constant(Constant::Everyone),
572                            ),
573                            ActorOrEvent::Event => &LambdaExpr::LanguageOfThoughtExpr(
574                                Expr::Constant(Constant::EveryEvent),
575                            ),
576                        }));
577                    };
578
579                    (possibles, lambda_type.clone())
580                }
581                ExprOrType::Expr {
582                    lambda_expr,
583                    parent,
584                } => {
585                    //We add the next uninitialized child to the context or go to the parent if
586                    //there are none.
587
588                    if let Some(child) = lambda_expr
589                        .get_children()
590                        .map(|x| x.0 as usize)
591                        .find(|child| self.pools[c.pool_index].pool[*child].is_type())
592                    {
593                        c.position = child;
594                        self.pq.push(c, true);
595                        continue;
596                    }
597
598                    if lambda_expr.inc_depth() {
599                        c.pop_lambda();
600                    }
601
602                    if let Some(p) = parent {
603                        c.position = *p;
604                        self.pq.push(c, true);
605                        continue;
606                    } else {
607                        //If the parent is None, we're done!
608                        self.pq.push_yield(ExprDetails {
609                            id: c.pool_index,
610                            root: LambdaExprRef(c.position as u32),
611                            size: c.depth,
612                            constant_function: c.is_constant(),
613                        });
614                        continue;
615                    }
616                }
617            };
618
619            let n = possibles.len();
620            let included = self.pq.include(n);
621            if n == 0 {
622                // dbg!(&self.pools[c.pool_index]);
623                //   panic!("There is no possible expression of type {lambda_type}");
624                //This is a dead-end.
625                continue;
626            }
627            let n_pools = self.pools.len();
628            if n_pools.is_multiple_of(10_000) {
629                self.pools.shrink_to_fit();
630            }
631
632            for _ in 0..n.saturating_sub(1) {
633                self.pools.push(self.pools[c.pool_index].clone());
634            }
635
636            let positions =
637                std::iter::once(c.pool_index).chain(n_pools..n_pools + n.saturating_sub(1));
638
639            for (((expr, pool_id), mut c), included) in possibles
640                .into_iter()
641                .zip(positions)
642                .zip(std::iter::repeat_n(c, n))
643                .zip(included)
644            {
645                c.pool_index = pool_id;
646                let pool = self.pools.get_mut(pool_id).unwrap();
647                pool.add_expr(expr, &mut c, &lambda_type);
648                self.push(c, included);
649            }
650
651            if let Some(x) = try_yield(self) {
652                return Some(x);
653            }
654        }
655
656        //If we've somehow exhausted the pq, lets yield anything remaining that's done.
657        if let Some(x) = try_yield(self) {
658            return Some(x);
659        }
660        None
661    }
662}
663
664impl<'src> RootedLambdaPool<'src, Expr<'src>> {
665    ///Create a [`LambdaSampler`] of a given type.
666    pub fn resample_from_expr<'a>(
667        &mut self,
668        possible_expressions: &'a PossibleExpressions<'src, Expr<'src>>,
669        helpers: Option<&HashMap<LambdaType, Vec<RootedLambdaPool<'src, Expr<'src>>>>>,
670        rng: &mut impl Rng,
671    ) -> Result<(), LambdaError> {
672        let position = LambdaExprRef((0..self.len()).choose(rng).unwrap() as u32);
673        let t = self.pool.get_type(position)?;
674
675        let pool = if let Some(helpers) = helpers
676            && rng.random_bool(0.2)
677            && let Some(v) = helpers.get(&t)
678            && !v.is_empty()
679        {
680            let pool = v.choose(rng).unwrap();
681            pool.clone()
682        } else {
683            let (pool, _) = self
684                .probabilistic_enumerate_from_expr(
685                    position,
686                    possible_expressions,
687                    |_| true,
688                    |_| true,
689                    rng,
690                )?
691                .next()
692                .unwrap();
693            pool
694        };
695
696        let offset = self.len() as u32;
697        let new_root = pool.root.0 + offset;
698        self.pool.0.extend(pool.pool.0.into_iter().map(|mut x| {
699            let children: Vec<_> = x
700                .get_children()
701                .map(|x| LambdaExprRef(x.0 + offset))
702                .collect();
703            x.change_children(children.into_iter());
704            x
705        }));
706        self.pool.0.swap(position.0 as usize, new_root as usize);
707        self.cleanup();
708        Ok(())
709    }
710
711    fn probabilistic_enumerate_from_expr<'a, R, E, F>(
712        &self,
713        position: LambdaExprRef,
714        possible_expressions: &'a PossibleExpressions<'src, Expr<'src>>,
715        eager_filter: E,
716        filter: F,
717        rng: &'a mut R,
718    ) -> Result<
719        LambdaEnumerator<'a, 'src, Expr<'src>, E, ProbabilisticEnumeration<'a, R, F>>,
720        TypeError,
721    >
722    where
723        R: Rng,
724        F: Fn(&ExprDetails) -> bool,
725        E: Fn(&Context) -> bool,
726    {
727        let (context, is_app_subformula) = Context::from_pos(self, position);
728        let output = self.pool.get_type(position)?;
729        let mut pq = BinaryHeap::default();
730        pq.push(RandomPQ::new(context, rng));
731        let pools = vec![UnfinishedLambdaPool {
732            pool: vec![ExprOrType::Type {
733                lambda_type: output,
734                parent: None,
735                is_app_subformula,
736            }],
737        }];
738        let enumerator = LambdaEnumerator {
739            pools,
740            possible_expressions,
741            eager_filter,
742            pq: ProbabilisticEnumeration {
743                rng,
744                reservoir_size: 1,
745                reservoir: BinaryHeap::default(),
746                done: false,
747                n_seen: 0,
748                filter,
749                backups: vec![],
750                pq,
751            },
752        };
753
754        Ok(enumerator)
755    }
756
757    ///Create a [`LambdaSampler`] of a given type with a filter
758    pub fn enumerator_filter<'a, F: Fn(&Context) -> bool>(
759        t: &LambdaType,
760        filter: F,
761        possible_expressions: &'a PossibleExpressions<'src, Expr<'src>>,
762    ) -> LambdaEnumerator<'a, 'src, Expr<'src>, F> {
763        let context = Context::new(0, vec![]);
764        let mut pq = BinaryHeap::default();
765        pq.push(Reverse(context));
766        let pools = vec![UnfinishedLambdaPool {
767            pool: vec![ExprOrType::Type {
768                lambda_type: t.clone(),
769                parent: None,
770                is_app_subformula: false,
771            }],
772        }];
773
774        LambdaEnumerator {
775            pools,
776            possible_expressions,
777            eager_filter: filter,
778            pq: NormalEnumeration(pq, VecDeque::default()),
779        }
780    }
781
782    ///Create a [`LambdaSampler`] of a given type.
783    pub fn enumerator<'a>(
784        t: &LambdaType,
785        possible_expressions: &'a PossibleExpressions<'src, Expr<'src>>,
786    ) -> LambdaEnumerator<'a, 'src, Expr<'src>, impl Fn(&'_ Context) -> bool> {
787        let context = Context::new(0, vec![]);
788        let mut pq = BinaryHeap::default();
789        pq.push(Reverse(context));
790        let pools = vec![UnfinishedLambdaPool {
791            pool: vec![ExprOrType::Type {
792                lambda_type: t.clone(),
793                parent: None,
794                is_app_subformula: false,
795            }],
796        }];
797
798        LambdaEnumerator {
799            pools,
800            possible_expressions,
801            eager_filter: |_: &Context| true,
802            pq: NormalEnumeration(pq, VecDeque::default()),
803        }
804    }
805
806    ///Creates a reusable random sampler by enumerating over the first `max_expr` expressions
807    pub fn sampler(
808        t: &LambdaType,
809        possible_expressions: &PossibleExpressions<'src, Expr<'src>>,
810        max_expr: usize,
811    ) -> LambdaSampler<'src, Expr<'src>> {
812        let enumerator = RootedLambdaPool::enumerator(t, possible_expressions);
813        let mut lambdas = Vec::with_capacity(max_expr);
814        let mut expr_details = Vec::with_capacity(max_expr);
815        for (lambda, expr_detail) in enumerator.take(max_expr) {
816            lambdas.push(lambda);
817            expr_details.push(expr_detail);
818        }
819
820        LambdaSampler {
821            lambdas,
822            expr_details,
823        }
824    }
825
826    ///Randomly generate a [`RootedLambdaPool`] of type `t`.
827    pub fn random_expr(
828        t: &LambdaType,
829        possible_expressions: &PossibleExpressions<'src, Expr<'src>>,
830        rng: &mut impl Rng,
831    ) -> RootedLambdaPool<'src, Expr<'src>> {
832        ProbabilisticEnumeration::new(
833            1,
834            t,
835            possible_expressions,
836            |_: &Context| true,
837            |_| true,
838            rng,
839        )
840        .next()
841        .unwrap()
842        .0
843    }
844
845    ///Randomly generate a [`RootedLambdaPool`] of type `t` without constant functions.
846    pub fn random_expr_no_constant(
847        t: &LambdaType,
848        possible_expressions: &PossibleExpressions<'src, Expr<'src>>,
849        rng: &mut impl Rng,
850    ) -> Option<RootedLambdaPool<'src, Expr<'src>>> {
851        ProbabilisticEnumeration::new(
852            1,
853            t,
854            possible_expressions,
855            |_: &Context| true,
856            |e| !e.constant_function,
857            rng,
858        )
859        .next()
860        .map(|x| x.0)
861    }
862}
863
864impl<'src> RootedLambdaPool<'src, Expr<'src>> {
865    ///Remove quantifiers which do not use their variable in their body.
866    pub fn prune_quantifiers(&mut self) {
867        let quantifiers = self
868            .pool
869            .bfs_from(self.root)
870            .filter_map(|(i, _)| match self.get(i) {
871                LambdaExpr::LanguageOfThoughtExpr(Expr::Quantifier { subformula, .. }) => {
872                    if !self
873                        .pool
874                        .bfs_from(LambdaExprRef(subformula.0))
875                        .any(|(x, d)| {
876                            if let LambdaExpr::BoundVariable(v, _) = self.get(x) {
877                                *v == d
878                            } else {
879                                false
880                            }
881                        })
882                    {
883                        Some((i, LambdaExprRef(subformula.0)))
884                    } else {
885                        None
886                    }
887                }
888                _ => None,
889            })
890            .collect::<Vec<_>>();
891
892        //By reversing, we ensure that we fix inner quantifiers before outer ones.
893        for (quantifier, subformula) in quantifiers.into_iter().rev() {
894            *self.pool.get_mut(quantifier) = self.pool.get(subformula).clone();
895            self.pool.bfs_from_mut(quantifier).for_each(|(x, d, _)| {
896                if let LambdaExpr::BoundVariable(b_d, _) = x
897                    && *b_d > d
898                {
899                    *b_d -= 1;
900                }
901            });
902        }
903        self.root = self.pool.cleanup(self.root);
904    }
905
906    ///Replace a random expression with something else of the same type.
907    pub fn swap_expr(
908        &mut self,
909        possible_expressions: &PossibleExpressions<'src, Expr<'src>>,
910        rng: &mut impl Rng,
911    ) -> Result<(), TypeError> {
912        let position = LambdaExprRef((0..self.len()).choose(rng).unwrap() as u32);
913
914        let (context, _) = Context::from_pos(self, position);
915
916        let output = self.pool.get_type(position)?;
917        let expr = self.pool.get(position);
918
919        let arguments: Vec<_> = expr
920            .get_children()
921            .map(|x| self.pool.get_type(x).unwrap())
922            .collect();
923
924        let mut replacements = possible_expressions.possiblities_fixed_children(
925            &output,
926            &arguments,
927            expr.var_type(),
928            &context,
929        );
930
931        //hacky fix!!
932        if replacements.is_empty()
933            && let LambdaExpr::LanguageOfThoughtExpr(Expr::Quantifier {
934                var_type,
935                restrictor,
936                subformula,
937                ..
938            }) = expr
939        {
940            for quantifier in [Quantifier::Universal, Quantifier::Existential] {
941                replacements.push(std::borrow::Cow::Owned(LambdaExpr::LanguageOfThoughtExpr(
942                    Expr::Quantifier {
943                        quantifier,
944                        var_type: *var_type,
945                        restrictor: *restrictor,
946                        subformula: *subformula,
947                    },
948                )));
949            }
950        }
951        for c in [Constant::EveryEvent, Constant::Everyone] {
952            if replacements.is_empty()
953                && expr == &LambdaExpr::LanguageOfThoughtExpr(Expr::Constant(c))
954            {
955                replacements = possible_expressions.possiblities_fixed_children(
956                    LambdaType::t(),
957                    &arguments,
958                    expr.var_type(),
959                    &context,
960                );
961                replacements.push(std::borrow::Cow::Owned(LambdaExpr::LanguageOfThoughtExpr(
962                    Expr::Constant(c),
963                )));
964            }
965        }
966
967        let choice = replacements.choose(rng).unwrap_or_else(|| {
968            panic!(
969                "There is no node with output {output} and arguments {}",
970                arguments
971                    .into_iter()
972                    .map(|x| x.to_string())
973                    .collect::<Vec<_>>()
974                    .join(", ")
975            )
976        });
977
978        let mut new_expr = choice.clone().into_owned();
979        new_expr.change_children(self.pool.get(position).get_children());
980        self.pool.0[position.0 as usize] = new_expr;
981
982        Ok(())
983    }
984
985    ///Replace a random expression with something else of the same type from within the same
986    ///expression.
987    pub fn swap_subtree(&mut self, rng: &mut impl Rng) -> Result<(), TypeError> {
988        let position = LambdaExprRef((0..self.len()).choose(rng).unwrap() as u32);
989        let (context, _) = Context::from_pos(self, position);
990        let alt = context.find_compatible(self, position)?;
991        if let Some(new_pos) = alt.choose(rng).copied() {
992            let offset = self.pool.0.len();
993            let mut lookup = HashMap::default();
994            let mut new_pool: Vec<LambdaExpr<'src, Expr<'src>>> = self
995                .pool
996                .bfs_from(new_pos)
997                .map(|(x, _)| {
998                    let n = lookup.len();
999                    lookup
1000                        .entry(x)
1001                        .or_insert(LambdaExprRef((n + offset) as u32));
1002                    self.get(x).clone()
1003                })
1004                .collect();
1005
1006            new_pool.iter_mut().for_each(|x| {
1007                let child: Vec<_> = x.get_children().map(|x| *lookup.get(&x).unwrap()).collect();
1008                x.change_children(child.into_iter());
1009            });
1010
1011            self.pool.0.extend(new_pool);
1012            self.pool.0.swap(position.0 as usize, offset);
1013            self.cleanup();
1014        }
1015        Ok(())
1016    }
1017}
1018
#[cfg(test)]
mod test {

    use std::collections::HashSet;

    use super::*;
    use crate::lambda::{LambdaPool, LambdaSummaryStats};
    use rand::SeedableRng;
    use rand_chacha::ChaCha8Rng;

    // Quantifiers that never use their bound variable are removed, both in parsed pools and in
    // a hand-built pool.
    #[test]
    fn prune_quantifier_test() -> anyhow::Result<()> {
        let mut pool =
            RootedLambdaPool::parse("some_e(x,all_e,AgentOf(a_2,e_1) & PatientOf(a_0,e_0))")?;

        pool.prune_quantifiers();
        assert_eq!(pool.to_string(), "AgentOf(a_2, e_1) & PatientOf(a_0, e_0)");

        let mut pool = RootedLambdaPool::parse(
            "some_e(x0, all_e, some(z, all_a, AgentOf(z, e_1) & PatientOf(a_0, e_0)))",
        )?;

        pool.prune_quantifiers();
        assert_eq!(
            pool.to_string(),
            "some(x, all_a, AgentOf(x, e_1) & PatientOf(a_0, e_0))"
        );

        let mut pool = RootedLambdaPool::parse("~every_e(z, pe_1, pa_2(a_0))")?;

        pool.prune_quantifiers();

        assert_eq!(pool.to_string(), "~pa_2(a_0)");
        let mut pool = RootedLambdaPool::new(
            LambdaPool(vec![
                LambdaExpr::Lambda(LambdaExprRef(1), LambdaType::E),
                LambdaExpr::Lambda(LambdaExprRef(2), LambdaType::T),
                LambdaExpr::LanguageOfThoughtExpr(Expr::Quantifier {
                    quantifier: Quantifier::Universal,
                    var_type: ActorOrEvent::Actor,
                    restrictor: ExprRef(3),
                    subformula: ExprRef(4),
                }),
                LambdaExpr::LanguageOfThoughtExpr(Expr::Constant(Constant::Property(
                    "1",
                    ActorOrEvent::Actor,
                ))),
                LambdaExpr::LanguageOfThoughtExpr(Expr::Quantifier {
                    quantifier: Quantifier::Existential,
                    var_type: ActorOrEvent::Actor,
                    restrictor: ExprRef(5),
                    subformula: ExprRef(6),
                }),
                LambdaExpr::LanguageOfThoughtExpr(Expr::Constant(Constant::Property(
                    "0",
                    ActorOrEvent::Actor,
                ))),
                LambdaExpr::LanguageOfThoughtExpr(Expr::Unary(
                    MonOp::Property("3", ActorOrEvent::Event),
                    ExprRef(7),
                )),
                LambdaExpr::BoundVariable(3, LambdaType::E),
            ]),
            LambdaExprRef(0),
        );

        assert_eq!(
            pool.to_string(),
            "lambda e x lambda t phi every(y, pa_1, some(z, pa_0, pe_3(x)))"
        );

        let mut parsed_pool = RootedLambdaPool::parse(
            "lambda e x lambda t phi every(y, pa_1, some(z, pa_0, pe_3(x)))",
        )?;
        parsed_pool.prune_quantifiers();
        pool.prune_quantifiers();
        assert_eq!(pool.to_string(), "lambda e x lambda t phi pe_3(x)");
        // The parsed pool prints identically to the hand-built one, so pruning it should give
        // the same result.
        assert_eq!(parsed_pool.to_string(), pool.to_string());

        Ok(())
    }

    // `swap_subtree` can replace the disjunction with either of its disjuncts.
    #[test]
    fn random_swap_tree() -> anyhow::Result<()> {
        let mut rng = ChaCha8Rng::seed_from_u64(2);
        let x = RootedLambdaPool::parse("lambda a x pa_cool(a_John) | pa_cool(x)")?;
        let mut h = HashSet::new();
        for _ in 0..100 {
            let mut z = x.clone();
            z.swap_subtree(&mut rng)?;
            h.insert(z.to_string());
        }
        for s in h.iter() {
            println!("{s}");
        }
        assert!(h.contains("lambda a x pa_cool(x)"));
        assert!(h.contains("lambda a x pa_cool(a_John)"));
        Ok(())
    }

    // `swap_expr` preserves the type of random and hand-written expressions.
    #[test]
    fn randomn_swap() -> anyhow::Result<()> {
        let mut rng = ChaCha8Rng::seed_from_u64(2);
        let actors = ["0", "1"];
        let available_event_properties = ["2", "3", "4"];
        let possible_expressions = PossibleExpressions::new(
            &actors,
            &available_event_properties,
            &available_event_properties,
        );
        for _ in 0..200 {
            let t = LambdaType::random(&mut rng);
            println!("{t}");
            let mut pool = RootedLambdaPool::random_expr(&t, &possible_expressions, &mut rng);
            println!("{t}: {pool}");
            assert_eq!(t, pool.get_type()?);
            println!("{pool:?}");
            pool.swap_expr(&possible_expressions, &mut rng)?;
            println!("{pool:?}");
            println!("{t}: {pool}");
            assert_eq!(t, pool.get_type()?);
        }
        let p =
            RootedLambdaPool::parse("lambda <a,t> P some(x, P(x), some_e(y, pe_2(y), pe_4(y)))")?;
        let t = p.get_type()?;
        for _ in 0..1000 {
            let mut pool = p.clone();
            assert_eq!(t, pool.get_type()?);

            pool.swap_expr(&possible_expressions, &mut rng)?;
            dbg!(&pool);
            println!("{t}: {pool}");
            assert_eq!(t, pool.get_type()?);
        }

        Ok(())
    }

    // Among the first 20 enumerated `<<a,t>,t>` expressions without a constant function is
    // `lambda <a,t> P P(a_john)`.
    #[test]
    fn enumerate() -> anyhow::Result<()> {
        let actors = ["john"];
        let actor_properties = ["a"];
        let event_properties = ["e"];
        let possibles = PossibleExpressions::new(&actors, &actor_properties, &event_properties);

        let t = LambdaType::from_string("<<a,t>,t>")?;

        let p = RootedLambdaPool::enumerator(&t, &possibles);
        let pools: HashSet<_> = p
            .filter_map(|(p, x)| {
                if !x.has_constant_function() {
                    Some(p.to_string())
                } else {
                    None
                }
            })
            .take(20)
            .collect();
        for p in pools.iter() {
            println!("{p}");
        }
        assert!(pools.contains("lambda <a,t> P P(a_john)"));

        Ok(())
    }

    // Disabled test kept for reference.
    // NOTE(review): confirm whether `random_expr_no_constant` still exists before re-enabling.
    /*
    #[test]
    fn random_expr_no_constant() -> anyhow::Result<()> {
        let actors = &["1", "2", "3", "4", "5"];
        let actor_properties = &["1", "2", "3", "4", "5"];
        let event_properties = &["1", "2", "3", "4", "5"];
        let poss = PossibleExpressions::new(actors, actor_properties, event_properties);
        let mut rng = ChaCha8Rng::seed_from_u64(1);

        let t = LambdaType::from_string("<<a,<e,t>>, <e,t>>")?;
        println!("{t}: ");
        let pool = RootedLambdaPool::random_expr_no_constant(&t, &poss, &mut rng).unwrap();
        println!("{pool}");

        for _ in 0..500 {
            let t = LambdaType::random(&mut rng);
            println!("{t}: ");
            if t.size() <= 3 {
                let pool = RootedLambdaPool::random_expr_no_constant(&t, &poss, &mut rng).unwrap();
                println!("{pool}");
            } else {
                println!("too big :(");
            }
        }
        Ok(())
    }*/

    // Smoke test: enumerating `<a,<e,e>>` and filtering out constant functions does not panic.
    #[test]
    fn enumerate_weirds() -> anyhow::Result<()> {
        let actors = &["1"];
        let actor_properties = &["1"];
        let event_properties = &["1"];
        let poss = PossibleExpressions::new(actors, actor_properties, event_properties);

        let t = "<a,<e,e>>";

        for (p, d) in RootedLambdaPool::enumerator(&LambdaType::from_string(t)?, &poss)
            .filter(|(_, e)| !e.has_constant_function())
            .take(5)
        {
            println!("{p} {d:?} {}", d.has_constant_function());
        }
        Ok(())
    }

    // Sampled expressions have the requested type, round-trip through `parse`, and keep their
    // type under resampling and subtree swaps.
    #[test]
    fn random_expr() -> anyhow::Result<()> {
        let actors = ["john"];
        let actor_properties = ["a"];
        let event_properties = ["e"];
        let possibles = PossibleExpressions::new(&actors, &actor_properties, &event_properties);
        let mut rng = ChaCha8Rng::seed_from_u64(0);

        let map = [(LambdaType::A, vec![RootedLambdaPool::parse("a_john")?])]
            .into_iter()
            .collect();

        for _ in 0..100 {
            let t = LambdaType::random(&mut rng);
            println!("sampling: {t}");
            let mut pool = RootedLambdaPool::random_expr(&t, &possibles, &mut rng);
            assert_eq!(t, pool.get_type()?);
            let s = pool.to_string();
            println!("{s}");
            let mut pool2 = RootedLambdaPool::parse(s.as_str())?;
            assert_eq!(s, pool2.to_string());
            println!("{pool}");
            pool2.resample_from_expr(&possibles, None, &mut rng)?;
            assert_eq!(pool2.get_type()?, t);
            pool2.resample_from_expr(&possibles, Some(&map), &mut rng)?;
            assert_eq!(pool2.get_type()?, t);

            pool.swap_subtree(&mut rng)?;
            assert_eq!(pool.get_type()?, t);
        }
        let t = LambdaType::from_string("<a,<a,t>>")?;
        for _ in 0..100 {
            let pool = RootedLambdaPool::random_expr(&t, &possibles, &mut rng);
            println!("{pool}");
        }
        Ok(())
    }

    // Enumerating type `a` yields exactly the actor constants. Also counts constant functions
    // among random `<a,<a,t>>` samples.
    #[test]
    fn constant_exprs() -> anyhow::Result<()> {
        let actors = ["john", "mary", "phil", "sue"];
        let actor_properties = ["a"];
        let event_properties = ["e"];
        let possibles = PossibleExpressions::new(&actors, &actor_properties, &event_properties);
        let mut rng = ChaCha8Rng::seed_from_u64(0);

        let v: HashSet<_> = RootedLambdaPool::enumerator(LambdaType::a(), &possibles)
            .map(|(x, _)| x.to_string())
            .take(4)
            .collect();

        assert_eq!(v, HashSet::from(actors.map(|x| format!("a_{x}"))));
        println!("{v:?}");

        let mut constants = 0;
        let t = LambdaType::from_string("<a, <a,t>>")?;
        for _ in 0..100 {
            println!("sampling: {t}");
            let pool = RootedLambdaPool::random_expr(&t, &possibles, &mut rng);
            match pool.stats() {
                LambdaSummaryStats::WellFormed {
                    constant_function, ..
                } => {
                    if constant_function {
                        constants += 1;
                    }
                }
                LambdaSummaryStats::Malformed => {
                    panic!("random_expr produced a malformed pool of type {t}: {pool}")
                }
            }
        }
        println!("{constants}");

        Ok(())
    }

    // Sampling type `a` is roughly uniform over the four actors.
    #[test]
    fn random_expr_counts() -> anyhow::Result<()> {
        let actors = ["john", "mary", "phil", "sue"];
        let actor_properties = ["a"];
        let event_properties = ["e"];
        let possibles = PossibleExpressions::new(&actors, &actor_properties, &event_properties);
        let mut rng = ChaCha8Rng::seed_from_u64(1);

        let mut counts: HashMap<_, usize> = HashMap::default();
        for _ in 0..1000 {
            let t = LambdaType::A;
            let pool = RootedLambdaPool::random_expr(&t, &possibles, &mut rng);
            assert_eq!(t, pool.get_type()?);
            let s = pool.to_string();
            *counts.entry(s).or_default() += 1;
        }
        assert_eq!(counts.len(), 4);
        dbg!(&counts);
        for (_, v) in counts.iter() {
            assert!(200 <= *v && *v <= 300);
        }

        counts.clear();
        for _ in 0..1000 {
            let t = LambdaType::at().clone();
            let pool = RootedLambdaPool::random_expr(&t, &possibles, &mut rng);
            assert_eq!(t, pool.get_type()?);
            let s = pool.to_string();
            *counts.entry(s).or_default() += 1;
        }
        dbg!(counts);
        Ok(())
    }
}