1use super::{
2 ActorOrEvent, BinOp, Constant, Display, Expr, ExprRef, MonOp, Quantifier, RootedLambdaPool,
3 thiserror,
4};
5use crate::lambda::{
6 LambdaError, LambdaExpr, LambdaExprRef, LambdaLanguageOfThought, LambdaPool,
7 types::{LambdaType, TypeError},
8};
9use ahash::HashMap;
10use chumsky::container::Container;
11use rand::{
12 Rng, RngExt,
13 distr::{Distribution, weighted::WeightedIndex},
14 seq::{IndexedRandom, IteratorRandom},
15};
16use std::{
17 cmp::Reverse,
18 collections::{BinaryHeap, VecDeque},
19 fmt::Debug,
20};
21use thiserror::Error;
22
23mod context;
24mod samplers;
25pub use context::Context;
26pub(crate) use samplers::PossibleExpr;
27pub use samplers::PossibleExpressions;
28
29#[derive(Debug, Error, Clone)]
30pub struct ExprOrTypeError();
31
32impl Display for ExprOrTypeError {
33 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
34 write!(f, "This ExprOrType is not an Expr!")
35 }
36}
37
38#[derive(Debug, Clone, Eq, PartialEq)]
39enum ExprOrType<'src, T: LambdaLanguageOfThought> {
40 Type {
41 lambda_type: LambdaType,
42 parent: Option<usize>,
43 is_app_subformula: bool,
44 },
45 Expr {
46 lambda_expr: LambdaExpr<'src, T>,
47 parent: Option<usize>,
48 },
49}
50
51impl<'src, T: LambdaLanguageOfThought> TryFrom<ExprOrType<'src, T>> for LambdaExpr<'src, T> {
52 type Error = ExprOrTypeError;
53
54 fn try_from(value: ExprOrType<'src, T>) -> Result<Self, Self::Error> {
55 match value {
56 ExprOrType::Type { .. } => Err(ExprOrTypeError()),
57 ExprOrType::Expr { lambda_expr, .. } => Ok(lambda_expr),
58 }
59 }
60}
61
62impl<T: LambdaLanguageOfThought> ExprOrType<'_, T> {
63 fn parent(&self) -> Option<usize> {
64 match self {
65 ExprOrType::Type { parent, .. } | ExprOrType::Expr { parent, .. } => *parent,
66 }
67 }
68
69 fn is_type(&self) -> bool {
70 matches!(self, ExprOrType::Type { .. })
71 }
72}
73
74#[derive(Debug, Clone, Eq, PartialEq)]
75struct UnfinishedLambdaPool<'src, T: LambdaLanguageOfThought> {
76 pool: Vec<ExprOrType<'src, T>>,
77}
78
79impl<T: LambdaLanguageOfThought> Default for UnfinishedLambdaPool<'_, T> {
80 fn default() -> Self {
81 Self { pool: vec![] }
82 }
83}
84
impl<'src, T: LambdaLanguageOfThought + Clone> UnfinishedLambdaPool<'src, T> {
    /// Fills the hole at `c.position` with `expr`.
    ///
    /// A new type hole is appended for each of the expression's children, the expression's
    /// child references are pointed at those holes, and the search context `c` is updated:
    /// node count, open holes, and bound variables introduced or used.
    ///
    /// `t` is the type of the hole being filled. For a lambda, `t.rhs()` becomes the type of
    /// the body hole.
    fn add_expr<'a>(&mut self, expr: PossibleExpr<'a, 'src, T>, c: &mut Context, t: &LambdaType) {
        // `app_details` carries the (function, argument) types when `expr` is an application.
        let (mut lambda_expr, app_details) = expr.into_expr();
        c.depth += 1;
        // One hole is opened per child and the hole being filled is closed. The addition comes
        // first, presumably so an unsigned counter cannot dip below zero in between.
        c.open_nodes += lambda_expr.n_children();
        c.open_nodes -= 1;
        let parent = self.pool[c.position].parent();
        match &mut lambda_expr {
            LambdaExpr::Lambda(body, arg) => {
                // Introduce the binder, then open a hole for the body.
                c.add_lambda(arg);
                *body = LambdaExprRef(self.pool.len() as u32);
                self.pool.push(ExprOrType::Type {
                    lambda_type: t.rhs().unwrap().clone(),
                    parent: Some(c.position),
                    is_app_subformula: false,
                });
            }
            LambdaExpr::BoundVariable(b, _) => {
                c.use_bvar(*b);
            }
            LambdaExpr::FreeVariable(..) => (),
            LambdaExpr::Application {
                subformula,
                argument,
            } => {
                // Two holes, appended in order: the function position (flagged as an
                // application subformula) and the argument.
                *subformula = LambdaExprRef(self.pool.len() as u32);
                *argument = LambdaExprRef((self.pool.len() + 1) as u32);
                let (subformula, argument) = app_details.unwrap();
                self.pool.push(ExprOrType::Type {
                    lambda_type: subformula,
                    parent: Some(c.position),
                    is_app_subformula: true,
                });
                self.pool.push(ExprOrType::Type {
                    lambda_type: argument,
                    parent: Some(c.position),
                    is_app_subformula: false,
                });
            }
            LambdaExpr::LanguageOfThoughtExpr(e) => {
                let children_start = self.pool.len();
                // Expressions with a `var_type` (e.g. quantifiers) also bind a variable.
                if let Some(t) = e.var_type() {
                    c.add_lambda(t);
                }
                // One hole per argument type, then point the expression's children at them.
                self.pool
                    .extend(e.get_arguments().map(|lambda_type| ExprOrType::Type {
                        lambda_type,
                        parent: Some(c.position),
                        is_app_subformula: false,
                    }));
                e.change_children(
                    (children_start..self.pool.len()).map(|x| LambdaExprRef(x as u32)),
                );
            }
        }
        // The filled slot keeps the parent link of the hole it replaces.
        self.pool[c.position] = ExprOrType::Expr {
            lambda_expr,
            parent,
        };
    }
}
146
147#[derive(Debug)]
148pub struct NormalEnumeration(BinaryHeap<Reverse<Context>>, VecDeque<ExprDetails>);
149
150impl EnumerationType for NormalEnumeration {
151 fn pop(&mut self) -> Option<Context> {
152 self.0.pop().map(|x| x.0)
153 }
154
155 fn push(&mut self, context: Context, _: bool) {
156 self.0.push(Reverse(context));
157 }
158
159 fn get_yield(&mut self) -> Option<ExprDetails> {
160 self.1.pop_front()
161 }
162
163 fn push_yield(&mut self, e: ExprDetails) {
164 self.1.push(e);
165 }
166
167 fn include(&mut self, n: usize) -> impl Iterator<Item = bool> + 'static {
168 std::iter::repeat_n(true, n)
169 }
170}
171
172impl ExprDetails {
173 fn score(&self) -> f64 {
174 (1.0 / (self.size as f64)) + if self.constant_function { 0.0 } else { 10.0 }
175 }
176
177 pub fn has_constant_function(&self) -> bool {
178 self.constant_function
179 }
180}
181
182#[derive(Debug, Clone, Copy, PartialEq)]
183struct KeyedExprDetails {
184 expr_details: ExprDetails,
185 k: f64,
186}
187
188impl Eq for KeyedExprDetails {}
189
190impl PartialOrd for KeyedExprDetails {
191 fn partial_cmp(&self, other: &Self) -> Option<std::cmp::Ordering> {
192 Some(self.cmp(other))
193 }
194}
195
196impl Ord for KeyedExprDetails {
197 fn cmp(&self, other: &Self) -> std::cmp::Ordering {
198 other.k.partial_cmp(&self.k).unwrap()
200 }
201}
202impl KeyedExprDetails {
203 fn new(expr_details: ExprDetails, rng: &mut impl Rng) -> Self {
204 let u: f64 = rng.random();
205 KeyedExprDetails {
206 expr_details,
207 k: u.powf(1.0 / expr_details.score()),
208 }
209 }
210}
211
212#[derive(Debug, Clone, PartialEq)]
213struct RandomPQ(Context, f64);
214
215impl Eq for RandomPQ {}
216
217impl RandomPQ {
218 fn new(c: Context, rng: &mut impl Rng) -> Self {
219 RandomPQ(c, rng.random())
220 }
221}
222
223#[derive(Debug)]
224struct ProbabilisticEnumeration<'a, R: Rng, F>
225where
226 F: Fn(&ExprDetails) -> bool,
227{
228 rng: &'a mut R,
229 reservoir_size: usize,
230 reservoir: BinaryHeap<KeyedExprDetails>,
231 backups: Vec<Context>,
232 pq: BinaryHeap<RandomPQ>,
233 filter: F,
234 n_seen: usize,
235 done: bool,
236}
237impl<R: Rng, F> ProbabilisticEnumeration<'_, R, F>
238where
239 F: Fn(&ExprDetails) -> bool,
240{
241 fn threshold(&self) -> Option<f64> {
242 self.reservoir.peek().map(|x| x.k)
243 }
244
245 fn new<'a, 'src, T: LambdaLanguageOfThought, E: Fn(&Context) -> bool>(
246 reservoir_size: usize,
247 t: &LambdaType,
248 possible_expressions: &'a PossibleExpressions<'src, T>,
249 eager_filter: E,
250 filter: F,
251 rng: &'a mut R,
252 ) -> LambdaEnumerator<'a, 'src, T, E, ProbabilisticEnumeration<'a, R, F>> {
253 let context = Context::new(0, vec![]);
254 let mut pq = BinaryHeap::default();
255 pq.push(RandomPQ::new(context, rng));
256 let pools = vec![UnfinishedLambdaPool {
257 pool: vec![ExprOrType::Type {
258 lambda_type: t.clone(),
259 parent: None,
260 is_app_subformula: false,
261 }],
262 }];
263
264 LambdaEnumerator {
265 pools,
266 possible_expressions,
267 eager_filter,
268 pq: ProbabilisticEnumeration {
269 rng,
270 reservoir_size,
271 reservoir: BinaryHeap::default(),
272 backups: vec![],
273 filter,
274 pq,
275 n_seen: 0,
276 done: false,
277 },
278 }
279 }
280}
281
282impl<R: Rng, F> EnumerationType for ProbabilisticEnumeration<'_, R, F>
283where
284 F: Fn(&ExprDetails) -> bool,
285{
286 fn pop(&mut self) -> Option<Context> {
287 self.pq.pop().map(|x| x.0).or_else(|| {
289 (0..self.backups.len()).choose(self.rng).and_then(|index| {
290 let last_item = self.backups.len() - 1;
291 self.backups.swap(index, last_item);
292 self.backups.pop()
293 })
294 })
295 }
296
297 fn push(&mut self, context: Context, included: bool) {
298 if included {
299 self.pq.push(RandomPQ::new(context, &mut self.rng));
300 } else {
301 self.backups.push(context);
302 }
303 }
304
305 fn get_yield(&mut self) -> Option<ExprDetails> {
306 if (self.done || self.pq.is_empty())
307 && let Some(x) = self.reservoir.pop()
308 {
309 Some(x.expr_details)
310 } else {
311 None
312 }
313 }
314
315 fn push_yield(&mut self, e: ExprDetails) {
316 let e = KeyedExprDetails::new(e, &mut self.rng);
317 if (self.filter)(&e.expr_details) {
318 self.n_seen += 1;
319 if self.reservoir_size > self.reservoir.len() {
320 self.reservoir.push(e);
321 } else if let Some(t) = self.threshold()
322 && e.k > t
323 {
324 self.reservoir.pop();
325 self.reservoir.push(e);
326 }
327 if self.n_seen >= self.reservoir_size * 20 {
328 self.pq.clear();
329 self.done = true;
330 }
331 }
332 }
333
334 fn include(&mut self, n: usize) -> impl Iterator<Item = bool> + 'static {
335 let x = (0..n).sample(self.rng, (n / 2).max(1));
336 let mut v = vec![false; n];
337 for i in x {
338 v[i] = true;
339 }
340 v.into_iter()
341 }
342}
343
344#[derive(Debug)]
345pub struct LambdaEnumerator<'a, 'src, T: LambdaLanguageOfThought, F, E = NormalEnumeration> {
347 pools: Vec<UnfinishedLambdaPool<'src, T>>,
348 possible_expressions: &'a PossibleExpressions<'src, T>,
349 eager_filter: F,
350 pq: E,
351}
352
353#[derive(Debug, Clone, Copy, Eq, PartialEq)]
355pub struct ExprDetails {
356 id: usize,
357 constant_function: bool,
358 root: LambdaExprRef,
359 size: usize,
360}
361
362impl ExprDetails {
363 pub fn size(&self) -> usize {
365 self.size
366 }
367}
368
369#[derive(Debug, Clone, Eq, PartialEq)]
370pub struct TypeAgnosticSampler<'src, T: LambdaLanguageOfThought> {
372 type_to_sampler: HashMap<LambdaType, (usize, LambdaSampler<'src, T>)>,
373 max_expr: usize,
374 max_types: usize,
375 possible_expressions: PossibleExpressions<'src, T>,
376}
377
378impl<'src> TypeAgnosticSampler<'src, Expr<'src>> {
379 pub fn sample(
381 &mut self,
382 lambda_type: LambdaType,
383 rng: &mut impl Rng,
384 ) -> RootedLambdaPool<'src, Expr<'src>> {
385 let (counts, exprs) = self
386 .type_to_sampler
387 .entry(lambda_type)
388 .or_insert_with_key(|t| {
389 (
390 1,
391 RootedLambdaPool::sampler(t, &self.possible_expressions, self.max_expr),
392 )
393 });
394 *counts += 1;
395 let sample = exprs.sample(rng);
396
397 if self.type_to_sampler.len() > self.max_types {
398 let (_, k) = self
399 .type_to_sampler
400 .iter()
401 .map(|(k, (n_visits, _))| (n_visits, k))
402 .min_by_key(|x| x.0)
403 .unwrap();
404
405 let t = k.clone();
406 self.type_to_sampler.remove(&t);
407 }
408
409 sample
410 }
411
412 #[must_use]
414 pub fn possibles(&self) -> &PossibleExpressions<'src, Expr<'src>> {
415 &self.possible_expressions
416 }
417}
418
419impl<'src, T: LambdaLanguageOfThought + Clone> RootedLambdaPool<'src, T> {
420 #[must_use]
422 pub fn typeless_sampler(
423 possible_expressions: PossibleExpressions<'src, T>,
424 max_expr: usize,
425 max_types: usize,
426 ) -> TypeAgnosticSampler<'src, T> {
427 assert!(max_types >= 1);
428 assert!(max_expr >= 1);
429 TypeAgnosticSampler {
430 possible_expressions,
431 max_expr,
432 max_types,
433 type_to_sampler: HashMap::default(),
434 }
435 }
436}
437
438#[derive(Debug, Clone, Eq, PartialEq)]
440pub struct LambdaSampler<'src, T: LambdaLanguageOfThought> {
441 lambdas: Vec<RootedLambdaPool<'src, T>>,
442 expr_details: Vec<ExprDetails>,
443}
444
445impl<'src, T: LambdaLanguageOfThought + Clone> Distribution<RootedLambdaPool<'src, T>>
446 for LambdaSampler<'src, T>
447{
448 fn sample<R: Rng + ?Sized>(&self, rng: &mut R) -> RootedLambdaPool<'src, T> {
449 let w = WeightedIndex::new(self.expr_details.iter().map(ExprDetails::score)).unwrap();
450 let i = w.sample(rng);
451 self.lambdas
452 .get(i)
453 .expect("The Lambda Sampler has no lambdas to sample :(")
454 .clone()
455 }
456}
457
458pub trait EnumerationType {
461 fn pop(&mut self) -> Option<Context>;
462 fn push(&mut self, context: Context, included: bool);
463 fn get_yield(&mut self) -> Option<ExprDetails>;
464 fn push_yield(&mut self, e: ExprDetails);
465 fn include(&mut self, n: usize) -> impl Iterator<Item = bool> + 'static;
466}
467
468fn try_yield<'src, T, F, E>(
469 x: &mut LambdaEnumerator<'_, 'src, T, F, E>,
470) -> Option<(RootedLambdaPool<'src, T>, ExprDetails)>
471where
472 T: LambdaLanguageOfThought,
473 E: EnumerationType,
474{
475 if let Some(item) = x.pq.get_yield() {
476 let p = std::mem::take(&mut x.pools[item.id]);
477 return Some((
478 RootedLambdaPool {
479 pool: LambdaPool(
480 p.pool
481 .into_iter()
482 .map(|x| LambdaExpr::try_from(x).unwrap())
483 .collect(),
484 ),
485 root: item.root,
486 },
487 item,
488 ));
489 }
490 None
491}
492
493impl<'a, 'src, T, F, E> LambdaEnumerator<'a, 'src, T, F, E>
494where
495 T: LambdaLanguageOfThought + Clone + Debug,
496 F: Fn(&Context) -> bool,
497 E: EnumerationType,
498{
499 fn push(&mut self, c: Context, included: bool) {
500 if (self.eager_filter)(&c) {
501 self.pq.push(c, included);
502 } else {
503 self.pools[c.pool_index] = UnfinishedLambdaPool::default();
504 }
505 }
506
507 pub fn eager_filter<F2>(self, eager_filter: F2) -> LambdaEnumerator<'a, 'src, T, F2, E> {
509 let LambdaEnumerator {
510 pools,
511 possible_expressions,
512 eager_filter: _,
513 pq,
514 } = self;
515
516 LambdaEnumerator {
517 pools,
518 possible_expressions,
519 eager_filter,
520 pq,
521 }
522 }
523}
524
525impl<'src, F, E> Iterator for LambdaEnumerator<'_, 'src, Expr<'src>, F, E>
526where
527 F: Fn(&Context) -> bool,
528 E: EnumerationType,
529{
530 type Item = (RootedLambdaPool<'src, Expr<'src>>, ExprDetails);
531
532 fn next(&mut self) -> Option<Self::Item> {
533 if let Some(x) = try_yield(self) {
534 return Some(x);
535 }
536
537 while let Some(mut c) = self.pq.pop() {
538 if let Some(x) = try_yield(self) {
539 self.push(c, true);
540 return Some(x);
541 }
542 let (possibles, lambda_type) = match &self.pools[c.pool_index].pool[c.position] {
543 ExprOrType::Type {
544 lambda_type,
545 is_app_subformula,
546 parent,
547 } => {
548 let mut possibles = self.possible_expressions.possibilities(
549 lambda_type,
550 *is_app_subformula,
551 &c,
552 );
553
554 if let Some(p) = parent
557 && let ExprOrType::Expr {
558 lambda_expr:
559 LambdaExpr::LanguageOfThoughtExpr(Expr::Quantifier {
560 var_type,
561 restrictor,
562 ..
563 }),
564 ..
565 } = self.pools[c.pool_index].pool[*p]
566 && restrictor.0 == c.position as u32
567 {
568 possibles.push(PossibleExpr::new_borrowed(match var_type {
569 ActorOrEvent::Actor => &LambdaExpr::LanguageOfThoughtExpr(
570 Expr::Constant(Constant::Everyone),
571 ),
572 ActorOrEvent::Event => &LambdaExpr::LanguageOfThoughtExpr(
573 Expr::Constant(Constant::EveryEvent),
574 ),
575 }));
576 }
577
578 (possibles, lambda_type.clone())
579 }
580 ExprOrType::Expr {
581 lambda_expr,
582 parent,
583 } => {
584 if let Some(child) = lambda_expr
588 .get_children()
589 .map(|x| x.0 as usize)
590 .find(|child| self.pools[c.pool_index].pool[*child].is_type())
591 {
592 c.position = child;
593 self.pq.push(c, true);
594 continue;
595 }
596
597 if lambda_expr.inc_depth() {
598 c.pop_lambda();
599 }
600
601 if let Some(p) = parent {
602 c.position = *p;
603 self.pq.push(c, true);
604 continue;
605 }
606 self.pq.push_yield(ExprDetails {
608 id: c.pool_index,
609 root: LambdaExprRef(c.position as u32),
610 size: c.depth,
611 constant_function: c.is_constant(),
612 });
613 continue;
614 }
615 };
616
617 let n = possibles.len();
618 let included = self.pq.include(n);
619 if n == 0 {
620 continue;
624 }
625 let n_pools = self.pools.len();
626 if n_pools.is_multiple_of(10_000) {
627 self.pools.shrink_to_fit();
628 }
629
630 for _ in 0..n.saturating_sub(1) {
631 self.pools.push(self.pools[c.pool_index].clone());
632 }
633
634 let positions =
635 std::iter::once(c.pool_index).chain(n_pools..n_pools + n.saturating_sub(1));
636
637 for (((expr, pool_id), mut c), included) in possibles
638 .into_iter()
639 .zip(positions)
640 .zip(std::iter::repeat_n(c, n))
641 .zip(included)
642 {
643 c.pool_index = pool_id;
644 let pool = self.pools.get_mut(pool_id).unwrap();
645 pool.add_expr(expr, &mut c, &lambda_type);
646 self.push(c, included);
647 }
648
649 if let Some(x) = try_yield(self) {
650 return Some(x);
651 }
652 }
653
654 if let Some(x) = try_yield(self) {
656 return Some(x);
657 }
658 None
659 }
660}
661
662impl<'src> RootedLambdaPool<'src, Expr<'src>> {
663 pub fn resample_from_expr<'a>(
665 &mut self,
666 possible_expressions: &'a PossibleExpressions<'src, Expr<'src>>,
667 helpers: Option<&HashMap<LambdaType, Vec<RootedLambdaPool<'src, Expr<'src>>>>>,
668 rng: &mut impl Rng,
669 ) -> Result<(), LambdaError> {
670 let position = LambdaExprRef((0..self.len()).choose(rng).unwrap() as u32);
671 let t = self.pool.get_type(position)?;
672
673 let pool = if let Some(helpers) = helpers
674 && rng.random_bool(0.2)
675 && let Some(v) = helpers.get(&t)
676 && !v.is_empty()
677 {
678 let pool = v.choose(rng).unwrap();
679 pool.clone()
680 } else {
681 let (pool, _) = self
682 .probabilistic_enumerate_from_expr(
683 position,
684 possible_expressions,
685 |_| true,
686 |_| true,
687 rng,
688 )?
689 .next()
690 .unwrap();
691 pool
692 };
693
694 let offset = self.len() as u32;
695 let new_root = pool.root.0 + offset;
696 self.pool.0.extend(pool.pool.0.into_iter().map(|mut x| {
697 let children: Vec<_> = x
698 .get_children()
699 .map(|x| LambdaExprRef(x.0 + offset))
700 .collect();
701 x.change_children(children.into_iter());
702 x
703 }));
704 self.pool.0.swap(position.0 as usize, new_root as usize);
705 self.cleanup();
706 Ok(())
707 }
708
709 fn probabilistic_enumerate_from_expr<'a, R, E, F>(
710 &self,
711 position: LambdaExprRef,
712 possible_expressions: &'a PossibleExpressions<'src, Expr<'src>>,
713 eager_filter: E,
714 filter: F,
715 rng: &'a mut R,
716 ) -> Result<
717 LambdaEnumerator<'a, 'src, Expr<'src>, E, ProbabilisticEnumeration<'a, R, F>>,
718 TypeError,
719 >
720 where
721 R: Rng,
722 F: Fn(&ExprDetails) -> bool,
723 E: Fn(&Context) -> bool,
724 {
725 let (context, is_app_subformula) = Context::from_pos(self, position);
726 let output = self.pool.get_type(position)?;
727 let mut pq = BinaryHeap::default();
728 pq.push(RandomPQ::new(context, rng));
729 let pools = vec![UnfinishedLambdaPool {
730 pool: vec![ExprOrType::Type {
731 lambda_type: output,
732 parent: None,
733 is_app_subformula,
734 }],
735 }];
736 let enumerator = LambdaEnumerator {
737 pools,
738 possible_expressions,
739 eager_filter,
740 pq: ProbabilisticEnumeration {
741 rng,
742 reservoir_size: 1,
743 reservoir: BinaryHeap::default(),
744 done: false,
745 n_seen: 0,
746 filter,
747 backups: vec![],
748 pq,
749 },
750 };
751
752 Ok(enumerator)
753 }
754
755 pub fn enumerator_filter<'a, F: Fn(&Context) -> bool>(
757 t: &LambdaType,
758 filter: F,
759 possible_expressions: &'a PossibleExpressions<'src, Expr<'src>>,
760 ) -> LambdaEnumerator<'a, 'src, Expr<'src>, F> {
761 let context = Context::new(0, vec![]);
762 let mut pq = BinaryHeap::default();
763 pq.push(Reverse(context));
764 let pools = vec![UnfinishedLambdaPool {
765 pool: vec![ExprOrType::Type {
766 lambda_type: t.clone(),
767 parent: None,
768 is_app_subformula: false,
769 }],
770 }];
771
772 LambdaEnumerator {
773 pools,
774 possible_expressions,
775 eager_filter: filter,
776 pq: NormalEnumeration(pq, VecDeque::default()),
777 }
778 }
779
780 #[must_use]
782 pub fn enumerator<'a>(
783 t: &LambdaType,
784 possible_expressions: &'a PossibleExpressions<'src, Expr<'src>>,
785 ) -> LambdaEnumerator<'a, 'src, Expr<'src>, impl Fn(&'_ Context) -> bool> {
786 let context = Context::new(0, vec![]);
787 let mut pq = BinaryHeap::default();
788 pq.push(Reverse(context));
789 let pools = vec![UnfinishedLambdaPool {
790 pool: vec![ExprOrType::Type {
791 lambda_type: t.clone(),
792 parent: None,
793 is_app_subformula: false,
794 }],
795 }];
796
797 LambdaEnumerator {
798 pools,
799 possible_expressions,
800 eager_filter: |_: &Context| true,
801 pq: NormalEnumeration(pq, VecDeque::default()),
802 }
803 }
804
805 #[must_use]
807 pub fn sampler(
808 t: &LambdaType,
809 possible_expressions: &PossibleExpressions<'src, Expr<'src>>,
810 max_expr: usize,
811 ) -> LambdaSampler<'src, Expr<'src>> {
812 let enumerator = RootedLambdaPool::enumerator(t, possible_expressions);
813 let mut lambdas = Vec::with_capacity(max_expr);
814 let mut expr_details = Vec::with_capacity(max_expr);
815 for (lambda, expr_detail) in enumerator.take(max_expr) {
816 lambdas.push(lambda);
817 expr_details.push(expr_detail);
818 }
819
820 LambdaSampler {
821 lambdas,
822 expr_details,
823 }
824 }
825
826 pub fn random_expr(
828 t: &LambdaType,
829 possible_expressions: &PossibleExpressions<'src, Expr<'src>>,
830 rng: &mut impl Rng,
831 ) -> RootedLambdaPool<'src, Expr<'src>> {
832 ProbabilisticEnumeration::new(
833 1,
834 t,
835 possible_expressions,
836 |_: &Context| true,
837 |_| true,
838 rng,
839 )
840 .next()
841 .unwrap()
842 .0
843 }
844
845 pub fn random_expr_no_constant(
847 t: &LambdaType,
848 possible_expressions: &PossibleExpressions<'src, Expr<'src>>,
849 rng: &mut impl Rng,
850 ) -> Option<RootedLambdaPool<'src, Expr<'src>>> {
851 ProbabilisticEnumeration::new(
852 1,
853 t,
854 possible_expressions,
855 |_: &Context| true,
856 |e| !e.constant_function,
857 rng,
858 )
859 .next()
860 .map(|x| x.0)
861 }
862}
863
impl<'src> RootedLambdaPool<'src, Expr<'src>> {
    /// Removes every quantifier whose bound variable does not occur in its subformula.
    ///
    /// Each such quantifier is replaced by its subformula, and its restrictor is dropped.
    /// Bound-variable indices greater than their depth below the removed quantifier are then
    /// decremented, so they keep pointing at the same binders.
    pub fn prune_quantifiers(&mut self) {
        // Collect (quantifier, subformula) pairs for vacuous quantifiers, in BFS order from the
        // root.
        let quantifiers = self
            .pool
            .bfs_from(self.root)
            .filter_map(|(i, _)| match self.get(i) {
                LambdaExpr::LanguageOfThoughtExpr(Expr::Quantifier { subformula, .. }) => {
                    // The quantifier's variable is used if some bound variable's index equals its
                    // depth `d` as reported by `bfs_from` starting at the subformula.
                    if self
                        .pool
                        .bfs_from(LambdaExprRef(subformula.0))
                        .any(|(x, d)| {
                            if let LambdaExpr::BoundVariable(v, _) = self.get(x) {
                                *v == d
                            } else {
                                false
                            }
                        })
                    {
                        None
                    } else {
                        Some((i, LambdaExprRef(subformula.0)))
                    }
                }
                _ => None,
            })
            .collect::<Vec<_>>();

        // Reverse BFS order, so inner (deeper) quantifiers are rewritten before outer ones.
        for (quantifier, subformula) in quantifiers.into_iter().rev() {
            *self.pool.get_mut(quantifier) = self.pool.get(subformula).clone();
            self.pool.bfs_from_mut(quantifier).for_each(|(x, d, _)| {
                if let LambdaExpr::BoundVariable(b_d, _) = x
                    && *b_d > d
                {
                    *b_d -= 1;
                }
            });
        }
        // Presumably compacts away nodes no longer reachable from the root and returns the new
        // root index.
        self.root = self.pool.cleanup(self.root);
    }

    /// Replaces the node at a uniformly chosen position with another candidate expression.
    ///
    /// The candidate has the same output type and child types, and the node's children are
    /// kept.
    ///
    /// # Errors
    /// Returns a [`TypeError`] if the chosen node's type cannot be computed.
    ///
    /// # Panics
    /// If no candidate fits the node.
    pub fn swap_expr(
        &mut self,
        possible_expressions: &PossibleExpressions<'src, Expr<'src>>,
        rng: &mut impl Rng,
    ) -> Result<(), TypeError> {
        let position = LambdaExprRef((0..self.len()).choose(rng).unwrap() as u32);

        let (context, _) = Context::from_pos(self, position);

        let output = self.pool.get_type(position)?;
        let expr = self.pool.get(position);

        let arguments: Vec<_> = expr
            .get_children()
            .map(|x| self.pool.get_type(x).unwrap())
            .collect();

        let mut replacements = possible_expressions.possiblities_fixed_children(
            &output,
            &arguments,
            expr.var_type(),
            &context,
        );

        // A quantifier with no other candidates can still become universal or existential.
        if replacements.is_empty()
            && let LambdaExpr::LanguageOfThoughtExpr(Expr::Quantifier {
                var_type,
                restrictor,
                subformula,
                ..
            }) = expr
        {
            for quantifier in [Quantifier::Universal, Quantifier::Existential] {
                replacements.push(std::borrow::Cow::Owned(LambdaExpr::LanguageOfThoughtExpr(
                    Expr::Quantifier {
                        quantifier,
                        var_type: *var_type,
                        restrictor: *restrictor,
                        subformula: *subformula,
                    },
                )));
            }
        }
        // Special case: an `EveryEvent` or `Everyone` node with no candidates falls back to
        // candidates of output type `t` with the same children, plus the constant itself.
        for c in [Constant::EveryEvent, Constant::Everyone] {
            if replacements.is_empty()
                && expr == &LambdaExpr::LanguageOfThoughtExpr(Expr::Constant(c))
            {
                replacements = possible_expressions.possiblities_fixed_children(
                    LambdaType::t(),
                    &arguments,
                    expr.var_type(),
                    &context,
                );
                replacements.push(std::borrow::Cow::Owned(LambdaExpr::LanguageOfThoughtExpr(
                    Expr::Constant(c),
                )));
            }
        }

        let choice = replacements.choose(rng).unwrap_or_else(|| {
            panic!(
                "There is no node with output {output} and arguments {}",
                arguments
                    .into_iter()
                    .map(|x| x.to_string())
                    .collect::<Vec<_>>()
                    .join(", ")
            )
        });

        let mut new_expr = choice.clone().into_owned();
        // Re-attach the original node's children to the replacement.
        new_expr.change_children(self.pool.get(position).get_children());
        self.pool.0[position.0 as usize] = new_expr;

        Ok(())
    }

    /// Replaces the subtree at a uniformly chosen position with a copy of a compatible subtree
    /// of the same expression, as found by `Context::find_compatible`.
    ///
    /// Does nothing if no compatible subtree exists.
    ///
    /// # Errors
    /// Propagates the [`TypeError`] returned by `find_compatible`.
    pub fn swap_subtree(&mut self, rng: &mut impl Rng) -> Result<(), TypeError> {
        let position = LambdaExprRef((0..self.len()).choose(rng).unwrap() as u32);
        let (context, _) = Context::from_pos(self, position);
        let alt = context.find_compatible(self, position)?;
        if let Some(new_pos) = alt.choose(rng).copied() {
            // Copy the chosen subtree to the end of the pool, numbering the copies in BFS
            // order starting at `offset` and recording old -> new indices in `lookup`.
            let offset = self.pool.0.len();
            let mut lookup = HashMap::default();
            let mut new_pool: Vec<LambdaExpr<'src, Expr<'src>>> = self
                .pool
                .bfs_from(new_pos)
                .map(|(x, _)| {
                    let n = lookup.len();
                    lookup
                        .entry(x)
                        .or_insert(LambdaExprRef((n + offset) as u32));
                    self.get(x).clone()
                })
                .collect();

            // Point every copied node at its copied children.
            for x in &mut new_pool {
                let child: Vec<_> = x.get_children().map(|x| *lookup.get(&x).unwrap()).collect();
                x.change_children(child.into_iter());
            }

            self.pool.0.extend(new_pool);
            // The copy's root (first in BFS order, at `offset`) takes the chosen position. The
            // displaced node is left unreachable for `cleanup`.
            self.pool.0.swap(position.0 as usize, offset);
            self.cleanup();
        }
        Ok(())
    }
}
1018
#[cfg(test)]
mod test {

    use std::collections::HashSet;

    use super::*;
    use crate::lambda::{LambdaPool, LambdaSummaryStats};
    use rand::SeedableRng;
    use rand_chacha::ChaCha8Rng;

    /// `prune_quantifiers` drops quantifiers whose variable is unused and keeps the remaining
    /// bound variables pointing at the right binders.
    #[test]
    fn prune_quantifier_test() -> anyhow::Result<()> {
        let mut pool =
            RootedLambdaPool::parse("some_e(x,all_e,AgentOf(a_2,e_1) & PatientOf(a_0,e_0))")?;

        pool.prune_quantifiers();
        assert_eq!(pool.to_string(), "AgentOf(a_2, e_1) & PatientOf(a_0, e_0)");

        // Only the outer quantifier is vacuous; the inner one uses its variable.
        let mut pool = RootedLambdaPool::parse(
            "some_e(x0, all_e, some(z, all_a, AgentOf(z, e_1) & PatientOf(a_0, e_0)))",
        )?;

        pool.prune_quantifiers();
        assert_eq!(
            pool.to_string(),
            "some(x, all_a, AgentOf(x, e_1) & PatientOf(a_0, e_0))"
        );

        let mut pool = RootedLambdaPool::parse("~every_e(z, pe_1, pa_2(a_0))")?;

        pool.prune_quantifiers();

        assert_eq!(pool.to_string(), "~pa_2(a_0)");
        // Hand-built pool: two nested vacuous quantifiers under two lambdas, whose innermost
        // bound variable (index 3) prints as the outer lambda's `x`.
        let mut pool = RootedLambdaPool::new(
            LambdaPool(vec![
                LambdaExpr::Lambda(LambdaExprRef(1), LambdaType::E),
                LambdaExpr::Lambda(LambdaExprRef(2), LambdaType::T),
                LambdaExpr::LanguageOfThoughtExpr(Expr::Quantifier {
                    quantifier: Quantifier::Universal,
                    var_type: ActorOrEvent::Actor,
                    restrictor: ExprRef(3),
                    subformula: ExprRef(4),
                }),
                LambdaExpr::LanguageOfThoughtExpr(Expr::Constant(Constant::Property(
                    "1",
                    ActorOrEvent::Actor,
                ))),
                LambdaExpr::LanguageOfThoughtExpr(Expr::Quantifier {
                    quantifier: Quantifier::Existential,
                    var_type: ActorOrEvent::Actor,
                    restrictor: ExprRef(5),
                    subformula: ExprRef(6),
                }),
                LambdaExpr::LanguageOfThoughtExpr(Expr::Constant(Constant::Property(
                    "0",
                    ActorOrEvent::Actor,
                ))),
                LambdaExpr::LanguageOfThoughtExpr(Expr::Unary(
                    MonOp::Property("3", ActorOrEvent::Event),
                    ExprRef(7),
                )),
                LambdaExpr::BoundVariable(3, LambdaType::E),
            ]),
            LambdaExprRef(0),
        );

        assert_eq!(
            pool.to_string(),
            "lambda e x lambda t phi every(y, pa_1, some(z, pa_0, pe_3(x)))"
        );

        // NOTE(review): `parsed_pool` is pruned but never checked against `pool`.
        let mut parsed_pool = RootedLambdaPool::parse(
            "lambda e x lambda t phi every(y, pa_1, some(z, pa_0, pe_3(x)))",
        )?;
        parsed_pool.prune_quantifiers();
        pool.prune_quantifiers();
        assert_eq!(pool.to_string(), "lambda e x lambda t phi pe_3(x)");

        Ok(())
    }

    /// `swap_subtree` on a disjunction can replace it with either disjunct.
    #[test]
    fn random_swap_tree() -> anyhow::Result<()> {
        let mut rng = ChaCha8Rng::seed_from_u64(2);
        let x = RootedLambdaPool::parse("lambda a x pa_cool(a_John) | pa_cool(x)")?;
        let mut h = HashSet::new();
        for _ in 0..100 {
            let mut z = x.clone();
            z.swap_subtree(&mut rng)?;
            h.insert(z.to_string());
        }
        for x in h.iter() {
            println!("{x}");
        }
        assert!(h.contains("lambda a x pa_cool(x)"));
        assert!(h.contains("lambda a x pa_cool(a_John)"));
        Ok(())
    }

    /// `swap_expr` preserves the type of random expressions and of a fixed quantified one.
    #[test]
    fn randomn_swap() -> anyhow::Result<()> {
        let mut rng = ChaCha8Rng::seed_from_u64(2);
        let actors = ["0", "1"];
        let available_event_properties = ["2", "3", "4"];
        let possible_expressions = PossibleExpressions::new(
            &actors,
            &available_event_properties,
            &available_event_properties,
        );
        for _ in 0..200 {
            let t = LambdaType::random(&mut rng);
            println!("{t}");
            let mut pool = RootedLambdaPool::random_expr(&t, &possible_expressions, &mut rng);
            println!("{t}: {pool}");
            assert_eq!(t, pool.get_type()?);
            println!("{pool:?}");
            pool.swap_expr(&possible_expressions, &mut rng)?;
            println!("{pool:?}");
            println!("{t}: {pool}");
            assert_eq!(t, pool.get_type()?);
        }
        let p =
            RootedLambdaPool::parse("lambda <a,t> P some(x, P(x), some_e(y, pe_2(y), pe_4(y)))")?;
        let t = p.get_type()?;
        for _ in 0..1000 {
            let mut pool = p.clone();
            assert_eq!(t, pool.get_type()?);

            pool.swap_expr(&possible_expressions, &mut rng)?;
            dbg!(&pool);
            println!("{t}: {pool}");
            assert_eq!(t, pool.get_type()?);
        }

        Ok(())
    }

    /// Enumerating `<<a,t>,t>` finds `lambda <a,t> P P(a_john)` among the first 20
    /// non-constant results.
    #[test]
    fn enumerate() -> anyhow::Result<()> {
        let actors = ["john"];
        let actor_properties = ["a"];
        let event_properties = ["e"];
        let possibles = PossibleExpressions::new(&actors, &actor_properties, &event_properties);

        let t = LambdaType::from_string("<<a,t>,t>")?;

        let p = RootedLambdaPool::enumerator(&t, &possibles);
        let pools: HashSet<_> = p
            .filter_map(|(p, x)| {
                if !x.has_constant_function() {
                    Some(p.to_string())
                } else {
                    None
                }
            })
            .take(20)
            .collect();
        for p in pools.iter() {
            println!("{p}");
        }
        assert!(pools.contains("lambda <a,t> P P(a_john)"));
        for p in pools {
            println!("{p}");
        }

        Ok(())
    }

    /// Smoke test: enumerating an unusual type `<a,<e,e>>` yields without panicking.
    #[test]
    fn enumerate_weirds() -> anyhow::Result<()> {
        let actors = &["1"];
        let actor_properties = &["1"];
        let event_properties = &["1"];
        let poss = PossibleExpressions::new(actors, actor_properties, event_properties);

        let t = "<a,<e,e>>";

        for (p, d) in RootedLambdaPool::enumerator(&LambdaType::from_string(t)?, &poss)
            .filter(|(_, e)| !e.has_constant_function())
            .take(5)
        {
            println!("{p} {d:?} {}", d.has_constant_function());
        }
        Ok(())
    }

    /// Random expressions have the requested type and round-trip through `parse`.
    /// Resampling (with and without helpers) and subtree swaps keep the type.
    #[test]
    fn random_expr() -> anyhow::Result<()> {
        let actors = ["john"];
        let actor_properties = ["a"];
        let event_properties = ["e"];
        let possibles = PossibleExpressions::new(&actors, &actor_properties, &event_properties);
        let mut rng = ChaCha8Rng::seed_from_u64(0);

        let map = [(LambdaType::A, vec![RootedLambdaPool::parse("a_john")?])]
            .into_iter()
            .collect();

        for _ in 0..100 {
            let t = LambdaType::random(&mut rng);
            println!("sampling: {t}");
            let mut pool = RootedLambdaPool::random_expr(&t, &possibles, &mut rng);
            assert_eq!(t, pool.get_type()?);
            let s = pool.to_string();
            println!("{s}");
            let mut pool2 = RootedLambdaPool::parse(s.as_str())?;
            assert_eq!(s, pool2.to_string());
            println!("{pool}");
            pool2.resample_from_expr(&possibles, None, &mut rng)?;
            assert_eq!(pool2.get_type()?, t);
            pool2.resample_from_expr(&possibles, Some(&map), &mut rng)?;
            assert_eq!(pool2.get_type()?, t);

            pool.swap_subtree(&mut rng)?;
            assert_eq!(pool.get_type()?, t);
        }
        let t = LambdaType::from_string("<a,<a,t>>")?;
        for _ in 0..100 {
            let pool = RootedLambdaPool::random_expr(&t, &possibles, &mut rng);
            println!("{pool}");
        }
        Ok(())
    }

    /// The first four enumerated `a` expressions are exactly the actor constants. Then
    /// random `<a,<a,t>>` expressions are sampled and the constant functions are counted.
    #[test]
    fn constant_exprs() -> anyhow::Result<()> {
        let actors = ["john", "mary", "phil", "sue"];
        let actor_properties = ["a"];
        let event_properties = ["e"];
        let possibles = PossibleExpressions::new(&actors, &actor_properties, &event_properties);
        let mut rng = ChaCha8Rng::seed_from_u64(0);

        let v: HashSet<_> = RootedLambdaPool::enumerator(LambdaType::a(), &possibles)
            .map(|(x, _)| x.to_string())
            .take(4)
            .collect();

        assert_eq!(v, HashSet::from(actors.map(|x| format!("a_{x}"))));
        println!("{v:?}");

        let mut constants = 0;
        for _ in 0..100 {
            let t = LambdaType::from_string("<a, <a,t>>")?;
            println!("sampling: {t}");
            let pool = RootedLambdaPool::random_expr(&t, &possibles, &mut rng);
            let x = pool.stats();
            match x {
                LambdaSummaryStats::WellFormed {
                    constant_function, ..
                } => {
                    if constant_function {
                        constants += 1;
                    }
                }
                LambdaSummaryStats::Malformed => todo!(),
            }
        }
        println!("{constants}");

        Ok(())
    }

    /// Sampling type `a` 1000 times gives each of the four actors between 200 and 300 times.
    #[test]
    fn random_expr_counts() -> anyhow::Result<()> {
        let actors = ["john", "mary", "phil", "sue"];
        let actor_properties = ["a"];
        let event_properties = ["e"];
        let possibles = PossibleExpressions::new(&actors, &actor_properties, &event_properties);
        let mut rng = ChaCha8Rng::seed_from_u64(1);

        let mut counts: HashMap<_, usize> = HashMap::default();
        for _ in 0..1000 {
            let t = LambdaType::A;
            let pool = RootedLambdaPool::random_expr(&t, &possibles, &mut rng);
            assert_eq!(t, pool.get_type()?);
            let s = pool.to_string();
            *counts.entry(s).or_default() += 1;
        }
        assert_eq!(counts.len(), 4);
        dbg!(&counts);
        for (_, v) in counts.iter() {
            assert!(200 <= *v && *v <= 300);
        }

        counts.clear();
        for _ in 0..1000 {
            let t = LambdaType::at().clone();
            let pool = RootedLambdaPool::random_expr(&t, &possibles, &mut rng);
            assert_eq!(t, pool.get_type()?);
            let s = pool.to_string();
            *counts.entry(s).or_default() += 1;
        }
        dbg!(counts);
        Ok(())
    }
}