1use super::{
2 ActorOrEvent, BinOp, Constant, Display, Expr, ExprRef, MonOp, Quantifier, RootedLambdaPool,
3 thiserror,
4};
5use crate::lambda::{
6 LambdaError, LambdaExpr, LambdaExprRef, LambdaLanguageOfThought, LambdaPool,
7 types::{LambdaType, TypeError},
8};
9use ahash::HashMap;
10use rand::{
11 Rng, RngExt,
12 distr::{Distribution, weighted::WeightedIndex},
13 seq::{IndexedRandom, IteratorRandom},
14};
15use std::{
16 cmp::Reverse,
17 collections::{BinaryHeap, VecDeque},
18 fmt::Debug,
19};
20use thiserror::Error;
21
22mod context;
23mod samplers;
24pub use context::Context;
25pub(crate) use samplers::PossibleExpr;
26pub use samplers::PossibleExpressions;
27
28#[derive(Debug, Error, Clone)]
29pub struct ExprOrTypeError();
30
31impl Display for ExprOrTypeError {
32 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
33 write!(f, "This ExprOrType is not an Expr!")
34 }
35}
36
37#[derive(Debug, Clone, Eq, PartialEq)]
38enum ExprOrType<'src, T: LambdaLanguageOfThought> {
39 Type {
40 lambda_type: LambdaType,
41 parent: Option<usize>,
42 is_app_subformula: bool,
43 },
44 Expr {
45 lambda_expr: LambdaExpr<'src, T>,
46 parent: Option<usize>,
47 },
48}
49
50impl<'src, T: LambdaLanguageOfThought> TryFrom<ExprOrType<'src, T>> for LambdaExpr<'src, T> {
51 type Error = ExprOrTypeError;
52
53 fn try_from(value: ExprOrType<'src, T>) -> Result<Self, Self::Error> {
54 match value {
55 ExprOrType::Type { .. } => Err(ExprOrTypeError()),
56 ExprOrType::Expr { lambda_expr, .. } => Ok(lambda_expr),
57 }
58 }
59}
60
61impl<T: LambdaLanguageOfThought> ExprOrType<'_, T> {
62 fn parent(&self) -> Option<usize> {
63 match self {
64 ExprOrType::Type { parent, .. } | ExprOrType::Expr { parent, .. } => *parent,
65 }
66 }
67
68 fn is_type(&self) -> bool {
69 matches!(self, ExprOrType::Type { .. })
70 }
71}
72
73#[derive(Debug, Clone, Eq, PartialEq)]
74struct UnfinishedLambdaPool<'src, T: LambdaLanguageOfThought> {
75 pool: Vec<ExprOrType<'src, T>>,
76}
77
78impl<T: LambdaLanguageOfThought> Default for UnfinishedLambdaPool<'_, T> {
79 fn default() -> Self {
80 Self { pool: vec![] }
81 }
82}
83
impl<'src, T: LambdaLanguageOfThought + Clone> UnfinishedLambdaPool<'src, T> {
    /// Fills the hole at `c.position` with `expr`.
    ///
    /// One new hole is appended to the pool for each child of the expression,
    /// the expression's child references are pointed at those holes, and the
    /// context `c` is updated (depth, open-hole count, bound variables).
    ///
    /// `t` is the type of the hole being filled. For a lambda it must be a
    /// function type: its right-hand side becomes the type of the body hole
    /// (`rhs()` is unwrapped).
    fn add_expr<'a>(&mut self, expr: PossibleExpr<'a, 'src, T>, c: &mut Context, t: &LambdaType) {
        // `app_details` carries the (function, argument) hole types for applications.
        let (mut lambda_expr, app_details) = expr.into_expr();
        // One more node placed.
        c.depth += 1;
        // The filled hole closes and each child opens a new one. Adding before
        // subtracting avoids underflow for leaves if the counter is unsigned.
        c.open_nodes += lambda_expr.n_children();
        c.open_nodes -= 1;
        // Preserve the parent link of the hole being replaced.
        let parent = self.pool[c.position].parent();
        match &mut lambda_expr {
            LambdaExpr::Lambda(body, arg) => {
                // The lambda binds a new variable of type `arg` for its body.
                c.add_lambda(arg);
                *body = LambdaExprRef::new(self.pool.len());
                self.pool.push(ExprOrType::Type {
                    lambda_type: t.rhs().unwrap().clone(),
                    parent: Some(c.position),
                    is_app_subformula: false,
                });
            }
            LambdaExpr::BoundVariable(b, _) => {
                // Record that this bound variable is referenced.
                c.use_bvar(*b);
            }
            LambdaExpr::FreeVariable(..) => (),
            LambdaExpr::Application {
                subformula,
                argument,
            } => {
                // Two holes: the function (flagged as an application
                // subformula) followed by its argument.
                *subformula = LambdaExprRef::new(self.pool.len());
                *argument = LambdaExprRef::new(self.pool.len() + 1);
                let (subformula, argument) = app_details.unwrap();
                self.pool.push(ExprOrType::Type {
                    lambda_type: subformula,
                    parent: Some(c.position),
                    is_app_subformula: true,
                });
                self.pool.push(ExprOrType::Type {
                    lambda_type: argument,
                    parent: Some(c.position),
                    is_app_subformula: false,
                });
            }
            LambdaExpr::LanguageOfThoughtExpr(e) => {
                let children_start = self.pool.len();
                // Expressions that bind a variable register it in the context.
                if let Some(t) = e.var_type() {
                    c.add_lambda(t);
                }
                // One hole per argument type, then rewire the children to them.
                self.pool
                    .extend(e.get_arguments().map(|lambda_type| ExprOrType::Type {
                        lambda_type,
                        parent: Some(c.position),
                        is_app_subformula: false,
                    }));
                e.change_children((children_start..self.pool.len()).map(LambdaExprRef::new));
            }
        }
        // Replace the hole with the now-wired expression.
        self.pool[c.position] = ExprOrType::Expr {
            lambda_expr,
            parent,
        };
    }
}
143
144#[derive(Debug)]
145pub struct NormalEnumeration(BinaryHeap<Reverse<Context>>, VecDeque<ExprDetails>);
146
147impl EnumerationType for NormalEnumeration {
148 fn pop(&mut self) -> Option<Context> {
149 self.0.pop().map(|x| x.0)
150 }
151
152 fn push(&mut self, context: Context, _: bool) {
153 self.0.push(Reverse(context));
154 }
155
156 fn get_yield(&mut self) -> Option<ExprDetails> {
157 self.1.pop_front()
158 }
159
160 fn push_yield(&mut self, e: ExprDetails) {
161 self.1.push_back(e);
162 }
163
164 fn include(&mut self, n: usize) -> impl Iterator<Item = bool> + 'static {
165 std::iter::repeat_n(true, n)
166 }
167}
168
169impl ExprDetails {
170 #[allow(clippy::cast_precision_loss)]
171 fn score(&self) -> f64 {
172 (1.0 / (self.size as f64)) + if self.constant_function { 0.0 } else { 10.0 }
173 }
174
175 pub fn has_constant_function(&self) -> bool {
176 self.constant_function
177 }
178}
179
180#[derive(Debug, Clone, Copy, PartialEq)]
181struct KeyedExprDetails {
182 expr_details: ExprDetails,
183 k: f64,
184}
185
186impl Eq for KeyedExprDetails {}
187
188impl PartialOrd for KeyedExprDetails {
189 fn partial_cmp(&self, other: &Self) -> Option<std::cmp::Ordering> {
190 Some(self.cmp(other))
191 }
192}
193
194impl Ord for KeyedExprDetails {
195 fn cmp(&self, other: &Self) -> std::cmp::Ordering {
196 other.k.partial_cmp(&self.k).unwrap()
198 }
199}
200impl KeyedExprDetails {
201 fn new(expr_details: ExprDetails, rng: &mut impl Rng) -> Self {
202 let u: f64 = rng.random();
203 KeyedExprDetails {
204 expr_details,
205 k: u.powf(1.0 / expr_details.score()),
206 }
207 }
208}
209
210#[derive(Debug, Clone, PartialEq)]
211struct RandomPQ(Context, f64);
212
213impl Eq for RandomPQ {}
214
215impl RandomPQ {
216 fn new(c: Context, rng: &mut impl Rng) -> Self {
217 RandomPQ(c, rng.random())
218 }
219}
220
221#[derive(Debug)]
222struct ProbabilisticEnumeration<'a, R: Rng, F>
223where
224 F: Fn(&ExprDetails) -> bool,
225{
226 rng: &'a mut R,
227 reservoir_size: usize,
228 reservoir: BinaryHeap<KeyedExprDetails>,
229 backups: Vec<Context>,
230 pq: BinaryHeap<RandomPQ>,
231 filter: F,
232 n_seen: usize,
233 done: bool,
234}
235impl<R: Rng, F> ProbabilisticEnumeration<'_, R, F>
236where
237 F: Fn(&ExprDetails) -> bool,
238{
239 fn threshold(&self) -> Option<f64> {
240 self.reservoir.peek().map(|x| x.k)
241 }
242
243 fn new<'a, 'src, T: LambdaLanguageOfThought, E: Fn(&Context) -> bool>(
244 reservoir_size: usize,
245 t: &LambdaType,
246 possible_expressions: &'a PossibleExpressions<'src, T>,
247 eager_filter: E,
248 filter: F,
249 rng: &'a mut R,
250 ) -> LambdaEnumerator<'a, 'src, T, E, ProbabilisticEnumeration<'a, R, F>> {
251 let context = Context::new(0, vec![]);
252 let mut pq = BinaryHeap::default();
253 pq.push(RandomPQ::new(context, rng));
254 let pools = vec![UnfinishedLambdaPool {
255 pool: vec![ExprOrType::Type {
256 lambda_type: t.clone(),
257 parent: None,
258 is_app_subformula: false,
259 }],
260 }];
261
262 LambdaEnumerator {
263 pools,
264 possible_expressions,
265 eager_filter,
266 pq: ProbabilisticEnumeration {
267 rng,
268 reservoir_size,
269 reservoir: BinaryHeap::default(),
270 backups: vec![],
271 filter,
272 pq,
273 n_seen: 0,
274 done: false,
275 },
276 }
277 }
278}
279
280impl<R: Rng, F> EnumerationType for ProbabilisticEnumeration<'_, R, F>
281where
282 F: Fn(&ExprDetails) -> bool,
283{
284 fn pop(&mut self) -> Option<Context> {
285 self.pq.pop().map(|x| x.0).or_else(|| {
287 (0..self.backups.len()).choose(self.rng).and_then(|index| {
288 let last_item = self.backups.len() - 1;
289 self.backups.swap(index, last_item);
290 self.backups.pop()
291 })
292 })
293 }
294
295 fn push(&mut self, context: Context, included: bool) {
296 if included {
297 self.pq.push(RandomPQ::new(context, &mut self.rng));
298 } else {
299 self.backups.push(context);
300 }
301 }
302
303 fn get_yield(&mut self) -> Option<ExprDetails> {
304 if (self.done || self.pq.is_empty())
305 && let Some(x) = self.reservoir.pop()
306 {
307 Some(x.expr_details)
308 } else {
309 None
310 }
311 }
312
313 fn push_yield(&mut self, e: ExprDetails) {
314 let e = KeyedExprDetails::new(e, &mut self.rng);
315 if (self.filter)(&e.expr_details) {
316 self.n_seen += 1;
317 if self.reservoir_size > self.reservoir.len() {
318 self.reservoir.push(e);
319 } else if let Some(t) = self.threshold()
320 && e.k > t
321 {
322 self.reservoir.pop();
323 self.reservoir.push(e);
324 }
325 if self.n_seen >= self.reservoir_size * 20 {
326 self.pq.clear();
327 self.done = true;
328 }
329 }
330 }
331
332 fn include(&mut self, n: usize) -> impl Iterator<Item = bool> + 'static {
333 let x = (0..n).sample(self.rng, (n / 2).max(1));
334 let mut v = vec![false; n];
335 for i in x {
336 v[i] = true;
337 }
338 v.into_iter()
339 }
340}
341
/// Iterator that enumerates complete lambda expressions by repeatedly filling
/// type holes of partially built pools.
///
/// The order in which partial expressions are expanded, and when finished
/// ones are handed out, is decided by the strategy `E`.
#[derive(Debug)]
pub struct LambdaEnumerator<'a, 'src, T: LambdaLanguageOfThought, F, E = NormalEnumeration> {
    /// Partially built pools; each queued `Context` refers to one via `pool_index`.
    pools: Vec<UnfinishedLambdaPool<'src, T>>,
    /// The expressions holes may be filled with.
    possible_expressions: &'a PossibleExpressions<'src, T>,
    /// Expanded contexts failing this predicate are discarded instead of queued.
    eager_filter: F,
    /// The enumeration strategy (queue of contexts and of finished results).
    pq: E,
}
349}
350
351#[derive(Debug, Clone, Copy, Eq, PartialEq)]
353pub struct ExprDetails {
354 id: usize,
355 constant_function: bool,
356 root: LambdaExprRef,
357 size: usize,
358}
359
360impl ExprDetails {
361 pub fn size(&self) -> usize {
363 self.size
364 }
365}
366
367#[derive(Debug, Clone, Eq, PartialEq)]
368pub struct TypeAgnosticSampler<'src, T: LambdaLanguageOfThought> {
370 type_to_sampler: HashMap<LambdaType, (usize, LambdaSampler<'src, T>)>,
371 max_expr: usize,
372 max_types: usize,
373 possible_expressions: PossibleExpressions<'src, T>,
374}
375
376impl<'src> TypeAgnosticSampler<'src, Expr<'src>> {
377 #[allow(clippy::missing_panics_doc)]
378 pub fn sample(
380 &mut self,
381 lambda_type: LambdaType,
382 rng: &mut impl Rng,
383 ) -> RootedLambdaPool<'src, Expr<'src>> {
384 let (counts, exprs) = self
385 .type_to_sampler
386 .entry(lambda_type)
387 .or_insert_with_key(|t| {
388 (
389 1,
390 RootedLambdaPool::sampler(t, &self.possible_expressions, self.max_expr),
391 )
392 });
393 *counts += 1;
394 let sample = exprs.sample(rng);
395
396 if self.type_to_sampler.len() > self.max_types {
397 let (_, k) = self
398 .type_to_sampler
399 .iter()
400 .map(|(k, (n_visits, _))| (n_visits, k))
401 .min_by_key(|x| x.0)
402 .unwrap();
403
404 let t = k.clone();
405 self.type_to_sampler.remove(&t);
406 }
407
408 sample
409 }
410
411 #[must_use]
413 pub fn possibles(&self) -> &PossibleExpressions<'src, Expr<'src>> {
414 &self.possible_expressions
415 }
416}
417
418impl<'src, T: LambdaLanguageOfThought + Clone> RootedLambdaPool<'src, T> {
419 #[must_use]
424 pub fn typeless_sampler(
425 possible_expressions: PossibleExpressions<'src, T>,
426 max_expr: usize,
427 max_types: usize,
428 ) -> TypeAgnosticSampler<'src, T> {
429 assert!(max_types >= 1);
430 assert!(max_expr >= 1);
431 TypeAgnosticSampler {
432 possible_expressions,
433 max_expr,
434 max_types,
435 type_to_sampler: HashMap::default(),
436 }
437 }
438}
439
440#[derive(Debug, Clone, Eq, PartialEq)]
442pub struct LambdaSampler<'src, T: LambdaLanguageOfThought> {
443 lambdas: Vec<RootedLambdaPool<'src, T>>,
444 expr_details: Vec<ExprDetails>,
445}
446
447impl<'src, T: LambdaLanguageOfThought + Clone> Distribution<RootedLambdaPool<'src, T>>
448 for LambdaSampler<'src, T>
449{
450 fn sample<R: Rng + ?Sized>(&self, rng: &mut R) -> RootedLambdaPool<'src, T> {
451 let w = WeightedIndex::new(self.expr_details.iter().map(ExprDetails::score)).unwrap();
452 let i = w.sample(rng);
453 self.lambdas
454 .get(i)
455 .expect("The Lambda Sampler has no lambdas to sample :(")
456 .clone()
457 }
458}
459
460pub trait EnumerationType {
463 fn pop(&mut self) -> Option<Context>;
464 fn push(&mut self, context: Context, included: bool);
465 fn get_yield(&mut self) -> Option<ExprDetails>;
466 fn push_yield(&mut self, e: ExprDetails);
467 fn include(&mut self, n: usize) -> impl Iterator<Item = bool> + 'static;
468}
469
470fn try_yield<'src, T, F, E>(
471 x: &mut LambdaEnumerator<'_, 'src, T, F, E>,
472) -> Option<(RootedLambdaPool<'src, T>, ExprDetails)>
473where
474 T: LambdaLanguageOfThought,
475 E: EnumerationType,
476{
477 if let Some(item) = x.pq.get_yield() {
478 let p = std::mem::take(&mut x.pools[item.id]);
479 return Some((
480 RootedLambdaPool {
481 pool: LambdaPool(
482 p.pool
483 .into_iter()
484 .map(|x| LambdaExpr::try_from(x).unwrap())
485 .collect(),
486 ),
487 root: item.root,
488 },
489 item,
490 ));
491 }
492 None
493}
494
495impl<'a, 'src, T, F, E> LambdaEnumerator<'a, 'src, T, F, E>
496where
497 T: LambdaLanguageOfThought + Clone + Debug,
498 F: Fn(&Context) -> bool,
499 E: EnumerationType,
500{
501 fn push(&mut self, c: Context, included: bool) {
502 if (self.eager_filter)(&c) {
503 self.pq.push(c, included);
504 } else {
505 self.pools[c.pool_index] = UnfinishedLambdaPool::default();
506 }
507 }
508
509 pub fn eager_filter<F2>(self, eager_filter: F2) -> LambdaEnumerator<'a, 'src, T, F2, E> {
511 let LambdaEnumerator {
512 pools,
513 possible_expressions,
514 eager_filter: _,
515 pq,
516 } = self;
517
518 LambdaEnumerator {
519 pools,
520 possible_expressions,
521 eager_filter,
522 pq,
523 }
524 }
525}
526
527impl<'src, F, E> Iterator for LambdaEnumerator<'_, 'src, Expr<'src>, F, E>
528where
529 F: Fn(&Context) -> bool,
530 E: EnumerationType,
531{
532 type Item = (RootedLambdaPool<'src, Expr<'src>>, ExprDetails);
533
534 #[allow(clippy::too_many_lines)]
535 fn next(&mut self) -> Option<Self::Item> {
536 if let Some(x) = try_yield(self) {
537 return Some(x);
538 }
539
540 while let Some(mut c) = self.pq.pop() {
541 if let Some(x) = try_yield(self) {
542 self.push(c, true);
543 return Some(x);
544 }
545 let (possibles, lambda_type) = match &self.pools[c.pool_index].pool[c.position] {
546 ExprOrType::Type {
547 lambda_type,
548 is_app_subformula,
549 parent,
550 } => {
551 let mut possibles = self.possible_expressions.possibilities(
552 lambda_type,
553 *is_app_subformula,
554 &c,
555 );
556
557 if let Some(p) = parent
560 && let ExprOrType::Expr {
561 lambda_expr:
562 LambdaExpr::LanguageOfThoughtExpr(Expr::Quantifier {
563 var_type,
564 restrictor,
565 ..
566 }),
567 ..
568 } = self.pools[c.pool_index].pool[*p]
569 && restrictor.0 == u32::try_from(c.position).unwrap()
570 {
571 possibles.push(PossibleExpr::new_borrowed(match var_type {
572 ActorOrEvent::Actor => &LambdaExpr::LanguageOfThoughtExpr(
573 Expr::Constant(Constant::Everyone),
574 ),
575 ActorOrEvent::Event => &LambdaExpr::LanguageOfThoughtExpr(
576 Expr::Constant(Constant::EveryEvent),
577 ),
578 }));
579 }
580
581 (possibles, lambda_type.clone())
582 }
583 ExprOrType::Expr {
584 lambda_expr,
585 parent,
586 } => {
587 if let Some(child) = lambda_expr
591 .get_children()
592 .map(|x| x.0 as usize)
593 .find(|child| self.pools[c.pool_index].pool[*child].is_type())
594 {
595 c.position = child;
596 self.pq.push(c, true);
597 continue;
598 }
599
600 if lambda_expr.inc_depth() {
601 c.pop_lambda();
602 }
603
604 if let Some(p) = parent {
605 c.position = *p;
606 self.pq.push(c, true);
607 continue;
608 }
609 self.pq.push_yield(ExprDetails {
611 id: c.pool_index,
612 root: LambdaExprRef(u32::try_from(c.position).unwrap()),
613 size: c.depth,
614 constant_function: c.is_constant(),
615 });
616 continue;
617 }
618 };
619
620 let n = possibles.len();
621 let included = self.pq.include(n);
622 if n == 0 {
623 continue;
624 }
625 let n_pools = self.pools.len();
626 if n_pools.is_multiple_of(10_000) {
627 self.pools.shrink_to_fit();
628 }
629 for _ in 0..n.saturating_sub(1) {
630 self.pools.push(self.pools[c.pool_index].clone());
631 }
632
633 let positions =
634 std::iter::once(c.pool_index).chain(n_pools..n_pools + n.saturating_sub(1));
635
636 for (((expr, pool_id), mut c), included) in possibles
637 .into_iter()
638 .zip(positions)
639 .zip(std::iter::repeat_n(c, n))
640 .zip(included)
641 {
642 c.pool_index = pool_id;
643 let pool = self.pools.get_mut(pool_id).unwrap();
644 pool.add_expr(expr, &mut c, &lambda_type);
645 self.push(c, included);
646 }
647
648 if let Some(x) = try_yield(self) {
649 return Some(x);
650 }
651 }
652
653 if let Some(x) = try_yield(self) {
655 return Some(x);
656 }
657 None
658 }
659}
660
661impl<'src> RootedLambdaPool<'src, Expr<'src>> {
662 pub fn resample_from_expr<'a>(
670 &mut self,
671 possible_expressions: &'a PossibleExpressions<'src, Expr<'src>>,
672 helpers: Option<&HashMap<LambdaType, Vec<RootedLambdaPool<'src, Expr<'src>>>>>,
673 rng: &mut impl Rng,
674 ) -> Result<(), LambdaError> {
675 let position = LambdaExprRef(u32::try_from((0..self.len()).choose(rng).unwrap()).unwrap());
676 let t = self.pool.get_type(position)?;
677
678 let pool = if let Some(helpers) = helpers
679 && rng.random_bool(0.2)
680 && let Some(v) = helpers.get(&t)
681 && !v.is_empty()
682 {
683 let pool = v.choose(rng).unwrap();
684 pool.clone()
685 } else {
686 let (pool, _) = self
687 .probabilistic_enumerate_from_expr(
688 position,
689 possible_expressions,
690 |_| true,
691 |_| true,
692 rng,
693 )?
694 .next()
695 .unwrap();
696 pool
697 };
698
699 let offset = u32::try_from(self.len()).unwrap();
700 let new_root = pool.root.0 + offset;
701 self.pool.0.extend(pool.pool.0.into_iter().map(|mut x| {
702 let children: Vec<_> = x
703 .get_children()
704 .map(|x| LambdaExprRef(x.0 + offset))
705 .collect();
706 x.change_children(children.into_iter());
707 x
708 }));
709 self.pool.0.swap(position.0 as usize, new_root as usize);
710 self.cleanup();
711 Ok(())
712 }
713
714 fn probabilistic_enumerate_from_expr<'a, R, E, F>(
715 &self,
716 position: LambdaExprRef,
717 possible_expressions: &'a PossibleExpressions<'src, Expr<'src>>,
718 eager_filter: E,
719 filter: F,
720 rng: &'a mut R,
721 ) -> Result<
722 LambdaEnumerator<'a, 'src, Expr<'src>, E, ProbabilisticEnumeration<'a, R, F>>,
723 TypeError,
724 >
725 where
726 R: Rng,
727 F: Fn(&ExprDetails) -> bool,
728 E: Fn(&Context) -> bool,
729 {
730 let (context, is_app_subformula) = Context::from_pos(self, position);
731 let output = self.pool.get_type(position)?;
732 let mut pq = BinaryHeap::default();
733 pq.push(RandomPQ::new(context, rng));
734 let pools = vec![UnfinishedLambdaPool {
735 pool: vec![ExprOrType::Type {
736 lambda_type: output,
737 parent: None,
738 is_app_subformula,
739 }],
740 }];
741 let enumerator = LambdaEnumerator {
742 pools,
743 possible_expressions,
744 eager_filter,
745 pq: ProbabilisticEnumeration {
746 rng,
747 reservoir_size: 1,
748 reservoir: BinaryHeap::default(),
749 done: false,
750 n_seen: 0,
751 filter,
752 backups: vec![],
753 pq,
754 },
755 };
756
757 Ok(enumerator)
758 }
759
760 pub fn enumerator_filter<'a, F: Fn(&Context) -> bool>(
762 t: &LambdaType,
763 filter: F,
764 possible_expressions: &'a PossibleExpressions<'src, Expr<'src>>,
765 ) -> LambdaEnumerator<'a, 'src, Expr<'src>, F> {
766 let context = Context::new(0, vec![]);
767 let mut pq = BinaryHeap::default();
768 pq.push(Reverse(context));
769 let pools = vec![UnfinishedLambdaPool {
770 pool: vec![ExprOrType::Type {
771 lambda_type: t.clone(),
772 parent: None,
773 is_app_subformula: false,
774 }],
775 }];
776
777 LambdaEnumerator {
778 pools,
779 possible_expressions,
780 eager_filter: filter,
781 pq: NormalEnumeration(pq, VecDeque::default()),
782 }
783 }
784
785 #[must_use]
787 pub fn enumerator<'a>(
788 t: &LambdaType,
789 possible_expressions: &'a PossibleExpressions<'src, Expr<'src>>,
790 ) -> LambdaEnumerator<'a, 'src, Expr<'src>, impl Fn(&'_ Context) -> bool> {
791 let context = Context::new(0, vec![]);
792 let mut pq = BinaryHeap::default();
793 pq.push(Reverse(context));
794 let pools = vec![UnfinishedLambdaPool {
795 pool: vec![ExprOrType::Type {
796 lambda_type: t.clone(),
797 parent: None,
798 is_app_subformula: false,
799 }],
800 }];
801
802 LambdaEnumerator {
803 pools,
804 possible_expressions,
805 eager_filter: |_: &Context| true,
806 pq: NormalEnumeration(pq, VecDeque::default()),
807 }
808 }
809
810 #[must_use]
812 pub fn sampler(
813 t: &LambdaType,
814 possible_expressions: &PossibleExpressions<'src, Expr<'src>>,
815 max_expr: usize,
816 ) -> LambdaSampler<'src, Expr<'src>> {
817 let enumerator = RootedLambdaPool::enumerator(t, possible_expressions);
818 let mut lambdas = Vec::with_capacity(max_expr);
819 let mut expr_details = Vec::with_capacity(max_expr);
820 for (lambda, expr_detail) in enumerator.take(max_expr) {
821 lambdas.push(lambda);
822 expr_details.push(expr_detail);
823 }
824
825 LambdaSampler {
826 lambdas,
827 expr_details,
828 }
829 }
830
831 pub fn random_expr(
836 t: &LambdaType,
837 possible_expressions: &PossibleExpressions<'src, Expr<'src>>,
838 rng: &mut impl Rng,
839 ) -> RootedLambdaPool<'src, Expr<'src>> {
840 ProbabilisticEnumeration::new(
841 1,
842 t,
843 possible_expressions,
844 |_: &Context| true,
845 |_| true,
846 rng,
847 )
848 .next()
849 .unwrap()
850 .0
851 }
852
853 pub fn random_expr_no_constant(
855 t: &LambdaType,
856 possible_expressions: &PossibleExpressions<'src, Expr<'src>>,
857 rng: &mut impl Rng,
858 ) -> Option<RootedLambdaPool<'src, Expr<'src>>> {
859 ProbabilisticEnumeration::new(
860 1,
861 t,
862 possible_expressions,
863 |_: &Context| true,
864 |e| !e.constant_function,
865 rng,
866 )
867 .next()
868 .map(|x| x.0)
869 }
870}
871
872impl<'src> RootedLambdaPool<'src, Expr<'src>> {
873 pub fn prune_quantifiers(&mut self) {
875 let quantifiers = self
876 .pool
877 .bfs_from(self.root)
878 .filter_map(|(i, _)| match self.get(i) {
879 LambdaExpr::LanguageOfThoughtExpr(Expr::Quantifier { subformula, .. }) => {
880 if self
881 .pool
882 .bfs_from(LambdaExprRef(subformula.0))
883 .any(|(x, d)| {
884 if let LambdaExpr::BoundVariable(v, _) = self.get(x) {
885 *v == d
886 } else {
887 false
888 }
889 })
890 {
891 None
892 } else {
893 Some((i, LambdaExprRef(subformula.0)))
894 }
895 }
896 _ => None,
897 })
898 .collect::<Vec<_>>();
899
900 for (quantifier, subformula) in quantifiers.into_iter().rev() {
902 *self.pool.get_mut(quantifier) = self.pool.get(subformula).clone();
903 self.pool.bfs_from_mut(quantifier).for_each(|(x, d, _)| {
904 if let LambdaExpr::BoundVariable(b_d, _) = x
905 && *b_d > d
906 {
907 *b_d -= 1;
908 }
909 });
910 }
911 self.root = self.pool.cleanup(self.root);
912 }
913
914 pub fn swap_expr(
922 &mut self,
923 possible_expressions: &PossibleExpressions<'src, Expr<'src>>,
924 rng: &mut impl Rng,
925 ) -> Result<(), TypeError> {
926 let position = LambdaExprRef(u32::try_from((0..self.len()).choose(rng).unwrap()).unwrap());
927
928 let (context, _) = Context::from_pos(self, position);
929
930 let output = self.pool.get_type(position)?;
931 let expr = self.pool.get(position);
932
933 let arguments: Vec<_> = expr
934 .get_children()
935 .map(|x| self.pool.get_type(x).unwrap())
936 .collect();
937
938 let mut replacements = possible_expressions.possiblities_fixed_children(
939 &output,
940 &arguments,
941 expr.var_type(),
942 &context,
943 );
944
945 if replacements.is_empty()
947 && let LambdaExpr::LanguageOfThoughtExpr(Expr::Quantifier {
948 var_type,
949 restrictor,
950 subformula,
951 ..
952 }) = expr
953 {
954 for quantifier in [Quantifier::Universal, Quantifier::Existential] {
955 replacements.push(std::borrow::Cow::Owned(LambdaExpr::LanguageOfThoughtExpr(
956 Expr::Quantifier {
957 quantifier,
958 var_type: *var_type,
959 restrictor: *restrictor,
960 subformula: *subformula,
961 },
962 )));
963 }
964 }
965 for c in [Constant::EveryEvent, Constant::Everyone] {
966 if replacements.is_empty()
967 && expr == &LambdaExpr::LanguageOfThoughtExpr(Expr::Constant(c))
968 {
969 replacements = possible_expressions.possiblities_fixed_children(
970 LambdaType::t(),
971 &arguments,
972 expr.var_type(),
973 &context,
974 );
975 replacements.push(std::borrow::Cow::Owned(LambdaExpr::LanguageOfThoughtExpr(
976 Expr::Constant(c),
977 )));
978 }
979 }
980
981 let choice = replacements.choose(rng).unwrap_or_else(|| {
982 panic!(
983 "There is no node with output {output} and arguments {}",
984 arguments
985 .into_iter()
986 .map(|x| x.to_string())
987 .collect::<Vec<_>>()
988 .join(", ")
989 )
990 });
991
992 let mut new_expr = choice.clone().into_owned();
993 new_expr.change_children(self.pool.get(position).get_children());
994 self.pool.0[position.0 as usize] = new_expr;
995
996 Ok(())
997 }
998
999 pub fn swap_subtree(&mut self, rng: &mut impl Rng) -> Result<(), TypeError> {
1008 let position = LambdaExprRef(u32::try_from((0..self.len()).choose(rng).unwrap()).unwrap());
1009 let (context, _) = Context::from_pos(self, position);
1010 let alt = context.find_compatible(self, position)?;
1011 if let Some(new_pos) = alt.choose(rng).copied() {
1012 let offset = self.pool.0.len();
1013 let mut lookup = HashMap::default();
1014 let mut new_pool: Vec<LambdaExpr<'src, Expr<'src>>> = self
1015 .pool
1016 .bfs_from(new_pos)
1017 .map(|(x, _)| {
1018 let n = lookup.len();
1019 lookup
1020 .entry(x)
1021 .or_insert(LambdaExprRef(u32::try_from(n + offset).unwrap()));
1022 self.get(x).clone()
1023 })
1024 .collect();
1025
1026 for x in &mut new_pool {
1027 let child: Vec<_> = x.get_children().map(|x| *lookup.get(&x).unwrap()).collect();
1028 x.change_children(child.into_iter());
1029 }
1030
1031 self.pool.0.extend(new_pool);
1032 self.pool.0.swap(position.0 as usize, offset);
1033 self.cleanup();
1034 }
1035 Ok(())
1036 }
1037}
1038
1039#[cfg(test)]
1040mod test {
1041
1042 use std::collections::HashSet;
1043
1044 use super::*;
1045 use crate::lambda::{LambdaPool, LambdaSummaryStats};
1046 use rand::SeedableRng;
1047 use rand_chacha::ChaCha8Rng;
1048
1049 #[test]
1050 fn prune_quantifier_test() -> anyhow::Result<()> {
1051 let mut pool =
1052 RootedLambdaPool::parse("some_e(x,all_e,AgentOf(a_2,e_1) & PatientOf(a_0,e_0))")?;
1053
1054 pool.prune_quantifiers();
1055 assert_eq!(pool.to_string(), "AgentOf(a_2, e_1) & PatientOf(a_0, e_0)");
1056
1057 let mut pool = RootedLambdaPool::parse(
1058 "some_e(x0, all_e, some(z, all_a, AgentOf(z, e_1) & PatientOf(a_0, e_0)))",
1059 )?;
1060
1061 pool.prune_quantifiers();
1062 assert_eq!(
1063 pool.to_string(),
1064 "some(x, all_a, AgentOf(x, e_1) & PatientOf(a_0, e_0))"
1065 );
1066
1067 let mut pool = RootedLambdaPool::parse("~every_e(z, pe_1, pa_2(a_0))")?;
1068
1069 pool.prune_quantifiers();
1070
1071 assert_eq!(pool.to_string(), "~pa_2(a_0)");
1072 let mut pool = RootedLambdaPool::new(
1073 LambdaPool(vec![
1074 LambdaExpr::Lambda(LambdaExprRef(1), LambdaType::E),
1075 LambdaExpr::Lambda(LambdaExprRef(2), LambdaType::T),
1076 LambdaExpr::LanguageOfThoughtExpr(Expr::Quantifier {
1077 quantifier: Quantifier::Universal,
1078 var_type: ActorOrEvent::Actor,
1079 restrictor: ExprRef(3),
1080 subformula: ExprRef(4),
1081 }),
1082 LambdaExpr::LanguageOfThoughtExpr(Expr::Constant(Constant::Property(
1083 "1",
1084 ActorOrEvent::Actor,
1085 ))),
1086 LambdaExpr::LanguageOfThoughtExpr(Expr::Quantifier {
1087 quantifier: Quantifier::Existential,
1088 var_type: ActorOrEvent::Actor,
1089 restrictor: ExprRef(5),
1090 subformula: ExprRef(6),
1091 }),
1092 LambdaExpr::LanguageOfThoughtExpr(Expr::Constant(Constant::Property(
1093 "0",
1094 ActorOrEvent::Actor,
1095 ))),
1096 LambdaExpr::LanguageOfThoughtExpr(Expr::Unary(
1097 MonOp::Property("3", ActorOrEvent::Event),
1098 ExprRef(7),
1099 )),
1100 LambdaExpr::BoundVariable(3, LambdaType::E),
1101 ]),
1102 LambdaExprRef(0),
1103 );
1104
1105 assert_eq!(
1106 pool.to_string(),
1107 "lambda e x lambda t phi every(y, pa_1, some(z, pa_0, pe_3(x)))"
1108 );
1109
1110 let mut parsed_pool = RootedLambdaPool::parse(
1111 "lambda e x lambda t phi every(y, pa_1, some(z, pa_0, pe_3(x)))",
1112 )?;
1113 parsed_pool.prune_quantifiers();
1114 pool.prune_quantifiers();
1115 assert_eq!(pool.to_string(), "lambda e x lambda t phi pe_3(x)");
1116
1117 Ok(())
1118 }
1119
1120 #[test]
1121 fn random_swap_tree() -> anyhow::Result<()> {
1122 let mut rng = ChaCha8Rng::seed_from_u64(2);
1123 let x = RootedLambdaPool::parse("lambda a x pa_cool(a_John) | pa_cool(x)")?;
1124 let mut h = HashSet::new();
1125 for _ in 0..100 {
1126 let mut z = x.clone();
1127 z.swap_subtree(&mut rng)?;
1128 h.insert(z.to_string());
1129 }
1130 for x in h.iter() {
1131 println!("{x}");
1132 }
1133 assert!(h.contains("lambda a x pa_cool(x)"));
1134 assert!(h.contains("lambda a x pa_cool(a_John)"));
1135 Ok(())
1136 }
1137
1138 #[test]
1139 fn randomn_swap() -> anyhow::Result<()> {
1140 let mut rng = ChaCha8Rng::seed_from_u64(2);
1141 let actors = ["0", "1"];
1142 let available_event_properties = ["2", "3", "4"];
1143 let possible_expressions = PossibleExpressions::new(
1144 &actors,
1145 &available_event_properties,
1146 &available_event_properties,
1147 );
1148 for _ in 0..200 {
1149 let t = LambdaType::random(&mut rng);
1150 println!("{t}");
1151 let mut pool = RootedLambdaPool::random_expr(&t, &possible_expressions, &mut rng);
1152 println!("{t}: {pool}");
1153 assert_eq!(t, pool.get_type()?);
1154 println!("{pool:?}");
1155 pool.swap_expr(&possible_expressions, &mut rng)?;
1156 println!("{pool:?}");
1157 println!("{t}: {pool}");
1158 assert_eq!(t, pool.get_type()?);
1159 }
1160 let p =
1161 RootedLambdaPool::parse("lambda <a,t> P some(x, P(x), some_e(y, pe_2(y), pe_4(y)))")?;
1162 let t = p.get_type()?;
1163 for _ in 0..1000 {
1164 let mut pool = p.clone();
1165 assert_eq!(t, pool.get_type()?);
1166
1167 pool.swap_expr(&possible_expressions, &mut rng)?;
1168 dbg!(&pool);
1169 println!("{t}: {pool}");
1170 assert_eq!(t, pool.get_type()?);
1171 }
1172
1173 Ok(())
1174 }
1175
1176 #[test]
1177 fn enumerate() -> anyhow::Result<()> {
1178 let actors = ["john"];
1179 let actor_properties = ["a"];
1180 let event_properties = ["e"];
1181 let possibles = PossibleExpressions::new(&actors, &actor_properties, &event_properties);
1182
1183 let t = LambdaType::from_string("<<a,t>,t>")?;
1184
1185 let p = RootedLambdaPool::enumerator(&t, &possibles);
1186 let pools: HashSet<_> = p
1187 .filter_map(|(p, x)| {
1188 if !x.has_constant_function() {
1189 Some(p.to_string())
1190 } else {
1191 None
1192 }
1193 })
1194 .take(20)
1195 .collect();
1196 for p in pools.iter() {
1197 println!("{p}");
1198 }
1199 assert!(pools.contains("lambda <a,t> P P(a_john)"));
1200 for p in pools {
1201 println!("{p}");
1202 }
1203
1204 Ok(())
1205 }
1206
1207 #[test]
1235 fn enumerate_weirds() -> anyhow::Result<()> {
1236 let actors = &["1"];
1237 let actor_properties = &["1"];
1238 let event_properties = &["1"];
1239 let poss = PossibleExpressions::new(actors, actor_properties, event_properties);
1240
1241 let t = "<a,<e,e>>";
1242
1243 for (p, d) in RootedLambdaPool::enumerator(&LambdaType::from_string(t)?, &poss)
1244 .filter(|(_, e)| !e.has_constant_function())
1245 .take(5)
1246 {
1247 println!("{p} {d:?} {}", d.has_constant_function());
1248 }
1249 Ok(())
1250 }
1251
1252 #[test]
1253 fn random_expr() -> anyhow::Result<()> {
1254 let actors = ["john"];
1255 let actor_properties = ["a"];
1256 let event_properties = ["e"];
1257 let possibles = PossibleExpressions::new(&actors, &actor_properties, &event_properties);
1258 let mut rng = ChaCha8Rng::seed_from_u64(0);
1259
1260 let map = [(LambdaType::A, vec![RootedLambdaPool::parse("a_john")?])]
1261 .into_iter()
1262 .collect();
1263
1264 for _ in 0..100 {
1265 let t = LambdaType::random(&mut rng);
1266 println!("sampling: {t}");
1267 let mut pool = RootedLambdaPool::random_expr(&t, &possibles, &mut rng);
1268 assert_eq!(t, pool.get_type()?);
1269 let s = pool.to_string();
1270 println!("{s}");
1271 let mut pool2 = RootedLambdaPool::parse(s.as_str())?;
1272 assert_eq!(s, pool2.to_string());
1273 println!("{pool}");
1274 pool2.resample_from_expr(&possibles, None, &mut rng)?;
1275 assert_eq!(pool2.get_type()?, t);
1276 pool2.resample_from_expr(&possibles, Some(&map), &mut rng)?;
1277 assert_eq!(pool2.get_type()?, t);
1278
1279 pool.swap_subtree(&mut rng)?;
1280 assert_eq!(pool.get_type()?, t);
1281 }
1282 let t = LambdaType::from_string("<a,<a,t>>")?;
1283 for _ in 0..100 {
1284 let pool = RootedLambdaPool::random_expr(&t, &possibles, &mut rng);
1285 println!("{pool}");
1286 }
1287 Ok(())
1288 }
1289
1290 #[test]
1291 fn constant_exprs() -> anyhow::Result<()> {
1292 let actors = ["john", "mary", "phil", "sue"];
1293 let actor_properties = ["a"];
1294 let event_properties = ["e"];
1295 let possibles = PossibleExpressions::new(&actors, &actor_properties, &event_properties);
1296 let mut rng = ChaCha8Rng::seed_from_u64(0);
1297
1298 let v: HashSet<_> = RootedLambdaPool::enumerator(LambdaType::a(), &possibles)
1299 .map(|(x, _)| x.to_string())
1300 .take(4)
1301 .collect();
1302
1303 assert_eq!(v, HashSet::from(actors.map(|x| format!("a_{x}"))));
1304 println!("{v:?}");
1305
1306 let mut constants = 0;
1307 for _ in 0..100 {
1308 let t = LambdaType::from_string("<a, <a,t>>")?;
1309 println!("sampling: {t}");
1310 let pool = RootedLambdaPool::random_expr(&t, &possibles, &mut rng);
1311 let x = pool.stats();
1312 match x {
1313 LambdaSummaryStats::WellFormed {
1314 constant_function, ..
1315 } => {
1316 if constant_function {
1317 constants += 1;
1318 }
1319 }
1320 LambdaSummaryStats::Malformed => todo!(),
1321 }
1322 }
1323 println!("{constants}");
1324
1325 Ok(())
1326 }
1327
1328 #[test]
1329 fn random_expr_counts() -> anyhow::Result<()> {
1330 let actors = ["john", "mary", "phil", "sue"];
1331 let actor_properties = ["a"];
1332 let event_properties = ["e"];
1333 let possibles = PossibleExpressions::new(&actors, &actor_properties, &event_properties);
1334 let mut rng = ChaCha8Rng::seed_from_u64(1);
1335
1336 let mut counts: HashMap<_, usize> = HashMap::default();
1337 for _ in 0..1000 {
1338 let t = LambdaType::A;
1339 let pool = RootedLambdaPool::random_expr(&t, &possibles, &mut rng);
1340 assert_eq!(t, pool.get_type()?);
1341 let s = pool.to_string();
1342 *counts.entry(s).or_default() += 1;
1343 }
1344 assert_eq!(counts.len(), 4);
1345 dbg!(&counts);
1346 for (_, v) in counts.iter() {
1347 assert!(200 <= *v && *v <= 300);
1348 }
1349
1350 counts.clear();
1351 for _ in 0..1000 {
1352 let t = LambdaType::at().clone();
1353 let pool = RootedLambdaPool::random_expr(&t, &possibles, &mut rng);
1354 assert_eq!(t, pool.get_type()?);
1355 let s = pool.to_string();
1356 *counts.entry(s).or_default() += 1;
1357 }
1358 dbg!(counts);
1359 Ok(())
1360 }
1361}