Skip to main content

simple_semantics/language/
mutations.rs

1use super::{
2    ActorOrEvent, BinOp, Constant, Display, Expr, ExprRef, MonOp, Quantifier, RootedLambdaPool,
3    thiserror,
4};
5use crate::lambda::{
6    LambdaError, LambdaExpr, LambdaExprRef, LambdaLanguageOfThought, LambdaPool,
7    types::{LambdaType, TypeError},
8};
9use ahash::HashMap;
10use rand::{
11    Rng, RngExt,
12    distr::{Distribution, weighted::WeightedIndex},
13    seq::{IndexedRandom, IteratorRandom},
14};
15use std::{
16    cmp::Reverse,
17    collections::{BinaryHeap, VecDeque},
18    fmt::Debug,
19};
20use thiserror::Error;
21
22mod context;
23mod samplers;
24pub use context::Context;
25pub(crate) use samplers::PossibleExpr;
26pub use samplers::PossibleExpressions;
27
28#[derive(Debug, Error, Clone)]
29pub struct ExprOrTypeError();
30
31impl Display for ExprOrTypeError {
32    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
33        write!(f, "This ExprOrType is not an Expr!")
34    }
35}
36
37#[derive(Debug, Clone, Eq, PartialEq)]
38enum ExprOrType<'src, T: LambdaLanguageOfThought> {
39    Type {
40        lambda_type: LambdaType,
41        parent: Option<usize>,
42        is_app_subformula: bool,
43    },
44    Expr {
45        lambda_expr: LambdaExpr<'src, T>,
46        parent: Option<usize>,
47    },
48}
49
50impl<'src, T: LambdaLanguageOfThought> TryFrom<ExprOrType<'src, T>> for LambdaExpr<'src, T> {
51    type Error = ExprOrTypeError;
52
53    fn try_from(value: ExprOrType<'src, T>) -> Result<Self, Self::Error> {
54        match value {
55            ExprOrType::Type { .. } => Err(ExprOrTypeError()),
56            ExprOrType::Expr { lambda_expr, .. } => Ok(lambda_expr),
57        }
58    }
59}
60
61impl<T: LambdaLanguageOfThought> ExprOrType<'_, T> {
62    fn parent(&self) -> Option<usize> {
63        match self {
64            ExprOrType::Type { parent, .. } | ExprOrType::Expr { parent, .. } => *parent,
65        }
66    }
67
68    fn is_type(&self) -> bool {
69        matches!(self, ExprOrType::Type { .. })
70    }
71}
72
73#[derive(Debug, Clone, Eq, PartialEq)]
74struct UnfinishedLambdaPool<'src, T: LambdaLanguageOfThought> {
75    pool: Vec<ExprOrType<'src, T>>,
76}
77
78impl<T: LambdaLanguageOfThought> Default for UnfinishedLambdaPool<'_, T> {
79    fn default() -> Self {
80        Self { pool: vec![] }
81    }
82}
83
impl<'src, T: LambdaLanguageOfThought + Clone> UnfinishedLambdaPool<'src, T> {
    ///Fill the hole at `c.position` with `expr`, where `t` is the type of that hole.
    ///
    ///A fresh type hole is appended to the pool for every child of `expr`, and `expr` is rewired
    ///to point at those new slots. `c` is updated as well: `depth` grows by one, `open_nodes`
    ///gains the new holes minus the one just filled, and any variable that `expr` introduces
    ///or uses is registered with the context.
    ///
    ///Panics if `expr` is a lambda and `t` has no right-hand side (`t.rhs()` is unwrapped), or if
    ///`expr` is an application and `into_expr` returned no application details.
    fn add_expr<'a>(&mut self, expr: PossibleExpr<'a, 'src, T>, c: &mut Context, t: &LambdaType) {
        //`app_details` holds the (subformula, argument) types; it is only unwrapped for
        //applications below.
        let (mut lambda_expr, app_details) = expr.into_expr();
        c.depth += 1;
        c.open_nodes += lambda_expr.n_children();
        //The hole being filled here is no longer open.
        c.open_nodes -= 1;
        //The placed expression keeps the parent link of the hole it replaces.
        let parent = self.pool[c.position].parent();
        match &mut lambda_expr {
            LambdaExpr::Lambda(body, arg) => {
                //Make the lambda's variable available while its body is being built.
                c.add_lambda(arg);
                *body = LambdaExprRef::new(self.pool.len());
                //The body's hole gets the right-hand side of the lambda's type.
                self.pool.push(ExprOrType::Type {
                    lambda_type: t.rhs().unwrap().clone(),
                    parent: Some(c.position),
                    is_app_subformula: false,
                });
            }
            LambdaExpr::BoundVariable(b, _) => {
                c.use_bvar(*b);
            }
            LambdaExpr::FreeVariable(..) => (),
            LambdaExpr::Application {
                subformula,
                argument,
            } => {
                //The two new holes are appended consecutively: function first, then argument.
                *subformula = LambdaExprRef::new(self.pool.len());
                *argument = LambdaExprRef::new(self.pool.len() + 1);
                let (subformula, argument) = app_details.unwrap();
                self.pool.push(ExprOrType::Type {
                    lambda_type: subformula,
                    parent: Some(c.position),
                    is_app_subformula: true,
                });
                self.pool.push(ExprOrType::Type {
                    lambda_type: argument,
                    parent: Some(c.position),
                    is_app_subformula: false,
                });
            }
            LambdaExpr::LanguageOfThoughtExpr(e) => {
                let children_start = self.pool.len();
                //Expressions that bind a variable (presumably quantifier-like ones) open a new
                //scope for their children.
                if let Some(t) = e.var_type() {
                    c.add_lambda(t);
                }
                //One hole per argument type, appended contiguously from `children_start`.
                self.pool
                    .extend(e.get_arguments().map(|lambda_type| ExprOrType::Type {
                        lambda_type,
                        parent: Some(c.position),
                        is_app_subformula: false,
                    }));
                e.change_children((children_start..self.pool.len()).map(LambdaExprRef::new));
            }
        }
        //Replace the hole with the now-wired expression.
        self.pool[c.position] = ExprOrType::Expr {
            lambda_expr,
            parent,
        };
    }
}
143
144#[derive(Debug)]
145pub struct NormalEnumeration(BinaryHeap<Reverse<Context>>, VecDeque<ExprDetails>);
146
147impl EnumerationType for NormalEnumeration {
148    fn pop(&mut self) -> Option<Context> {
149        self.0.pop().map(|x| x.0)
150    }
151
152    fn push(&mut self, context: Context, _: bool) {
153        self.0.push(Reverse(context));
154    }
155
156    fn get_yield(&mut self) -> Option<ExprDetails> {
157        self.1.pop_front()
158    }
159
160    fn push_yield(&mut self, e: ExprDetails) {
161        self.1.push_back(e);
162    }
163
164    fn include(&mut self, n: usize) -> impl Iterator<Item = bool> + 'static {
165        std::iter::repeat_n(true, n)
166    }
167}
168
169impl ExprDetails {
170    #[allow(clippy::cast_precision_loss)]
171    fn score(&self) -> f64 {
172        (1.0 / (self.size as f64)) + if self.constant_function { 0.0 } else { 10.0 }
173    }
174
175    pub fn has_constant_function(&self) -> bool {
176        self.constant_function
177    }
178}
179
180#[derive(Debug, Clone, Copy, PartialEq)]
181struct KeyedExprDetails {
182    expr_details: ExprDetails,
183    k: f64,
184}
185
186impl Eq for KeyedExprDetails {}
187
188impl PartialOrd for KeyedExprDetails {
189    fn partial_cmp(&self, other: &Self) -> Option<std::cmp::Ordering> {
190        Some(self.cmp(other))
191    }
192}
193
194impl Ord for KeyedExprDetails {
195    fn cmp(&self, other: &Self) -> std::cmp::Ordering {
196        //reversed since we need a min-heap not a max-heap
197        other.k.partial_cmp(&self.k).unwrap()
198    }
199}
200impl KeyedExprDetails {
201    fn new(expr_details: ExprDetails, rng: &mut impl Rng) -> Self {
202        let u: f64 = rng.random();
203        KeyedExprDetails {
204            expr_details,
205            k: u.powf(1.0 / expr_details.score()),
206        }
207    }
208}
209
210#[derive(Debug, Clone, PartialEq)]
211struct RandomPQ(Context, f64);
212
213impl Eq for RandomPQ {}
214
215impl RandomPQ {
216    fn new(c: Context, rng: &mut impl Rng) -> Self {
217        RandomPQ(c, rng.random())
218    }
219}
220
221#[derive(Debug)]
222struct ProbabilisticEnumeration<'a, R: Rng, F>
223where
224    F: Fn(&ExprDetails) -> bool,
225{
226    rng: &'a mut R,
227    reservoir_size: usize,
228    reservoir: BinaryHeap<KeyedExprDetails>,
229    backups: Vec<Context>,
230    pq: BinaryHeap<RandomPQ>,
231    filter: F,
232    n_seen: usize,
233    done: bool,
234}
235impl<R: Rng, F> ProbabilisticEnumeration<'_, R, F>
236where
237    F: Fn(&ExprDetails) -> bool,
238{
239    fn threshold(&self) -> Option<f64> {
240        self.reservoir.peek().map(|x| x.k)
241    }
242
243    fn new<'a, 'src, T: LambdaLanguageOfThought, E: Fn(&Context) -> bool>(
244        reservoir_size: usize,
245        t: &LambdaType,
246        possible_expressions: &'a PossibleExpressions<'src, T>,
247        eager_filter: E,
248        filter: F,
249        rng: &'a mut R,
250    ) -> LambdaEnumerator<'a, 'src, T, E, ProbabilisticEnumeration<'a, R, F>> {
251        let context = Context::new(0, vec![]);
252        let mut pq = BinaryHeap::default();
253        pq.push(RandomPQ::new(context, rng));
254        let pools = vec![UnfinishedLambdaPool {
255            pool: vec![ExprOrType::Type {
256                lambda_type: t.clone(),
257                parent: None,
258                is_app_subformula: false,
259            }],
260        }];
261
262        LambdaEnumerator {
263            pools,
264            possible_expressions,
265            eager_filter,
266            pq: ProbabilisticEnumeration {
267                rng,
268                reservoir_size,
269                reservoir: BinaryHeap::default(),
270                backups: vec![],
271                filter,
272                pq,
273                n_seen: 0,
274                done: false,
275            },
276        }
277    }
278}
279
280impl<R: Rng, F> EnumerationType for ProbabilisticEnumeration<'_, R, F>
281where
282    F: Fn(&ExprDetails) -> bool,
283{
284    fn pop(&mut self) -> Option<Context> {
285        //Pop from min-heap, or grab a random back up if the min-heap is exhausted
286        self.pq.pop().map(|x| x.0).or_else(|| {
287            (0..self.backups.len()).choose(self.rng).and_then(|index| {
288                let last_item = self.backups.len() - 1;
289                self.backups.swap(index, last_item);
290                self.backups.pop()
291            })
292        })
293    }
294
295    fn push(&mut self, context: Context, included: bool) {
296        if included {
297            self.pq.push(RandomPQ::new(context, &mut self.rng));
298        } else {
299            self.backups.push(context);
300        }
301    }
302
303    fn get_yield(&mut self) -> Option<ExprDetails> {
304        if (self.done || self.pq.is_empty())
305            && let Some(x) = self.reservoir.pop()
306        {
307            Some(x.expr_details)
308        } else {
309            None
310        }
311    }
312
313    fn push_yield(&mut self, e: ExprDetails) {
314        let e = KeyedExprDetails::new(e, &mut self.rng);
315        if (self.filter)(&e.expr_details) {
316            self.n_seen += 1;
317            if self.reservoir_size > self.reservoir.len() {
318                self.reservoir.push(e);
319            } else if let Some(t) = self.threshold()
320                && e.k > t
321            {
322                self.reservoir.pop();
323                self.reservoir.push(e);
324            }
325            if self.n_seen >= self.reservoir_size * 20 {
326                self.pq.clear();
327                self.done = true;
328            }
329        }
330    }
331
332    fn include(&mut self, n: usize) -> impl Iterator<Item = bool> + 'static {
333        let x = (0..n).sample(self.rng, (n / 2).max(1));
334        let mut v = vec![false; n];
335        for i in x {
336            v[i] = true;
337        }
338        v.into_iter()
339    }
340}
341
#[derive(Debug)]
///An iterator that enumerates over all possible expressions of a given type.
///
///`F` is an eager filter over partial [`Context`]s. `E` is the [`EnumerationType`] strategy that
///decides the expansion order and which finished expressions are returned.
pub struct LambdaEnumerator<'a, 'src, T: LambdaLanguageOfThought, F, E = NormalEnumeration> {
    //Partially built expressions; a context's `pool_index` selects one of these.
    pools: Vec<UnfinishedLambdaPool<'src, T>>,
    //Candidate expressions used to fill type holes.
    possible_expressions: &'a PossibleExpressions<'src, T>,
    //Contexts rejected by this predicate are dropped and their pool is cleared.
    eager_filter: F,
    //The enumeration strategy (frontier of contexts plus finished results).
    pq: E,
}
350
351///Provides detail about a generated lambda expression
352#[derive(Debug, Clone, Copy, Eq, PartialEq)]
353pub struct ExprDetails {
354    id: usize,
355    constant_function: bool,
356    root: LambdaExprRef,
357    size: usize,
358}
359
360impl ExprDetails {
361    ///Get the size of the associated [`RootedLambdaPool`].
362    pub fn size(&self) -> usize {
363        self.size
364    }
365}
366
367#[derive(Debug, Clone, Eq, PartialEq)]
368///A re-usable sampler for sampling expressions of arbitrary types while caching frequent types
369pub struct TypeAgnosticSampler<'src, T: LambdaLanguageOfThought> {
370    type_to_sampler: HashMap<LambdaType, (usize, LambdaSampler<'src, T>)>,
371    max_expr: usize,
372    max_types: usize,
373    possible_expressions: PossibleExpressions<'src, T>,
374}
375
376impl<'src> TypeAgnosticSampler<'src, Expr<'src>> {
377    #[allow(clippy::missing_panics_doc)]
378    ///Samples an expression of a given type
379    pub fn sample(
380        &mut self,
381        lambda_type: LambdaType,
382        rng: &mut impl Rng,
383    ) -> RootedLambdaPool<'src, Expr<'src>> {
384        let (counts, exprs) = self
385            .type_to_sampler
386            .entry(lambda_type)
387            .or_insert_with_key(|t| {
388                (
389                    1,
390                    RootedLambdaPool::sampler(t, &self.possible_expressions, self.max_expr),
391                )
392            });
393        *counts += 1;
394        let sample = exprs.sample(rng);
395
396        if self.type_to_sampler.len() > self.max_types {
397            let (_, k) = self
398                .type_to_sampler
399                .iter()
400                .map(|(k, (n_visits, _))| (n_visits, k))
401                .min_by_key(|x| x.0)
402                .unwrap();
403
404            let t = k.clone();
405            self.type_to_sampler.remove(&t);
406        }
407
408        sample
409    }
410
411    ///Get a reference to the [`PossibleExpressions`] used by the model
412    #[must_use]
413    pub fn possibles(&self) -> &PossibleExpressions<'src, Expr<'src>> {
414        &self.possible_expressions
415    }
416}
417
418impl<'src, T: LambdaLanguageOfThought + Clone> RootedLambdaPool<'src, T> {
419    ///Create a sampler which can sample arbitrary types.
420    ///
421    ///# Panics
422    ///Will panic if `max_types` == 0 or `max_expr` == 0
423    #[must_use]
424    pub fn typeless_sampler(
425        possible_expressions: PossibleExpressions<'src, T>,
426        max_expr: usize,
427        max_types: usize,
428    ) -> TypeAgnosticSampler<'src, T> {
429        assert!(max_types >= 1);
430        assert!(max_expr >= 1);
431        TypeAgnosticSampler {
432            possible_expressions,
433            max_expr,
434            max_types,
435            type_to_sampler: HashMap::default(),
436        }
437    }
438}
439
440///A struct which samples expressions from a distribution.
441#[derive(Debug, Clone, Eq, PartialEq)]
442pub struct LambdaSampler<'src, T: LambdaLanguageOfThought> {
443    lambdas: Vec<RootedLambdaPool<'src, T>>,
444    expr_details: Vec<ExprDetails>,
445}
446
447impl<'src, T: LambdaLanguageOfThought + Clone> Distribution<RootedLambdaPool<'src, T>>
448    for LambdaSampler<'src, T>
449{
450    fn sample<R: Rng + ?Sized>(&self, rng: &mut R) -> RootedLambdaPool<'src, T> {
451        let w = WeightedIndex::new(self.expr_details.iter().map(ExprDetails::score)).unwrap();
452        let i = w.sample(rng);
453        self.lambdas
454            .get(i)
455            .expect("The Lambda Sampler has no lambdas to sample :(")
456            .clone()
457    }
458}
459
///A trait that handles how enumeration is processed (either normal enumeration or by doing
///reservoir sampling).
pub trait EnumerationType {
    ///Take the next partial-expression [`Context`] to expand, or `None` when none is left.
    fn pop(&mut self) -> Option<Context>;
    ///Queue a [`Context`] for later expansion.
    ///
    ///`included` is the flag produced by [`Self::include`] for this expansion. Implementations
    ///may use it to decide where the context is stored, or ignore it.
    fn push(&mut self, context: Context, included: bool);
    ///Take the next finished expression that is ready to be returned, if any.
    fn get_yield(&mut self) -> Option<ExprDetails>;
    ///Record a finished expression as a candidate to be returned.
    fn push_yield(&mut self, e: ExprDetails);
    ///Produce `n` flags, one per candidate expansion of a hole.
    ///
    ///The enumerator forwards each flag as `included` when it pushes that expansion.
    fn include(&mut self, n: usize) -> impl Iterator<Item = bool> + 'static;
}
469
470fn try_yield<'src, T, F, E>(
471    x: &mut LambdaEnumerator<'_, 'src, T, F, E>,
472) -> Option<(RootedLambdaPool<'src, T>, ExprDetails)>
473where
474    T: LambdaLanguageOfThought,
475    E: EnumerationType,
476{
477    if let Some(item) = x.pq.get_yield() {
478        let p = std::mem::take(&mut x.pools[item.id]);
479        return Some((
480            RootedLambdaPool {
481                pool: LambdaPool(
482                    p.pool
483                        .into_iter()
484                        .map(|x| LambdaExpr::try_from(x).unwrap())
485                        .collect(),
486                ),
487                root: item.root,
488            },
489            item,
490        ));
491    }
492    None
493}
494
495impl<'a, 'src, T, F, E> LambdaEnumerator<'a, 'src, T, F, E>
496where
497    T: LambdaLanguageOfThought + Clone + Debug,
498    F: Fn(&Context) -> bool,
499    E: EnumerationType,
500{
501    fn push(&mut self, c: Context, included: bool) {
502        if (self.eager_filter)(&c) {
503            self.pq.push(c, included);
504        } else {
505            self.pools[c.pool_index] = UnfinishedLambdaPool::default();
506        }
507    }
508
509    ///Change the `eager_filter` function for this enumerator
510    pub fn eager_filter<F2>(self, eager_filter: F2) -> LambdaEnumerator<'a, 'src, T, F2, E> {
511        let LambdaEnumerator {
512            pools,
513            possible_expressions,
514            eager_filter: _,
515            pq,
516        } = self;
517
518        LambdaEnumerator {
519            pools,
520            possible_expressions,
521            eager_filter,
522            pq,
523        }
524    }
525}
526
527impl<'src, F, E> Iterator for LambdaEnumerator<'_, 'src, Expr<'src>, F, E>
528where
529    F: Fn(&Context) -> bool,
530    E: EnumerationType,
531{
532    type Item = (RootedLambdaPool<'src, Expr<'src>>, ExprDetails);
533
534    #[allow(clippy::too_many_lines)]
535    fn next(&mut self) -> Option<Self::Item> {
536        if let Some(x) = try_yield(self) {
537            return Some(x);
538        }
539
540        while let Some(mut c) = self.pq.pop() {
541            if let Some(x) = try_yield(self) {
542                self.push(c, true);
543                return Some(x);
544            }
545            let (possibles, lambda_type) = match &self.pools[c.pool_index].pool[c.position] {
546                ExprOrType::Type {
547                    lambda_type,
548                    is_app_subformula,
549                    parent,
550                } => {
551                    let mut possibles = self.possible_expressions.possibilities(
552                        lambda_type,
553                        *is_app_subformula,
554                        &c,
555                    );
556
557                    //Super hacky way to introduce all_e and all_a in quantifiers even though the
558                    //types are messed up.
559                    if let Some(p) = parent
560                        && let ExprOrType::Expr {
561                            lambda_expr:
562                                LambdaExpr::LanguageOfThoughtExpr(Expr::Quantifier {
563                                    var_type,
564                                    restrictor,
565                                    ..
566                                }),
567                            ..
568                        } = self.pools[c.pool_index].pool[*p]
569                        && restrictor.0 == u32::try_from(c.position).unwrap()
570                    {
571                        possibles.push(PossibleExpr::new_borrowed(match var_type {
572                            ActorOrEvent::Actor => &LambdaExpr::LanguageOfThoughtExpr(
573                                Expr::Constant(Constant::Everyone),
574                            ),
575                            ActorOrEvent::Event => &LambdaExpr::LanguageOfThoughtExpr(
576                                Expr::Constant(Constant::EveryEvent),
577                            ),
578                        }));
579                    }
580
581                    (possibles, lambda_type.clone())
582                }
583                ExprOrType::Expr {
584                    lambda_expr,
585                    parent,
586                } => {
587                    //We add the next uninitialized child to the context or go to the parent if
588                    //there are none.
589
590                    if let Some(child) = lambda_expr
591                        .get_children()
592                        .map(|x| x.0 as usize)
593                        .find(|child| self.pools[c.pool_index].pool[*child].is_type())
594                    {
595                        c.position = child;
596                        self.pq.push(c, true);
597                        continue;
598                    }
599
600                    if lambda_expr.inc_depth() {
601                        c.pop_lambda();
602                    }
603
604                    if let Some(p) = parent {
605                        c.position = *p;
606                        self.pq.push(c, true);
607                        continue;
608                    }
609                    //If the parent is None, we're done!
610                    self.pq.push_yield(ExprDetails {
611                        id: c.pool_index,
612                        root: LambdaExprRef(u32::try_from(c.position).unwrap()),
613                        size: c.depth,
614                        constant_function: c.is_constant(),
615                    });
616                    continue;
617                }
618            };
619
620            let n = possibles.len();
621            let included = self.pq.include(n);
622            if n == 0 {
623                continue;
624            }
625            let n_pools = self.pools.len();
626            if n_pools.is_multiple_of(10_000) {
627                self.pools.shrink_to_fit();
628            }
629            for _ in 0..n.saturating_sub(1) {
630                self.pools.push(self.pools[c.pool_index].clone());
631            }
632
633            let positions =
634                std::iter::once(c.pool_index).chain(n_pools..n_pools + n.saturating_sub(1));
635
636            for (((expr, pool_id), mut c), included) in possibles
637                .into_iter()
638                .zip(positions)
639                .zip(std::iter::repeat_n(c, n))
640                .zip(included)
641            {
642                c.pool_index = pool_id;
643                let pool = self.pools.get_mut(pool_id).unwrap();
644                pool.add_expr(expr, &mut c, &lambda_type);
645                self.push(c, included);
646            }
647
648            if let Some(x) = try_yield(self) {
649                return Some(x);
650            }
651        }
652
653        //If we've somehow exhausted the pq, lets yield anything remaining that's done.
654        if let Some(x) = try_yield(self) {
655            return Some(x);
656        }
657        None
658    }
659}
660
661impl<'src> RootedLambdaPool<'src, Expr<'src>> {
662    ///Create a [`LambdaSampler`] of a given type.
663    ///
664    ///# Errors
665    ///Will return a [`LambdaError`] if the tree is malformed
666    ///
667    ///# Panics
668    ///Will panic if the size of the tree is greater than [`u32::MAX`].
669    pub fn resample_from_expr<'a>(
670        &mut self,
671        possible_expressions: &'a PossibleExpressions<'src, Expr<'src>>,
672        helpers: Option<&HashMap<LambdaType, Vec<RootedLambdaPool<'src, Expr<'src>>>>>,
673        rng: &mut impl Rng,
674    ) -> Result<(), LambdaError> {
675        let position = LambdaExprRef(u32::try_from((0..self.len()).choose(rng).unwrap()).unwrap());
676        let t = self.pool.get_type(position)?;
677
678        let pool = if let Some(helpers) = helpers
679            && rng.random_bool(0.2)
680            && let Some(v) = helpers.get(&t)
681            && !v.is_empty()
682        {
683            let pool = v.choose(rng).unwrap();
684            pool.clone()
685        } else {
686            let (pool, _) = self
687                .probabilistic_enumerate_from_expr(
688                    position,
689                    possible_expressions,
690                    |_| true,
691                    |_| true,
692                    rng,
693                )?
694                .next()
695                .unwrap();
696            pool
697        };
698
699        let offset = u32::try_from(self.len()).unwrap();
700        let new_root = pool.root.0 + offset;
701        self.pool.0.extend(pool.pool.0.into_iter().map(|mut x| {
702            let children: Vec<_> = x
703                .get_children()
704                .map(|x| LambdaExprRef(x.0 + offset))
705                .collect();
706            x.change_children(children.into_iter());
707            x
708        }));
709        self.pool.0.swap(position.0 as usize, new_root as usize);
710        self.cleanup();
711        Ok(())
712    }
713
714    fn probabilistic_enumerate_from_expr<'a, R, E, F>(
715        &self,
716        position: LambdaExprRef,
717        possible_expressions: &'a PossibleExpressions<'src, Expr<'src>>,
718        eager_filter: E,
719        filter: F,
720        rng: &'a mut R,
721    ) -> Result<
722        LambdaEnumerator<'a, 'src, Expr<'src>, E, ProbabilisticEnumeration<'a, R, F>>,
723        TypeError,
724    >
725    where
726        R: Rng,
727        F: Fn(&ExprDetails) -> bool,
728        E: Fn(&Context) -> bool,
729    {
730        let (context, is_app_subformula) = Context::from_pos(self, position);
731        let output = self.pool.get_type(position)?;
732        let mut pq = BinaryHeap::default();
733        pq.push(RandomPQ::new(context, rng));
734        let pools = vec![UnfinishedLambdaPool {
735            pool: vec![ExprOrType::Type {
736                lambda_type: output,
737                parent: None,
738                is_app_subformula,
739            }],
740        }];
741        let enumerator = LambdaEnumerator {
742            pools,
743            possible_expressions,
744            eager_filter,
745            pq: ProbabilisticEnumeration {
746                rng,
747                reservoir_size: 1,
748                reservoir: BinaryHeap::default(),
749                done: false,
750                n_seen: 0,
751                filter,
752                backups: vec![],
753                pq,
754            },
755        };
756
757        Ok(enumerator)
758    }
759
760    ///Create a [`LambdaSampler`] of a given type with a filter
761    pub fn enumerator_filter<'a, F: Fn(&Context) -> bool>(
762        t: &LambdaType,
763        filter: F,
764        possible_expressions: &'a PossibleExpressions<'src, Expr<'src>>,
765    ) -> LambdaEnumerator<'a, 'src, Expr<'src>, F> {
766        let context = Context::new(0, vec![]);
767        let mut pq = BinaryHeap::default();
768        pq.push(Reverse(context));
769        let pools = vec![UnfinishedLambdaPool {
770            pool: vec![ExprOrType::Type {
771                lambda_type: t.clone(),
772                parent: None,
773                is_app_subformula: false,
774            }],
775        }];
776
777        LambdaEnumerator {
778            pools,
779            possible_expressions,
780            eager_filter: filter,
781            pq: NormalEnumeration(pq, VecDeque::default()),
782        }
783    }
784
785    ///Create a [`LambdaSampler`] of a given type.
786    #[must_use]
787    pub fn enumerator<'a>(
788        t: &LambdaType,
789        possible_expressions: &'a PossibleExpressions<'src, Expr<'src>>,
790    ) -> LambdaEnumerator<'a, 'src, Expr<'src>, impl Fn(&'_ Context) -> bool> {
791        let context = Context::new(0, vec![]);
792        let mut pq = BinaryHeap::default();
793        pq.push(Reverse(context));
794        let pools = vec![UnfinishedLambdaPool {
795            pool: vec![ExprOrType::Type {
796                lambda_type: t.clone(),
797                parent: None,
798                is_app_subformula: false,
799            }],
800        }];
801
802        LambdaEnumerator {
803            pools,
804            possible_expressions,
805            eager_filter: |_: &Context| true,
806            pq: NormalEnumeration(pq, VecDeque::default()),
807        }
808    }
809
810    ///Creates a reusable random sampler by enumerating over the first `max_expr` expressions
811    #[must_use]
812    pub fn sampler(
813        t: &LambdaType,
814        possible_expressions: &PossibleExpressions<'src, Expr<'src>>,
815        max_expr: usize,
816    ) -> LambdaSampler<'src, Expr<'src>> {
817        let enumerator = RootedLambdaPool::enumerator(t, possible_expressions);
818        let mut lambdas = Vec::with_capacity(max_expr);
819        let mut expr_details = Vec::with_capacity(max_expr);
820        for (lambda, expr_detail) in enumerator.take(max_expr) {
821            lambdas.push(lambda);
822            expr_details.push(expr_detail);
823        }
824
825        LambdaSampler {
826            lambdas,
827            expr_details,
828        }
829    }
830
831    ///Randomly generate a [`RootedLambdaPool`] of type `t`.
832    ///
833    ///# Panics
834    ///Will panic if no such type can be generated.
835    pub fn random_expr(
836        t: &LambdaType,
837        possible_expressions: &PossibleExpressions<'src, Expr<'src>>,
838        rng: &mut impl Rng,
839    ) -> RootedLambdaPool<'src, Expr<'src>> {
840        ProbabilisticEnumeration::new(
841            1,
842            t,
843            possible_expressions,
844            |_: &Context| true,
845            |_| true,
846            rng,
847        )
848        .next()
849        .unwrap()
850        .0
851    }
852
853    ///Randomly generate a [`RootedLambdaPool`] of type `t` without constant functions.
854    pub fn random_expr_no_constant(
855        t: &LambdaType,
856        possible_expressions: &PossibleExpressions<'src, Expr<'src>>,
857        rng: &mut impl Rng,
858    ) -> Option<RootedLambdaPool<'src, Expr<'src>>> {
859        ProbabilisticEnumeration::new(
860            1,
861            t,
862            possible_expressions,
863            |_: &Context| true,
864            |e| !e.constant_function,
865            rng,
866        )
867        .next()
868        .map(|x| x.0)
869    }
870}
871
872impl<'src> RootedLambdaPool<'src, Expr<'src>> {
    ///Remove quantifiers which do not use their variable in their body.
    ///
    ///Each such quantifier node is overwritten with a copy of its subformula node, which drops
    ///its restrictor. Bound variables in that subformula whose index exceeds their depth are
    ///decremented, presumably so they keep pointing at the same outer binder. Finally
    ///`self.pool.cleanup` is run from the root and its result becomes the new root.
    pub fn prune_quantifiers(&mut self) {
        //Collect (quantifier, subformula) pairs, in BFS order from the root, for every quantifier
        //whose variable never occurs in its subformula.
        let quantifiers = self
            .pool
            .bfs_from(self.root)
            .filter_map(|(i, _)| match self.get(i) {
                LambdaExpr::LanguageOfThoughtExpr(Expr::Quantifier { subformula, .. }) => {
                    if self
                        .pool
                        .bfs_from(LambdaExprRef(subformula.0))
                        .any(|(x, d)| {
                            //Assumes `d` counts the binders between the subformula and `x`, so
                            //`v == d` means `x` refers to this quantifier. TODO confirm against
                            //`bfs_from`.
                            if let LambdaExpr::BoundVariable(v, _) = self.get(x) {
                                *v == d
                            } else {
                                false
                            }
                        })
                    {
                        None
                    } else {
                        Some((i, LambdaExprRef(subformula.0)))
                    }
                }
                _ => None,
            })
            .collect::<Vec<_>>();

        //By reversing, we ensure that we fix inner quantifiers before outer ones.
        for (quantifier, subformula) in quantifiers.into_iter().rev() {
            //Overwrite the quantifier with its subformula; the old subformula node is left
            //unreachable.
            *self.pool.get_mut(quantifier) = self.pool.get(subformula).clone();
            //Variables pointing past the removed binder shift down by one.
            self.pool.bfs_from_mut(quantifier).for_each(|(x, d, _)| {
                if let LambdaExpr::BoundVariable(b_d, _) = x
                    && *b_d > d
                {
                    *b_d -= 1;
                }
            });
        }
        self.root = self.pool.cleanup(self.root);
    }
913
914    ///Replace a random expression with something else of the same type.
915    ///
916    ///# Errors
917    ///Will return a [`TypeError`] if there is no compatible type to replace
918    ///
919    ///# Panics
920    ///Will panic if the size of the tree is greater than [`u32::MAX`].
921    pub fn swap_expr(
922        &mut self,
923        possible_expressions: &PossibleExpressions<'src, Expr<'src>>,
924        rng: &mut impl Rng,
925    ) -> Result<(), TypeError> {
926        let position = LambdaExprRef(u32::try_from((0..self.len()).choose(rng).unwrap()).unwrap());
927
928        let (context, _) = Context::from_pos(self, position);
929
930        let output = self.pool.get_type(position)?;
931        let expr = self.pool.get(position);
932
933        let arguments: Vec<_> = expr
934            .get_children()
935            .map(|x| self.pool.get_type(x).unwrap())
936            .collect();
937
938        let mut replacements = possible_expressions.possiblities_fixed_children(
939            &output,
940            &arguments,
941            expr.var_type(),
942            &context,
943        );
944
945        //hacky fix!!
946        if replacements.is_empty()
947            && let LambdaExpr::LanguageOfThoughtExpr(Expr::Quantifier {
948                var_type,
949                restrictor,
950                subformula,
951                ..
952            }) = expr
953        {
954            for quantifier in [Quantifier::Universal, Quantifier::Existential] {
955                replacements.push(std::borrow::Cow::Owned(LambdaExpr::LanguageOfThoughtExpr(
956                    Expr::Quantifier {
957                        quantifier,
958                        var_type: *var_type,
959                        restrictor: *restrictor,
960                        subformula: *subformula,
961                    },
962                )));
963            }
964        }
965        for c in [Constant::EveryEvent, Constant::Everyone] {
966            if replacements.is_empty()
967                && expr == &LambdaExpr::LanguageOfThoughtExpr(Expr::Constant(c))
968            {
969                replacements = possible_expressions.possiblities_fixed_children(
970                    LambdaType::t(),
971                    &arguments,
972                    expr.var_type(),
973                    &context,
974                );
975                replacements.push(std::borrow::Cow::Owned(LambdaExpr::LanguageOfThoughtExpr(
976                    Expr::Constant(c),
977                )));
978            }
979        }
980
981        let choice = replacements.choose(rng).unwrap_or_else(|| {
982            panic!(
983                "There is no node with output {output} and arguments {}",
984                arguments
985                    .into_iter()
986                    .map(|x| x.to_string())
987                    .collect::<Vec<_>>()
988                    .join(", ")
989            )
990        });
991
992        let mut new_expr = choice.clone().into_owned();
993        new_expr.change_children(self.pool.get(position).get_children());
994        self.pool.0[position.0 as usize] = new_expr;
995
996        Ok(())
997    }
998
999    ///Replace a random expression with something else of the same type from within the same
1000    ///expression.
1001    ///
1002    ///# Errors
1003    ///Will return [`TypeError`] if a compatible subtree can't be found.
1004    ///
1005    ///# Panics
1006    ///Will panic if the size of the tree is greater than [`u32::MAX`].
1007    pub fn swap_subtree(&mut self, rng: &mut impl Rng) -> Result<(), TypeError> {
1008        let position = LambdaExprRef(u32::try_from((0..self.len()).choose(rng).unwrap()).unwrap());
1009        let (context, _) = Context::from_pos(self, position);
1010        let alt = context.find_compatible(self, position)?;
1011        if let Some(new_pos) = alt.choose(rng).copied() {
1012            let offset = self.pool.0.len();
1013            let mut lookup = HashMap::default();
1014            let mut new_pool: Vec<LambdaExpr<'src, Expr<'src>>> = self
1015                .pool
1016                .bfs_from(new_pos)
1017                .map(|(x, _)| {
1018                    let n = lookup.len();
1019                    lookup
1020                        .entry(x)
1021                        .or_insert(LambdaExprRef(u32::try_from(n + offset).unwrap()));
1022                    self.get(x).clone()
1023                })
1024                .collect();
1025
1026            for x in &mut new_pool {
1027                let child: Vec<_> = x.get_children().map(|x| *lookup.get(&x).unwrap()).collect();
1028                x.change_children(child.into_iter());
1029            }
1030
1031            self.pool.0.extend(new_pool);
1032            self.pool.0.swap(position.0 as usize, offset);
1033            self.cleanup();
1034        }
1035        Ok(())
1036    }
1037}
1038
#[cfg(test)]
mod test {

    use std::collections::HashSet;

    use super::*;
    use crate::lambda::{LambdaPool, LambdaSummaryStats};
    use rand::SeedableRng;
    use rand_chacha::ChaCha8Rng;

    #[test]
    fn prune_quantifier_test() -> anyhow::Result<()> {
        let mut pool =
            RootedLambdaPool::parse("some_e(x,all_e,AgentOf(a_2,e_1) & PatientOf(a_0,e_0))")?;

        pool.prune_quantifiers();
        assert_eq!(pool.to_string(), "AgentOf(a_2, e_1) & PatientOf(a_0, e_0)");

        let mut pool = RootedLambdaPool::parse(
            "some_e(x0, all_e, some(z, all_a, AgentOf(z, e_1) & PatientOf(a_0, e_0)))",
        )?;

        pool.prune_quantifiers();
        assert_eq!(
            pool.to_string(),
            "some(x, all_a, AgentOf(x, e_1) & PatientOf(a_0, e_0))"
        );

        let mut pool = RootedLambdaPool::parse("~every_e(z, pe_1, pa_2(a_0))")?;

        pool.prune_quantifiers();

        assert_eq!(pool.to_string(), "~pa_2(a_0)");
        let mut pool = RootedLambdaPool::new(
            LambdaPool(vec![
                LambdaExpr::Lambda(LambdaExprRef(1), LambdaType::E),
                LambdaExpr::Lambda(LambdaExprRef(2), LambdaType::T),
                LambdaExpr::LanguageOfThoughtExpr(Expr::Quantifier {
                    quantifier: Quantifier::Universal,
                    var_type: ActorOrEvent::Actor,
                    restrictor: ExprRef(3),
                    subformula: ExprRef(4),
                }),
                LambdaExpr::LanguageOfThoughtExpr(Expr::Constant(Constant::Property(
                    "1",
                    ActorOrEvent::Actor,
                ))),
                LambdaExpr::LanguageOfThoughtExpr(Expr::Quantifier {
                    quantifier: Quantifier::Existential,
                    var_type: ActorOrEvent::Actor,
                    restrictor: ExprRef(5),
                    subformula: ExprRef(6),
                }),
                LambdaExpr::LanguageOfThoughtExpr(Expr::Constant(Constant::Property(
                    "0",
                    ActorOrEvent::Actor,
                ))),
                LambdaExpr::LanguageOfThoughtExpr(Expr::Unary(
                    MonOp::Property("3", ActorOrEvent::Event),
                    ExprRef(7),
                )),
                LambdaExpr::BoundVariable(3, LambdaType::E),
            ]),
            LambdaExprRef(0),
        );

        assert_eq!(
            pool.to_string(),
            "lambda e x lambda t phi every(y, pa_1, some(z, pa_0, pe_3(x)))"
        );

        // NOTE(review): `parsed_pool` is pruned (so a panic would be caught) but its result is
        // never compared with `pool`; consider asserting they print the same once confirmed.
        let mut parsed_pool = RootedLambdaPool::parse(
            "lambda e x lambda t phi every(y, pa_1, some(z, pa_0, pe_3(x)))",
        )?;
        parsed_pool.prune_quantifiers();
        pool.prune_quantifiers();
        assert_eq!(pool.to_string(), "lambda e x lambda t phi pe_3(x)");

        Ok(())
    }

    #[test]
    fn random_swap_tree() -> anyhow::Result<()> {
        let mut rng = ChaCha8Rng::seed_from_u64(2);
        let x = RootedLambdaPool::parse("lambda a x pa_cool(a_John) | pa_cool(x)")?;
        let mut h = HashSet::new();
        for _ in 0..100 {
            let mut z = x.clone();
            z.swap_subtree(&mut rng)?;
            h.insert(z.to_string());
        }
        for x in h.iter() {
            println!("{x}");
        }
        assert!(h.contains("lambda a x pa_cool(x)"));
        assert!(h.contains("lambda a x pa_cool(a_John)"));
        Ok(())
    }

    #[test]
    fn random_swap() -> anyhow::Result<()> {
        let mut rng = ChaCha8Rng::seed_from_u64(2);
        let actors = ["0", "1"];
        let available_event_properties = ["2", "3", "4"];
        let possible_expressions = PossibleExpressions::new(
            &actors,
            &available_event_properties,
            &available_event_properties,
        );
        for _ in 0..200 {
            let t = LambdaType::random(&mut rng);
            println!("{t}");
            let mut pool = RootedLambdaPool::random_expr(&t, &possible_expressions, &mut rng);
            println!("{t}: {pool}");
            assert_eq!(t, pool.get_type()?);
            println!("{pool:?}");
            pool.swap_expr(&possible_expressions, &mut rng)?;
            println!("{pool:?}");
            println!("{t}: {pool}");
            assert_eq!(t, pool.get_type()?);
        }
        let p =
            RootedLambdaPool::parse("lambda <a,t> P some(x, P(x), some_e(y, pe_2(y), pe_4(y)))")?;
        let t = p.get_type()?;
        for _ in 0..1000 {
            let mut pool = p.clone();
            assert_eq!(t, pool.get_type()?);

            pool.swap_expr(&possible_expressions, &mut rng)?;
            println!("{t}: {pool}");
            assert_eq!(t, pool.get_type()?);
        }

        Ok(())
    }

    #[test]
    fn enumerate() -> anyhow::Result<()> {
        let actors = ["john"];
        let actor_properties = ["a"];
        let event_properties = ["e"];
        let possibles = PossibleExpressions::new(&actors, &actor_properties, &event_properties);

        let t = LambdaType::from_string("<<a,t>,t>")?;

        let p = RootedLambdaPool::enumerator(&t, &possibles);
        let pools: HashSet<_> = p
            .filter_map(|(p, x)| {
                if !x.has_constant_function() {
                    Some(p.to_string())
                } else {
                    None
                }
            })
            .take(20)
            .collect();
        for p in pools.iter() {
            println!("{p}");
        }
        assert!(pools.contains("lambda <a,t> P P(a_john)"));

        Ok(())
    }

    /*
    #[test]
    fn random_expr_no_constant() -> anyhow::Result<()> {
        let actors = &["1", "2", "3", "4", "5"];
        let actor_properties = &["1", "2", "3", "4", "5"];
        let event_properties = &["1", "2", "3", "4", "5"];
        let poss = PossibleExpressions::new(actors, actor_properties, event_properties);
        let mut rng = ChaCha8Rng::seed_from_u64(1);

        let t = LambdaType::from_string("<<a,<e,t>>, <e,t>>")?;
        println!("{t}: ");
        let pool = RootedLambdaPool::random_expr_no_constant(&t, &poss, &mut rng).unwrap();
        println!("{pool}");

        for _ in 0..500 {
            let t = LambdaType::random(&mut rng);
            println!("{t}: ");
            if t.size() <= 3 {
                let pool = RootedLambdaPool::random_expr_no_constant(&t, &poss, &mut rng).unwrap();
                println!("{pool}");
            } else {
                println!("too big :(");
            }
        }
        Ok(())
    }*/

    #[test]
    fn enumerate_weirds() -> anyhow::Result<()> {
        let actors = &["1"];
        let actor_properties = &["1"];
        let event_properties = &["1"];
        let poss = PossibleExpressions::new(actors, actor_properties, event_properties);

        let t = "<a,<e,e>>";

        for (p, d) in RootedLambdaPool::enumerator(&LambdaType::from_string(t)?, &poss)
            .filter(|(_, e)| !e.has_constant_function())
            .take(5)
        {
            println!("{p} {d:?} {}", d.has_constant_function());
        }
        Ok(())
    }

    #[test]
    fn random_expr() -> anyhow::Result<()> {
        let actors = ["john"];
        let actor_properties = ["a"];
        let event_properties = ["e"];
        let possibles = PossibleExpressions::new(&actors, &actor_properties, &event_properties);
        let mut rng = ChaCha8Rng::seed_from_u64(0);

        let map = [(LambdaType::A, vec![RootedLambdaPool::parse("a_john")?])]
            .into_iter()
            .collect();

        for _ in 0..100 {
            let t = LambdaType::random(&mut rng);
            println!("sampling: {t}");
            let mut pool = RootedLambdaPool::random_expr(&t, &possibles, &mut rng);
            assert_eq!(t, pool.get_type()?);
            let s = pool.to_string();
            println!("{s}");
            let mut pool2 = RootedLambdaPool::parse(s.as_str())?;
            assert_eq!(s, pool2.to_string());
            println!("{pool}");
            pool2.resample_from_expr(&possibles, None, &mut rng)?;
            assert_eq!(pool2.get_type()?, t);
            pool2.resample_from_expr(&possibles, Some(&map), &mut rng)?;
            assert_eq!(pool2.get_type()?, t);

            pool.swap_subtree(&mut rng)?;
            assert_eq!(pool.get_type()?, t);
        }
        let t = LambdaType::from_string("<a,<a,t>>")?;
        for _ in 0..100 {
            let pool = RootedLambdaPool::random_expr(&t, &possibles, &mut rng);
            println!("{pool}");
        }
        Ok(())
    }

    #[test]
    fn constant_exprs() -> anyhow::Result<()> {
        let actors = ["john", "mary", "phil", "sue"];
        let actor_properties = ["a"];
        let event_properties = ["e"];
        let possibles = PossibleExpressions::new(&actors, &actor_properties, &event_properties);
        let mut rng = ChaCha8Rng::seed_from_u64(0);

        let v: HashSet<_> = RootedLambdaPool::enumerator(LambdaType::a(), &possibles)
            .map(|(x, _)| x.to_string())
            .take(4)
            .collect();

        assert_eq!(v, HashSet::from(actors.map(|x| format!("a_{x}"))));
        println!("{v:?}");

        // The sampled type never changes, so parse it once.
        let t = LambdaType::from_string("<a, <a,t>>")?;
        let mut constants = 0;
        for _ in 0..100 {
            println!("sampling: {t}");
            let pool = RootedLambdaPool::random_expr(&t, &possibles, &mut rng);
            match pool.stats() {
                LambdaSummaryStats::WellFormed {
                    constant_function, ..
                } => {
                    if constant_function {
                        constants += 1;
                    }
                }
                // A malformed sample is a bug in `random_expr`, not unfinished test code.
                LambdaSummaryStats::Malformed => {
                    panic!("random_expr produced a malformed expression: {pool}")
                }
            }
        }
        println!("{constants}");

        Ok(())
    }

    #[test]
    fn random_expr_counts() -> anyhow::Result<()> {
        let actors = ["john", "mary", "phil", "sue"];
        let actor_properties = ["a"];
        let event_properties = ["e"];
        let possibles = PossibleExpressions::new(&actors, &actor_properties, &event_properties);
        let mut rng = ChaCha8Rng::seed_from_u64(1);

        let mut counts: HashMap<_, usize> = HashMap::default();
        let t = LambdaType::A;
        for _ in 0..1000 {
            let pool = RootedLambdaPool::random_expr(&t, &possibles, &mut rng);
            assert_eq!(t, pool.get_type()?);
            let s = pool.to_string();
            *counts.entry(s).or_default() += 1;
        }
        assert_eq!(counts.len(), 4);
        dbg!(&counts);
        // Four actors sampled 1000 times: each should land roughly uniformly (~250).
        for v in counts.values() {
            assert!((200..=300).contains(v));
        }

        counts.clear();
        let t = LambdaType::at().clone();
        for _ in 0..1000 {
            let pool = RootedLambdaPool::random_expr(&t, &possibles, &mut rng);
            assert_eq!(t, pool.get_type()?);
            let s = pool.to_string();
            *counts.entry(s).or_default() += 1;
        }
        dbg!(counts);
        Ok(())
    }
}