1use std::{
2 cmp::Reverse,
3 collections::{BinaryHeap, VecDeque},
4 fmt::Debug,
5};
6
7use ahash::HashMap;
8use chumsky::container::Container;
9use rand::{
10 Rng,
11 distr::{Distribution, weighted::WeightedIndex},
12 seq::{IndexedRandom, IteratorRandom},
13};
14use thiserror::Error;
15
16use super::*;
17use crate::lambda::{
18 LambdaError, LambdaExpr, LambdaExprRef, LambdaLanguageOfThought, LambdaPool,
19 types::{LambdaType, TypeError},
20};
21
22mod context;
23mod samplers;
24pub use context::Context;
25use samplers::PossibleExpr;
26pub use samplers::PossibleExpressions;
27
28#[derive(Debug, Error, Clone)]
29pub struct ExprOrTypeError();
30
31impl Display for ExprOrTypeError {
32 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
33 write!(f, "This ExprOrType is not an Expr!")
34 }
35}
36
/// One slot of an [`UnfinishedLambdaPool`]: either a hole of a known type that still has to be
/// filled, or an expression that has already been chosen for that slot.
#[derive(Debug, Clone, Eq, PartialEq)]
enum ExprOrType<'src, T: LambdaLanguageOfThought> {
    /// A hole that must be filled with an expression of `lambda_type`.
    Type {
        lambda_type: LambdaType,
        /// Pool index of the node this hole belongs to (`None` for the root).
        parent: Option<usize>,
        /// `true` when this hole is the function position of an application; it is passed to
        /// `PossibleExpressions::possibilities` when candidates are generated.
        is_app_subformula: bool,
    },
    /// A filled node; its children are pool indices.
    Expr {
        lambda_expr: LambdaExpr<'src, T>,
        /// Pool index of the parent node (`None` for the root).
        parent: Option<usize>,
    },
}
49
50impl<'src, T: LambdaLanguageOfThought> TryFrom<ExprOrType<'src, T>> for LambdaExpr<'src, T> {
51 type Error = ExprOrTypeError;
52
53 fn try_from(value: ExprOrType<'src, T>) -> Result<Self, Self::Error> {
54 match value {
55 ExprOrType::Type { .. } => Err(ExprOrTypeError()),
56 ExprOrType::Expr { lambda_expr, .. } => Ok(lambda_expr),
57 }
58 }
59}
60
61impl<'src, T: LambdaLanguageOfThought> ExprOrType<'src, T> {
62 fn parent(&self) -> Option<usize> {
63 match self {
64 ExprOrType::Type { parent, .. } | ExprOrType::Expr { parent, .. } => *parent,
65 }
66 }
67
68 fn is_type(&self) -> bool {
69 matches!(self, ExprOrType::Type { .. })
70 }
71}
72
73#[derive(Debug, Clone, Eq, PartialEq)]
74struct UnfinishedLambdaPool<'src, T: LambdaLanguageOfThought> {
75 pool: Vec<ExprOrType<'src, T>>,
76}
77
78impl<'src, T: LambdaLanguageOfThought> Default for UnfinishedLambdaPool<'src, T> {
79 fn default() -> Self {
80 Self { pool: vec![] }
81 }
82}
83
impl<'src, T: LambdaLanguageOfThought + Clone> UnfinishedLambdaPool<'src, T> {
    /// Fills the hole at `c.position` with `expr`.
    ///
    /// For every child the chosen expression needs, a new typed hole is appended to the end of
    /// the pool (with `c.position` as its parent), and the expression's child references are
    /// pointed at those new holes. `c` is updated along the way: `depth` is incremented once,
    /// `open_nodes` is adjusted by the number of new holes minus the one just filled, and any
    /// variable introduced by the expression is added with `add_lambda`.
    ///
    /// `t` is the type of the hole being filled; it is only consulted for lambdas.
    ///
    /// # Panics
    /// Panics if `expr` is a lambda and `t.rhs()` cannot be unwrapped, or if `expr` is an
    /// application and `into_expr` returned no application details.
    fn add_expr<'a>(&mut self, expr: PossibleExpr<'a, 'src, T>, c: &mut Context, t: &LambdaType) {
        let (mut lambda_expr, app_details) = expr.into_expr();
        c.depth += 1;
        // `n_children` new holes are opened and one is closed. Adding before subtracting
        // avoids a transient underflow if the counter is unsigned.
        c.open_nodes += lambda_expr.n_children();
        c.open_nodes -= 1;
        let parent = self.pool[c.position].parent();
        match &mut lambda_expr {
            LambdaExpr::Lambda(body, arg) => {
                // Bring the lambda's argument into scope for its body.
                c.add_lambda(arg);
                // The body becomes a new hole at the end of the pool, typed with `t.rhs()`
                // (presumably the result type of the function type `t`).
                *body = LambdaExprRef(self.pool.len() as u32);
                self.pool.push(ExprOrType::Type {
                    lambda_type: t.rhs().unwrap().clone(),
                    parent: Some(c.position),
                    is_app_subformula: false,
                })
            }
            LambdaExpr::BoundVariable(b, _) => {
                c.use_bvar(*b);
            }
            LambdaExpr::FreeVariable(..) => (),
            LambdaExpr::Application {
                subformula,
                argument,
            } => {
                // Two new holes: the function at `len` and its argument at `len + 1`, typed
                // from the application details returned by `into_expr`.
                *subformula = LambdaExprRef(self.pool.len() as u32);
                *argument = LambdaExprRef((self.pool.len() + 1) as u32);
                let (subformula, argument) = app_details.unwrap();
                self.pool.push(ExprOrType::Type {
                    lambda_type: subformula,
                    parent: Some(c.position),
                    is_app_subformula: true,
                });
                self.pool.push(ExprOrType::Type {
                    lambda_type: argument,
                    parent: Some(c.position),
                    is_app_subformula: false,
                });
            }
            LambdaExpr::LanguageOfThoughtExpr(e) => {
                let children_start = self.pool.len();
                // Expressions that bind a variable (presumably quantifiers) bring it into scope.
                if let Some(t) = e.var_type() {
                    c.add_lambda(t);
                }
                // One new hole per argument type, then point the expression's children at them.
                self.pool
                    .extend(e.get_arguments().map(|lambda_type| ExprOrType::Type {
                        lambda_type,
                        parent: Some(c.position),
                        is_app_subformula: false,
                    }));
                e.change_children(
                    (children_start..self.pool.len()).map(|x| LambdaExprRef(x as u32)),
                );
            }
        }
        // Replace the hole with the filled node, keeping its original parent link.
        self.pool[c.position] = ExprOrType::Expr {
            lambda_expr,
            parent,
        };
    }
}
145
146#[derive(Debug)]
147pub struct NormalEnumeration(BinaryHeap<Reverse<Context>>, VecDeque<ExprDetails>);
148
149impl EnumerationType for NormalEnumeration {
150 fn pop(&mut self) -> Option<Context> {
151 self.0.pop().map(|x| x.0)
152 }
153
154 fn push(&mut self, context: Context, _: bool) {
155 self.0.push(Reverse(context))
156 }
157
158 fn get_yield(&mut self) -> Option<ExprDetails> {
159 self.1.pop_front()
160 }
161
162 fn push_yield(&mut self, e: ExprDetails) {
163 self.1.push(e);
164 }
165
166 fn include(&mut self, n: usize) -> impl Iterator<Item = bool> + 'static {
167 std::iter::repeat_n(true, n)
168 }
169}
170
171impl ExprDetails {
172 fn score(&self) -> f64 {
173 (1.0 / (self.size as f64))
174 + match self.constant_function {
175 true => 0.0,
176 false => 10.0,
177 }
178 }
179
180 pub fn has_constant_function(&self) -> bool {
181 self.constant_function
182 }
183}
184
185#[derive(Debug, Clone, Copy, PartialEq)]
186struct KeyedExprDetails {
187 expr_details: ExprDetails,
188 k: f64,
189}
190
191impl Eq for KeyedExprDetails {}
192
193impl PartialOrd for KeyedExprDetails {
194 fn partial_cmp(&self, other: &Self) -> Option<std::cmp::Ordering> {
195 Some(self.cmp(other))
196 }
197}
198
199impl Ord for KeyedExprDetails {
200 fn cmp(&self, other: &Self) -> std::cmp::Ordering {
201 other.k.partial_cmp(&self.k).unwrap()
203 }
204}
205impl KeyedExprDetails {
206 fn new(expr_details: ExprDetails, rng: &mut impl Rng) -> Self {
207 let u: f64 = rng.random();
208 KeyedExprDetails {
209 expr_details,
210 k: u.powf(1.0 / expr_details.score()),
211 }
212 }
213}
214
215#[derive(Debug, Clone, PartialEq)]
216struct RandomPQ(Context, f64);
217
218impl Eq for RandomPQ {}
219
220impl RandomPQ {
221 fn new(c: Context, rng: &mut impl Rng) -> Self {
222 RandomPQ(c, rng.random())
223 }
224}
225
/// Randomised enumeration strategy that keeps a weighted reservoir of finished expressions.
///
/// Contexts selected by `include` go into `pq`. The others are parked in `backups` and only drawn
/// (uniformly at random) once `pq` is empty.
#[derive(Debug)]
struct ProbabilisticEnumeration<'a, R: Rng, F>
where
    F: Fn(&ExprDetails) -> bool,
{
    /// Source of randomness for priorities, inclusion decisions and reservoir keys.
    rng: &'a mut R,
    /// Maximum number of finished expressions kept in `reservoir`.
    reservoir_size: usize,
    /// Finished expressions with random keys; `peek` returns the smallest key.
    reservoir: BinaryHeap<KeyedExprDetails>,
    /// Contexts not selected for `pq`; used as a fallback when `pq` is empty.
    backups: Vec<Context>,
    /// Contexts to expand, each paired with a random priority.
    pq: BinaryHeap<RandomPQ>,
    /// Finished expressions failing this predicate never enter the reservoir.
    filter: F,
    /// Number of finished expressions that passed `filter`.
    n_seen: usize,
    /// Set once `n_seen` reaches `reservoir_size * 20`; `pq` is then cleared.
    done: bool,
}
240impl<R: Rng, F> ProbabilisticEnumeration<'_, R, F>
241where
242 F: Fn(&ExprDetails) -> bool,
243{
244 fn threshold(&self) -> Option<f64> {
245 self.reservoir.peek().map(|x| x.k)
246 }
247
248 fn new<'a, 'src, T: LambdaLanguageOfThought, E: Fn(&Context) -> bool>(
249 reservoir_size: usize,
250 t: &LambdaType,
251 possible_expressions: &'a PossibleExpressions<'src, T>,
252 eager_filter: E,
253 filter: F,
254 rng: &'a mut R,
255 ) -> LambdaEnumerator<'a, 'src, T, E, ProbabilisticEnumeration<'a, R, F>> {
256 let context = Context::new(0, vec![]);
257 let mut pq = BinaryHeap::default();
258 pq.push(RandomPQ::new(context, rng));
259 let pools = vec![UnfinishedLambdaPool {
260 pool: vec![ExprOrType::Type {
261 lambda_type: t.clone(),
262 parent: None,
263 is_app_subformula: false,
264 }],
265 }];
266
267 LambdaEnumerator {
268 pools,
269 possible_expressions,
270 eager_filter,
271 pq: ProbabilisticEnumeration {
272 rng,
273 reservoir_size,
274 reservoir: BinaryHeap::default(),
275 backups: vec![],
276 filter,
277 pq,
278 n_seen: 0,
279 done: false,
280 },
281 }
282 }
283}
284
285impl<R: Rng, F> EnumerationType for ProbabilisticEnumeration<'_, R, F>
286where
287 F: Fn(&ExprDetails) -> bool,
288{
289 fn pop(&mut self) -> Option<Context> {
290 self.pq.pop().map(|x| x.0).or_else(|| {
292 (0..self.backups.len()).choose(self.rng).and_then(|index| {
293 let last_item = self.backups.len() - 1;
294 self.backups.swap(index, last_item);
295 self.backups.pop()
296 })
297 })
298 }
299
300 fn push(&mut self, context: Context, included: bool) {
301 if included {
302 self.pq.push(RandomPQ::new(context, &mut self.rng));
303 } else {
304 self.backups.push(context);
305 }
306 }
307
308 fn get_yield(&mut self) -> Option<ExprDetails> {
309 if (self.done || self.pq.is_empty())
310 && let Some(x) = self.reservoir.pop()
311 {
312 Some(x.expr_details)
313 } else {
314 None
315 }
316 }
317
318 fn push_yield(&mut self, e: ExprDetails) {
319 let e = KeyedExprDetails::new(e, &mut self.rng);
320 if (self.filter)(&e.expr_details) {
321 self.n_seen += 1;
322 if self.reservoir_size > self.reservoir.len() {
323 self.reservoir.push(e)
324 } else if let Some(t) = self.threshold()
325 && e.k > t
326 {
327 self.reservoir.pop();
328 self.reservoir.push(e)
329 }
330 if self.n_seen >= self.reservoir_size * 20 {
331 self.pq.clear();
332 self.done = true;
333 }
334 }
335 }
336
337 fn include(&mut self, n: usize) -> impl Iterator<Item = bool> + 'static {
338 let x = (0..n).choose_multiple(self.rng, (n / 2).max(1));
339 let mut v = vec![false; n];
340 for i in x {
341 v[i] = true;
342 }
343 v.into_iter()
344 }
345}
346
/// Lazily enumerates well-typed expressions by repeatedly filling typed holes.
///
/// `F` is an eager filter applied to each partial expression's `Context` before it is queued.
/// `E` is the enumeration strategy, [`NormalEnumeration`] by default.
#[derive(Debug)]
pub struct LambdaEnumerator<'a, 'src, T: LambdaLanguageOfThought, F, E = NormalEnumeration> {
    /// All partial expressions; contexts and finished results refer to them by index.
    pools: Vec<UnfinishedLambdaPool<'src, T>>,
    /// Candidate expressions consulted when filling a hole.
    possible_expressions: &'a PossibleExpressions<'src, T>,
    eager_filter: F,
    /// Enumeration strategy holding the frontier of contexts and the finished results.
    pq: E,
}
355
/// Summary of a finished expression produced by a [`LambdaEnumerator`].
#[derive(Debug, Clone, Copy, Eq, PartialEq)]
pub struct ExprDetails {
    /// Index of the pool holding the expression inside the enumerator.
    id: usize,
    /// The value of `Context::is_constant` when the expression was finished.
    constant_function: bool,
    /// Pool index of the expression's root node.
    root: LambdaExprRef,
    /// The context's `depth` counter when the expression was finished; `add_expr` increments
    /// it once per node placed.
    size: usize,
}

impl ExprDetails {
    /// Size of the expression as counted during enumeration.
    pub fn size(&self) -> usize {
        self.size
    }
}
371
/// Samples random expressions of any requested type, caching one [`LambdaSampler`] per type.
#[derive(Debug, Clone, Eq, PartialEq)]
pub struct TypeAgnosticSampler<'src, T: LambdaLanguageOfThought> {
    /// Per-type cache: a usage counter and the sampler built for that type.
    type_to_sampler: HashMap<LambdaType, (usize, LambdaSampler<'src, T>)>,
    /// Number of expressions enumerated when building each per-type sampler.
    max_expr: usize,
    /// Maximum number of per-type samplers kept in the cache.
    max_types: usize,
    possible_expressions: PossibleExpressions<'src, T>,
}
380
381impl<'src> TypeAgnosticSampler<'src, Expr<'src>> {
382 pub fn sample(
384 &mut self,
385 lambda_type: LambdaType,
386 rng: &mut impl Rng,
387 ) -> RootedLambdaPool<'src, Expr<'src>> {
388 let (counts, exprs) = self
389 .type_to_sampler
390 .entry(lambda_type)
391 .or_insert_with_key(|t| {
392 (
393 1,
394 RootedLambdaPool::sampler(t, &self.possible_expressions, self.max_expr),
395 )
396 });
397 *counts += 1;
398 let sample = exprs.sample(rng);
399
400 if self.type_to_sampler.len() > self.max_types {
401 let (_, k) = self
402 .type_to_sampler
403 .iter()
404 .map(|(k, (n_visits, _))| (n_visits, k))
405 .min_by_key(|x| x.0)
406 .unwrap();
407
408 let t = k.clone();
409 self.type_to_sampler.remove(&t);
410 }
411
412 sample
413 }
414
415 pub fn possibles(&self) -> &PossibleExpressions<'src, Expr<'src>> {
417 &self.possible_expressions
418 }
419}
420
421impl<'src, T: LambdaLanguageOfThought + Clone> RootedLambdaPool<'src, T> {
422 pub fn typeless_sampler(
424 possible_expressions: PossibleExpressions<'src, T>,
425 max_expr: usize,
426 max_types: usize,
427 ) -> TypeAgnosticSampler<'src, T> {
428 assert!(max_types >= 1);
429 assert!(max_expr >= 1);
430 TypeAgnosticSampler {
431 possible_expressions,
432 max_expr,
433 max_types,
434 type_to_sampler: HashMap::default(),
435 }
436 }
437}
438
439#[derive(Debug, Clone, Eq, PartialEq)]
441pub struct LambdaSampler<'src, T: LambdaLanguageOfThought> {
442 lambdas: Vec<RootedLambdaPool<'src, T>>,
443 expr_details: Vec<ExprDetails>,
444}
445
446impl<'src, T: LambdaLanguageOfThought + Clone> Distribution<RootedLambdaPool<'src, T>>
447 for LambdaSampler<'src, T>
448{
449 fn sample<R: Rng + ?Sized>(&self, rng: &mut R) -> RootedLambdaPool<'src, T> {
450 let w = WeightedIndex::new(self.expr_details.iter().map(|x| x.score())).unwrap();
451 let i = w.sample(rng);
452 self.lambdas
453 .get(i)
454 .expect("The Lambda Sampler has no lambdas to sample :(")
455 .clone()
456 }
457}
458
/// Strategy that decides in which order a [`LambdaEnumerator`] expands partial expressions and
/// when it releases finished ones.
pub trait EnumerationType {
    /// Next context to expand, or `None` when the frontier is exhausted.
    fn pop(&mut self) -> Option<Context>;
    /// Adds a context to the frontier. `included` is the hint produced by
    /// [`EnumerationType::include`]; strategies may ignore it.
    fn push(&mut self, context: Context, included: bool);
    /// Next finished expression the strategy is willing to release, if any.
    fn push_yield_placeholder_doc_never_used(&self) {}
    fn get_yield(&mut self) -> Option<ExprDetails>;
    /// Records a finished expression.
    fn push_yield(&mut self, e: ExprDetails);
    /// For `n` candidate expansions of one hole, yields one inclusion flag per candidate.
    fn include(&mut self, n: usize) -> impl Iterator<Item = bool> + 'static;
}
468
469fn try_yield<'a, 'src, T, F, E>(
470 x: &mut LambdaEnumerator<'a, 'src, T, F, E>,
471) -> Option<(RootedLambdaPool<'src, T>, ExprDetails)>
472where
473 T: LambdaLanguageOfThought,
474 E: EnumerationType,
475{
476 if let Some(item) = x.pq.get_yield() {
477 let p = std::mem::take(&mut x.pools[item.id]);
478 return Some((
479 RootedLambdaPool {
480 pool: LambdaPool(
481 p.pool
482 .into_iter()
483 .map(|x| LambdaExpr::try_from(x).unwrap())
484 .collect(),
485 ),
486 root: item.root,
487 },
488 item,
489 ));
490 }
491 None
492}
493
494impl<'a, 'src, T, F, E> LambdaEnumerator<'a, 'src, T, F, E>
495where
496 T: LambdaLanguageOfThought + Clone + Debug,
497 F: Fn(&Context) -> bool,
498 E: EnumerationType,
499{
500 fn push(&mut self, c: Context, included: bool) {
501 if (self.eager_filter)(&c) {
502 self.pq.push(c, included);
503 } else {
504 self.pools[c.pool_index] = UnfinishedLambdaPool::default();
505 }
506 }
507
508 pub fn eager_filter<F2>(self, eager_filter: F2) -> LambdaEnumerator<'a, 'src, T, F2, E> {
510 let LambdaEnumerator {
511 pools,
512 possible_expressions,
513 eager_filter: _,
514 pq,
515 } = self;
516
517 LambdaEnumerator {
518 pools,
519 possible_expressions,
520 eager_filter,
521 pq,
522 }
523 }
524}
525
526impl<'a, 'src, F, E> Iterator for LambdaEnumerator<'a, 'src, Expr<'src>, F, E>
527where
528 F: Fn(&Context) -> bool,
529 E: EnumerationType,
530{
531 type Item = (RootedLambdaPool<'src, Expr<'src>>, ExprDetails);
532
533 fn next(&mut self) -> Option<Self::Item> {
534 if let Some(x) = try_yield(self) {
535 return Some(x);
536 }
537
538 while let Some(mut c) = self.pq.pop() {
539 if let Some(x) = try_yield(self) {
540 self.push(c, true);
541 return Some(x);
542 }
543 let (possibles, lambda_type) = match &self.pools[c.pool_index].pool[c.position] {
544 ExprOrType::Type {
545 lambda_type,
546 is_app_subformula,
547 parent,
548 } => {
549 let mut possibles = self.possible_expressions.possibilities(
550 lambda_type,
551 *is_app_subformula,
552 &c,
553 );
554
555 if let Some(p) = parent
558 && let ExprOrType::Expr {
559 lambda_expr:
560 LambdaExpr::LanguageOfThoughtExpr(Expr::Quantifier {
561 var_type,
562 restrictor,
563 ..
564 }),
565 ..
566 } = self.pools[c.pool_index].pool[*p]
567 && restrictor.0 == c.position as u32
568 {
569 possibles.push(PossibleExpr::new_borrowed(match var_type {
570 ActorOrEvent::Actor => &LambdaExpr::LanguageOfThoughtExpr(
571 Expr::Constant(Constant::Everyone),
572 ),
573 ActorOrEvent::Event => &LambdaExpr::LanguageOfThoughtExpr(
574 Expr::Constant(Constant::EveryEvent),
575 ),
576 }));
577 };
578
579 (possibles, lambda_type.clone())
580 }
581 ExprOrType::Expr {
582 lambda_expr,
583 parent,
584 } => {
585 if let Some(child) = lambda_expr
589 .get_children()
590 .map(|x| x.0 as usize)
591 .find(|child| self.pools[c.pool_index].pool[*child].is_type())
592 {
593 c.position = child;
594 self.pq.push(c, true);
595 continue;
596 }
597
598 if lambda_expr.inc_depth() {
599 c.pop_lambda();
600 }
601
602 if let Some(p) = parent {
603 c.position = *p;
604 self.pq.push(c, true);
605 continue;
606 } else {
607 self.pq.push_yield(ExprDetails {
609 id: c.pool_index,
610 root: LambdaExprRef(c.position as u32),
611 size: c.depth,
612 constant_function: c.is_constant(),
613 });
614 continue;
615 }
616 }
617 };
618
619 let n = possibles.len();
620 let included = self.pq.include(n);
621 if n == 0 {
622 continue;
626 }
627 let n_pools = self.pools.len();
628 if n_pools.is_multiple_of(10_000) {
629 self.pools.shrink_to_fit();
630 }
631
632 for _ in 0..n.saturating_sub(1) {
633 self.pools.push(self.pools[c.pool_index].clone());
634 }
635
636 let positions =
637 std::iter::once(c.pool_index).chain(n_pools..n_pools + n.saturating_sub(1));
638
639 for (((expr, pool_id), mut c), included) in possibles
640 .into_iter()
641 .zip(positions)
642 .zip(std::iter::repeat_n(c, n))
643 .zip(included)
644 {
645 c.pool_index = pool_id;
646 let pool = self.pools.get_mut(pool_id).unwrap();
647 pool.add_expr(expr, &mut c, &lambda_type);
648 self.push(c, included);
649 }
650
651 if let Some(x) = try_yield(self) {
652 return Some(x);
653 }
654 }
655
656 if let Some(x) = try_yield(self) {
658 return Some(x);
659 }
660 None
661 }
662}
663
664impl<'src> RootedLambdaPool<'src, Expr<'src>> {
665 pub fn resample_from_expr<'a>(
667 &mut self,
668 possible_expressions: &'a PossibleExpressions<'src, Expr<'src>>,
669 helpers: Option<&HashMap<LambdaType, Vec<RootedLambdaPool<'src, Expr<'src>>>>>,
670 rng: &mut impl Rng,
671 ) -> Result<(), LambdaError> {
672 let position = LambdaExprRef((0..self.len()).choose(rng).unwrap() as u32);
673 let t = self.pool.get_type(position)?;
674
675 let pool = if let Some(helpers) = helpers
676 && rng.random_bool(0.2)
677 && let Some(v) = helpers.get(&t)
678 && !v.is_empty()
679 {
680 let pool = v.choose(rng).unwrap();
681 pool.clone()
682 } else {
683 let (pool, _) = self
684 .probabilistic_enumerate_from_expr(
685 position,
686 possible_expressions,
687 |_| true,
688 |_| true,
689 rng,
690 )?
691 .next()
692 .unwrap();
693 pool
694 };
695
696 let offset = self.len() as u32;
697 let new_root = pool.root.0 + offset;
698 self.pool.0.extend(pool.pool.0.into_iter().map(|mut x| {
699 let children: Vec<_> = x
700 .get_children()
701 .map(|x| LambdaExprRef(x.0 + offset))
702 .collect();
703 x.change_children(children.into_iter());
704 x
705 }));
706 self.pool.0.swap(position.0 as usize, new_root as usize);
707 self.cleanup();
708 Ok(())
709 }
710
711 fn probabilistic_enumerate_from_expr<'a, R, E, F>(
712 &self,
713 position: LambdaExprRef,
714 possible_expressions: &'a PossibleExpressions<'src, Expr<'src>>,
715 eager_filter: E,
716 filter: F,
717 rng: &'a mut R,
718 ) -> Result<
719 LambdaEnumerator<'a, 'src, Expr<'src>, E, ProbabilisticEnumeration<'a, R, F>>,
720 TypeError,
721 >
722 where
723 R: Rng,
724 F: Fn(&ExprDetails) -> bool,
725 E: Fn(&Context) -> bool,
726 {
727 let (context, is_app_subformula) = Context::from_pos(self, position);
728 let output = self.pool.get_type(position)?;
729 let mut pq = BinaryHeap::default();
730 pq.push(RandomPQ::new(context, rng));
731 let pools = vec![UnfinishedLambdaPool {
732 pool: vec![ExprOrType::Type {
733 lambda_type: output,
734 parent: None,
735 is_app_subformula,
736 }],
737 }];
738 let enumerator = LambdaEnumerator {
739 pools,
740 possible_expressions,
741 eager_filter,
742 pq: ProbabilisticEnumeration {
743 rng,
744 reservoir_size: 1,
745 reservoir: BinaryHeap::default(),
746 done: false,
747 n_seen: 0,
748 filter,
749 backups: vec![],
750 pq,
751 },
752 };
753
754 Ok(enumerator)
755 }
756
757 pub fn enumerator_filter<'a, F: Fn(&Context) -> bool>(
759 t: &LambdaType,
760 filter: F,
761 possible_expressions: &'a PossibleExpressions<'src, Expr<'src>>,
762 ) -> LambdaEnumerator<'a, 'src, Expr<'src>, F> {
763 let context = Context::new(0, vec![]);
764 let mut pq = BinaryHeap::default();
765 pq.push(Reverse(context));
766 let pools = vec![UnfinishedLambdaPool {
767 pool: vec![ExprOrType::Type {
768 lambda_type: t.clone(),
769 parent: None,
770 is_app_subformula: false,
771 }],
772 }];
773
774 LambdaEnumerator {
775 pools,
776 possible_expressions,
777 eager_filter: filter,
778 pq: NormalEnumeration(pq, VecDeque::default()),
779 }
780 }
781
782 pub fn enumerator<'a>(
784 t: &LambdaType,
785 possible_expressions: &'a PossibleExpressions<'src, Expr<'src>>,
786 ) -> LambdaEnumerator<'a, 'src, Expr<'src>, impl Fn(&'_ Context) -> bool> {
787 let context = Context::new(0, vec![]);
788 let mut pq = BinaryHeap::default();
789 pq.push(Reverse(context));
790 let pools = vec![UnfinishedLambdaPool {
791 pool: vec![ExprOrType::Type {
792 lambda_type: t.clone(),
793 parent: None,
794 is_app_subformula: false,
795 }],
796 }];
797
798 LambdaEnumerator {
799 pools,
800 possible_expressions,
801 eager_filter: |_: &Context| true,
802 pq: NormalEnumeration(pq, VecDeque::default()),
803 }
804 }
805
806 pub fn sampler(
808 t: &LambdaType,
809 possible_expressions: &PossibleExpressions<'src, Expr<'src>>,
810 max_expr: usize,
811 ) -> LambdaSampler<'src, Expr<'src>> {
812 let enumerator = RootedLambdaPool::enumerator(t, possible_expressions);
813 let mut lambdas = Vec::with_capacity(max_expr);
814 let mut expr_details = Vec::with_capacity(max_expr);
815 for (lambda, expr_detail) in enumerator.take(max_expr) {
816 lambdas.push(lambda);
817 expr_details.push(expr_detail);
818 }
819
820 LambdaSampler {
821 lambdas,
822 expr_details,
823 }
824 }
825
826 pub fn random_expr(
828 t: &LambdaType,
829 possible_expressions: &PossibleExpressions<'src, Expr<'src>>,
830 rng: &mut impl Rng,
831 ) -> RootedLambdaPool<'src, Expr<'src>> {
832 ProbabilisticEnumeration::new(
833 1,
834 t,
835 possible_expressions,
836 |_: &Context| true,
837 |_| true,
838 rng,
839 )
840 .next()
841 .unwrap()
842 .0
843 }
844
845 pub fn random_expr_no_constant(
847 t: &LambdaType,
848 possible_expressions: &PossibleExpressions<'src, Expr<'src>>,
849 rng: &mut impl Rng,
850 ) -> Option<RootedLambdaPool<'src, Expr<'src>>> {
851 ProbabilisticEnumeration::new(
852 1,
853 t,
854 possible_expressions,
855 |_: &Context| true,
856 |e| !e.constant_function,
857 rng,
858 )
859 .next()
860 .map(|x| x.0)
861 }
862}
863
impl<'src> RootedLambdaPool<'src, Expr<'src>> {
    /// Removes quantifiers whose subformula never refers to the quantified variable, replacing
    /// each such quantifier by its subformula. The restrictor is dropped along with it.
    /// The pool is compacted with `cleanup` afterwards.
    pub fn prune_quantifiers(&mut self) {
        // Collect (quantifier, subformula) pairs in BFS order from the root. The second item
        // yielded by `bfs_from` is compared against bound-variable indices; presumably it is
        // the number of binders between the start node and `x`. A quantifier is kept when some
        // bound variable in its subformula has index equal to that depth, i.e. refers to it.
        let quantifiers = self
            .pool
            .bfs_from(self.root)
            .filter_map(|(i, _)| match self.get(i) {
                LambdaExpr::LanguageOfThoughtExpr(Expr::Quantifier { subformula, .. }) => {
                    if !self
                        .pool
                        .bfs_from(LambdaExprRef(subformula.0))
                        .any(|(x, d)| {
                            if let LambdaExpr::BoundVariable(v, _) = self.get(x) {
                                *v == d
                            } else {
                                false
                            }
                        })
                    {
                        Some((i, LambdaExprRef(subformula.0)))
                    } else {
                        None
                    }
                }
                _ => None,
            })
            .collect::<Vec<_>>();

        // Process deepest-first (reverse BFS). Each quantifier node is overwritten with its
        // subformula's node, and bound variables pointing past the removed binder are
        // decremented.
        for (quantifier, subformula) in quantifiers.into_iter().rev() {
            *self.pool.get_mut(quantifier) = self.pool.get(subformula).clone();
            self.pool.bfs_from_mut(quantifier).for_each(|(x, d, _)| {
                if let LambdaExpr::BoundVariable(b_d, _) = x
                    && *b_d > d
                {
                    *b_d -= 1;
                }
            });
        }
        self.root = self.pool.cleanup(self.root);
    }

    /// Replaces a uniformly chosen node by another expression with the same output type and
    /// the same child types, keeping the node's children.
    ///
    /// # Errors
    /// Returns a [`TypeError`] if the type of the chosen node cannot be computed.
    ///
    /// # Panics
    /// Panics if no replacement is available or a child's type cannot be computed.
    pub fn swap_expr(
        &mut self,
        possible_expressions: &PossibleExpressions<'src, Expr<'src>>,
        rng: &mut impl Rng,
    ) -> Result<(), TypeError> {
        let position = LambdaExprRef((0..self.len()).choose(rng).unwrap() as u32);

        let (context, _) = Context::from_pos(self, position);

        let output = self.pool.get_type(position)?;
        let expr = self.pool.get(position);

        let arguments: Vec<_> = expr
            .get_children()
            .map(|x| self.pool.get_type(x).unwrap())
            .collect();

        let mut replacements = possible_expressions.possiblities_fixed_children(
            &output,
            &arguments,
            expr.var_type(),
            &context,
        );

        // A quantifier with no other candidate can still switch between universal and
        // existential over the same variable, restrictor and subformula.
        if replacements.is_empty()
            && let LambdaExpr::LanguageOfThoughtExpr(Expr::Quantifier {
                var_type,
                restrictor,
                subformula,
                ..
            }) = expr
        {
            for quantifier in [Quantifier::Universal, Quantifier::Existential] {
                replacements.push(std::borrow::Cow::Owned(LambdaExpr::LanguageOfThoughtExpr(
                    Expr::Quantifier {
                        quantifier,
                        var_type: *var_type,
                        restrictor: *restrictor,
                        subformula: *subformula,
                    },
                )));
            }
        }
        // The "every event" / "everyone" constants fall back to candidates of type
        // `LambdaType::t()` with the same arguments, plus the constant itself.
        for c in [Constant::EveryEvent, Constant::Everyone] {
            if replacements.is_empty()
                && expr == &LambdaExpr::LanguageOfThoughtExpr(Expr::Constant(c))
            {
                replacements = possible_expressions.possiblities_fixed_children(
                    LambdaType::t(),
                    &arguments,
                    expr.var_type(),
                    &context,
                );
                replacements.push(std::borrow::Cow::Owned(LambdaExpr::LanguageOfThoughtExpr(
                    Expr::Constant(c),
                )));
            }
        }

        let choice = replacements.choose(rng).unwrap_or_else(|| {
            panic!(
                "There is no node with output {output} and arguments {}",
                arguments
                    .into_iter()
                    .map(|x| x.to_string())
                    .collect::<Vec<_>>()
                    .join(", ")
            )
        });

        // Keep the original node's children, so only this one node changes.
        let mut new_expr = choice.clone().into_owned();
        new_expr.change_children(self.pool.get(position).get_children());
        self.pool.0[position.0 as usize] = new_expr;

        Ok(())
    }

    /// Replaces a uniformly chosen subtree with a copy of another subtree that
    /// `Context::find_compatible` reports as compatible. Does nothing if none is found.
    ///
    /// # Errors
    /// Propagates the [`TypeError`] from `find_compatible`.
    pub fn swap_subtree(&mut self, rng: &mut impl Rng) -> Result<(), TypeError> {
        let position = LambdaExprRef((0..self.len()).choose(rng).unwrap() as u32);
        let (context, _) = Context::from_pos(self, position);
        let alt = context.find_compatible(self, position)?;
        if let Some(new_pos) = alt.choose(rng).copied() {
            // Copy the donor subtree in BFS order to the end of the pool. `lookup` maps each
            // old index to its new index (`offset + BFS rank`).
            let offset = self.pool.0.len();
            let mut lookup = HashMap::default();
            let mut new_pool: Vec<LambdaExpr<'src, Expr<'src>>> = self
                .pool
                .bfs_from(new_pos)
                .map(|(x, _)| {
                    let n = lookup.len();
                    lookup
                        .entry(x)
                        .or_insert(LambdaExprRef((n + offset) as u32));
                    self.get(x).clone()
                })
                .collect();

            // Redirect the copied nodes' children to the copies.
            new_pool.iter_mut().for_each(|x| {
                let child: Vec<_> = x.get_children().map(|x| *lookup.get(&x).unwrap()).collect();
                x.change_children(child.into_iter());
            });

            // The copy's root is first in BFS order, i.e. at `offset`; swap it into `position`.
            self.pool.0.extend(new_pool);
            self.pool.0.swap(position.0 as usize, offset);
            self.cleanup();
        }
        Ok(())
    }
}
1018
#[cfg(test)]
mod test {

    use std::collections::HashSet;

    use super::*;
    use crate::lambda::{LambdaPool, LambdaSummaryStats};
    use rand::SeedableRng;
    use rand_chacha::ChaCha8Rng;

    #[test]
    fn prune_quantifier_test() -> anyhow::Result<()> {
        let mut pool =
            RootedLambdaPool::parse("some_e(x,all_e,AgentOf(a_2,e_1) & PatientOf(a_0,e_0))")?;

        pool.prune_quantifiers();
        assert_eq!(pool.to_string(), "AgentOf(a_2, e_1) & PatientOf(a_0, e_0)");

        let mut pool = RootedLambdaPool::parse(
            "some_e(x0, all_e, some(z, all_a, AgentOf(z, e_1) & PatientOf(a_0, e_0)))",
        )?;

        pool.prune_quantifiers();
        assert_eq!(
            pool.to_string(),
            "some(x, all_a, AgentOf(x, e_1) & PatientOf(a_0, e_0))"
        );

        let mut pool = RootedLambdaPool::parse("~every_e(z, pe_1, pa_2(a_0))")?;

        pool.prune_quantifiers();

        assert_eq!(pool.to_string(), "~pa_2(a_0)");
        let mut pool = RootedLambdaPool::new(
            LambdaPool(vec![
                LambdaExpr::Lambda(LambdaExprRef(1), LambdaType::E),
                LambdaExpr::Lambda(LambdaExprRef(2), LambdaType::T),
                LambdaExpr::LanguageOfThoughtExpr(Expr::Quantifier {
                    quantifier: Quantifier::Universal,
                    var_type: ActorOrEvent::Actor,
                    restrictor: ExprRef(3),
                    subformula: ExprRef(4),
                }),
                LambdaExpr::LanguageOfThoughtExpr(Expr::Constant(Constant::Property(
                    "1",
                    ActorOrEvent::Actor,
                ))),
                LambdaExpr::LanguageOfThoughtExpr(Expr::Quantifier {
                    quantifier: Quantifier::Existential,
                    var_type: ActorOrEvent::Actor,
                    restrictor: ExprRef(5),
                    subformula: ExprRef(6),
                }),
                LambdaExpr::LanguageOfThoughtExpr(Expr::Constant(Constant::Property(
                    "0",
                    ActorOrEvent::Actor,
                ))),
                LambdaExpr::LanguageOfThoughtExpr(Expr::Unary(
                    MonOp::Property("3", ActorOrEvent::Event),
                    ExprRef(7),
                )),
                LambdaExpr::BoundVariable(3, LambdaType::E),
            ]),
            LambdaExprRef(0),
        );

        assert_eq!(
            pool.to_string(),
            "lambda e x lambda t phi every(y, pa_1, some(z, pa_0, pe_3(x)))"
        );

        let mut parsed_pool = RootedLambdaPool::parse(
            "lambda e x lambda t phi every(y, pa_1, some(z, pa_0, pe_3(x)))",
        )?;
        parsed_pool.prune_quantifiers();
        pool.prune_quantifiers();
        assert_eq!(pool.to_string(), "lambda e x lambda t phi pe_3(x)");
        // The parsed and the hand-built pool denote the same expression; pruning must agree.
        assert_eq!(parsed_pool.to_string(), pool.to_string());

        Ok(())
    }

    #[test]
    fn random_swap_tree() -> anyhow::Result<()> {
        let mut rng = ChaCha8Rng::seed_from_u64(2);
        let x = RootedLambdaPool::parse("lambda a x pa_cool(a_John) | pa_cool(x)")?;
        let mut h = HashSet::new();
        for _ in 0..100 {
            let mut z = x.clone();
            z.swap_subtree(&mut rng)?;
            h.insert(z.to_string());
        }
        for x in h.iter() {
            println!("{x}");
        }
        assert!(h.contains("lambda a x pa_cool(x)"));
        assert!(h.contains("lambda a x pa_cool(a_John)"));
        Ok(())
    }

    #[test]
    fn randomn_swap() -> anyhow::Result<()> {
        let mut rng = ChaCha8Rng::seed_from_u64(2);
        let actors = ["0", "1"];
        let available_event_properties = ["2", "3", "4"];
        let possible_expressions = PossibleExpressions::new(
            &actors,
            &available_event_properties,
            &available_event_properties,
        );
        for _ in 0..200 {
            let t = LambdaType::random(&mut rng);
            println!("{t}");
            let mut pool = RootedLambdaPool::random_expr(&t, &possible_expressions, &mut rng);
            println!("{t}: {pool}");
            assert_eq!(t, pool.get_type()?);
            println!("{pool:?}");
            pool.swap_expr(&possible_expressions, &mut rng)?;
            println!("{pool:?}");
            println!("{t}: {pool}");
            assert_eq!(t, pool.get_type()?);
        }
        let p =
            RootedLambdaPool::parse("lambda <a,t> P some(x, P(x), some_e(y, pe_2(y), pe_4(y)))")?;
        let t = p.get_type()?;
        for _ in 0..1000 {
            let mut pool = p.clone();
            assert_eq!(t, pool.get_type()?);

            pool.swap_expr(&possible_expressions, &mut rng)?;
            dbg!(&pool);
            println!("{t}: {pool}");
            assert_eq!(t, pool.get_type()?);
        }

        Ok(())
    }

    #[test]
    fn enumerate() -> anyhow::Result<()> {
        let actors = ["john"];
        let actor_properties = ["a"];
        let event_properties = ["e"];
        let possibles = PossibleExpressions::new(&actors, &actor_properties, &event_properties);

        let t = LambdaType::from_string("<<a,t>,t>")?;

        let p = RootedLambdaPool::enumerator(&t, &possibles);
        let pools: HashSet<_> = p
            .filter_map(|(p, x)| {
                if !x.has_constant_function() {
                    Some(p.to_string())
                } else {
                    None
                }
            })
            .take(20)
            .collect();
        for p in pools.iter() {
            println!("{p}");
        }
        assert!(pools.contains("lambda <a,t> P P(a_john)"));

        Ok(())
    }

    #[test]
    fn enumerate_weirds() -> anyhow::Result<()> {
        let actors = &["1"];
        let actor_properties = &["1"];
        let event_properties = &["1"];
        let poss = PossibleExpressions::new(actors, actor_properties, event_properties);

        let t = "<a,<e,e>>";

        for (p, d) in RootedLambdaPool::enumerator(&LambdaType::from_string(t)?, &poss)
            .filter(|(_, e)| !e.has_constant_function())
            .take(5)
        {
            println!("{p} {d:?} {}", d.has_constant_function());
        }
        Ok(())
    }

    #[test]
    fn random_expr() -> anyhow::Result<()> {
        let actors = ["john"];
        let actor_properties = ["a"];
        let event_properties = ["e"];
        let possibles = PossibleExpressions::new(&actors, &actor_properties, &event_properties);
        let mut rng = ChaCha8Rng::seed_from_u64(0);

        let map = [(LambdaType::A, vec![RootedLambdaPool::parse("a_john")?])]
            .into_iter()
            .collect();

        for _ in 0..100 {
            let t = LambdaType::random(&mut rng);
            println!("sampling: {t}");
            let mut pool = RootedLambdaPool::random_expr(&t, &possibles, &mut rng);
            assert_eq!(t, pool.get_type()?);
            let s = pool.to_string();
            println!("{s}");
            let mut pool2 = RootedLambdaPool::parse(s.as_str())?;
            assert_eq!(s, pool2.to_string());
            println!("{pool}");
            pool2.resample_from_expr(&possibles, None, &mut rng)?;
            assert_eq!(pool2.get_type()?, t);
            pool2.resample_from_expr(&possibles, Some(&map), &mut rng)?;
            assert_eq!(pool2.get_type()?, t);

            pool.swap_subtree(&mut rng)?;
            assert_eq!(pool.get_type()?, t);
        }
        let t = LambdaType::from_string("<a,<a,t>>")?;
        for _ in 0..100 {
            let pool = RootedLambdaPool::random_expr(&t, &possibles, &mut rng);
            println!("{pool}");
        }
        Ok(())
    }

    #[test]
    fn constant_exprs() -> anyhow::Result<()> {
        let actors = ["john", "mary", "phil", "sue"];
        let actor_properties = ["a"];
        let event_properties = ["e"];
        let possibles = PossibleExpressions::new(&actors, &actor_properties, &event_properties);
        let mut rng = ChaCha8Rng::seed_from_u64(0);

        let v: HashSet<_> = RootedLambdaPool::enumerator(LambdaType::a(), &possibles)
            .map(|(x, _)| x.to_string())
            .take(4)
            .collect();

        assert_eq!(v, HashSet::from(actors.map(|x| format!("a_{x}"))));
        println!("{v:?}");

        let mut constants = 0;
        for _ in 0..100 {
            let t = LambdaType::from_string("<a, <a,t>>")?;
            println!("sampling: {t}");
            let pool = RootedLambdaPool::random_expr(&t, &possibles, &mut rng);
            let x = pool.stats();
            match x {
                LambdaSummaryStats::WellFormed {
                    constant_function, ..
                } => {
                    if constant_function {
                        constants += 1;
                    }
                }
                LambdaSummaryStats::Malformed => {
                    panic!("random_expr produced a malformed expression: {pool}")
                }
            }
        }
        println!("{constants}");

        Ok(())
    }

    #[test]
    fn random_expr_counts() -> anyhow::Result<()> {
        let actors = ["john", "mary", "phil", "sue"];
        let actor_properties = ["a"];
        let event_properties = ["e"];
        let possibles = PossibleExpressions::new(&actors, &actor_properties, &event_properties);
        let mut rng = ChaCha8Rng::seed_from_u64(1);

        let mut counts: HashMap<_, usize> = HashMap::default();
        for _ in 0..1000 {
            let t = LambdaType::A;
            let pool = RootedLambdaPool::random_expr(&t, &possibles, &mut rng);
            assert_eq!(t, pool.get_type()?);
            let s = pool.to_string();
            *counts.entry(s).or_default() += 1;
        }
        assert_eq!(counts.len(), 4);
        dbg!(&counts);
        for (_, v) in counts.iter() {
            assert!(200 <= *v && *v <= 300);
        }

        counts.clear();
        for _ in 0..1000 {
            let t = LambdaType::at().clone();
            let pool = RootedLambdaPool::random_expr(&t, &possibles, &mut rng);
            assert_eq!(t, pool.get_type()?);
            let s = pool.to_string();
            *counts.entry(s).or_default() += 1;
        }
        dbg!(counts);
        Ok(())
    }
}