1use super::{
2 ActorOrEvent, BinOp, Constant, Display, Expr, ExprRef, MonOp, Quantifier, RootedLambdaPool,
3 thiserror,
4};
5use crate::lambda::{
6 LambdaError, LambdaExpr, LambdaExprRef, LambdaLanguageOfThought, LambdaPool,
7 types::{LambdaType, TypeError},
8};
9use ahash::HashMap;
10use chumsky::container::Container;
11use rand::{
12 Rng, RngExt,
13 distr::{Distribution, weighted::WeightedIndex},
14 seq::{IndexedRandom, IteratorRandom},
15};
16use std::{
17 cmp::Reverse,
18 collections::{BinaryHeap, VecDeque},
19 fmt::Debug,
20};
21use thiserror::Error;
22
23mod context;
24mod samplers;
25pub use context::Context;
26pub(crate) use samplers::PossibleExpr;
27pub use samplers::PossibleExpressions;
28
29#[derive(Debug, Error, Clone)]
30pub struct ExprOrTypeError();
31
32impl Display for ExprOrTypeError {
33 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
34 write!(f, "This ExprOrType is not an Expr!")
35 }
36}
37
/// One node of a partially built lambda expression.
///
/// While an expression is being enumerated, every slot of an `UnfinishedLambdaPool`
/// is either a hole that still has to be filled with an expression of a known type
/// (`Type`) or an expression that has already been chosen (`Expr`).
#[derive(Debug, Clone, Eq, PartialEq)]
enum ExprOrType<'src, T: LambdaLanguageOfThought> {
    /// A hole that must be filled with an expression of `lambda_type`.
    Type {
        lambda_type: LambdaType,
        /// Pool index of the parent node; `None` for the root.
        parent: Option<usize>,
        /// `true` when this hole is the function (`subformula`) slot of an application.
        is_app_subformula: bool,
    },
    /// A chosen expression; its children are indices into the same pool.
    Expr {
        lambda_expr: LambdaExpr<'src, T>,
        /// Pool index of the parent node; `None` for the root.
        parent: Option<usize>,
    },
}
50
51impl<'src, T: LambdaLanguageOfThought> TryFrom<ExprOrType<'src, T>> for LambdaExpr<'src, T> {
52 type Error = ExprOrTypeError;
53
54 fn try_from(value: ExprOrType<'src, T>) -> Result<Self, Self::Error> {
55 match value {
56 ExprOrType::Type { .. } => Err(ExprOrTypeError()),
57 ExprOrType::Expr { lambda_expr, .. } => Ok(lambda_expr),
58 }
59 }
60}
61
62impl<T: LambdaLanguageOfThought> ExprOrType<'_, T> {
63 fn parent(&self) -> Option<usize> {
64 match self {
65 ExprOrType::Type { parent, .. } | ExprOrType::Expr { parent, .. } => *parent,
66 }
67 }
68
69 fn is_type(&self) -> bool {
70 matches!(self, ExprOrType::Type { .. })
71 }
72}
73
74#[derive(Debug, Clone, Eq, PartialEq)]
75struct UnfinishedLambdaPool<'src, T: LambdaLanguageOfThought> {
76 pool: Vec<ExprOrType<'src, T>>,
77}
78
79impl<T: LambdaLanguageOfThought> Default for UnfinishedLambdaPool<'_, T> {
80 fn default() -> Self {
81 Self { pool: vec![] }
82 }
83}
84
impl<'src, T: LambdaLanguageOfThought + Clone> UnfinishedLambdaPool<'src, T> {
    /// Fills the hole at `c.position` with `expr` and appends a new type hole for
    /// each of its children.
    ///
    /// `t` is the type of the hole being filled. It is only consulted for lambdas,
    /// whose body hole receives `t.rhs()`. The context `c` is updated in place:
    /// `depth` grows by one, `open_nodes` grows by the number of new children and
    /// shrinks by the hole that was just closed, and variables introduced or used by
    /// the expression are registered via `add_lambda` / `use_bvar`.
    ///
    /// # Panics
    /// Panics if `c.position` is out of bounds, if `expr` is a lambda and `t` has no
    /// right-hand side, or if `expr` is an application but `into_expr` returned no
    /// argument types.
    fn add_expr<'a>(&mut self, expr: PossibleExpr<'a, 'src, T>, c: &mut Context, t: &LambdaType) {
        let (mut lambda_expr, app_details) = expr.into_expr();
        c.depth += 1;
        // New children are added before the filled hole is subtracted; presumably
        // this avoids a transient underflow if `open_nodes` is unsigned (TODO confirm).
        c.open_nodes += lambda_expr.n_children();
        c.open_nodes -= 1;
        // The parent link lives on the hole being replaced; carry it over.
        let parent = self.pool[c.position].parent();
        match &mut lambda_expr {
            LambdaExpr::Lambda(body, arg) => {
                c.add_lambda(arg);
                // The body becomes a new hole appended at the end of the pool.
                *body = LambdaExprRef::new(self.pool.len());
                self.pool.push(ExprOrType::Type {
                    lambda_type: t.rhs().unwrap().clone(),
                    parent: Some(c.position),
                    is_app_subformula: false,
                });
            }
            LambdaExpr::BoundVariable(b, _) => {
                c.use_bvar(*b);
            }
            LambdaExpr::FreeVariable(..) => (),
            LambdaExpr::Application {
                subformula,
                argument,
            } => {
                // Function and argument holes go at the next two pool indices.
                *subformula = LambdaExprRef::new(self.pool.len());
                *argument = LambdaExprRef::new(self.pool.len() + 1);
                let (subformula, argument) = app_details.unwrap();
                self.pool.push(ExprOrType::Type {
                    lambda_type: subformula,
                    parent: Some(c.position),
                    is_app_subformula: true,
                });
                self.pool.push(ExprOrType::Type {
                    lambda_type: argument,
                    parent: Some(c.position),
                    is_app_subformula: false,
                });
            }
            LambdaExpr::LanguageOfThoughtExpr(e) => {
                let children_start = self.pool.len();
                // Expressions that bind a variable report its type via `var_type`.
                if let Some(t) = e.var_type() {
                    c.add_lambda(t);
                }
                self.pool
                    .extend(e.get_arguments().map(|lambda_type| ExprOrType::Type {
                        lambda_type,
                        parent: Some(c.position),
                        is_app_subformula: false,
                    }));
                // Point the expression's children at the holes just appended.
                e.change_children((children_start..self.pool.len()).map(LambdaExprRef::new));
            }
        }
        self.pool[c.position] = ExprOrType::Expr {
            lambda_expr,
            parent,
        };
    }
}
144
145#[derive(Debug)]
146pub struct NormalEnumeration(BinaryHeap<Reverse<Context>>, VecDeque<ExprDetails>);
147
148impl EnumerationType for NormalEnumeration {
149 fn pop(&mut self) -> Option<Context> {
150 self.0.pop().map(|x| x.0)
151 }
152
153 fn push(&mut self, context: Context, _: bool) {
154 self.0.push(Reverse(context));
155 }
156
157 fn get_yield(&mut self) -> Option<ExprDetails> {
158 self.1.pop_front()
159 }
160
161 fn push_yield(&mut self, e: ExprDetails) {
162 self.1.push(e);
163 }
164
165 fn include(&mut self, n: usize) -> impl Iterator<Item = bool> + 'static {
166 std::iter::repeat_n(true, n)
167 }
168}
169
170impl ExprDetails {
171 #[allow(clippy::cast_precision_loss)]
172 fn score(&self) -> f64 {
173 (1.0 / (self.size as f64)) + if self.constant_function { 0.0 } else { 10.0 }
174 }
175
176 pub fn has_constant_function(&self) -> bool {
177 self.constant_function
178 }
179}
180
181#[derive(Debug, Clone, Copy, PartialEq)]
182struct KeyedExprDetails {
183 expr_details: ExprDetails,
184 k: f64,
185}
186
187impl Eq for KeyedExprDetails {}
188
189impl PartialOrd for KeyedExprDetails {
190 fn partial_cmp(&self, other: &Self) -> Option<std::cmp::Ordering> {
191 Some(self.cmp(other))
192 }
193}
194
195impl Ord for KeyedExprDetails {
196 fn cmp(&self, other: &Self) -> std::cmp::Ordering {
197 other.k.partial_cmp(&self.k).unwrap()
199 }
200}
201impl KeyedExprDetails {
202 fn new(expr_details: ExprDetails, rng: &mut impl Rng) -> Self {
203 let u: f64 = rng.random();
204 KeyedExprDetails {
205 expr_details,
206 k: u.powf(1.0 / expr_details.score()),
207 }
208 }
209}
210
211#[derive(Debug, Clone, PartialEq)]
212struct RandomPQ(Context, f64);
213
214impl Eq for RandomPQ {}
215
216impl RandomPQ {
217 fn new(c: Context, rng: &mut impl Rng) -> Self {
218 RandomPQ(c, rng.random())
219 }
220}
221
222#[derive(Debug)]
223struct ProbabilisticEnumeration<'a, R: Rng, F>
224where
225 F: Fn(&ExprDetails) -> bool,
226{
227 rng: &'a mut R,
228 reservoir_size: usize,
229 reservoir: BinaryHeap<KeyedExprDetails>,
230 backups: Vec<Context>,
231 pq: BinaryHeap<RandomPQ>,
232 filter: F,
233 n_seen: usize,
234 done: bool,
235}
236impl<R: Rng, F> ProbabilisticEnumeration<'_, R, F>
237where
238 F: Fn(&ExprDetails) -> bool,
239{
240 fn threshold(&self) -> Option<f64> {
241 self.reservoir.peek().map(|x| x.k)
242 }
243
244 fn new<'a, 'src, T: LambdaLanguageOfThought, E: Fn(&Context) -> bool>(
245 reservoir_size: usize,
246 t: &LambdaType,
247 possible_expressions: &'a PossibleExpressions<'src, T>,
248 eager_filter: E,
249 filter: F,
250 rng: &'a mut R,
251 ) -> LambdaEnumerator<'a, 'src, T, E, ProbabilisticEnumeration<'a, R, F>> {
252 let context = Context::new(0, vec![]);
253 let mut pq = BinaryHeap::default();
254 pq.push(RandomPQ::new(context, rng));
255 let pools = vec![UnfinishedLambdaPool {
256 pool: vec![ExprOrType::Type {
257 lambda_type: t.clone(),
258 parent: None,
259 is_app_subformula: false,
260 }],
261 }];
262
263 LambdaEnumerator {
264 pools,
265 possible_expressions,
266 eager_filter,
267 pq: ProbabilisticEnumeration {
268 rng,
269 reservoir_size,
270 reservoir: BinaryHeap::default(),
271 backups: vec![],
272 filter,
273 pq,
274 n_seen: 0,
275 done: false,
276 },
277 }
278 }
279}
280
281impl<R: Rng, F> EnumerationType for ProbabilisticEnumeration<'_, R, F>
282where
283 F: Fn(&ExprDetails) -> bool,
284{
285 fn pop(&mut self) -> Option<Context> {
286 self.pq.pop().map(|x| x.0).or_else(|| {
288 (0..self.backups.len()).choose(self.rng).and_then(|index| {
289 let last_item = self.backups.len() - 1;
290 self.backups.swap(index, last_item);
291 self.backups.pop()
292 })
293 })
294 }
295
296 fn push(&mut self, context: Context, included: bool) {
297 if included {
298 self.pq.push(RandomPQ::new(context, &mut self.rng));
299 } else {
300 self.backups.push(context);
301 }
302 }
303
304 fn get_yield(&mut self) -> Option<ExprDetails> {
305 if (self.done || self.pq.is_empty())
306 && let Some(x) = self.reservoir.pop()
307 {
308 Some(x.expr_details)
309 } else {
310 None
311 }
312 }
313
314 fn push_yield(&mut self, e: ExprDetails) {
315 let e = KeyedExprDetails::new(e, &mut self.rng);
316 if (self.filter)(&e.expr_details) {
317 self.n_seen += 1;
318 if self.reservoir_size > self.reservoir.len() {
319 self.reservoir.push(e);
320 } else if let Some(t) = self.threshold()
321 && e.k > t
322 {
323 self.reservoir.pop();
324 self.reservoir.push(e);
325 }
326 if self.n_seen >= self.reservoir_size * 20 {
327 self.pq.clear();
328 self.done = true;
329 }
330 }
331 }
332
333 fn include(&mut self, n: usize) -> impl Iterator<Item = bool> + 'static {
334 let x = (0..n).sample(self.rng, (n / 2).max(1));
335 let mut v = vec![false; n];
336 for i in x {
337 v[i] = true;
338 }
339 v.into_iter()
340 }
341}
342
/// Lazy enumerator over lambda expressions of a fixed type.
///
/// For `T = Expr`, iterating yields each finished expression together with its
/// `ExprDetails`. The scheduling strategy `E` decides the order (the exhaustive
/// `NormalEnumeration` by default). Every new search context is checked against
/// `eager_filter` before it is scheduled.
#[derive(Debug)]
pub struct LambdaEnumerator<'a, 'src, T: LambdaLanguageOfThought, F, E = NormalEnumeration> {
    /// One partially built expression per search branch, indexed by `Context::pool_index`.
    pools: Vec<UnfinishedLambdaPool<'src, T>>,
    /// Building blocks that can be placed into holes.
    possible_expressions: &'a PossibleExpressions<'src, T>,
    /// Predicate on new contexts; rejected branches are discarded and their pools freed.
    eager_filter: F,
    /// Scheduling strategy for contexts and finished expressions.
    pq: E,
}
351
/// Summary of a finished expression produced by a `LambdaEnumerator`.
#[derive(Debug, Clone, Copy, Eq, PartialEq)]
pub struct ExprDetails {
    /// Index of the enumerator pool that holds the expression.
    id: usize,
    /// Whether the enumeration context reported the expression as a constant
    /// function when it was completed.
    constant_function: bool,
    /// Index of the expression's root node within its pool.
    root: LambdaExprRef,
    /// Copied from the context's `depth` counter, which `add_expr` increments once
    /// per placed node; presumably the node count (TODO confirm nothing else edits `depth`).
    size: usize,
}

impl ExprDetails {
    /// Size of the expression as recorded by the enumerator (the `size` field).
    pub fn size(&self) -> usize {
        self.size
    }
}
367
/// Samples expressions of any requested type by caching one `LambdaSampler` per type.
///
/// Each cached sampler is paired with a count of how often its type was requested.
/// The counts decide which sampler is evicted when the cache would exceed
/// `max_types` entries.
#[derive(Debug, Clone, Eq, PartialEq)]
pub struct TypeAgnosticSampler<'src, T: LambdaLanguageOfThought> {
    /// Per-type request count and sampler.
    type_to_sampler: HashMap<LambdaType, (usize, LambdaSampler<'src, T>)>,
    /// Maximum number of expressions enumerated when a type's sampler is built.
    max_expr: usize,
    /// Maximum number of per-type samplers kept in the cache.
    max_types: usize,
    /// Building blocks shared by every sampler.
    possible_expressions: PossibleExpressions<'src, T>,
}
376
377impl<'src> TypeAgnosticSampler<'src, Expr<'src>> {
378 #[allow(clippy::missing_panics_doc)]
379 pub fn sample(
381 &mut self,
382 lambda_type: LambdaType,
383 rng: &mut impl Rng,
384 ) -> RootedLambdaPool<'src, Expr<'src>> {
385 let (counts, exprs) = self
386 .type_to_sampler
387 .entry(lambda_type)
388 .or_insert_with_key(|t| {
389 (
390 1,
391 RootedLambdaPool::sampler(t, &self.possible_expressions, self.max_expr),
392 )
393 });
394 *counts += 1;
395 let sample = exprs.sample(rng);
396
397 if self.type_to_sampler.len() > self.max_types {
398 let (_, k) = self
399 .type_to_sampler
400 .iter()
401 .map(|(k, (n_visits, _))| (n_visits, k))
402 .min_by_key(|x| x.0)
403 .unwrap();
404
405 let t = k.clone();
406 self.type_to_sampler.remove(&t);
407 }
408
409 sample
410 }
411
412 #[must_use]
414 pub fn possibles(&self) -> &PossibleExpressions<'src, Expr<'src>> {
415 &self.possible_expressions
416 }
417}
418
419impl<'src, T: LambdaLanguageOfThought + Clone> RootedLambdaPool<'src, T> {
420 #[must_use]
425 pub fn typeless_sampler(
426 possible_expressions: PossibleExpressions<'src, T>,
427 max_expr: usize,
428 max_types: usize,
429 ) -> TypeAgnosticSampler<'src, T> {
430 assert!(max_types >= 1);
431 assert!(max_expr >= 1);
432 TypeAgnosticSampler {
433 possible_expressions,
434 max_expr,
435 max_types,
436 type_to_sampler: HashMap::default(),
437 }
438 }
439}
440
441#[derive(Debug, Clone, Eq, PartialEq)]
443pub struct LambdaSampler<'src, T: LambdaLanguageOfThought> {
444 lambdas: Vec<RootedLambdaPool<'src, T>>,
445 expr_details: Vec<ExprDetails>,
446}
447
448impl<'src, T: LambdaLanguageOfThought + Clone> Distribution<RootedLambdaPool<'src, T>>
449 for LambdaSampler<'src, T>
450{
451 fn sample<R: Rng + ?Sized>(&self, rng: &mut R) -> RootedLambdaPool<'src, T> {
452 let w = WeightedIndex::new(self.expr_details.iter().map(ExprDetails::score)).unwrap();
453 let i = w.sample(rng);
454 self.lambdas
455 .get(i)
456 .expect("The Lambda Sampler has no lambdas to sample :(")
457 .clone()
458 }
459}
460
/// Scheduling strategy used by a `LambdaEnumerator`.
///
/// It decides which search `Context` is expanded next and when finished
/// expressions are handed back to the caller.
pub trait EnumerationType {
    /// Returns the next context to expand, or `None` when nothing is left.
    fn pop(&mut self) -> Option<Context>;
    /// Schedules `context`. `included` is the flag produced by `include` for this
    /// context; contexts re-scheduled while walking a pool are pushed with `true`.
    fn push(&mut self, context: Context, included: bool);
    /// Returns a finished expression that is ready to be yielded, if any.
    fn get_yield(&mut self) -> Option<ExprDetails>;
    /// Records a finished expression; the strategy may keep, filter, or drop it.
    fn push_yield(&mut self, e: ExprDetails);
    /// Produces one flag per candidate when a hole with `n` candidates is expanded.
    fn include(&mut self, n: usize) -> impl Iterator<Item = bool> + 'static;
}
470
471fn try_yield<'src, T, F, E>(
472 x: &mut LambdaEnumerator<'_, 'src, T, F, E>,
473) -> Option<(RootedLambdaPool<'src, T>, ExprDetails)>
474where
475 T: LambdaLanguageOfThought,
476 E: EnumerationType,
477{
478 if let Some(item) = x.pq.get_yield() {
479 let p = std::mem::take(&mut x.pools[item.id]);
480 return Some((
481 RootedLambdaPool {
482 pool: LambdaPool(
483 p.pool
484 .into_iter()
485 .map(|x| LambdaExpr::try_from(x).unwrap())
486 .collect(),
487 ),
488 root: item.root,
489 },
490 item,
491 ));
492 }
493 None
494}
495
496impl<'a, 'src, T, F, E> LambdaEnumerator<'a, 'src, T, F, E>
497where
498 T: LambdaLanguageOfThought + Clone + Debug,
499 F: Fn(&Context) -> bool,
500 E: EnumerationType,
501{
502 fn push(&mut self, c: Context, included: bool) {
503 if (self.eager_filter)(&c) {
504 self.pq.push(c, included);
505 } else {
506 self.pools[c.pool_index] = UnfinishedLambdaPool::default();
507 }
508 }
509
510 pub fn eager_filter<F2>(self, eager_filter: F2) -> LambdaEnumerator<'a, 'src, T, F2, E> {
512 let LambdaEnumerator {
513 pools,
514 possible_expressions,
515 eager_filter: _,
516 pq,
517 } = self;
518
519 LambdaEnumerator {
520 pools,
521 possible_expressions,
522 eager_filter,
523 pq,
524 }
525 }
526}
527
528impl<'src, F, E> Iterator for LambdaEnumerator<'_, 'src, Expr<'src>, F, E>
529where
530 F: Fn(&Context) -> bool,
531 E: EnumerationType,
532{
533 type Item = (RootedLambdaPool<'src, Expr<'src>>, ExprDetails);
534
535 #[allow(clippy::too_many_lines)]
536 fn next(&mut self) -> Option<Self::Item> {
537 if let Some(x) = try_yield(self) {
538 return Some(x);
539 }
540
541 while let Some(mut c) = self.pq.pop() {
542 if let Some(x) = try_yield(self) {
543 self.push(c, true);
544 return Some(x);
545 }
546 let (possibles, lambda_type) = match &self.pools[c.pool_index].pool[c.position] {
547 ExprOrType::Type {
548 lambda_type,
549 is_app_subformula,
550 parent,
551 } => {
552 let mut possibles = self.possible_expressions.possibilities(
553 lambda_type,
554 *is_app_subformula,
555 &c,
556 );
557
558 if let Some(p) = parent
561 && let ExprOrType::Expr {
562 lambda_expr:
563 LambdaExpr::LanguageOfThoughtExpr(Expr::Quantifier {
564 var_type,
565 restrictor,
566 ..
567 }),
568 ..
569 } = self.pools[c.pool_index].pool[*p]
570 && restrictor.0 == u32::try_from(c.position).unwrap()
571 {
572 possibles.push(PossibleExpr::new_borrowed(match var_type {
573 ActorOrEvent::Actor => &LambdaExpr::LanguageOfThoughtExpr(
574 Expr::Constant(Constant::Everyone),
575 ),
576 ActorOrEvent::Event => &LambdaExpr::LanguageOfThoughtExpr(
577 Expr::Constant(Constant::EveryEvent),
578 ),
579 }));
580 }
581
582 (possibles, lambda_type.clone())
583 }
584 ExprOrType::Expr {
585 lambda_expr,
586 parent,
587 } => {
588 if let Some(child) = lambda_expr
592 .get_children()
593 .map(|x| x.0 as usize)
594 .find(|child| self.pools[c.pool_index].pool[*child].is_type())
595 {
596 c.position = child;
597 self.pq.push(c, true);
598 continue;
599 }
600
601 if lambda_expr.inc_depth() {
602 c.pop_lambda();
603 }
604
605 if let Some(p) = parent {
606 c.position = *p;
607 self.pq.push(c, true);
608 continue;
609 }
610 self.pq.push_yield(ExprDetails {
612 id: c.pool_index,
613 root: LambdaExprRef(u32::try_from(c.position).unwrap()),
614 size: c.depth,
615 constant_function: c.is_constant(),
616 });
617 continue;
618 }
619 };
620
621 let n = possibles.len();
622 let included = self.pq.include(n);
623 if n == 0 {
624 continue;
625 }
626 let n_pools = self.pools.len();
627 if n_pools.is_multiple_of(10_000) {
628 self.pools.shrink_to_fit();
629 }
630 for _ in 0..n.saturating_sub(1) {
631 self.pools.push(self.pools[c.pool_index].clone());
632 }
633
634 let positions =
635 std::iter::once(c.pool_index).chain(n_pools..n_pools + n.saturating_sub(1));
636
637 for (((expr, pool_id), mut c), included) in possibles
638 .into_iter()
639 .zip(positions)
640 .zip(std::iter::repeat_n(c, n))
641 .zip(included)
642 {
643 c.pool_index = pool_id;
644 let pool = self.pools.get_mut(pool_id).unwrap();
645 pool.add_expr(expr, &mut c, &lambda_type);
646 self.push(c, included);
647 }
648
649 if let Some(x) = try_yield(self) {
650 return Some(x);
651 }
652 }
653
654 if let Some(x) = try_yield(self) {
656 return Some(x);
657 }
658 None
659 }
660}
661
662impl<'src> RootedLambdaPool<'src, Expr<'src>> {
663 pub fn resample_from_expr<'a>(
671 &mut self,
672 possible_expressions: &'a PossibleExpressions<'src, Expr<'src>>,
673 helpers: Option<&HashMap<LambdaType, Vec<RootedLambdaPool<'src, Expr<'src>>>>>,
674 rng: &mut impl Rng,
675 ) -> Result<(), LambdaError> {
676 let position = LambdaExprRef(u32::try_from((0..self.len()).choose(rng).unwrap()).unwrap());
677 let t = self.pool.get_type(position)?;
678
679 let pool = if let Some(helpers) = helpers
680 && rng.random_bool(0.2)
681 && let Some(v) = helpers.get(&t)
682 && !v.is_empty()
683 {
684 let pool = v.choose(rng).unwrap();
685 pool.clone()
686 } else {
687 let (pool, _) = self
688 .probabilistic_enumerate_from_expr(
689 position,
690 possible_expressions,
691 |_| true,
692 |_| true,
693 rng,
694 )?
695 .next()
696 .unwrap();
697 pool
698 };
699
700 let offset = u32::try_from(self.len()).unwrap();
701 let new_root = pool.root.0 + offset;
702 self.pool.0.extend(pool.pool.0.into_iter().map(|mut x| {
703 let children: Vec<_> = x
704 .get_children()
705 .map(|x| LambdaExprRef(x.0 + offset))
706 .collect();
707 x.change_children(children.into_iter());
708 x
709 }));
710 self.pool.0.swap(position.0 as usize, new_root as usize);
711 self.cleanup();
712 Ok(())
713 }
714
715 fn probabilistic_enumerate_from_expr<'a, R, E, F>(
716 &self,
717 position: LambdaExprRef,
718 possible_expressions: &'a PossibleExpressions<'src, Expr<'src>>,
719 eager_filter: E,
720 filter: F,
721 rng: &'a mut R,
722 ) -> Result<
723 LambdaEnumerator<'a, 'src, Expr<'src>, E, ProbabilisticEnumeration<'a, R, F>>,
724 TypeError,
725 >
726 where
727 R: Rng,
728 F: Fn(&ExprDetails) -> bool,
729 E: Fn(&Context) -> bool,
730 {
731 let (context, is_app_subformula) = Context::from_pos(self, position);
732 let output = self.pool.get_type(position)?;
733 let mut pq = BinaryHeap::default();
734 pq.push(RandomPQ::new(context, rng));
735 let pools = vec![UnfinishedLambdaPool {
736 pool: vec![ExprOrType::Type {
737 lambda_type: output,
738 parent: None,
739 is_app_subformula,
740 }],
741 }];
742 let enumerator = LambdaEnumerator {
743 pools,
744 possible_expressions,
745 eager_filter,
746 pq: ProbabilisticEnumeration {
747 rng,
748 reservoir_size: 1,
749 reservoir: BinaryHeap::default(),
750 done: false,
751 n_seen: 0,
752 filter,
753 backups: vec![],
754 pq,
755 },
756 };
757
758 Ok(enumerator)
759 }
760
761 pub fn enumerator_filter<'a, F: Fn(&Context) -> bool>(
763 t: &LambdaType,
764 filter: F,
765 possible_expressions: &'a PossibleExpressions<'src, Expr<'src>>,
766 ) -> LambdaEnumerator<'a, 'src, Expr<'src>, F> {
767 let context = Context::new(0, vec![]);
768 let mut pq = BinaryHeap::default();
769 pq.push(Reverse(context));
770 let pools = vec![UnfinishedLambdaPool {
771 pool: vec![ExprOrType::Type {
772 lambda_type: t.clone(),
773 parent: None,
774 is_app_subformula: false,
775 }],
776 }];
777
778 LambdaEnumerator {
779 pools,
780 possible_expressions,
781 eager_filter: filter,
782 pq: NormalEnumeration(pq, VecDeque::default()),
783 }
784 }
785
786 #[must_use]
788 pub fn enumerator<'a>(
789 t: &LambdaType,
790 possible_expressions: &'a PossibleExpressions<'src, Expr<'src>>,
791 ) -> LambdaEnumerator<'a, 'src, Expr<'src>, impl Fn(&'_ Context) -> bool> {
792 let context = Context::new(0, vec![]);
793 let mut pq = BinaryHeap::default();
794 pq.push(Reverse(context));
795 let pools = vec![UnfinishedLambdaPool {
796 pool: vec![ExprOrType::Type {
797 lambda_type: t.clone(),
798 parent: None,
799 is_app_subformula: false,
800 }],
801 }];
802
803 LambdaEnumerator {
804 pools,
805 possible_expressions,
806 eager_filter: |_: &Context| true,
807 pq: NormalEnumeration(pq, VecDeque::default()),
808 }
809 }
810
811 #[must_use]
813 pub fn sampler(
814 t: &LambdaType,
815 possible_expressions: &PossibleExpressions<'src, Expr<'src>>,
816 max_expr: usize,
817 ) -> LambdaSampler<'src, Expr<'src>> {
818 let enumerator = RootedLambdaPool::enumerator(t, possible_expressions);
819 let mut lambdas = Vec::with_capacity(max_expr);
820 let mut expr_details = Vec::with_capacity(max_expr);
821 for (lambda, expr_detail) in enumerator.take(max_expr) {
822 lambdas.push(lambda);
823 expr_details.push(expr_detail);
824 }
825
826 LambdaSampler {
827 lambdas,
828 expr_details,
829 }
830 }
831
832 pub fn random_expr(
837 t: &LambdaType,
838 possible_expressions: &PossibleExpressions<'src, Expr<'src>>,
839 rng: &mut impl Rng,
840 ) -> RootedLambdaPool<'src, Expr<'src>> {
841 ProbabilisticEnumeration::new(
842 1,
843 t,
844 possible_expressions,
845 |_: &Context| true,
846 |_| true,
847 rng,
848 )
849 .next()
850 .unwrap()
851 .0
852 }
853
854 pub fn random_expr_no_constant(
856 t: &LambdaType,
857 possible_expressions: &PossibleExpressions<'src, Expr<'src>>,
858 rng: &mut impl Rng,
859 ) -> Option<RootedLambdaPool<'src, Expr<'src>>> {
860 ProbabilisticEnumeration::new(
861 1,
862 t,
863 possible_expressions,
864 |_: &Context| true,
865 |e| !e.constant_function,
866 rng,
867 )
868 .next()
869 .map(|x| x.0)
870 }
871}
872
impl<'src> RootedLambdaPool<'src, Expr<'src>> {
    /// Removes quantifiers whose bound variable is never used in their subformula.
    ///
    /// Each such quantifier node is overwritten with a copy of its subformula node,
    /// and bound variables below it whose de Bruijn index exceeds their depth (as
    /// reported by `bfs_from_mut`) are decremented by one. Afterwards the pool's
    /// `cleanup` is called and the root is updated to the index it returns.
    pub fn prune_quantifiers(&mut self) {
        let quantifiers = self
            .pool
            .bfs_from(self.root)
            .filter_map(|(i, _)| match self.get(i) {
                LambdaExpr::LanguageOfThoughtExpr(Expr::Quantifier { subformula, .. }) => {
                    // The quantifier is used if some bound variable in its
                    // subformula has an index equal to its depth there.
                    if self
                        .pool
                        .bfs_from(LambdaExprRef(subformula.0))
                        .any(|(x, d)| {
                            if let LambdaExpr::BoundVariable(v, _) = self.get(x) {
                                *v == d
                            } else {
                                false
                            }
                        })
                    {
                        None
                    } else {
                        Some((i, LambdaExprRef(subformula.0)))
                    }
                }
                _ => None,
            })
            .collect::<Vec<_>>();

        // Reverse BFS order: later-discovered (deeper) quantifiers are pruned first.
        for (quantifier, subformula) in quantifiers.into_iter().rev() {
            *self.pool.get_mut(quantifier) = self.pool.get(subformula).clone();
            self.pool.bfs_from_mut(quantifier).for_each(|(x, d, _)| {
                if let LambdaExpr::BoundVariable(b_d, _) = x
                    && *b_d > d
                {
                    *b_d -= 1;
                }
            });
        }
        self.root = self.pool.cleanup(self.root);
    }

    /// Replaces a uniformly chosen node with another expression that has the same
    /// output type and the same child types, keeping the original children.
    ///
    /// Candidates come from `possiblities_fixed_children`. If there are none and the
    /// node is a quantifier, the candidates are its universal and existential
    /// versions. If there are still none and the node is the `EveryEvent` or
    /// `Everyone` constant, candidates of output type `t` are looked up and the
    /// constant itself is added as an option.
    ///
    /// # Errors
    /// Returns an error if the type of the chosen node cannot be computed.
    ///
    /// # Panics
    /// Panics if the pool is empty, if a child's type cannot be computed, or if no
    /// replacement candidate exists.
    pub fn swap_expr(
        &mut self,
        possible_expressions: &PossibleExpressions<'src, Expr<'src>>,
        rng: &mut impl Rng,
    ) -> Result<(), TypeError> {
        let position = LambdaExprRef(u32::try_from((0..self.len()).choose(rng).unwrap()).unwrap());

        let (context, _) = Context::from_pos(self, position);

        let output = self.pool.get_type(position)?;
        let expr = self.pool.get(position);

        let arguments: Vec<_> = expr
            .get_children()
            .map(|x| self.pool.get_type(x).unwrap())
            .collect();

        let mut replacements = possible_expressions.possiblities_fixed_children(
            &output,
            &arguments,
            expr.var_type(),
            &context,
        );

        if replacements.is_empty()
            && let LambdaExpr::LanguageOfThoughtExpr(Expr::Quantifier {
                var_type,
                restrictor,
                subformula,
                ..
            }) = expr
        {
            for quantifier in [Quantifier::Universal, Quantifier::Existential] {
                replacements.push(std::borrow::Cow::Owned(LambdaExpr::LanguageOfThoughtExpr(
                    Expr::Quantifier {
                        quantifier,
                        var_type: *var_type,
                        restrictor: *restrictor,
                        subformula: *subformula,
                    },
                )));
            }
        }
        for c in [Constant::EveryEvent, Constant::Everyone] {
            if replacements.is_empty()
                && expr == &LambdaExpr::LanguageOfThoughtExpr(Expr::Constant(c))
            {
                replacements = possible_expressions.possiblities_fixed_children(
                    LambdaType::t(),
                    &arguments,
                    expr.var_type(),
                    &context,
                );
                replacements.push(std::borrow::Cow::Owned(LambdaExpr::LanguageOfThoughtExpr(
                    Expr::Constant(c),
                )));
            }
        }

        let choice = replacements.choose(rng).unwrap_or_else(|| {
            panic!(
                "There is no node with output {output} and arguments {}",
                arguments
                    .into_iter()
                    .map(|x| x.to_string())
                    .collect::<Vec<_>>()
                    .join(", ")
            )
        });

        // Keep the original children: only the node itself is swapped.
        let mut new_expr = choice.clone().into_owned();
        new_expr.change_children(self.pool.get(position).get_children());
        self.pool.0[position.0 as usize] = new_expr;

        Ok(())
    }

    /// Replaces a uniformly chosen node with a copy of a compatible subtree found
    /// elsewhere in the expression (`Context::find_compatible`).
    ///
    /// The copied subtree is appended to the pool with remapped child indices, its
    /// root is swapped into the chosen position, and `cleanup` is called. If no
    /// compatible subtree exists, the pool is left unchanged.
    ///
    /// # Errors
    /// Returns an error if `find_compatible` fails.
    ///
    /// # Panics
    /// Panics if the pool is empty.
    pub fn swap_subtree(&mut self, rng: &mut impl Rng) -> Result<(), TypeError> {
        let position = LambdaExprRef(u32::try_from((0..self.len()).choose(rng).unwrap()).unwrap());
        let (context, _) = Context::from_pos(self, position);
        let alt = context.find_compatible(self, position)?;
        if let Some(new_pos) = alt.choose(rng).copied() {
            let offset = self.pool.0.len();
            // Maps each node of the copied subtree to its new index (in BFS order,
            // starting at `offset`).
            let mut lookup = HashMap::default();
            let mut new_pool: Vec<LambdaExpr<'src, Expr<'src>>> = self
                .pool
                .bfs_from(new_pos)
                .map(|(x, _)| {
                    let n = lookup.len();
                    lookup
                        .entry(x)
                        .or_insert(LambdaExprRef(u32::try_from(n + offset).unwrap()));
                    self.get(x).clone()
                })
                .collect();

            for x in &mut new_pool {
                let child: Vec<_> = x.get_children().map(|x| *lookup.get(&x).unwrap()).collect();
                x.change_children(child.into_iter());
            }

            self.pool.0.extend(new_pool);
            // The subtree root was emitted first, so it sits at `offset`.
            self.pool.0.swap(position.0 as usize, offset);
            self.cleanup();
        }
        Ok(())
    }
}
1039
#[cfg(test)]
mod test {

    use std::collections::HashSet;

    use super::*;
    use crate::lambda::{LambdaPool, LambdaSummaryStats};
    use rand::SeedableRng;
    use rand_chacha::ChaCha8Rng;

    /// `prune_quantifiers` drops quantifiers whose variable is unused and keeps
    /// the remaining bound variables pointing at the right binders.
    #[test]
    fn prune_quantifier_test() -> anyhow::Result<()> {
        let mut pool =
            RootedLambdaPool::parse("some_e(x,all_e,AgentOf(a_2,e_1) & PatientOf(a_0,e_0))")?;

        pool.prune_quantifiers();
        assert_eq!(pool.to_string(), "AgentOf(a_2, e_1) & PatientOf(a_0, e_0)");

        // Only the unused outer quantifier goes; the inner one binds `z`.
        let mut pool = RootedLambdaPool::parse(
            "some_e(x0, all_e, some(z, all_a, AgentOf(z, e_1) & PatientOf(a_0, e_0)))",
        )?;

        pool.prune_quantifiers();
        assert_eq!(
            pool.to_string(),
            "some(x, all_a, AgentOf(x, e_1) & PatientOf(a_0, e_0))"
        );

        let mut pool = RootedLambdaPool::parse("~every_e(z, pe_1, pa_2(a_0))")?;

        pool.prune_quantifiers();

        assert_eq!(pool.to_string(), "~pa_2(a_0)");
        // Hand-built pool: the bound variable with index 3 skips `z`, `y` and `phi`
        // to reach the outer lambda `x`, so its index must be lowered when both
        // quantifiers are pruned.
        let mut pool = RootedLambdaPool::new(
            LambdaPool(vec![
                LambdaExpr::Lambda(LambdaExprRef(1), LambdaType::E),
                LambdaExpr::Lambda(LambdaExprRef(2), LambdaType::T),
                LambdaExpr::LanguageOfThoughtExpr(Expr::Quantifier {
                    quantifier: Quantifier::Universal,
                    var_type: ActorOrEvent::Actor,
                    restrictor: ExprRef(3),
                    subformula: ExprRef(4),
                }),
                LambdaExpr::LanguageOfThoughtExpr(Expr::Constant(Constant::Property(
                    "1",
                    ActorOrEvent::Actor,
                ))),
                LambdaExpr::LanguageOfThoughtExpr(Expr::Quantifier {
                    quantifier: Quantifier::Existential,
                    var_type: ActorOrEvent::Actor,
                    restrictor: ExprRef(5),
                    subformula: ExprRef(6),
                }),
                LambdaExpr::LanguageOfThoughtExpr(Expr::Constant(Constant::Property(
                    "0",
                    ActorOrEvent::Actor,
                ))),
                LambdaExpr::LanguageOfThoughtExpr(Expr::Unary(
                    MonOp::Property("3", ActorOrEvent::Event),
                    ExprRef(7),
                )),
                LambdaExpr::BoundVariable(3, LambdaType::E),
            ]),
            LambdaExprRef(0),
        );

        assert_eq!(
            pool.to_string(),
            "lambda e x lambda t phi every(y, pa_1, some(z, pa_0, pe_3(x)))"
        );

        // NOTE(review): `parsed_pool` is pruned but its result is never asserted.
        let mut parsed_pool = RootedLambdaPool::parse(
            "lambda e x lambda t phi every(y, pa_1, some(z, pa_0, pe_3(x)))",
        )?;
        parsed_pool.prune_quantifiers();
        pool.prune_quantifiers();
        assert_eq!(pool.to_string(), "lambda e x lambda t phi pe_3(x)");

        Ok(())
    }

    /// Over 100 seeded draws, `swap_subtree` yields both single-disjunct variants.
    #[test]
    fn random_swap_tree() -> anyhow::Result<()> {
        let mut rng = ChaCha8Rng::seed_from_u64(2);
        let x = RootedLambdaPool::parse("lambda a x pa_cool(a_John) | pa_cool(x)")?;
        let mut h = HashSet::new();
        for _ in 0..100 {
            let mut z = x.clone();
            z.swap_subtree(&mut rng)?;
            h.insert(z.to_string());
        }
        for x in h.iter() {
            println!("{x}");
        }
        assert!(h.contains("lambda a x pa_cool(x)"));
        assert!(h.contains("lambda a x pa_cool(a_John)"));
        Ok(())
    }

    /// `swap_expr` preserves the type of random and hand-written expressions.
    #[test]
    fn randomn_swap() -> anyhow::Result<()> {
        let mut rng = ChaCha8Rng::seed_from_u64(2);
        let actors = ["0", "1"];
        let available_event_properties = ["2", "3", "4"];
        let possible_expressions = PossibleExpressions::new(
            &actors,
            &available_event_properties,
            &available_event_properties,
        );
        for _ in 0..200 {
            let t = LambdaType::random(&mut rng);
            println!("{t}");
            let mut pool = RootedLambdaPool::random_expr(&t, &possible_expressions, &mut rng);
            println!("{t}: {pool}");
            assert_eq!(t, pool.get_type()?);
            println!("{pool:?}");
            pool.swap_expr(&possible_expressions, &mut rng)?;
            println!("{pool:?}");
            println!("{t}: {pool}");
            assert_eq!(t, pool.get_type()?);
        }
        let p =
            RootedLambdaPool::parse("lambda <a,t> P some(x, P(x), some_e(y, pe_2(y), pe_4(y)))")?;
        let t = p.get_type()?;
        for _ in 0..1000 {
            let mut pool = p.clone();
            assert_eq!(t, pool.get_type()?);

            pool.swap_expr(&possible_expressions, &mut rng)?;
            dbg!(&pool);
            println!("{t}: {pool}");
            assert_eq!(t, pool.get_type()?);
        }

        Ok(())
    }

    /// The exhaustive enumerator reaches `P(a_john)` among the first 20
    /// non-constant expressions of type `<<a,t>,t>`.
    #[test]
    fn enumerate() -> anyhow::Result<()> {
        let actors = ["john"];
        let actor_properties = ["a"];
        let event_properties = ["e"];
        let possibles = PossibleExpressions::new(&actors, &actor_properties, &event_properties);

        let t = LambdaType::from_string("<<a,t>,t>")?;

        let p = RootedLambdaPool::enumerator(&t, &possibles);
        let pools: HashSet<_> = p
            .filter_map(|(p, x)| {
                if !x.has_constant_function() {
                    Some(p.to_string())
                } else {
                    None
                }
            })
            .take(20)
            .collect();
        for p in pools.iter() {
            println!("{p}");
        }
        assert!(pools.contains("lambda <a,t> P P(a_john)"));
        for p in pools {
            println!("{p}");
        }

        Ok(())
    }

    /// Smoke test: enumerating `<a,<e,e>>` terminates and prints a few results.
    #[test]
    fn enumerate_weirds() -> anyhow::Result<()> {
        let actors = &["1"];
        let actor_properties = &["1"];
        let event_properties = &["1"];
        let poss = PossibleExpressions::new(actors, actor_properties, event_properties);

        let t = "<a,<e,e>>";

        for (p, d) in RootedLambdaPool::enumerator(&LambdaType::from_string(t)?, &poss)
            .filter(|(_, e)| !e.has_constant_function())
            .take(5)
        {
            println!("{p} {d:?} {}", d.has_constant_function());
        }
        Ok(())
    }

    /// Random expressions round-trip through parsing and keep their type under
    /// `resample_from_expr` (with and without helpers) and `swap_subtree`.
    #[test]
    fn random_expr() -> anyhow::Result<()> {
        let actors = ["john"];
        let actor_properties = ["a"];
        let event_properties = ["e"];
        let possibles = PossibleExpressions::new(&actors, &actor_properties, &event_properties);
        let mut rng = ChaCha8Rng::seed_from_u64(0);

        let map = [(LambdaType::A, vec![RootedLambdaPool::parse("a_john")?])]
            .into_iter()
            .collect();

        for _ in 0..100 {
            let t = LambdaType::random(&mut rng);
            println!("sampling: {t}");
            let mut pool = RootedLambdaPool::random_expr(&t, &possibles, &mut rng);
            assert_eq!(t, pool.get_type()?);
            let s = pool.to_string();
            println!("{s}");
            let mut pool2 = RootedLambdaPool::parse(s.as_str())?;
            assert_eq!(s, pool2.to_string());
            println!("{pool}");
            pool2.resample_from_expr(&possibles, None, &mut rng)?;
            assert_eq!(pool2.get_type()?, t);
            pool2.resample_from_expr(&possibles, Some(&map), &mut rng)?;
            assert_eq!(pool2.get_type()?, t);

            pool.swap_subtree(&mut rng)?;
            assert_eq!(pool.get_type()?, t);
        }
        let t = LambdaType::from_string("<a,<a,t>>")?;
        for _ in 0..100 {
            let pool = RootedLambdaPool::random_expr(&t, &possibles, &mut rng);
            println!("{pool}");
        }
        Ok(())
    }

    /// The first four expressions of type `a` are exactly the four actor constants;
    /// the count of constant functions among random `<a,<a,t>>` draws is only printed.
    #[test]
    fn constant_exprs() -> anyhow::Result<()> {
        let actors = ["john", "mary", "phil", "sue"];
        let actor_properties = ["a"];
        let event_properties = ["e"];
        let possibles = PossibleExpressions::new(&actors, &actor_properties, &event_properties);
        let mut rng = ChaCha8Rng::seed_from_u64(0);

        let v: HashSet<_> = RootedLambdaPool::enumerator(LambdaType::a(), &possibles)
            .map(|(x, _)| x.to_string())
            .take(4)
            .collect();

        assert_eq!(v, HashSet::from(actors.map(|x| format!("a_{x}"))));
        println!("{v:?}");

        let mut constants = 0;
        for _ in 0..100 {
            let t = LambdaType::from_string("<a, <a,t>>")?;
            println!("sampling: {t}");
            let pool = RootedLambdaPool::random_expr(&t, &possibles, &mut rng);
            let x = pool.stats();
            match x {
                LambdaSummaryStats::WellFormed {
                    constant_function, ..
                } => {
                    if constant_function {
                        constants += 1;
                    }
                }
                // NOTE(review): `todo!()` effectively asserts samples are well formed.
                LambdaSummaryStats::Malformed => todo!(),
            }
        }
        println!("{constants}");

        Ok(())
    }

    /// Random actors are drawn roughly uniformly: each of the four appears 200–300
    /// times in 1000 draws. The `<a,t>` counts are only printed.
    #[test]
    fn random_expr_counts() -> anyhow::Result<()> {
        let actors = ["john", "mary", "phil", "sue"];
        let actor_properties = ["a"];
        let event_properties = ["e"];
        let possibles = PossibleExpressions::new(&actors, &actor_properties, &event_properties);
        let mut rng = ChaCha8Rng::seed_from_u64(1);

        let mut counts: HashMap<_, usize> = HashMap::default();
        for _ in 0..1000 {
            let t = LambdaType::A;
            let pool = RootedLambdaPool::random_expr(&t, &possibles, &mut rng);
            assert_eq!(t, pool.get_type()?);
            let s = pool.to_string();
            *counts.entry(s).or_default() += 1;
        }
        assert_eq!(counts.len(), 4);
        dbg!(&counts);
        for (_, v) in counts.iter() {
            assert!(200 <= *v && *v <= 300);
        }

        counts.clear();
        for _ in 0..1000 {
            let t = LambdaType::at().clone();
            let pool = RootedLambdaPool::random_expr(&t, &possibles, &mut rng);
            assert_eq!(t, pool.get_type()?);
            let s = pool.to_string();
            *counts.entry(s).or_default() += 1;
        }
        dbg!(counts);
        Ok(())
    }
}