simple_semantics/language/
mutations.rs

1use super::{
2    ActorOrEvent, BinOp, Constant, Display, Expr, ExprRef, MonOp, Quantifier, RootedLambdaPool,
3    thiserror,
4};
5use crate::lambda::{
6    LambdaError, LambdaExpr, LambdaExprRef, LambdaLanguageOfThought, LambdaPool,
7    types::{LambdaType, TypeError},
8};
9use ahash::HashMap;
10use chumsky::container::Container;
11use rand::{
12    Rng, RngExt,
13    distr::{Distribution, weighted::WeightedIndex},
14    seq::{IndexedRandom, IteratorRandom},
15};
16use std::{
17    cmp::Reverse,
18    collections::{BinaryHeap, VecDeque},
19    fmt::Debug,
20};
21use thiserror::Error;
22
23mod context;
24mod samplers;
25pub use context::Context;
26pub(crate) use samplers::PossibleExpr;
27pub use samplers::PossibleExpressions;
28
///Error produced when an `ExprOrType` that is still a `Type` placeholder is converted into a
///`LambdaExpr`.
#[derive(Debug, Error, Clone)]
pub struct ExprOrTypeError();
31
32impl Display for ExprOrTypeError {
33    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
34        write!(f, "This ExprOrType is not an Expr!")
35    }
36}
37
///A slot in an [`UnfinishedLambdaPool`]: either a placeholder that only records the type that
///must eventually be generated there, or an expression that has already been chosen.
#[derive(Debug, Clone, Eq, PartialEq)]
enum ExprOrType<'src, T: LambdaLanguageOfThought> {
    ///A slot that has not been expanded yet.
    Type {
        ///The type an expression placed in this slot must have.
        lambda_type: LambdaType,
        ///Pool index of the slot's parent, or `None` for the root slot.
        parent: Option<usize>,
        ///`true` when this slot is the `subformula` side of an application.
        is_app_subformula: bool,
    },
    ///A slot that has been filled with a concrete expression.
    Expr {
        ///The expression stored in this slot.
        lambda_expr: LambdaExpr<'src, T>,
        ///Pool index of the slot's parent, or `None` for the root slot.
        parent: Option<usize>,
    },
}
50
51impl<'src, T: LambdaLanguageOfThought> TryFrom<ExprOrType<'src, T>> for LambdaExpr<'src, T> {
52    type Error = ExprOrTypeError;
53
54    fn try_from(value: ExprOrType<'src, T>) -> Result<Self, Self::Error> {
55        match value {
56            ExprOrType::Type { .. } => Err(ExprOrTypeError()),
57            ExprOrType::Expr { lambda_expr, .. } => Ok(lambda_expr),
58        }
59    }
60}
61
62impl<T: LambdaLanguageOfThought> ExprOrType<'_, T> {
63    fn parent(&self) -> Option<usize> {
64        match self {
65            ExprOrType::Type { parent, .. } | ExprOrType::Expr { parent, .. } => *parent,
66        }
67    }
68
69    fn is_type(&self) -> bool {
70        matches!(self, ExprOrType::Type { .. })
71    }
72}
73
///A partially generated expression tree, stored as a flat list of slots that refer to each
///other by index.
#[derive(Debug, Clone, Eq, PartialEq)]
struct UnfinishedLambdaPool<'src, T: LambdaLanguageOfThought> {
    ///The slots of the tree; each one is either a filled expression or a type placeholder.
    pool: Vec<ExprOrType<'src, T>>,
}
78
79impl<T: LambdaLanguageOfThought> Default for UnfinishedLambdaPool<'_, T> {
80    fn default() -> Self {
81        Self { pool: vec![] }
82    }
83}
84
impl<'src, T: LambdaLanguageOfThought + Clone> UnfinishedLambdaPool<'src, T> {
    ///Fills the slot at `c.position` with `expr` and appends one new type placeholder to the end
    ///of the pool for every child the expression needs.
    ///
    ///`t` is the type of the slot being filled; it is only consulted for lambdas, whose body
    ///placeholder gets the type `t.rhs()`. The context `c` is updated along the way (depth,
    ///open-node count, and bound-variable bookkeeping).
    ///
    ///# Panics
    ///Panics if `expr` is a lambda and `t.rhs()` is an error, or if `expr` is an application
    ///and `into_expr` did not supply the subformula/argument types.
    fn add_expr<'a>(&mut self, expr: PossibleExpr<'a, 'src, T>, c: &mut Context, t: &LambdaType) {
        let (mut lambda_expr, app_details) = expr.into_expr();
        c.depth += 1;
        //Every child opens a new slot and the slot being filled is closed. The addition comes
        //first, presumably so the counter never dips below zero in between.
        c.open_nodes += lambda_expr.n_children();
        c.open_nodes -= 1;
        //Remember the parent of the placeholder before it is overwritten at the end.
        let parent = self.pool[c.position].parent();
        match &mut lambda_expr {
            LambdaExpr::Lambda(body, arg) => {
                c.add_lambda(arg);
                //The body will live in the slot that is about to be pushed.
                *body = LambdaExprRef::new(self.pool.len());
                self.pool.push(ExprOrType::Type {
                    lambda_type: t.rhs().unwrap().clone(),
                    parent: Some(c.position),
                    is_app_subformula: false,
                });
            }
            LambdaExpr::BoundVariable(b, _) => {
                c.use_bvar(*b);
            }
            LambdaExpr::FreeVariable(..) => (),
            LambdaExpr::Application {
                subformula,
                argument,
            } => {
                //Both indices are computed before pushing: the subformula goes in the next free
                //slot and the argument in the one after it.
                *subformula = LambdaExprRef::new(self.pool.len());
                *argument = LambdaExprRef::new(self.pool.len() + 1);
                //From here on `subformula`/`argument` are the child *types*, shadowing the refs.
                let (subformula, argument) = app_details.unwrap();
                self.pool.push(ExprOrType::Type {
                    lambda_type: subformula,
                    parent: Some(c.position),
                    is_app_subformula: true,
                });
                self.pool.push(ExprOrType::Type {
                    lambda_type: argument,
                    parent: Some(c.position),
                    is_app_subformula: false,
                });
            }
            LambdaExpr::LanguageOfThoughtExpr(e) => {
                let children_start = self.pool.len();
                //Expressions that report a variable type register it in the context like a lambda.
                if let Some(t) = e.var_type() {
                    c.add_lambda(t);
                }
                self.pool
                    .extend(e.get_arguments().map(|lambda_type| ExprOrType::Type {
                        lambda_type,
                        parent: Some(c.position),
                        is_app_subformula: false,
                    }));
                //Point the expression's children at the placeholders that were just appended.
                e.change_children((children_start..self.pool.len()).map(LambdaExprRef::new));
            }
        }
        self.pool[c.position] = ExprOrType::Expr {
            lambda_expr,
            parent,
        };
    }
}
144
///State for exhaustive enumeration: field `0` is the queue of contexts still to be expanded
///(wrapped in [`Reverse`], so the smallest [`Context`] is popped first) and field `1` holds the
///details of finished expressions that are waiting to be yielded.
#[derive(Debug)]
pub struct NormalEnumeration(BinaryHeap<Reverse<Context>>, VecDeque<ExprDetails>);
147
148impl EnumerationType for NormalEnumeration {
149    fn pop(&mut self) -> Option<Context> {
150        self.0.pop().map(|x| x.0)
151    }
152
153    fn push(&mut self, context: Context, _: bool) {
154        self.0.push(Reverse(context));
155    }
156
157    fn get_yield(&mut self) -> Option<ExprDetails> {
158        self.1.pop_front()
159    }
160
161    fn push_yield(&mut self, e: ExprDetails) {
162        self.1.push(e);
163    }
164
165    fn include(&mut self, n: usize) -> impl Iterator<Item = bool> + 'static {
166        std::iter::repeat_n(true, n)
167    }
168}
169
170impl ExprDetails {
171    #[allow(clippy::cast_precision_loss)]
172    fn score(&self) -> f64 {
173        (1.0 / (self.size as f64)) + if self.constant_function { 0.0 } else { 10.0 }
174    }
175
176    pub fn has_constant_function(&self) -> bool {
177        self.constant_function
178    }
179}
180
///An [`ExprDetails`] paired with the random key that decides whether it stays in the reservoir.
#[derive(Debug, Clone, Copy, PartialEq)]
struct KeyedExprDetails {
    ///The finished expression this key belongs to.
    expr_details: ExprDetails,
    ///The random key; `push_yield` keeps the entries with the largest keys.
    k: f64,
}

//`Eq` cannot be derived because of the `f64` field, so it is asserted by hand.
impl Eq for KeyedExprDetails {}
188
impl PartialOrd for KeyedExprDetails {
    ///Delegates to [`Ord::cmp`], so the partial and the total order always agree.
    fn partial_cmp(&self, other: &Self) -> Option<std::cmp::Ordering> {
        Some(self.cmp(other))
    }
}
194
195impl Ord for KeyedExprDetails {
196    fn cmp(&self, other: &Self) -> std::cmp::Ordering {
197        //reversed since we need a min-heap not a max-heap
198        other.k.partial_cmp(&self.k).unwrap()
199    }
200}
201impl KeyedExprDetails {
202    fn new(expr_details: ExprDetails, rng: &mut impl Rng) -> Self {
203        let u: f64 = rng.random();
204        KeyedExprDetails {
205            expr_details,
206            k: u.powf(1.0 / expr_details.score()),
207        }
208    }
209}
210
///A [`Context`] paired with a random `f64` drawn when the entry is created (see
///`RandomPQ::new`).
//NOTE(review): the ordering used by the heap of `RandomPQ` is not defined in this part of the
//file; presumably it compares the random component - confirm.
#[derive(Debug, Clone, PartialEq)]
struct RandomPQ(Context, f64);

//`Eq` cannot be derived because of the `f64` field, so it is asserted by hand.
impl Eq for RandomPQ {}
215
216impl RandomPQ {
217    fn new(c: Context, rng: &mut impl Rng) -> Self {
218        RandomPQ(c, rng.random())
219    }
220}
221
///State for randomised enumeration: contexts are explored in random order and finished
///expressions are collected in a bounded reservoir before being yielded.
#[derive(Debug)]
struct ProbabilisticEnumeration<'a, R: Rng, F>
where
    F: Fn(&ExprDetails) -> bool,
{
    ///Source of randomness for keys, queue priorities and inclusion flags.
    rng: &'a mut R,
    ///Maximum number of finished expressions kept in `reservoir`.
    reservoir_size: usize,
    ///Finished expressions kept so far; the entry with the smallest key is on top.
    reservoir: BinaryHeap<KeyedExprDetails>,
    ///Contexts that were pushed as "not included"; only used once `pq` is empty.
    backups: Vec<Context>,
    ///Contexts waiting to be expanded.
    pq: BinaryHeap<RandomPQ>,
    ///Predicate a finished expression must satisfy to be considered for the reservoir.
    filter: F,
    ///Number of finished expressions that passed `filter` so far.
    n_seen: usize,
    ///Set once `n_seen` reaches `reservoir_size * 20`; from then on the reservoir is drained.
    done: bool,
}
236impl<R: Rng, F> ProbabilisticEnumeration<'_, R, F>
237where
238    F: Fn(&ExprDetails) -> bool,
239{
240    fn threshold(&self) -> Option<f64> {
241        self.reservoir.peek().map(|x| x.k)
242    }
243
244    fn new<'a, 'src, T: LambdaLanguageOfThought, E: Fn(&Context) -> bool>(
245        reservoir_size: usize,
246        t: &LambdaType,
247        possible_expressions: &'a PossibleExpressions<'src, T>,
248        eager_filter: E,
249        filter: F,
250        rng: &'a mut R,
251    ) -> LambdaEnumerator<'a, 'src, T, E, ProbabilisticEnumeration<'a, R, F>> {
252        let context = Context::new(0, vec![]);
253        let mut pq = BinaryHeap::default();
254        pq.push(RandomPQ::new(context, rng));
255        let pools = vec![UnfinishedLambdaPool {
256            pool: vec![ExprOrType::Type {
257                lambda_type: t.clone(),
258                parent: None,
259                is_app_subformula: false,
260            }],
261        }];
262
263        LambdaEnumerator {
264            pools,
265            possible_expressions,
266            eager_filter,
267            pq: ProbabilisticEnumeration {
268                rng,
269                reservoir_size,
270                reservoir: BinaryHeap::default(),
271                backups: vec![],
272                filter,
273                pq,
274                n_seen: 0,
275                done: false,
276            },
277        }
278    }
279}
280
281impl<R: Rng, F> EnumerationType for ProbabilisticEnumeration<'_, R, F>
282where
283    F: Fn(&ExprDetails) -> bool,
284{
285    fn pop(&mut self) -> Option<Context> {
286        //Pop from min-heap, or grab a random back up if the min-heap is exhausted
287        self.pq.pop().map(|x| x.0).or_else(|| {
288            (0..self.backups.len()).choose(self.rng).and_then(|index| {
289                let last_item = self.backups.len() - 1;
290                self.backups.swap(index, last_item);
291                self.backups.pop()
292            })
293        })
294    }
295
296    fn push(&mut self, context: Context, included: bool) {
297        if included {
298            self.pq.push(RandomPQ::new(context, &mut self.rng));
299        } else {
300            self.backups.push(context);
301        }
302    }
303
304    fn get_yield(&mut self) -> Option<ExprDetails> {
305        if (self.done || self.pq.is_empty())
306            && let Some(x) = self.reservoir.pop()
307        {
308            Some(x.expr_details)
309        } else {
310            None
311        }
312    }
313
314    fn push_yield(&mut self, e: ExprDetails) {
315        let e = KeyedExprDetails::new(e, &mut self.rng);
316        if (self.filter)(&e.expr_details) {
317            self.n_seen += 1;
318            if self.reservoir_size > self.reservoir.len() {
319                self.reservoir.push(e);
320            } else if let Some(t) = self.threshold()
321                && e.k > t
322            {
323                self.reservoir.pop();
324                self.reservoir.push(e);
325            }
326            if self.n_seen >= self.reservoir_size * 20 {
327                self.pq.clear();
328                self.done = true;
329            }
330        }
331    }
332
333    fn include(&mut self, n: usize) -> impl Iterator<Item = bool> + 'static {
334        let x = (0..n).sample(self.rng, (n / 2).max(1));
335        let mut v = vec![false; n];
336        for i in x {
337            v[i] = true;
338        }
339        v.into_iter()
340    }
341}
342
#[derive(Debug)]
///An iterator that enumerates over all possible expressions of a given type.
///
///`F` is the eager filter applied to every [`Context`] before it is queued and `E` is the
///enumeration strategy (see [`EnumerationType`]).
pub struct LambdaEnumerator<'a, 'src, T: LambdaLanguageOfThought, F, E = NormalEnumeration> {
    //One partially built expression per candidate; a context's `pool_index` and an
    //`ExprDetails`' `id` both index into this vector.
    pools: Vec<UnfinishedLambdaPool<'src, T>>,
    //The expressions that may be used to fill a placeholder.
    possible_expressions: &'a PossibleExpressions<'src, T>,
    //Contexts rejected by this predicate are dropped and their pool is emptied.
    eager_filter: F,
    //The enumeration strategy: holds pending contexts and finished expressions.
    pq: E,
}
351
///Provides detail about a generated lambda expression
#[derive(Debug, Clone, Copy, Eq, PartialEq)]
pub struct ExprDetails {
    //Index of the expression's pool inside the enumerator that produced it.
    id: usize,
    //Whether the generating context reported the expression as a constant function.
    constant_function: bool,
    //Position of the root expression inside its pool.
    root: LambdaExprRef,
    //Size of the expression (the context's `depth` counter when it was finished).
    size: usize,
}

impl ExprDetails {
    ///Get the size of the associated [`RootedLambdaPool`].
    ///
    ///This is the number of expressions that were added while generating it.
    pub fn size(&self) -> usize {
        self.size
    }
}
367
#[derive(Debug, Clone, Eq, PartialEq)]
///A re-usable sampler for sampling expressions of arbitrary types while caching frequent types
pub struct TypeAgnosticSampler<'src, T: LambdaLanguageOfThought> {
    //Cache from type to (visit counter, sampler for that type).
    type_to_sampler: HashMap<LambdaType, (usize, LambdaSampler<'src, T>)>,
    //How many expressions each per-type sampler enumerates when it is built.
    max_expr: usize,
    //Once the cache holds more entries than this, the least-visited one is evicted.
    max_types: usize,
    //The expressions available to every per-type sampler.
    possible_expressions: PossibleExpressions<'src, T>,
}
376
377impl<'src> TypeAgnosticSampler<'src, Expr<'src>> {
378    #[allow(clippy::missing_panics_doc)]
379    ///Samples an expression of a given type
380    pub fn sample(
381        &mut self,
382        lambda_type: LambdaType,
383        rng: &mut impl Rng,
384    ) -> RootedLambdaPool<'src, Expr<'src>> {
385        let (counts, exprs) = self
386            .type_to_sampler
387            .entry(lambda_type)
388            .or_insert_with_key(|t| {
389                (
390                    1,
391                    RootedLambdaPool::sampler(t, &self.possible_expressions, self.max_expr),
392                )
393            });
394        *counts += 1;
395        let sample = exprs.sample(rng);
396
397        if self.type_to_sampler.len() > self.max_types {
398            let (_, k) = self
399                .type_to_sampler
400                .iter()
401                .map(|(k, (n_visits, _))| (n_visits, k))
402                .min_by_key(|x| x.0)
403                .unwrap();
404
405            let t = k.clone();
406            self.type_to_sampler.remove(&t);
407        }
408
409        sample
410    }
411
412    ///Get a reference to the [`PossibleExpressions`] used by the model
413    #[must_use]
414    pub fn possibles(&self) -> &PossibleExpressions<'src, Expr<'src>> {
415        &self.possible_expressions
416    }
417}
418
419impl<'src, T: LambdaLanguageOfThought + Clone> RootedLambdaPool<'src, T> {
420    ///Create a sampler which can sample arbitrary types.
421    ///
422    ///# Panics
423    ///Will panic if `max_types` == 0 or `max_expr` == 0
424    #[must_use]
425    pub fn typeless_sampler(
426        possible_expressions: PossibleExpressions<'src, T>,
427        max_expr: usize,
428        max_types: usize,
429    ) -> TypeAgnosticSampler<'src, T> {
430        assert!(max_types >= 1);
431        assert!(max_expr >= 1);
432        TypeAgnosticSampler {
433            possible_expressions,
434            max_expr,
435            max_types,
436            type_to_sampler: HashMap::default(),
437        }
438    }
439}
440
///A struct which samples expressions from a distribution.
///
///Expressions are drawn with probability proportional to the score of their [`ExprDetails`].
#[derive(Debug, Clone, Eq, PartialEq)]
pub struct LambdaSampler<'src, T: LambdaLanguageOfThought> {
    //The candidate expressions.
    lambdas: Vec<RootedLambdaPool<'src, T>>,
    //Details for each candidate; `expr_details[i]` describes `lambdas[i]`.
    expr_details: Vec<ExprDetails>,
}
447
448impl<'src, T: LambdaLanguageOfThought + Clone> Distribution<RootedLambdaPool<'src, T>>
449    for LambdaSampler<'src, T>
450{
451    fn sample<R: Rng + ?Sized>(&self, rng: &mut R) -> RootedLambdaPool<'src, T> {
452        let w = WeightedIndex::new(self.expr_details.iter().map(ExprDetails::score)).unwrap();
453        let i = w.sample(rng);
454        self.lambdas
455            .get(i)
456            .expect("The Lambda Sampler has no lambdas to sample :(")
457            .clone()
458    }
459}
460
///A trait that handles how enumeration is processed (either normal enumeration or by doing
///reservoir sampling).
pub trait EnumerationType {
    ///Take the next context that should be expanded, or `None` if there is nothing left.
    fn pop(&mut self) -> Option<Context>;
    ///Store a context for later expansion. `included` is the flag that [`Self::include`]
    ///produced for the candidate this context came from.
    fn push(&mut self, context: Context, included: bool);
    ///Take the details of a finished expression that is ready to be returned, if any.
    fn get_yield(&mut self) -> Option<ExprDetails>;
    ///Register the details of an expression that has just been finished.
    fn push_yield(&mut self, e: ExprDetails);
    ///Produce one flag per candidate expansion (`n` candidates); each flag is later handed to
    ///[`Self::push`] together with the corresponding context.
    fn include(&mut self, n: usize) -> impl Iterator<Item = bool> + 'static;
}
470
471fn try_yield<'src, T, F, E>(
472    x: &mut LambdaEnumerator<'_, 'src, T, F, E>,
473) -> Option<(RootedLambdaPool<'src, T>, ExprDetails)>
474where
475    T: LambdaLanguageOfThought,
476    E: EnumerationType,
477{
478    if let Some(item) = x.pq.get_yield() {
479        let p = std::mem::take(&mut x.pools[item.id]);
480        return Some((
481            RootedLambdaPool {
482                pool: LambdaPool(
483                    p.pool
484                        .into_iter()
485                        .map(|x| LambdaExpr::try_from(x).unwrap())
486                        .collect(),
487                ),
488                root: item.root,
489            },
490            item,
491        ));
492    }
493    None
494}
495
496impl<'a, 'src, T, F, E> LambdaEnumerator<'a, 'src, T, F, E>
497where
498    T: LambdaLanguageOfThought + Clone + Debug,
499    F: Fn(&Context) -> bool,
500    E: EnumerationType,
501{
502    fn push(&mut self, c: Context, included: bool) {
503        if (self.eager_filter)(&c) {
504            self.pq.push(c, included);
505        } else {
506            self.pools[c.pool_index] = UnfinishedLambdaPool::default();
507        }
508    }
509
510    ///Change the `eager_filter` function for this enumerator
511    pub fn eager_filter<F2>(self, eager_filter: F2) -> LambdaEnumerator<'a, 'src, T, F2, E> {
512        let LambdaEnumerator {
513            pools,
514            possible_expressions,
515            eager_filter: _,
516            pq,
517        } = self;
518
519        LambdaEnumerator {
520            pools,
521            possible_expressions,
522            eager_filter,
523            pq,
524        }
525    }
526}
527
impl<'src, F, E> Iterator for LambdaEnumerator<'_, 'src, Expr<'src>, F, E>
where
    F: Fn(&Context) -> bool,
    E: EnumerationType,
{
    type Item = (RootedLambdaPool<'src, Expr<'src>>, ExprDetails);

    ///Runs the search until the strategy has a finished expression ready, and returns it.
    ///
    ///Each step pops a context. If the context points at a type placeholder, the pool is
    ///copied once per possible expression and the placeholder is filled in each copy. If it
    ///points at a filled expression, the context moves to the next unfilled child, then to the
    ///parent, and is reported as finished once the root has no unfilled slots left.
    #[allow(clippy::too_many_lines)]
    fn next(&mut self) -> Option<Self::Item> {
        if let Some(x) = try_yield(self) {
            return Some(x);
        }

        while let Some(mut c) = self.pq.pop() {
            //Something may have become ready in the meantime; put the context back and yield.
            if let Some(x) = try_yield(self) {
                self.push(c, true);
                return Some(x);
            }
            let (possibles, lambda_type) = match &self.pools[c.pool_index].pool[c.position] {
                ExprOrType::Type {
                    lambda_type,
                    is_app_subformula,
                    parent,
                } => {
                    let mut possibles = self.possible_expressions.possibilities(
                        lambda_type,
                        *is_app_subformula,
                        &c,
                    );

                    //Super hacky way to introduce all_e and all_a in quantifiers even though the
                    //types are messed up.
                    //This only applies when the placeholder is the restrictor of a quantifier.
                    if let Some(p) = parent
                        && let ExprOrType::Expr {
                            lambda_expr:
                                LambdaExpr::LanguageOfThoughtExpr(Expr::Quantifier {
                                    var_type,
                                    restrictor,
                                    ..
                                }),
                            ..
                        } = self.pools[c.pool_index].pool[*p]
                        && restrictor.0 == u32::try_from(c.position).unwrap()
                    {
                        possibles.push(PossibleExpr::new_borrowed(match var_type {
                            ActorOrEvent::Actor => &LambdaExpr::LanguageOfThoughtExpr(
                                Expr::Constant(Constant::Everyone),
                            ),
                            ActorOrEvent::Event => &LambdaExpr::LanguageOfThoughtExpr(
                                Expr::Constant(Constant::EveryEvent),
                            ),
                        }));
                    }

                    (possibles, lambda_type.clone())
                }
                ExprOrType::Expr {
                    lambda_expr,
                    parent,
                } => {
                    //We add the next uninitialized child to the context or go to the parent if
                    //there are none.

                    if let Some(child) = lambda_expr
                        .get_children()
                        .map(|x| x.0 as usize)
                        .find(|child| self.pools[c.pool_index].pool[*child].is_type())
                    {
                        c.position = child;
                        self.pq.push(c, true);
                        continue;
                    }

                    //All children are filled: leave the scope of this expression's binder,
                    //if it introduced one.
                    if lambda_expr.inc_depth() {
                        c.pop_lambda();
                    }

                    if let Some(p) = parent {
                        c.position = *p;
                        self.pq.push(c, true);
                        continue;
                    }
                    //If the parent is None, we're done!
                    self.pq.push_yield(ExprDetails {
                        id: c.pool_index,
                        root: LambdaExprRef(u32::try_from(c.position).unwrap()),
                        size: c.depth,
                        constant_function: c.is_constant(),
                    });
                    continue;
                }
            };

            let n = possibles.len();
            //Note that `include` is called (and may consume randomness) even when `n` is zero.
            let included = self.pq.include(n);
            if n == 0 {
                continue;
            }
            let n_pools = self.pools.len();
            //Occasionally release spare capacity of the (ever growing) pool list.
            if n_pools.is_multiple_of(10_000) {
                self.pools.shrink_to_fit();
            }
            //The first candidate reuses the current pool; every other candidate gets a copy.
            for _ in 0..n.saturating_sub(1) {
                self.pools.push(self.pools[c.pool_index].clone());
            }

            let positions =
                std::iter::once(c.pool_index).chain(n_pools..n_pools + n.saturating_sub(1));

            //Fill the placeholder in each pool with its candidate and queue the new contexts.
            for (((expr, pool_id), mut c), included) in possibles
                .into_iter()
                .zip(positions)
                .zip(std::iter::repeat_n(c, n))
                .zip(included)
            {
                c.pool_index = pool_id;
                let pool = self.pools.get_mut(pool_id).unwrap();
                pool.add_expr(expr, &mut c, &lambda_type);
                self.push(c, included);
            }

            if let Some(x) = try_yield(self) {
                return Some(x);
            }
        }

        //If we've somehow exhausted the pq, lets yield anything remaining that's done.
        if let Some(x) = try_yield(self) {
            return Some(x);
        }
        None
    }
}
661
662impl<'src> RootedLambdaPool<'src, Expr<'src>> {
    ///Replace a randomly chosen sub-expression with a newly sampled expression of the same
    ///type.
    ///
    ///If `helpers` is given, then with probability 0.2 the replacement is cloned from the
    ///helper pools registered for that type (when there are any); otherwise it is generated by
    ///probabilistic enumeration from `possible_expressions`.
    ///
    ///# Errors
    ///Will return a [`LambdaError`] if the type of the chosen position cannot be determined
    ///(e.g. the tree is malformed).
    ///
    ///# Panics
    ///Will panic if the pool is empty, if no replacement expression can be generated, or if
    ///the size of the tree is greater than [`u32::MAX`].
    pub fn resample_from_expr<'a>(
        &mut self,
        possible_expressions: &'a PossibleExpressions<'src, Expr<'src>>,
        helpers: Option<&HashMap<LambdaType, Vec<RootedLambdaPool<'src, Expr<'src>>>>>,
        rng: &mut impl Rng,
    ) -> Result<(), LambdaError> {
        //Pick the node to replace uniformly among all positions of the pool.
        let position = LambdaExprRef(u32::try_from((0..self.len()).choose(rng).unwrap()).unwrap());
        let t = self.pool.get_type(position)?;

        let pool = if let Some(helpers) = helpers
            && rng.random_bool(0.2)
            && let Some(v) = helpers.get(&t)
            && !v.is_empty()
        {
            let pool = v.choose(rng).unwrap();
            pool.clone()
        } else {
            let (pool, _) = self
                .probabilistic_enumerate_from_expr(
                    position,
                    possible_expressions,
                    |_| true,
                    |_| true,
                    rng,
                )?
                .next()
                .unwrap();
            pool
        };

        //Append the replacement after the existing nodes, shifting its internal references.
        let offset = u32::try_from(self.len()).unwrap();
        let new_root = pool.root.0 + offset;
        self.pool.0.extend(pool.pool.0.into_iter().map(|mut x| {
            let children: Vec<_> = x
                .get_children()
                .map(|x| LambdaExprRef(x.0 + offset))
                .collect();
            x.change_children(children.into_iter());
            x
        }));
        //Put the replacement's root where the old node was; the old node ends up at
        //`new_root` and is left for `cleanup`.
        self.pool.0.swap(position.0 as usize, new_root as usize);
        self.cleanup();
        Ok(())
    }
714
715    fn probabilistic_enumerate_from_expr<'a, R, E, F>(
716        &self,
717        position: LambdaExprRef,
718        possible_expressions: &'a PossibleExpressions<'src, Expr<'src>>,
719        eager_filter: E,
720        filter: F,
721        rng: &'a mut R,
722    ) -> Result<
723        LambdaEnumerator<'a, 'src, Expr<'src>, E, ProbabilisticEnumeration<'a, R, F>>,
724        TypeError,
725    >
726    where
727        R: Rng,
728        F: Fn(&ExprDetails) -> bool,
729        E: Fn(&Context) -> bool,
730    {
731        let (context, is_app_subformula) = Context::from_pos(self, position);
732        let output = self.pool.get_type(position)?;
733        let mut pq = BinaryHeap::default();
734        pq.push(RandomPQ::new(context, rng));
735        let pools = vec![UnfinishedLambdaPool {
736            pool: vec![ExprOrType::Type {
737                lambda_type: output,
738                parent: None,
739                is_app_subformula,
740            }],
741        }];
742        let enumerator = LambdaEnumerator {
743            pools,
744            possible_expressions,
745            eager_filter,
746            pq: ProbabilisticEnumeration {
747                rng,
748                reservoir_size: 1,
749                reservoir: BinaryHeap::default(),
750                done: false,
751                n_seen: 0,
752                filter,
753                backups: vec![],
754                pq,
755            },
756        };
757
758        Ok(enumerator)
759    }
760
761    ///Create a [`LambdaSampler`] of a given type with a filter
762    pub fn enumerator_filter<'a, F: Fn(&Context) -> bool>(
763        t: &LambdaType,
764        filter: F,
765        possible_expressions: &'a PossibleExpressions<'src, Expr<'src>>,
766    ) -> LambdaEnumerator<'a, 'src, Expr<'src>, F> {
767        let context = Context::new(0, vec![]);
768        let mut pq = BinaryHeap::default();
769        pq.push(Reverse(context));
770        let pools = vec![UnfinishedLambdaPool {
771            pool: vec![ExprOrType::Type {
772                lambda_type: t.clone(),
773                parent: None,
774                is_app_subformula: false,
775            }],
776        }];
777
778        LambdaEnumerator {
779            pools,
780            possible_expressions,
781            eager_filter: filter,
782            pq: NormalEnumeration(pq, VecDeque::default()),
783        }
784    }
785
786    ///Create a [`LambdaSampler`] of a given type.
787    #[must_use]
788    pub fn enumerator<'a>(
789        t: &LambdaType,
790        possible_expressions: &'a PossibleExpressions<'src, Expr<'src>>,
791    ) -> LambdaEnumerator<'a, 'src, Expr<'src>, impl Fn(&'_ Context) -> bool> {
792        let context = Context::new(0, vec![]);
793        let mut pq = BinaryHeap::default();
794        pq.push(Reverse(context));
795        let pools = vec![UnfinishedLambdaPool {
796            pool: vec![ExprOrType::Type {
797                lambda_type: t.clone(),
798                parent: None,
799                is_app_subformula: false,
800            }],
801        }];
802
803        LambdaEnumerator {
804            pools,
805            possible_expressions,
806            eager_filter: |_: &Context| true,
807            pq: NormalEnumeration(pq, VecDeque::default()),
808        }
809    }
810
811    ///Creates a reusable random sampler by enumerating over the first `max_expr` expressions
812    #[must_use]
813    pub fn sampler(
814        t: &LambdaType,
815        possible_expressions: &PossibleExpressions<'src, Expr<'src>>,
816        max_expr: usize,
817    ) -> LambdaSampler<'src, Expr<'src>> {
818        let enumerator = RootedLambdaPool::enumerator(t, possible_expressions);
819        let mut lambdas = Vec::with_capacity(max_expr);
820        let mut expr_details = Vec::with_capacity(max_expr);
821        for (lambda, expr_detail) in enumerator.take(max_expr) {
822            lambdas.push(lambda);
823            expr_details.push(expr_detail);
824        }
825
826        LambdaSampler {
827            lambdas,
828            expr_details,
829        }
830    }
831
832    ///Randomly generate a [`RootedLambdaPool`] of type `t`.
833    ///
834    ///# Panics
835    ///Will panic if no such type can be generated.
836    pub fn random_expr(
837        t: &LambdaType,
838        possible_expressions: &PossibleExpressions<'src, Expr<'src>>,
839        rng: &mut impl Rng,
840    ) -> RootedLambdaPool<'src, Expr<'src>> {
841        ProbabilisticEnumeration::new(
842            1,
843            t,
844            possible_expressions,
845            |_: &Context| true,
846            |_| true,
847            rng,
848        )
849        .next()
850        .unwrap()
851        .0
852    }
853
854    ///Randomly generate a [`RootedLambdaPool`] of type `t` without constant functions.
855    pub fn random_expr_no_constant(
856        t: &LambdaType,
857        possible_expressions: &PossibleExpressions<'src, Expr<'src>>,
858        rng: &mut impl Rng,
859    ) -> Option<RootedLambdaPool<'src, Expr<'src>>> {
860        ProbabilisticEnumeration::new(
861            1,
862            t,
863            possible_expressions,
864            |_: &Context| true,
865            |e| !e.constant_function,
866            rng,
867        )
868        .next()
869        .map(|x| x.0)
870    }
871}
872
873impl<'src> RootedLambdaPool<'src, Expr<'src>> {
    ///Remove quantifiers which do not use their variable in their body.
    ///
    ///Each such quantifier is overwritten by a copy of its subformula, the bound-variable
    ///indices below it are adjusted, and the pool is cleaned up afterwards.
    pub fn prune_quantifiers(&mut self) {
        //Collect (quantifier, subformula) pairs for every quantifier whose subformula contains
        //no bound variable matching the depth reported by the traversal.
        //NOTE(review): presumably that depth is the number of binders between the subformula
        //root and the variable, so `*v == d` means "refers to this quantifier" - confirm.
        let quantifiers = self
            .pool
            .bfs_from(self.root)
            .filter_map(|(i, _)| match self.get(i) {
                LambdaExpr::LanguageOfThoughtExpr(Expr::Quantifier { subformula, .. }) => {
                    if self
                        .pool
                        .bfs_from(LambdaExprRef(subformula.0))
                        .any(|(x, d)| {
                            if let LambdaExpr::BoundVariable(v, _) = self.get(x) {
                                *v == d
                            } else {
                                false
                            }
                        })
                    {
                        None
                    } else {
                        Some((i, LambdaExprRef(subformula.0)))
                    }
                }
                _ => None,
            })
            .collect::<Vec<_>>();

        //By reversing, we ensure that we fix inner quantifiers before outer ones.
        for (quantifier, subformula) in quantifiers.into_iter().rev() {
            //Overwrite the quantifier with its subformula...
            *self.pool.get_mut(quantifier) = self.pool.get(subformula).clone();
            //...and lower every bound-variable index that is greater than its depth, since one
            //binder has just been removed.
            self.pool.bfs_from_mut(quantifier).for_each(|(x, d, _)| {
                if let LambdaExpr::BoundVariable(b_d, _) = x
                    && *b_d > d
                {
                    *b_d -= 1;
                }
            });
        }
        //Drop the nodes that are no longer reachable and update the root accordingly.
        self.root = self.pool.cleanup(self.root);
    }
914
915    ///Replace a random expression with something else of the same type.
916    ///
917    ///# Errors
918    ///Will return a [`TypeError`] if there is no compatible type to replace
919    ///
920    ///# Panics
921    ///Will panic if the size of the tree is greater than [`u32::MAX`].
922    pub fn swap_expr(
923        &mut self,
924        possible_expressions: &PossibleExpressions<'src, Expr<'src>>,
925        rng: &mut impl Rng,
926    ) -> Result<(), TypeError> {
927        let position = LambdaExprRef(u32::try_from((0..self.len()).choose(rng).unwrap()).unwrap());
928
929        let (context, _) = Context::from_pos(self, position);
930
931        let output = self.pool.get_type(position)?;
932        let expr = self.pool.get(position);
933
934        let arguments: Vec<_> = expr
935            .get_children()
936            .map(|x| self.pool.get_type(x).unwrap())
937            .collect();
938
939        let mut replacements = possible_expressions.possiblities_fixed_children(
940            &output,
941            &arguments,
942            expr.var_type(),
943            &context,
944        );
945
946        //hacky fix!!
947        if replacements.is_empty()
948            && let LambdaExpr::LanguageOfThoughtExpr(Expr::Quantifier {
949                var_type,
950                restrictor,
951                subformula,
952                ..
953            }) = expr
954        {
955            for quantifier in [Quantifier::Universal, Quantifier::Existential] {
956                replacements.push(std::borrow::Cow::Owned(LambdaExpr::LanguageOfThoughtExpr(
957                    Expr::Quantifier {
958                        quantifier,
959                        var_type: *var_type,
960                        restrictor: *restrictor,
961                        subformula: *subformula,
962                    },
963                )));
964            }
965        }
966        for c in [Constant::EveryEvent, Constant::Everyone] {
967            if replacements.is_empty()
968                && expr == &LambdaExpr::LanguageOfThoughtExpr(Expr::Constant(c))
969            {
970                replacements = possible_expressions.possiblities_fixed_children(
971                    LambdaType::t(),
972                    &arguments,
973                    expr.var_type(),
974                    &context,
975                );
976                replacements.push(std::borrow::Cow::Owned(LambdaExpr::LanguageOfThoughtExpr(
977                    Expr::Constant(c),
978                )));
979            }
980        }
981
982        let choice = replacements.choose(rng).unwrap_or_else(|| {
983            panic!(
984                "There is no node with output {output} and arguments {}",
985                arguments
986                    .into_iter()
987                    .map(|x| x.to_string())
988                    .collect::<Vec<_>>()
989                    .join(", ")
990            )
991        });
992
993        let mut new_expr = choice.clone().into_owned();
994        new_expr.change_children(self.pool.get(position).get_children());
995        self.pool.0[position.0 as usize] = new_expr;
996
997        Ok(())
998    }
999
1000    ///Replace a random expression with something else of the same type from within the same
1001    ///expression.
1002    ///
1003    ///# Errors
1004    ///Will return [`TypeError`] if a compatible subtree can't be found.
1005    ///
1006    ///# Panics
1007    ///Will panic if the size of the tree is greater than [`u32::MAX`].
1008    pub fn swap_subtree(&mut self, rng: &mut impl Rng) -> Result<(), TypeError> {
1009        let position = LambdaExprRef(u32::try_from((0..self.len()).choose(rng).unwrap()).unwrap());
1010        let (context, _) = Context::from_pos(self, position);
1011        let alt = context.find_compatible(self, position)?;
1012        if let Some(new_pos) = alt.choose(rng).copied() {
1013            let offset = self.pool.0.len();
1014            let mut lookup = HashMap::default();
1015            let mut new_pool: Vec<LambdaExpr<'src, Expr<'src>>> = self
1016                .pool
1017                .bfs_from(new_pos)
1018                .map(|(x, _)| {
1019                    let n = lookup.len();
1020                    lookup
1021                        .entry(x)
1022                        .or_insert(LambdaExprRef(u32::try_from(n + offset).unwrap()));
1023                    self.get(x).clone()
1024                })
1025                .collect();
1026
1027            for x in &mut new_pool {
1028                let child: Vec<_> = x.get_children().map(|x| *lookup.get(&x).unwrap()).collect();
1029                x.change_children(child.into_iter());
1030            }
1031
1032            self.pool.0.extend(new_pool);
1033            self.pool.0.swap(position.0 as usize, offset);
1034            self.cleanup();
1035        }
1036        Ok(())
1037    }
1038}
1039
1040#[cfg(test)]
1041mod test {
1042
1043    use std::collections::HashSet;
1044
1045    use super::*;
1046    use crate::lambda::{LambdaPool, LambdaSummaryStats};
1047    use rand::SeedableRng;
1048    use rand_chacha::ChaCha8Rng;
1049
1050    #[test]
1051    fn prune_quantifier_test() -> anyhow::Result<()> {
1052        let mut pool =
1053            RootedLambdaPool::parse("some_e(x,all_e,AgentOf(a_2,e_1) & PatientOf(a_0,e_0))")?;
1054
1055        pool.prune_quantifiers();
1056        assert_eq!(pool.to_string(), "AgentOf(a_2, e_1) & PatientOf(a_0, e_0)");
1057
1058        let mut pool = RootedLambdaPool::parse(
1059            "some_e(x0, all_e, some(z, all_a, AgentOf(z, e_1) & PatientOf(a_0, e_0)))",
1060        )?;
1061
1062        pool.prune_quantifiers();
1063        assert_eq!(
1064            pool.to_string(),
1065            "some(x, all_a, AgentOf(x, e_1) & PatientOf(a_0, e_0))"
1066        );
1067
1068        let mut pool = RootedLambdaPool::parse("~every_e(z, pe_1, pa_2(a_0))")?;
1069
1070        pool.prune_quantifiers();
1071
1072        assert_eq!(pool.to_string(), "~pa_2(a_0)");
1073        let mut pool = RootedLambdaPool::new(
1074            LambdaPool(vec![
1075                LambdaExpr::Lambda(LambdaExprRef(1), LambdaType::E),
1076                LambdaExpr::Lambda(LambdaExprRef(2), LambdaType::T),
1077                LambdaExpr::LanguageOfThoughtExpr(Expr::Quantifier {
1078                    quantifier: Quantifier::Universal,
1079                    var_type: ActorOrEvent::Actor,
1080                    restrictor: ExprRef(3),
1081                    subformula: ExprRef(4),
1082                }),
1083                LambdaExpr::LanguageOfThoughtExpr(Expr::Constant(Constant::Property(
1084                    "1",
1085                    ActorOrEvent::Actor,
1086                ))),
1087                LambdaExpr::LanguageOfThoughtExpr(Expr::Quantifier {
1088                    quantifier: Quantifier::Existential,
1089                    var_type: ActorOrEvent::Actor,
1090                    restrictor: ExprRef(5),
1091                    subformula: ExprRef(6),
1092                }),
1093                LambdaExpr::LanguageOfThoughtExpr(Expr::Constant(Constant::Property(
1094                    "0",
1095                    ActorOrEvent::Actor,
1096                ))),
1097                LambdaExpr::LanguageOfThoughtExpr(Expr::Unary(
1098                    MonOp::Property("3", ActorOrEvent::Event),
1099                    ExprRef(7),
1100                )),
1101                LambdaExpr::BoundVariable(3, LambdaType::E),
1102            ]),
1103            LambdaExprRef(0),
1104        );
1105
1106        assert_eq!(
1107            pool.to_string(),
1108            "lambda e x lambda t phi every(y, pa_1, some(z, pa_0, pe_3(x)))"
1109        );
1110
1111        let mut parsed_pool = RootedLambdaPool::parse(
1112            "lambda e x lambda t phi every(y, pa_1, some(z, pa_0, pe_3(x)))",
1113        )?;
1114        parsed_pool.prune_quantifiers();
1115        pool.prune_quantifiers();
1116        assert_eq!(pool.to_string(), "lambda e x lambda t phi pe_3(x)");
1117
1118        Ok(())
1119    }
1120
1121    #[test]
1122    fn random_swap_tree() -> anyhow::Result<()> {
1123        let mut rng = ChaCha8Rng::seed_from_u64(2);
1124        let x = RootedLambdaPool::parse("lambda a x pa_cool(a_John) | pa_cool(x)")?;
1125        let mut h = HashSet::new();
1126        for _ in 0..100 {
1127            let mut z = x.clone();
1128            z.swap_subtree(&mut rng)?;
1129            h.insert(z.to_string());
1130        }
1131        for x in h.iter() {
1132            println!("{x}");
1133        }
1134        assert!(h.contains("lambda a x pa_cool(x)"));
1135        assert!(h.contains("lambda a x pa_cool(a_John)"));
1136        Ok(())
1137    }
1138
1139    #[test]
1140    fn randomn_swap() -> anyhow::Result<()> {
1141        let mut rng = ChaCha8Rng::seed_from_u64(2);
1142        let actors = ["0", "1"];
1143        let available_event_properties = ["2", "3", "4"];
1144        let possible_expressions = PossibleExpressions::new(
1145            &actors,
1146            &available_event_properties,
1147            &available_event_properties,
1148        );
1149        for _ in 0..200 {
1150            let t = LambdaType::random(&mut rng);
1151            println!("{t}");
1152            let mut pool = RootedLambdaPool::random_expr(&t, &possible_expressions, &mut rng);
1153            println!("{t}: {pool}");
1154            assert_eq!(t, pool.get_type()?);
1155            println!("{pool:?}");
1156            pool.swap_expr(&possible_expressions, &mut rng)?;
1157            println!("{pool:?}");
1158            println!("{t}: {pool}");
1159            assert_eq!(t, pool.get_type()?);
1160        }
1161        let p =
1162            RootedLambdaPool::parse("lambda <a,t> P some(x, P(x), some_e(y, pe_2(y), pe_4(y)))")?;
1163        let t = p.get_type()?;
1164        for _ in 0..1000 {
1165            let mut pool = p.clone();
1166            assert_eq!(t, pool.get_type()?);
1167
1168            pool.swap_expr(&possible_expressions, &mut rng)?;
1169            dbg!(&pool);
1170            println!("{t}: {pool}");
1171            assert_eq!(t, pool.get_type()?);
1172        }
1173
1174        Ok(())
1175    }
1176
1177    #[test]
1178    fn enumerate() -> anyhow::Result<()> {
1179        let actors = ["john"];
1180        let actor_properties = ["a"];
1181        let event_properties = ["e"];
1182        let possibles = PossibleExpressions::new(&actors, &actor_properties, &event_properties);
1183
1184        let t = LambdaType::from_string("<<a,t>,t>")?;
1185
1186        let p = RootedLambdaPool::enumerator(&t, &possibles);
1187        let pools: HashSet<_> = p
1188            .filter_map(|(p, x)| {
1189                if !x.has_constant_function() {
1190                    Some(p.to_string())
1191                } else {
1192                    None
1193                }
1194            })
1195            .take(20)
1196            .collect();
1197        for p in pools.iter() {
1198            println!("{p}");
1199        }
1200        assert!(pools.contains("lambda <a,t> P P(a_john)"));
1201        for p in pools {
1202            println!("{p}");
1203        }
1204
1205        Ok(())
1206    }
1207
1208    /*
1209    #[test]
1210    fn random_expr_no_constant() -> anyhow::Result<()> {
1211        let actors = &["1", "2", "3", "4", "5"];
1212        let actor_properties = &["1", "2", "3", "4", "5"];
1213        let event_properties = &["1", "2", "3", "4", "5"];
1214        let poss = PossibleExpressions::new(actors, actor_properties, event_properties);
1215        let mut rng = ChaCha8Rng::seed_from_u64(1);
1216
1217        let t = LambdaType::from_string("<<a,<e,t>>, <e,t>>")?;
1218        println!("{t}: ");
1219        let pool = RootedLambdaPool::random_expr_no_constant(&t, &poss, &mut rng).unwrap();
1220        println!("{pool}");
1221
1222        for _ in 0..500 {
1223            let t = LambdaType::random(&mut rng);
1224            println!("{t}: ");
1225            if t.size() <= 3 {
1226                let pool = RootedLambdaPool::random_expr_no_constant(&t, &poss, &mut rng).unwrap();
1227                println!("{pool}");
1228            } else {
1229                println!("too big :(");
1230            }
1231        }
1232        Ok(())
1233    }*/
1234
1235    #[test]
1236    fn enumerate_weirds() -> anyhow::Result<()> {
1237        let actors = &["1"];
1238        let actor_properties = &["1"];
1239        let event_properties = &["1"];
1240        let poss = PossibleExpressions::new(actors, actor_properties, event_properties);
1241
1242        let t = "<a,<e,e>>";
1243
1244        for (p, d) in RootedLambdaPool::enumerator(&LambdaType::from_string(t)?, &poss)
1245            .filter(|(_, e)| !e.has_constant_function())
1246            .take(5)
1247        {
1248            println!("{p} {d:?} {}", d.has_constant_function());
1249        }
1250        Ok(())
1251    }
1252
1253    #[test]
1254    fn random_expr() -> anyhow::Result<()> {
1255        let actors = ["john"];
1256        let actor_properties = ["a"];
1257        let event_properties = ["e"];
1258        let possibles = PossibleExpressions::new(&actors, &actor_properties, &event_properties);
1259        let mut rng = ChaCha8Rng::seed_from_u64(0);
1260
1261        let map = [(LambdaType::A, vec![RootedLambdaPool::parse("a_john")?])]
1262            .into_iter()
1263            .collect();
1264
1265        for _ in 0..100 {
1266            let t = LambdaType::random(&mut rng);
1267            println!("sampling: {t}");
1268            let mut pool = RootedLambdaPool::random_expr(&t, &possibles, &mut rng);
1269            assert_eq!(t, pool.get_type()?);
1270            let s = pool.to_string();
1271            println!("{s}");
1272            let mut pool2 = RootedLambdaPool::parse(s.as_str())?;
1273            assert_eq!(s, pool2.to_string());
1274            println!("{pool}");
1275            pool2.resample_from_expr(&possibles, None, &mut rng)?;
1276            assert_eq!(pool2.get_type()?, t);
1277            pool2.resample_from_expr(&possibles, Some(&map), &mut rng)?;
1278            assert_eq!(pool2.get_type()?, t);
1279
1280            pool.swap_subtree(&mut rng)?;
1281            assert_eq!(pool.get_type()?, t);
1282        }
1283        let t = LambdaType::from_string("<a,<a,t>>")?;
1284        for _ in 0..100 {
1285            let pool = RootedLambdaPool::random_expr(&t, &possibles, &mut rng);
1286            println!("{pool}");
1287        }
1288        Ok(())
1289    }
1290
1291    #[test]
1292    fn constant_exprs() -> anyhow::Result<()> {
1293        let actors = ["john", "mary", "phil", "sue"];
1294        let actor_properties = ["a"];
1295        let event_properties = ["e"];
1296        let possibles = PossibleExpressions::new(&actors, &actor_properties, &event_properties);
1297        let mut rng = ChaCha8Rng::seed_from_u64(0);
1298
1299        let v: HashSet<_> = RootedLambdaPool::enumerator(LambdaType::a(), &possibles)
1300            .map(|(x, _)| x.to_string())
1301            .take(4)
1302            .collect();
1303
1304        assert_eq!(v, HashSet::from(actors.map(|x| format!("a_{x}"))));
1305        println!("{v:?}");
1306
1307        let mut constants = 0;
1308        for _ in 0..100 {
1309            let t = LambdaType::from_string("<a, <a,t>>")?;
1310            println!("sampling: {t}");
1311            let pool = RootedLambdaPool::random_expr(&t, &possibles, &mut rng);
1312            let x = pool.stats();
1313            match x {
1314                LambdaSummaryStats::WellFormed {
1315                    constant_function, ..
1316                } => {
1317                    if constant_function {
1318                        constants += 1;
1319                    }
1320                }
1321                LambdaSummaryStats::Malformed => todo!(),
1322            }
1323        }
1324        println!("{constants}");
1325
1326        Ok(())
1327    }
1328
1329    #[test]
1330    fn random_expr_counts() -> anyhow::Result<()> {
1331        let actors = ["john", "mary", "phil", "sue"];
1332        let actor_properties = ["a"];
1333        let event_properties = ["e"];
1334        let possibles = PossibleExpressions::new(&actors, &actor_properties, &event_properties);
1335        let mut rng = ChaCha8Rng::seed_from_u64(1);
1336
1337        let mut counts: HashMap<_, usize> = HashMap::default();
1338        for _ in 0..1000 {
1339            let t = LambdaType::A;
1340            let pool = RootedLambdaPool::random_expr(&t, &possibles, &mut rng);
1341            assert_eq!(t, pool.get_type()?);
1342            let s = pool.to_string();
1343            *counts.entry(s).or_default() += 1;
1344        }
1345        assert_eq!(counts.len(), 4);
1346        dbg!(&counts);
1347        for (_, v) in counts.iter() {
1348            assert!(200 <= *v && *v <= 300);
1349        }
1350
1351        counts.clear();
1352        for _ in 0..1000 {
1353            let t = LambdaType::at().clone();
1354            let pool = RootedLambdaPool::random_expr(&t, &possibles, &mut rng);
1355            assert_eq!(t, pool.get_type()?);
1356            let s = pool.to_string();
1357            *counts.entry(s).or_default() += 1;
1358        }
1359        dbg!(counts);
1360        Ok(())
1361    }
1362}