simple_semantics/language/
mutations.rs

1use super::{
2    ActorOrEvent, BinOp, Constant, Display, Expr, ExprRef, MonOp, Quantifier, RootedLambdaPool,
3    thiserror,
4};
5use crate::lambda::{
6    LambdaError, LambdaExpr, LambdaExprRef, LambdaLanguageOfThought, LambdaPool,
7    types::{LambdaType, TypeError},
8};
9use ahash::HashMap;
10use chumsky::container::Container;
11use rand::{
12    Rng, RngExt,
13    distr::{Distribution, weighted::WeightedIndex},
14    seq::{IndexedRandom, IteratorRandom},
15};
16use std::{
17    cmp::Reverse,
18    collections::{BinaryHeap, VecDeque},
19    fmt::Debug,
20};
21use thiserror::Error;
22
23mod context;
24mod samplers;
25pub use context::Context;
26pub(crate) use samplers::PossibleExpr;
27pub use samplers::PossibleExpressions;
28
29#[derive(Debug, Error, Clone)]
30pub struct ExprOrTypeError();
31
32impl Display for ExprOrTypeError {
33    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
34        write!(f, "This ExprOrType is not an Expr!")
35    }
36}
37
38#[derive(Debug, Clone, Eq, PartialEq)]
39enum ExprOrType<'src, T: LambdaLanguageOfThought> {
40    Type {
41        lambda_type: LambdaType,
42        parent: Option<usize>,
43        is_app_subformula: bool,
44    },
45    Expr {
46        lambda_expr: LambdaExpr<'src, T>,
47        parent: Option<usize>,
48    },
49}
50
51impl<'src, T: LambdaLanguageOfThought> TryFrom<ExprOrType<'src, T>> for LambdaExpr<'src, T> {
52    type Error = ExprOrTypeError;
53
54    fn try_from(value: ExprOrType<'src, T>) -> Result<Self, Self::Error> {
55        match value {
56            ExprOrType::Type { .. } => Err(ExprOrTypeError()),
57            ExprOrType::Expr { lambda_expr, .. } => Ok(lambda_expr),
58        }
59    }
60}
61
62impl<T: LambdaLanguageOfThought> ExprOrType<'_, T> {
63    fn parent(&self) -> Option<usize> {
64        match self {
65            ExprOrType::Type { parent, .. } | ExprOrType::Expr { parent, .. } => *parent,
66        }
67    }
68
69    fn is_type(&self) -> bool {
70        matches!(self, ExprOrType::Type { .. })
71    }
72}
73
74#[derive(Debug, Clone, Eq, PartialEq)]
75struct UnfinishedLambdaPool<'src, T: LambdaLanguageOfThought> {
76    pool: Vec<ExprOrType<'src, T>>,
77}
78
79impl<T: LambdaLanguageOfThought> Default for UnfinishedLambdaPool<'_, T> {
80    fn default() -> Self {
81        Self { pool: vec![] }
82    }
83}
84
85impl<'src, T: LambdaLanguageOfThought + Clone> UnfinishedLambdaPool<'src, T> {
86    fn add_expr<'a>(&mut self, expr: PossibleExpr<'a, 'src, T>, c: &mut Context, t: &LambdaType) {
87        let (mut lambda_expr, app_details) = expr.into_expr();
88        c.depth += 1;
89        c.open_nodes += lambda_expr.n_children();
90        c.open_nodes -= 1;
91        let parent = self.pool[c.position].parent();
92        match &mut lambda_expr {
93            LambdaExpr::Lambda(body, arg) => {
94                c.add_lambda(arg);
95                *body = LambdaExprRef(self.pool.len() as u32);
96                self.pool.push(ExprOrType::Type {
97                    lambda_type: t.rhs().unwrap().clone(),
98                    parent: Some(c.position),
99                    is_app_subformula: false,
100                });
101            }
102            LambdaExpr::BoundVariable(b, _) => {
103                c.use_bvar(*b);
104            }
105            LambdaExpr::FreeVariable(..) => (),
106            LambdaExpr::Application {
107                subformula,
108                argument,
109            } => {
110                *subformula = LambdaExprRef(self.pool.len() as u32);
111                *argument = LambdaExprRef((self.pool.len() + 1) as u32);
112                let (subformula, argument) = app_details.unwrap();
113                self.pool.push(ExprOrType::Type {
114                    lambda_type: subformula,
115                    parent: Some(c.position),
116                    is_app_subformula: true,
117                });
118                self.pool.push(ExprOrType::Type {
119                    lambda_type: argument,
120                    parent: Some(c.position),
121                    is_app_subformula: false,
122                });
123            }
124            LambdaExpr::LanguageOfThoughtExpr(e) => {
125                let children_start = self.pool.len();
126                if let Some(t) = e.var_type() {
127                    c.add_lambda(t);
128                }
129                self.pool
130                    .extend(e.get_arguments().map(|lambda_type| ExprOrType::Type {
131                        lambda_type,
132                        parent: Some(c.position),
133                        is_app_subformula: false,
134                    }));
135                e.change_children(
136                    (children_start..self.pool.len()).map(|x| LambdaExprRef(x as u32)),
137                );
138            }
139        }
140        self.pool[c.position] = ExprOrType::Expr {
141            lambda_expr,
142            parent,
143        };
144    }
145}
146
147#[derive(Debug)]
148pub struct NormalEnumeration(BinaryHeap<Reverse<Context>>, VecDeque<ExprDetails>);
149
150impl EnumerationType for NormalEnumeration {
151    fn pop(&mut self) -> Option<Context> {
152        self.0.pop().map(|x| x.0)
153    }
154
155    fn push(&mut self, context: Context, _: bool) {
156        self.0.push(Reverse(context));
157    }
158
159    fn get_yield(&mut self) -> Option<ExprDetails> {
160        self.1.pop_front()
161    }
162
163    fn push_yield(&mut self, e: ExprDetails) {
164        self.1.push(e);
165    }
166
167    fn include(&mut self, n: usize) -> impl Iterator<Item = bool> + 'static {
168        std::iter::repeat_n(true, n)
169    }
170}
171
172impl ExprDetails {
173    fn score(&self) -> f64 {
174        (1.0 / (self.size as f64)) + if self.constant_function { 0.0 } else { 10.0 }
175    }
176
177    pub fn has_constant_function(&self) -> bool {
178        self.constant_function
179    }
180}
181
182#[derive(Debug, Clone, Copy, PartialEq)]
183struct KeyedExprDetails {
184    expr_details: ExprDetails,
185    k: f64,
186}
187
188impl Eq for KeyedExprDetails {}
189
190impl PartialOrd for KeyedExprDetails {
191    fn partial_cmp(&self, other: &Self) -> Option<std::cmp::Ordering> {
192        Some(self.cmp(other))
193    }
194}
195
196impl Ord for KeyedExprDetails {
197    fn cmp(&self, other: &Self) -> std::cmp::Ordering {
198        //reversed since we need a min-heap not a max-heap
199        other.k.partial_cmp(&self.k).unwrap()
200    }
201}
202impl KeyedExprDetails {
203    fn new(expr_details: ExprDetails, rng: &mut impl Rng) -> Self {
204        let u: f64 = rng.random();
205        KeyedExprDetails {
206            expr_details,
207            k: u.powf(1.0 / expr_details.score()),
208        }
209    }
210}
211
212#[derive(Debug, Clone, PartialEq)]
213struct RandomPQ(Context, f64);
214
215impl Eq for RandomPQ {}
216
217impl RandomPQ {
218    fn new(c: Context, rng: &mut impl Rng) -> Self {
219        RandomPQ(c, rng.random())
220    }
221}
222
223#[derive(Debug)]
224struct ProbabilisticEnumeration<'a, R: Rng, F>
225where
226    F: Fn(&ExprDetails) -> bool,
227{
228    rng: &'a mut R,
229    reservoir_size: usize,
230    reservoir: BinaryHeap<KeyedExprDetails>,
231    backups: Vec<Context>,
232    pq: BinaryHeap<RandomPQ>,
233    filter: F,
234    n_seen: usize,
235    done: bool,
236}
237impl<R: Rng, F> ProbabilisticEnumeration<'_, R, F>
238where
239    F: Fn(&ExprDetails) -> bool,
240{
241    fn threshold(&self) -> Option<f64> {
242        self.reservoir.peek().map(|x| x.k)
243    }
244
245    fn new<'a, 'src, T: LambdaLanguageOfThought, E: Fn(&Context) -> bool>(
246        reservoir_size: usize,
247        t: &LambdaType,
248        possible_expressions: &'a PossibleExpressions<'src, T>,
249        eager_filter: E,
250        filter: F,
251        rng: &'a mut R,
252    ) -> LambdaEnumerator<'a, 'src, T, E, ProbabilisticEnumeration<'a, R, F>> {
253        let context = Context::new(0, vec![]);
254        let mut pq = BinaryHeap::default();
255        pq.push(RandomPQ::new(context, rng));
256        let pools = vec![UnfinishedLambdaPool {
257            pool: vec![ExprOrType::Type {
258                lambda_type: t.clone(),
259                parent: None,
260                is_app_subformula: false,
261            }],
262        }];
263
264        LambdaEnumerator {
265            pools,
266            possible_expressions,
267            eager_filter,
268            pq: ProbabilisticEnumeration {
269                rng,
270                reservoir_size,
271                reservoir: BinaryHeap::default(),
272                backups: vec![],
273                filter,
274                pq,
275                n_seen: 0,
276                done: false,
277            },
278        }
279    }
280}
281
282impl<R: Rng, F> EnumerationType for ProbabilisticEnumeration<'_, R, F>
283where
284    F: Fn(&ExprDetails) -> bool,
285{
286    fn pop(&mut self) -> Option<Context> {
287        //Pop from min-heap, or grab a random back up if the min-heap is exhausted
288        self.pq.pop().map(|x| x.0).or_else(|| {
289            (0..self.backups.len()).choose(self.rng).and_then(|index| {
290                let last_item = self.backups.len() - 1;
291                self.backups.swap(index, last_item);
292                self.backups.pop()
293            })
294        })
295    }
296
297    fn push(&mut self, context: Context, included: bool) {
298        if included {
299            self.pq.push(RandomPQ::new(context, &mut self.rng));
300        } else {
301            self.backups.push(context);
302        }
303    }
304
305    fn get_yield(&mut self) -> Option<ExprDetails> {
306        if (self.done || self.pq.is_empty())
307            && let Some(x) = self.reservoir.pop()
308        {
309            Some(x.expr_details)
310        } else {
311            None
312        }
313    }
314
315    fn push_yield(&mut self, e: ExprDetails) {
316        let e = KeyedExprDetails::new(e, &mut self.rng);
317        if (self.filter)(&e.expr_details) {
318            self.n_seen += 1;
319            if self.reservoir_size > self.reservoir.len() {
320                self.reservoir.push(e);
321            } else if let Some(t) = self.threshold()
322                && e.k > t
323            {
324                self.reservoir.pop();
325                self.reservoir.push(e);
326            }
327            if self.n_seen >= self.reservoir_size * 20 {
328                self.pq.clear();
329                self.done = true;
330            }
331        }
332    }
333
334    fn include(&mut self, n: usize) -> impl Iterator<Item = bool> + 'static {
335        let x = (0..n).sample(self.rng, (n / 2).max(1));
336        let mut v = vec![false; n];
337        for i in x {
338            v[i] = true;
339        }
340        v.into_iter()
341    }
342}
343
344#[derive(Debug)]
345///An iterator that enumerates over all possible expressions of a given type.
346pub struct LambdaEnumerator<'a, 'src, T: LambdaLanguageOfThought, F, E = NormalEnumeration> {
347    pools: Vec<UnfinishedLambdaPool<'src, T>>,
348    possible_expressions: &'a PossibleExpressions<'src, T>,
349    eager_filter: F,
350    pq: E,
351}
352
353///Provides detail about a generated lambda expression
354#[derive(Debug, Clone, Copy, Eq, PartialEq)]
355pub struct ExprDetails {
356    id: usize,
357    constant_function: bool,
358    root: LambdaExprRef,
359    size: usize,
360}
361
362impl ExprDetails {
363    ///Get the size of the associated [`RootedLambdaPool`].
364    pub fn size(&self) -> usize {
365        self.size
366    }
367}
368
369#[derive(Debug, Clone, Eq, PartialEq)]
370///A re-usable sampler for sampling expressions of arbitrary types while caching frequent types
371pub struct TypeAgnosticSampler<'src, T: LambdaLanguageOfThought> {
372    type_to_sampler: HashMap<LambdaType, (usize, LambdaSampler<'src, T>)>,
373    max_expr: usize,
374    max_types: usize,
375    possible_expressions: PossibleExpressions<'src, T>,
376}
377
378impl<'src> TypeAgnosticSampler<'src, Expr<'src>> {
379    ///Samples an expression of a given type
380    pub fn sample(
381        &mut self,
382        lambda_type: LambdaType,
383        rng: &mut impl Rng,
384    ) -> RootedLambdaPool<'src, Expr<'src>> {
385        let (counts, exprs) = self
386            .type_to_sampler
387            .entry(lambda_type)
388            .or_insert_with_key(|t| {
389                (
390                    1,
391                    RootedLambdaPool::sampler(t, &self.possible_expressions, self.max_expr),
392                )
393            });
394        *counts += 1;
395        let sample = exprs.sample(rng);
396
397        if self.type_to_sampler.len() > self.max_types {
398            let (_, k) = self
399                .type_to_sampler
400                .iter()
401                .map(|(k, (n_visits, _))| (n_visits, k))
402                .min_by_key(|x| x.0)
403                .unwrap();
404
405            let t = k.clone();
406            self.type_to_sampler.remove(&t);
407        }
408
409        sample
410    }
411
412    ///Get a reference to the [`PossibleExpressions`] used by the model
413    #[must_use]
414    pub fn possibles(&self) -> &PossibleExpressions<'src, Expr<'src>> {
415        &self.possible_expressions
416    }
417}
418
419impl<'src, T: LambdaLanguageOfThought + Clone> RootedLambdaPool<'src, T> {
420    ///Create a sampler which can sample arbitrary types.
421    #[must_use]
422    pub fn typeless_sampler(
423        possible_expressions: PossibleExpressions<'src, T>,
424        max_expr: usize,
425        max_types: usize,
426    ) -> TypeAgnosticSampler<'src, T> {
427        assert!(max_types >= 1);
428        assert!(max_expr >= 1);
429        TypeAgnosticSampler {
430            possible_expressions,
431            max_expr,
432            max_types,
433            type_to_sampler: HashMap::default(),
434        }
435    }
436}
437
438///A struct which samples expressions from a distribution.
439#[derive(Debug, Clone, Eq, PartialEq)]
440pub struct LambdaSampler<'src, T: LambdaLanguageOfThought> {
441    lambdas: Vec<RootedLambdaPool<'src, T>>,
442    expr_details: Vec<ExprDetails>,
443}
444
445impl<'src, T: LambdaLanguageOfThought + Clone> Distribution<RootedLambdaPool<'src, T>>
446    for LambdaSampler<'src, T>
447{
448    fn sample<R: Rng + ?Sized>(&self, rng: &mut R) -> RootedLambdaPool<'src, T> {
449        let w = WeightedIndex::new(self.expr_details.iter().map(ExprDetails::score)).unwrap();
450        let i = w.sample(rng);
451        self.lambdas
452            .get(i)
453            .expect("The Lambda Sampler has no lambdas to sample :(")
454            .clone()
455    }
456}
457
458///A trait that handles how enumeration is processed (either normal enumeration or by doing
459///reservoir sampling).
460pub trait EnumerationType {
461    fn pop(&mut self) -> Option<Context>;
462    fn push(&mut self, context: Context, included: bool);
463    fn get_yield(&mut self) -> Option<ExprDetails>;
464    fn push_yield(&mut self, e: ExprDetails);
465    fn include(&mut self, n: usize) -> impl Iterator<Item = bool> + 'static;
466}
467
468fn try_yield<'src, T, F, E>(
469    x: &mut LambdaEnumerator<'_, 'src, T, F, E>,
470) -> Option<(RootedLambdaPool<'src, T>, ExprDetails)>
471where
472    T: LambdaLanguageOfThought,
473    E: EnumerationType,
474{
475    if let Some(item) = x.pq.get_yield() {
476        let p = std::mem::take(&mut x.pools[item.id]);
477        return Some((
478            RootedLambdaPool {
479                pool: LambdaPool(
480                    p.pool
481                        .into_iter()
482                        .map(|x| LambdaExpr::try_from(x).unwrap())
483                        .collect(),
484                ),
485                root: item.root,
486            },
487            item,
488        ));
489    }
490    None
491}
492
493impl<'a, 'src, T, F, E> LambdaEnumerator<'a, 'src, T, F, E>
494where
495    T: LambdaLanguageOfThought + Clone + Debug,
496    F: Fn(&Context) -> bool,
497    E: EnumerationType,
498{
499    fn push(&mut self, c: Context, included: bool) {
500        if (self.eager_filter)(&c) {
501            self.pq.push(c, included);
502        } else {
503            self.pools[c.pool_index] = UnfinishedLambdaPool::default();
504        }
505    }
506
507    ///Change the `eager_filter` function for this enumerator
508    pub fn eager_filter<F2>(self, eager_filter: F2) -> LambdaEnumerator<'a, 'src, T, F2, E> {
509        let LambdaEnumerator {
510            pools,
511            possible_expressions,
512            eager_filter: _,
513            pq,
514        } = self;
515
516        LambdaEnumerator {
517            pools,
518            possible_expressions,
519            eager_filter,
520            pq,
521        }
522    }
523}
524
525impl<'src, F, E> Iterator for LambdaEnumerator<'_, 'src, Expr<'src>, F, E>
526where
527    F: Fn(&Context) -> bool,
528    E: EnumerationType,
529{
530    type Item = (RootedLambdaPool<'src, Expr<'src>>, ExprDetails);
531
532    fn next(&mut self) -> Option<Self::Item> {
533        if let Some(x) = try_yield(self) {
534            return Some(x);
535        }
536
537        while let Some(mut c) = self.pq.pop() {
538            if let Some(x) = try_yield(self) {
539                self.push(c, true);
540                return Some(x);
541            }
542            let (possibles, lambda_type) = match &self.pools[c.pool_index].pool[c.position] {
543                ExprOrType::Type {
544                    lambda_type,
545                    is_app_subformula,
546                    parent,
547                } => {
548                    let mut possibles = self.possible_expressions.possibilities(
549                        lambda_type,
550                        *is_app_subformula,
551                        &c,
552                    );
553
554                    //Super hacky way to introduce all_e and all_a in quantifiers even though the
555                    //types are messed up.
556                    if let Some(p) = parent
557                        && let ExprOrType::Expr {
558                            lambda_expr:
559                                LambdaExpr::LanguageOfThoughtExpr(Expr::Quantifier {
560                                    var_type,
561                                    restrictor,
562                                    ..
563                                }),
564                            ..
565                        } = self.pools[c.pool_index].pool[*p]
566                        && restrictor.0 == c.position as u32
567                    {
568                        possibles.push(PossibleExpr::new_borrowed(match var_type {
569                            ActorOrEvent::Actor => &LambdaExpr::LanguageOfThoughtExpr(
570                                Expr::Constant(Constant::Everyone),
571                            ),
572                            ActorOrEvent::Event => &LambdaExpr::LanguageOfThoughtExpr(
573                                Expr::Constant(Constant::EveryEvent),
574                            ),
575                        }));
576                    }
577
578                    (possibles, lambda_type.clone())
579                }
580                ExprOrType::Expr {
581                    lambda_expr,
582                    parent,
583                } => {
584                    //We add the next uninitialized child to the context or go to the parent if
585                    //there are none.
586
587                    if let Some(child) = lambda_expr
588                        .get_children()
589                        .map(|x| x.0 as usize)
590                        .find(|child| self.pools[c.pool_index].pool[*child].is_type())
591                    {
592                        c.position = child;
593                        self.pq.push(c, true);
594                        continue;
595                    }
596
597                    if lambda_expr.inc_depth() {
598                        c.pop_lambda();
599                    }
600
601                    if let Some(p) = parent {
602                        c.position = *p;
603                        self.pq.push(c, true);
604                        continue;
605                    }
606                    //If the parent is None, we're done!
607                    self.pq.push_yield(ExprDetails {
608                        id: c.pool_index,
609                        root: LambdaExprRef(c.position as u32),
610                        size: c.depth,
611                        constant_function: c.is_constant(),
612                    });
613                    continue;
614                }
615            };
616
617            let n = possibles.len();
618            let included = self.pq.include(n);
619            if n == 0 {
620                // dbg!(&self.pools[c.pool_index]);
621                //   panic!("There is no possible expression of type {lambda_type}");
622                //This is a dead-end.
623                continue;
624            }
625            let n_pools = self.pools.len();
626            if n_pools.is_multiple_of(10_000) {
627                self.pools.shrink_to_fit();
628            }
629
630            for _ in 0..n.saturating_sub(1) {
631                self.pools.push(self.pools[c.pool_index].clone());
632            }
633
634            let positions =
635                std::iter::once(c.pool_index).chain(n_pools..n_pools + n.saturating_sub(1));
636
637            for (((expr, pool_id), mut c), included) in possibles
638                .into_iter()
639                .zip(positions)
640                .zip(std::iter::repeat_n(c, n))
641                .zip(included)
642            {
643                c.pool_index = pool_id;
644                let pool = self.pools.get_mut(pool_id).unwrap();
645                pool.add_expr(expr, &mut c, &lambda_type);
646                self.push(c, included);
647            }
648
649            if let Some(x) = try_yield(self) {
650                return Some(x);
651            }
652        }
653
654        //If we've somehow exhausted the pq, lets yield anything remaining that's done.
655        if let Some(x) = try_yield(self) {
656            return Some(x);
657        }
658        None
659    }
660}
661
662impl<'src> RootedLambdaPool<'src, Expr<'src>> {
663    ///Create a [`LambdaSampler`] of a given type.
664    pub fn resample_from_expr<'a>(
665        &mut self,
666        possible_expressions: &'a PossibleExpressions<'src, Expr<'src>>,
667        helpers: Option<&HashMap<LambdaType, Vec<RootedLambdaPool<'src, Expr<'src>>>>>,
668        rng: &mut impl Rng,
669    ) -> Result<(), LambdaError> {
670        let position = LambdaExprRef((0..self.len()).choose(rng).unwrap() as u32);
671        let t = self.pool.get_type(position)?;
672
673        let pool = if let Some(helpers) = helpers
674            && rng.random_bool(0.2)
675            && let Some(v) = helpers.get(&t)
676            && !v.is_empty()
677        {
678            let pool = v.choose(rng).unwrap();
679            pool.clone()
680        } else {
681            let (pool, _) = self
682                .probabilistic_enumerate_from_expr(
683                    position,
684                    possible_expressions,
685                    |_| true,
686                    |_| true,
687                    rng,
688                )?
689                .next()
690                .unwrap();
691            pool
692        };
693
694        let offset = self.len() as u32;
695        let new_root = pool.root.0 + offset;
696        self.pool.0.extend(pool.pool.0.into_iter().map(|mut x| {
697            let children: Vec<_> = x
698                .get_children()
699                .map(|x| LambdaExprRef(x.0 + offset))
700                .collect();
701            x.change_children(children.into_iter());
702            x
703        }));
704        self.pool.0.swap(position.0 as usize, new_root as usize);
705        self.cleanup();
706        Ok(())
707    }
708
709    fn probabilistic_enumerate_from_expr<'a, R, E, F>(
710        &self,
711        position: LambdaExprRef,
712        possible_expressions: &'a PossibleExpressions<'src, Expr<'src>>,
713        eager_filter: E,
714        filter: F,
715        rng: &'a mut R,
716    ) -> Result<
717        LambdaEnumerator<'a, 'src, Expr<'src>, E, ProbabilisticEnumeration<'a, R, F>>,
718        TypeError,
719    >
720    where
721        R: Rng,
722        F: Fn(&ExprDetails) -> bool,
723        E: Fn(&Context) -> bool,
724    {
725        let (context, is_app_subformula) = Context::from_pos(self, position);
726        let output = self.pool.get_type(position)?;
727        let mut pq = BinaryHeap::default();
728        pq.push(RandomPQ::new(context, rng));
729        let pools = vec![UnfinishedLambdaPool {
730            pool: vec![ExprOrType::Type {
731                lambda_type: output,
732                parent: None,
733                is_app_subformula,
734            }],
735        }];
736        let enumerator = LambdaEnumerator {
737            pools,
738            possible_expressions,
739            eager_filter,
740            pq: ProbabilisticEnumeration {
741                rng,
742                reservoir_size: 1,
743                reservoir: BinaryHeap::default(),
744                done: false,
745                n_seen: 0,
746                filter,
747                backups: vec![],
748                pq,
749            },
750        };
751
752        Ok(enumerator)
753    }
754
755    ///Create a [`LambdaSampler`] of a given type with a filter
756    pub fn enumerator_filter<'a, F: Fn(&Context) -> bool>(
757        t: &LambdaType,
758        filter: F,
759        possible_expressions: &'a PossibleExpressions<'src, Expr<'src>>,
760    ) -> LambdaEnumerator<'a, 'src, Expr<'src>, F> {
761        let context = Context::new(0, vec![]);
762        let mut pq = BinaryHeap::default();
763        pq.push(Reverse(context));
764        let pools = vec![UnfinishedLambdaPool {
765            pool: vec![ExprOrType::Type {
766                lambda_type: t.clone(),
767                parent: None,
768                is_app_subformula: false,
769            }],
770        }];
771
772        LambdaEnumerator {
773            pools,
774            possible_expressions,
775            eager_filter: filter,
776            pq: NormalEnumeration(pq, VecDeque::default()),
777        }
778    }
779
780    ///Create a [`LambdaSampler`] of a given type.
781    #[must_use]
782    pub fn enumerator<'a>(
783        t: &LambdaType,
784        possible_expressions: &'a PossibleExpressions<'src, Expr<'src>>,
785    ) -> LambdaEnumerator<'a, 'src, Expr<'src>, impl Fn(&'_ Context) -> bool> {
786        let context = Context::new(0, vec![]);
787        let mut pq = BinaryHeap::default();
788        pq.push(Reverse(context));
789        let pools = vec![UnfinishedLambdaPool {
790            pool: vec![ExprOrType::Type {
791                lambda_type: t.clone(),
792                parent: None,
793                is_app_subformula: false,
794            }],
795        }];
796
797        LambdaEnumerator {
798            pools,
799            possible_expressions,
800            eager_filter: |_: &Context| true,
801            pq: NormalEnumeration(pq, VecDeque::default()),
802        }
803    }
804
805    ///Creates a reusable random sampler by enumerating over the first `max_expr` expressions
806    #[must_use]
807    pub fn sampler(
808        t: &LambdaType,
809        possible_expressions: &PossibleExpressions<'src, Expr<'src>>,
810        max_expr: usize,
811    ) -> LambdaSampler<'src, Expr<'src>> {
812        let enumerator = RootedLambdaPool::enumerator(t, possible_expressions);
813        let mut lambdas = Vec::with_capacity(max_expr);
814        let mut expr_details = Vec::with_capacity(max_expr);
815        for (lambda, expr_detail) in enumerator.take(max_expr) {
816            lambdas.push(lambda);
817            expr_details.push(expr_detail);
818        }
819
820        LambdaSampler {
821            lambdas,
822            expr_details,
823        }
824    }
825
826    ///Randomly generate a [`RootedLambdaPool`] of type `t`.
827    pub fn random_expr(
828        t: &LambdaType,
829        possible_expressions: &PossibleExpressions<'src, Expr<'src>>,
830        rng: &mut impl Rng,
831    ) -> RootedLambdaPool<'src, Expr<'src>> {
832        ProbabilisticEnumeration::new(
833            1,
834            t,
835            possible_expressions,
836            |_: &Context| true,
837            |_| true,
838            rng,
839        )
840        .next()
841        .unwrap()
842        .0
843    }
844
845    ///Randomly generate a [`RootedLambdaPool`] of type `t` without constant functions.
846    pub fn random_expr_no_constant(
847        t: &LambdaType,
848        possible_expressions: &PossibleExpressions<'src, Expr<'src>>,
849        rng: &mut impl Rng,
850    ) -> Option<RootedLambdaPool<'src, Expr<'src>>> {
851        ProbabilisticEnumeration::new(
852            1,
853            t,
854            possible_expressions,
855            |_: &Context| true,
856            |e| !e.constant_function,
857            rng,
858        )
859        .next()
860        .map(|x| x.0)
861    }
862}
863
864impl<'src> RootedLambdaPool<'src, Expr<'src>> {
865    ///Remove quantifiers which do not use their variable in their body.
866    pub fn prune_quantifiers(&mut self) {
867        let quantifiers = self
868            .pool
869            .bfs_from(self.root)
870            .filter_map(|(i, _)| match self.get(i) {
871                LambdaExpr::LanguageOfThoughtExpr(Expr::Quantifier { subformula, .. }) => {
872                    if self
873                        .pool
874                        .bfs_from(LambdaExprRef(subformula.0))
875                        .any(|(x, d)| {
876                            if let LambdaExpr::BoundVariable(v, _) = self.get(x) {
877                                *v == d
878                            } else {
879                                false
880                            }
881                        })
882                    {
883                        None
884                    } else {
885                        Some((i, LambdaExprRef(subformula.0)))
886                    }
887                }
888                _ => None,
889            })
890            .collect::<Vec<_>>();
891
892        //By reversing, we ensure that we fix inner quantifiers before outer ones.
893        for (quantifier, subformula) in quantifiers.into_iter().rev() {
894            *self.pool.get_mut(quantifier) = self.pool.get(subformula).clone();
895            self.pool.bfs_from_mut(quantifier).for_each(|(x, d, _)| {
896                if let LambdaExpr::BoundVariable(b_d, _) = x
897                    && *b_d > d
898                {
899                    *b_d -= 1;
900                }
901            });
902        }
903        self.root = self.pool.cleanup(self.root);
904    }
905
906    ///Replace a random expression with something else of the same type.
907    pub fn swap_expr(
908        &mut self,
909        possible_expressions: &PossibleExpressions<'src, Expr<'src>>,
910        rng: &mut impl Rng,
911    ) -> Result<(), TypeError> {
912        let position = LambdaExprRef((0..self.len()).choose(rng).unwrap() as u32);
913
914        let (context, _) = Context::from_pos(self, position);
915
916        let output = self.pool.get_type(position)?;
917        let expr = self.pool.get(position);
918
919        let arguments: Vec<_> = expr
920            .get_children()
921            .map(|x| self.pool.get_type(x).unwrap())
922            .collect();
923
924        let mut replacements = possible_expressions.possiblities_fixed_children(
925            &output,
926            &arguments,
927            expr.var_type(),
928            &context,
929        );
930
931        //hacky fix!!
932        if replacements.is_empty()
933            && let LambdaExpr::LanguageOfThoughtExpr(Expr::Quantifier {
934                var_type,
935                restrictor,
936                subformula,
937                ..
938            }) = expr
939        {
940            for quantifier in [Quantifier::Universal, Quantifier::Existential] {
941                replacements.push(std::borrow::Cow::Owned(LambdaExpr::LanguageOfThoughtExpr(
942                    Expr::Quantifier {
943                        quantifier,
944                        var_type: *var_type,
945                        restrictor: *restrictor,
946                        subformula: *subformula,
947                    },
948                )));
949            }
950        }
951        for c in [Constant::EveryEvent, Constant::Everyone] {
952            if replacements.is_empty()
953                && expr == &LambdaExpr::LanguageOfThoughtExpr(Expr::Constant(c))
954            {
955                replacements = possible_expressions.possiblities_fixed_children(
956                    LambdaType::t(),
957                    &arguments,
958                    expr.var_type(),
959                    &context,
960                );
961                replacements.push(std::borrow::Cow::Owned(LambdaExpr::LanguageOfThoughtExpr(
962                    Expr::Constant(c),
963                )));
964            }
965        }
966
967        let choice = replacements.choose(rng).unwrap_or_else(|| {
968            panic!(
969                "There is no node with output {output} and arguments {}",
970                arguments
971                    .into_iter()
972                    .map(|x| x.to_string())
973                    .collect::<Vec<_>>()
974                    .join(", ")
975            )
976        });
977
978        let mut new_expr = choice.clone().into_owned();
979        new_expr.change_children(self.pool.get(position).get_children());
980        self.pool.0[position.0 as usize] = new_expr;
981
982        Ok(())
983    }
984
985    ///Replace a random expression with something else of the same type from within the same
986    ///expression.
987    pub fn swap_subtree(&mut self, rng: &mut impl Rng) -> Result<(), TypeError> {
988        let position = LambdaExprRef((0..self.len()).choose(rng).unwrap() as u32);
989        let (context, _) = Context::from_pos(self, position);
990        let alt = context.find_compatible(self, position)?;
991        if let Some(new_pos) = alt.choose(rng).copied() {
992            let offset = self.pool.0.len();
993            let mut lookup = HashMap::default();
994            let mut new_pool: Vec<LambdaExpr<'src, Expr<'src>>> = self
995                .pool
996                .bfs_from(new_pos)
997                .map(|(x, _)| {
998                    let n = lookup.len();
999                    lookup
1000                        .entry(x)
1001                        .or_insert(LambdaExprRef((n + offset) as u32));
1002                    self.get(x).clone()
1003                })
1004                .collect();
1005
1006            for x in &mut new_pool {
1007                let child: Vec<_> = x.get_children().map(|x| *lookup.get(&x).unwrap()).collect();
1008                x.change_children(child.into_iter());
1009            }
1010
1011            self.pool.0.extend(new_pool);
1012            self.pool.0.swap(position.0 as usize, offset);
1013            self.cleanup();
1014        }
1015        Ok(())
1016    }
1017}
1018
1019#[cfg(test)]
1020mod test {
1021
1022    use std::collections::HashSet;
1023
1024    use super::*;
1025    use crate::lambda::{LambdaPool, LambdaSummaryStats};
1026    use rand::SeedableRng;
1027    use rand_chacha::ChaCha8Rng;
1028
1029    #[test]
1030    fn prune_quantifier_test() -> anyhow::Result<()> {
1031        let mut pool =
1032            RootedLambdaPool::parse("some_e(x,all_e,AgentOf(a_2,e_1) & PatientOf(a_0,e_0))")?;
1033
1034        pool.prune_quantifiers();
1035        assert_eq!(pool.to_string(), "AgentOf(a_2, e_1) & PatientOf(a_0, e_0)");
1036
1037        let mut pool = RootedLambdaPool::parse(
1038            "some_e(x0, all_e, some(z, all_a, AgentOf(z, e_1) & PatientOf(a_0, e_0)))",
1039        )?;
1040
1041        pool.prune_quantifiers();
1042        assert_eq!(
1043            pool.to_string(),
1044            "some(x, all_a, AgentOf(x, e_1) & PatientOf(a_0, e_0))"
1045        );
1046
1047        let mut pool = RootedLambdaPool::parse("~every_e(z, pe_1, pa_2(a_0))")?;
1048
1049        pool.prune_quantifiers();
1050
1051        assert_eq!(pool.to_string(), "~pa_2(a_0)");
1052        let mut pool = RootedLambdaPool::new(
1053            LambdaPool(vec![
1054                LambdaExpr::Lambda(LambdaExprRef(1), LambdaType::E),
1055                LambdaExpr::Lambda(LambdaExprRef(2), LambdaType::T),
1056                LambdaExpr::LanguageOfThoughtExpr(Expr::Quantifier {
1057                    quantifier: Quantifier::Universal,
1058                    var_type: ActorOrEvent::Actor,
1059                    restrictor: ExprRef(3),
1060                    subformula: ExprRef(4),
1061                }),
1062                LambdaExpr::LanguageOfThoughtExpr(Expr::Constant(Constant::Property(
1063                    "1",
1064                    ActorOrEvent::Actor,
1065                ))),
1066                LambdaExpr::LanguageOfThoughtExpr(Expr::Quantifier {
1067                    quantifier: Quantifier::Existential,
1068                    var_type: ActorOrEvent::Actor,
1069                    restrictor: ExprRef(5),
1070                    subformula: ExprRef(6),
1071                }),
1072                LambdaExpr::LanguageOfThoughtExpr(Expr::Constant(Constant::Property(
1073                    "0",
1074                    ActorOrEvent::Actor,
1075                ))),
1076                LambdaExpr::LanguageOfThoughtExpr(Expr::Unary(
1077                    MonOp::Property("3", ActorOrEvent::Event),
1078                    ExprRef(7),
1079                )),
1080                LambdaExpr::BoundVariable(3, LambdaType::E),
1081            ]),
1082            LambdaExprRef(0),
1083        );
1084
1085        assert_eq!(
1086            pool.to_string(),
1087            "lambda e x lambda t phi every(y, pa_1, some(z, pa_0, pe_3(x)))"
1088        );
1089
1090        let mut parsed_pool = RootedLambdaPool::parse(
1091            "lambda e x lambda t phi every(y, pa_1, some(z, pa_0, pe_3(x)))",
1092        )?;
1093        parsed_pool.prune_quantifiers();
1094        pool.prune_quantifiers();
1095        assert_eq!(pool.to_string(), "lambda e x lambda t phi pe_3(x)");
1096
1097        Ok(())
1098    }
1099
1100    #[test]
1101    fn random_swap_tree() -> anyhow::Result<()> {
1102        let mut rng = ChaCha8Rng::seed_from_u64(2);
1103        let x = RootedLambdaPool::parse("lambda a x pa_cool(a_John) | pa_cool(x)")?;
1104        let mut h = HashSet::new();
1105        for _ in 0..100 {
1106            let mut z = x.clone();
1107            z.swap_subtree(&mut rng)?;
1108            h.insert(z.to_string());
1109        }
1110        for x in h.iter() {
1111            println!("{x}");
1112        }
1113        assert!(h.contains("lambda a x pa_cool(x)"));
1114        assert!(h.contains("lambda a x pa_cool(a_John)"));
1115        Ok(())
1116    }
1117
1118    #[test]
1119    fn randomn_swap() -> anyhow::Result<()> {
1120        let mut rng = ChaCha8Rng::seed_from_u64(2);
1121        let actors = ["0", "1"];
1122        let available_event_properties = ["2", "3", "4"];
1123        let possible_expressions = PossibleExpressions::new(
1124            &actors,
1125            &available_event_properties,
1126            &available_event_properties,
1127        );
1128        for _ in 0..200 {
1129            let t = LambdaType::random(&mut rng);
1130            println!("{t}");
1131            let mut pool = RootedLambdaPool::random_expr(&t, &possible_expressions, &mut rng);
1132            println!("{t}: {pool}");
1133            assert_eq!(t, pool.get_type()?);
1134            println!("{pool:?}");
1135            pool.swap_expr(&possible_expressions, &mut rng)?;
1136            println!("{pool:?}");
1137            println!("{t}: {pool}");
1138            assert_eq!(t, pool.get_type()?);
1139        }
1140        let p =
1141            RootedLambdaPool::parse("lambda <a,t> P some(x, P(x), some_e(y, pe_2(y), pe_4(y)))")?;
1142        let t = p.get_type()?;
1143        for _ in 0..1000 {
1144            let mut pool = p.clone();
1145            assert_eq!(t, pool.get_type()?);
1146
1147            pool.swap_expr(&possible_expressions, &mut rng)?;
1148            dbg!(&pool);
1149            println!("{t}: {pool}");
1150            assert_eq!(t, pool.get_type()?);
1151        }
1152
1153        Ok(())
1154    }
1155
1156    #[test]
1157    fn enumerate() -> anyhow::Result<()> {
1158        let actors = ["john"];
1159        let actor_properties = ["a"];
1160        let event_properties = ["e"];
1161        let possibles = PossibleExpressions::new(&actors, &actor_properties, &event_properties);
1162
1163        let t = LambdaType::from_string("<<a,t>,t>")?;
1164
1165        let p = RootedLambdaPool::enumerator(&t, &possibles);
1166        let pools: HashSet<_> = p
1167            .filter_map(|(p, x)| {
1168                if !x.has_constant_function() {
1169                    Some(p.to_string())
1170                } else {
1171                    None
1172                }
1173            })
1174            .take(20)
1175            .collect();
1176        for p in pools.iter() {
1177            println!("{p}");
1178        }
1179        assert!(pools.contains("lambda <a,t> P P(a_john)"));
1180        for p in pools {
1181            println!("{p}");
1182        }
1183
1184        Ok(())
1185    }
1186
1187    /*
1188    #[test]
1189    fn random_expr_no_constant() -> anyhow::Result<()> {
1190        let actors = &["1", "2", "3", "4", "5"];
1191        let actor_properties = &["1", "2", "3", "4", "5"];
1192        let event_properties = &["1", "2", "3", "4", "5"];
1193        let poss = PossibleExpressions::new(actors, actor_properties, event_properties);
1194        let mut rng = ChaCha8Rng::seed_from_u64(1);
1195
1196        let t = LambdaType::from_string("<<a,<e,t>>, <e,t>>")?;
1197        println!("{t}: ");
1198        let pool = RootedLambdaPool::random_expr_no_constant(&t, &poss, &mut rng).unwrap();
1199        println!("{pool}");
1200
1201        for _ in 0..500 {
1202            let t = LambdaType::random(&mut rng);
1203            println!("{t}: ");
1204            if t.size() <= 3 {
1205                let pool = RootedLambdaPool::random_expr_no_constant(&t, &poss, &mut rng).unwrap();
1206                println!("{pool}");
1207            } else {
1208                println!("too big :(");
1209            }
1210        }
1211        Ok(())
1212    }*/
1213
1214    #[test]
1215    fn enumerate_weirds() -> anyhow::Result<()> {
1216        let actors = &["1"];
1217        let actor_properties = &["1"];
1218        let event_properties = &["1"];
1219        let poss = PossibleExpressions::new(actors, actor_properties, event_properties);
1220
1221        let t = "<a,<e,e>>";
1222
1223        for (p, d) in RootedLambdaPool::enumerator(&LambdaType::from_string(t)?, &poss)
1224            .filter(|(_, e)| !e.has_constant_function())
1225            .take(5)
1226        {
1227            println!("{p} {d:?} {}", d.has_constant_function());
1228        }
1229        Ok(())
1230    }
1231
1232    #[test]
1233    fn random_expr() -> anyhow::Result<()> {
1234        let actors = ["john"];
1235        let actor_properties = ["a"];
1236        let event_properties = ["e"];
1237        let possibles = PossibleExpressions::new(&actors, &actor_properties, &event_properties);
1238        let mut rng = ChaCha8Rng::seed_from_u64(0);
1239
1240        let map = [(LambdaType::A, vec![RootedLambdaPool::parse("a_john")?])]
1241            .into_iter()
1242            .collect();
1243
1244        for _ in 0..100 {
1245            let t = LambdaType::random(&mut rng);
1246            println!("sampling: {t}");
1247            let mut pool = RootedLambdaPool::random_expr(&t, &possibles, &mut rng);
1248            assert_eq!(t, pool.get_type()?);
1249            let s = pool.to_string();
1250            println!("{s}");
1251            let mut pool2 = RootedLambdaPool::parse(s.as_str())?;
1252            assert_eq!(s, pool2.to_string());
1253            println!("{pool}");
1254            pool2.resample_from_expr(&possibles, None, &mut rng)?;
1255            assert_eq!(pool2.get_type()?, t);
1256            pool2.resample_from_expr(&possibles, Some(&map), &mut rng)?;
1257            assert_eq!(pool2.get_type()?, t);
1258
1259            pool.swap_subtree(&mut rng)?;
1260            assert_eq!(pool.get_type()?, t);
1261        }
1262        let t = LambdaType::from_string("<a,<a,t>>")?;
1263        for _ in 0..100 {
1264            let pool = RootedLambdaPool::random_expr(&t, &possibles, &mut rng);
1265            println!("{pool}");
1266        }
1267        Ok(())
1268    }
1269
1270    #[test]
1271    fn constant_exprs() -> anyhow::Result<()> {
1272        let actors = ["john", "mary", "phil", "sue"];
1273        let actor_properties = ["a"];
1274        let event_properties = ["e"];
1275        let possibles = PossibleExpressions::new(&actors, &actor_properties, &event_properties);
1276        let mut rng = ChaCha8Rng::seed_from_u64(0);
1277
1278        let v: HashSet<_> = RootedLambdaPool::enumerator(LambdaType::a(), &possibles)
1279            .map(|(x, _)| x.to_string())
1280            .take(4)
1281            .collect();
1282
1283        assert_eq!(v, HashSet::from(actors.map(|x| format!("a_{x}"))));
1284        println!("{v:?}");
1285
1286        let mut constants = 0;
1287        for _ in 0..100 {
1288            let t = LambdaType::from_string("<a, <a,t>>")?;
1289            println!("sampling: {t}");
1290            let pool = RootedLambdaPool::random_expr(&t, &possibles, &mut rng);
1291            let x = pool.stats();
1292            match x {
1293                LambdaSummaryStats::WellFormed {
1294                    constant_function, ..
1295                } => {
1296                    if constant_function {
1297                        constants += 1;
1298                    }
1299                }
1300                LambdaSummaryStats::Malformed => todo!(),
1301            }
1302        }
1303        println!("{constants}");
1304
1305        Ok(())
1306    }
1307
1308    #[test]
1309    fn random_expr_counts() -> anyhow::Result<()> {
1310        let actors = ["john", "mary", "phil", "sue"];
1311        let actor_properties = ["a"];
1312        let event_properties = ["e"];
1313        let possibles = PossibleExpressions::new(&actors, &actor_properties, &event_properties);
1314        let mut rng = ChaCha8Rng::seed_from_u64(1);
1315
1316        let mut counts: HashMap<_, usize> = HashMap::default();
1317        for _ in 0..1000 {
1318            let t = LambdaType::A;
1319            let pool = RootedLambdaPool::random_expr(&t, &possibles, &mut rng);
1320            assert_eq!(t, pool.get_type()?);
1321            let s = pool.to_string();
1322            *counts.entry(s).or_default() += 1;
1323        }
1324        assert_eq!(counts.len(), 4);
1325        dbg!(&counts);
1326        for (_, v) in counts.iter() {
1327            assert!(200 <= *v && *v <= 300);
1328        }
1329
1330        counts.clear();
1331        for _ in 0..1000 {
1332            let t = LambdaType::at().clone();
1333            let pool = RootedLambdaPool::random_expr(&t, &possibles, &mut rng);
1334            assert_eq!(t, pool.get_type()?);
1335            let s = pool.to_string();
1336            *counts.entry(s).or_default() += 1;
1337        }
1338        dbg!(counts);
1339        Ok(())
1340    }
1341}