1use super::{
2 Debug, LambdaExpr, LambdaExprRef, LambdaLanguageOfThought, LambdaPool, LambdaType, RandomPQ,
3 RootedLambdaPool, TypeError,
4};
5use ahash::{HashMap, HashSet};
6use itertools::Either;
7
/// State used to detect functions that never use their bound variable.
///
/// It is stored per lambda in scope (see `Context::lambdas`) and once for the whole
/// expression (`Context::constant_function`).
///
/// NOTE: variant order is significant. The derived `Ord` gives
/// `Constant < PotentiallyConstant < NonConstant`, and that order is used as a sort key
/// by both the `Context` and the `RandomPQ` orderings.
#[derive(Debug, Copy, Clone, Eq, PartialEq, Ord, PartialOrd, Hash)]
pub(super) enum ConstantFunctionState {
    /// Sticky: once reached, `update` never leaves this state.
    Constant,
    /// No use of the variable has been recorded yet.
    /// `use_var` turns this into `NonConstant`; `done` turns it into `Constant`.
    PotentiallyConstant,
    /// A use of the variable has been recorded (or nothing is being tracked).
    NonConstant,
}
14
15impl ConstantFunctionState {
16 fn update(&mut self, new: Self) {
17 match self {
18 ConstantFunctionState::Constant => (),
19 ConstantFunctionState::PotentiallyConstant => match new {
20 ConstantFunctionState::Constant => *self = ConstantFunctionState::Constant,
21 ConstantFunctionState::PotentiallyConstant | ConstantFunctionState::NonConstant => {
22 }
23 },
24 ConstantFunctionState::NonConstant => match new {
25 ConstantFunctionState::Constant => *self = ConstantFunctionState::Constant,
26 ConstantFunctionState::PotentiallyConstant => {
27 *self = ConstantFunctionState::PotentiallyConstant;
28 }
29 ConstantFunctionState::NonConstant => *self = ConstantFunctionState::NonConstant,
30 },
31 }
32 }
33
34 fn use_var(&mut self) {
35 match self {
36 ConstantFunctionState::PotentiallyConstant => {
37 *self = ConstantFunctionState::NonConstant;
38 }
39 ConstantFunctionState::Constant | ConstantFunctionState::NonConstant => (),
40 }
41 }
42
43 fn done(&mut self) {
44 match self {
45 ConstantFunctionState::PotentiallyConstant => *self = ConstantFunctionState::Constant,
46 ConstantFunctionState::Constant | ConstantFunctionState::NonConstant => (),
47 }
48 }
49}
50
/// Binder and bookkeeping state at some point while building or walking a lambda expression.
///
/// NOTE(review): `Ord` for `Context` compares only `done`, `open_depth_score()`,
/// `constant_function` and `pool_index`. The derived `PartialEq`/`Eq` compare every field,
/// so `cmp` can return `Equal` for two contexts that are not `==`.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct Context {
    /// Lambdas in scope, outermost first, each paired with the state of its bound variable.
    lambdas: Vec<(LambdaType, ConstantFunctionState)>,
    /// Maps a result type to the argument types for which a function type from that
    /// argument to the result is reachable from the types in scope.
    /// Rebuilt by `update_possible_types`.
    possible_types: HashMap<LambdaType, HashSet<LambdaType>>,
    /// Last tie-breaker in the `Context` ordering.
    /// NOTE(review): presumably an index into a pool of candidates — confirm.
    pub(super) pool_index: usize,
    /// Set by `Context::new`.
    /// NOTE(review): meaning not visible here — presumably the position being filled; confirm.
    pub(super) position: usize,
    /// Counter that feeds `open_depth_score`.
    /// `from_pos`/`find_compatible` increment it once per visited node.
    pub(super) depth: usize,
    /// Whether this context is finished. It is the first key in both the `Context` and
    /// `RandomPQ` orderings.
    done: bool,
    /// Starts at 1 in `Context::new` and is squared in `open_depth_score`.
    /// NOTE(review): presumably the number of still-unfilled nodes — confirm.
    pub(super) open_nodes: usize,
    /// Aggregate constant-function state for the expression as a whole.
    constant_function: ConstantFunctionState,
}
64
65impl PartialOrd for RandomPQ {
66 fn partial_cmp(&self, other: &Self) -> Option<std::cmp::Ordering> {
67 Some(self.cmp(other))
68 }
69}
70
71impl Ord for RandomPQ {
73 fn cmp(&self, other: &Self) -> std::cmp::Ordering {
74 let c = &self.0;
75 let o = &other.0;
76
77 c.done
78 .cmp(&o.done)
79 .then(o.open_depth_score().cmp(&c.open_depth_score()))
80 .then(o.lambdas.len().cmp(&c.lambdas.len()))
81 .then(o.constant_function.cmp(&c.constant_function))
82 .then(self.1.partial_cmp(&other.1).unwrap())
83 }
84}
85
86impl PartialOrd for Context {
87 fn partial_cmp(&self, other: &Self) -> Option<std::cmp::Ordering> {
88 Some(self.cmp(other))
89 }
90}
91
92impl Ord for Context {
93 fn cmp(&self, other: &Self) -> std::cmp::Ordering {
94 other
95 .done
96 .cmp(&self.done)
97 .then(self.open_depth_score().cmp(&other.open_depth_score()))
98 .then(self.constant_function.cmp(&other.constant_function))
99 .then(self.pool_index.cmp(&other.pool_index))
100 }
101}
102impl Context {
103 fn open_depth_score(&self) -> usize {
104 self.depth + self.open_nodes.pow(2) + self.lambdas.len()
105 }
106}
107
108impl<T: LambdaLanguageOfThought> LambdaPool<'_, T> {
109 fn compatible_with(
111 &self,
112 pos: LambdaExprRef,
113 new_context: &Context,
114 old_context: &Context,
115 ) -> bool {
116 for (x, d) in self.bfs_from(pos) {
117 if let LambdaExpr::BoundVariable(b, _) = self.get(x)
118 && b + 1 > d
119 {
120 let old_lambda_pos = old_context.lambdas.len() + d - b - 1;
122
123 if b + 1 > d + new_context.lambdas.len() {
124 return false;
128 }
129 let new_lambda_pos = new_context.lambdas.len() + d - b - 1;
130 if new_context.lambdas.get(new_lambda_pos)
131 != old_context.lambdas.get(old_lambda_pos)
132 {
133 return false;
134 }
135 }
136 }
137 true
138 }
139}
140
impl Context {
    /// Returns this context's `depth` counter.
    ///
    /// NOTE(review): `from_pos` and `find_compatible` increment `depth` once per node they
    /// visit. For contexts built there, this is a count of traversed nodes rather than a
    /// nesting depth — confirm that is the intended notion of "length".
    // `len` reports a counter, not a collection size, so no `is_empty` companion is provided.
    #[allow(clippy::len_without_is_empty)]
    #[must_use]
    pub fn len(&self) -> usize {
        self.depth
    }

    /// Returns the number of lambdas currently in scope.
    #[must_use]
    pub fn n_vars(&self) -> usize {
        self.lambdas.len()
    }

    /// Reconstructs the context in effect at node `pos` of `pool`.
    ///
    /// The tree is walked depth-first from `pool.root`:
    /// - a lambda is pushed for every node with a `var_type()`;
    /// - each bound-variable use is recorded;
    /// - lambdas are popped whenever the walk returns to a node with fewer binders in scope.
    ///
    /// The walk stops at `pos`, before `pos`'s own lambda (if any) is added.
    ///
    /// Returns the context together with a flag that is `true` when `pos` was reached as
    /// the `subformula` side of an `Application`. If `pos` is never reached, the context
    /// left at the end of the walk is returned and the flag is `false`.
    pub(super) fn from_pos<T: LambdaLanguageOfThought>(
        pool: &RootedLambdaPool<'_, T>,
        pos: LambdaExprRef,
    ) -> (Context, bool) {
        let mut context = Context::new(0, vec![]);
        // Stack entries: (node, lambdas in scope at that node, is `subformula` of an Application).
        let mut stack = vec![(pool.root, 0, false)];
        let mut return_is_subformula = false;

        while let Some((x, n_lambdas, is_subformula)) = stack.pop() {
            context.depth += 1;
            let e = pool.get(x);
            // Coming back from a deeper branch: close lambdas until the scope matches this node.
            if context.lambdas.len() != n_lambdas {
                for _ in 0..(context.lambdas.len() - n_lambdas) {
                    context.pop_lambda();
                }
            }

            if pos == x {
                return_is_subformula = is_subformula;
                break;
            }

            if let Some(v) = e.var_type() {
                context.add_lambda(v);
            } else if let LambdaExpr::BoundVariable(n, _) = e {
                context.use_bvar(*n);
            }

            if let LambdaExpr::Application {
                subformula,
                argument,
            } = e
            {
                // Pushed last, so `argument` is visited before `subformula`.
                stack.push((*subformula, context.lambdas.len(), true));
                stack.push((*argument, context.lambdas.len(), false));
            } else {
                stack.extend(e.get_children().map(|x| (x, context.lambdas.len(), false)));
            }
        }
        (context, return_is_subformula)
    }

    /// Finds the nodes of `pool` that could replace the node at `pos`.
    ///
    /// A node `x` qualifies when all of the following hold:
    /// - `x` is not `pos`;
    /// - `x` has the same type as `pos`;
    /// - `LambdaPool::compatible_with(x, self, context at x)` holds, where `self` is taken
    ///   as the new context and the context in effect at `x` as the old one.
    ///
    /// # Errors
    /// Propagates any `TypeError` from `get_type` on `pos` or on a visited node.
    pub(super) fn find_compatible<T: LambdaLanguageOfThought>(
        &self,
        pool: &RootedLambdaPool<'_, T>,
        pos: LambdaExprRef,
    ) -> Result<Vec<LambdaExprRef>, TypeError> {
        let t = pool.pool.get_type(pos)?;

        let mut this_context = Context::new(0, vec![]);
        // Stack entries: (node, lambdas in scope at that node).
        let mut stack = vec![(pool.root, 0)];
        let mut options: Vec<_> = vec![];
        while let Some((x, n_lambdas)) = stack.pop() {
            this_context.depth += 1;
            let e = pool.get(x);
            // Coming back from a deeper branch: close lambdas until the scope matches this node.
            if this_context.lambdas.len() != n_lambdas {
                for _ in 0..(this_context.lambdas.len() - n_lambdas) {
                    this_context.pop_lambda();
                }
            }
            // `this_context` is now the context in effect at `x` (its own lambda not yet added).
            if pos != x
                && t == pool.pool.get_type(x)?
                && pool.pool.compatible_with(x, self, &this_context)
            {
                options.push(x);
            }

            if let Some(v) = e.var_type() {
                this_context.add_lambda(v);
            } else if let LambdaExpr::BoundVariable(n, _) = e {
                this_context.use_bvar(*n);
            }

            stack.extend(e.get_children().map(|x| (x, this_context.lambdas.len())));
        }
        Ok(options)
    }

    /// Recomputes `possible_types` from scratch.
    ///
    /// Known types start as the lambda types in scope plus `LambdaType::a()`, `e()`, `t()`,
    /// `at()` and `et()`. The map is then filled by iterating to a fixpoint: whenever a
    /// known function type's argument type is also known, its result type becomes known,
    /// and `argument` is recorded under `possible_types[result]`.
    fn update_possible_types(&mut self) {
        self.possible_types.clear();

        // (result, argument) pairs discovered in the current round.
        let mut new_types: HashSet<(&LambdaType, &LambdaType)> = HashSet::default();
        // Every type currently known to be constructible.
        let mut base_types: HashSet<_> = self.lambdas.iter().map(|(x, _)| x).collect();
        base_types.insert(LambdaType::a());
        base_types.insert(LambdaType::e());
        base_types.insert(LambdaType::t());
        base_types.insert(LambdaType::at());
        base_types.insert(LambdaType::et());

        loop {
            for subformula in &base_types {
                if let Ok((argument, result_type)) = subformula.split() {
                    let already_has_type = self
                        .possible_types
                        .get(result_type)
                        .is_some_and(|x| x.contains(argument));

                    if base_types.contains(argument) && !already_has_type {
                        new_types.insert((result_type, argument));
                    }
                }
            }
            // Fixpoint reached: no new (result, argument) pair this round.
            if new_types.is_empty() {
                break;
            }
            for (result, argument) in &new_types {
                self.possible_types
                    .entry((*result).clone())
                    .or_default()
                    .insert((*argument).clone());
            }
            // Newly derivable result types become inputs for the next round.
            base_types.extend(new_types.drain().map(|(result, _arg)| result));
        }
    }

    /// Creates a context at `position` with `lambdas` already in scope.
    ///
    /// `open_nodes` starts at 1. `constant_function` starts as `PotentiallyConstant` when
    /// any lambdas are supplied and `NonConstant` otherwise. `possible_types` is computed
    /// immediately.
    pub(super) fn new(position: usize, lambdas: Vec<(LambdaType, ConstantFunctionState)>) -> Self {
        let mut c = Context {
            pool_index: 0,
            position,
            done: false,
            depth: 0,
            possible_types: HashMap::default(),
            open_nodes: 1,
            constant_function: if lambdas.is_empty() {
                ConstantFunctionState::NonConstant
            } else {
                ConstantFunctionState::PotentiallyConstant
            },
            lambdas,
        };
        c.update_possible_types();
        c
    }

    /// Pushes a new innermost lambda of type `t`, starting as `PotentiallyConstant`.
    ///
    /// A `NonConstant` `constant_function` is raised to `PotentiallyConstant`, and
    /// `possible_types` is recomputed.
    pub(super) fn add_lambda(&mut self, t: &LambdaType) {
        self.constant_function
            .update(ConstantFunctionState::PotentiallyConstant);
        self.lambdas
            .push((t.clone(), ConstantFunctionState::PotentiallyConstant));
        self.update_possible_types();
    }

    /// Removes the innermost lambda and folds its final state into `constant_function`.
    ///
    /// A lambda that is still `PotentiallyConstant` here had no use recorded via
    /// `use_bvar`, so it is finalised as `Constant`. Once no lambdas remain, a
    /// `constant_function` that is still `PotentiallyConstant` is downgraded to
    /// `NonConstant`. `possible_types` is recomputed.
    ///
    /// # Panics
    /// Panics if there are no lambdas in scope.
    pub(super) fn pop_lambda(&mut self) {
        let (_, mut function_state) = self.lambdas.pop().unwrap();
        function_state.done();
        self.constant_function.update(function_state);
        if self.lambdas.is_empty() {
            self.constant_function.use_var();
        }
        self.update_possible_types();
    }

    /// Records a use of bound variable `b` (0 = innermost lambda).
    ///
    /// If that lambda is `PotentiallyConstant`, it becomes `NonConstant`.
    ///
    /// # Panics
    /// Panics if `b >= self.lambdas.len()`.
    pub(super) fn use_bvar(&mut self, b: usize) {
        let n = self.lambdas.len() - b - 1;
        self.lambdas.get_mut(n).unwrap().1.use_var();
    }

    /// Returns `true` when `constant_function` is `ConstantFunctionState::Constant`.
    #[must_use]
    pub fn is_constant(&self) -> bool {
        self.constant_function == ConstantFunctionState::Constant
    }

    /// Iterates over the types of the lambdas in scope, outermost first.
    pub fn current_variable_types(&self) -> impl Iterator<Item = &LambdaType> {
        self.lambdas.iter().map(|(x, _)| x)
    }

    /// Yields the `(function_type, argument_type)` pairs that can produce `lambda_type`.
    ///
    /// There is one pair per argument type recorded in `possible_types[lambda_type]`, with
    /// `function_type = LambdaType::compose(argument_type, lambda_type)`. The iterator is
    /// empty when nothing can produce `lambda_type`.
    pub fn applications<'a, 'b: 'a>(
        &'a self,
        lambda_type: &'b LambdaType,
    ) -> impl Iterator<Item = (LambdaType, LambdaType)> + 'a {
        match self.possible_types.get(lambda_type) {
            Some(x) => Either::Left(x.iter().map(|x| {
                (
                    LambdaType::compose(x.clone(), lambda_type.clone()),
                    x.clone(),
                )
            })),
            None => Either::Right(std::iter::empty()),
        }
    }

    /// Yields a `BoundVariable` expression for every lambda in scope whose type equals
    /// `lambda_type`.
    ///
    /// Indices count from the innermost lambda (0). Lambdas are visited outermost first,
    /// so the largest indices come first.
    pub fn variables<'src, T: LambdaLanguageOfThought>(
        &self,
        lambda_type: &LambdaType,
    ) -> impl Iterator<Item = LambdaExpr<'src, T>> {
        let n = self.lambdas.len();
        self.lambdas
            .iter()
            .enumerate()
            .filter_map(move |(i, (lambda, _))| {
                if lambda == lambda_type {
                    Some(LambdaExpr::BoundVariable(n - i - 1, lambda.clone()))
                } else {
                    None
                }
            })
    }
}
357
#[cfg(test)]
mod test {
    use super::*;

    /// Builds a lambda-free, `NonConstant` context; only the ordering-relevant fields vary.
    fn bare_context(depth: usize, done: bool, open_nodes: usize) -> Context {
        Context {
            depth,
            done,
            lambdas: vec![],
            possible_types: HashMap::default(),
            pool_index: 0,
            position: 0,
            open_nodes,
            constant_function: ConstantFunctionState::NonConstant,
        }
    }

    #[test]
    fn test_context() -> anyhow::Result<()> {
        let shallow = bare_context(1, false, 0);
        let deeper = bare_context(2, false, 0);
        let finished = bare_context(5, true, 0);
        let finished_open = bare_context(5, true, 54);

        // Unfinished contexts are ordered by their open/depth score...
        assert!(shallow < deeper);
        // ...finished contexts sort before unfinished ones regardless of depth...
        assert!(finished < deeper);
        assert!(finished < shallow);
        // ...and among finished contexts, fewer open nodes sort first.
        assert!(finished < finished_open);

        Ok(())
    }

    #[test]
    fn possible_type_check() -> anyhow::Result<()> {
        let lambdas = [
            "<a,t>",
            "<<a,t>, <a,t>>",
            "<<a,t>, <<a,t>, <e,t>>>",
            "<<a,t>, e>",
            "<e, <a,<a,t>>>",
        ]
        .into_iter()
        .map(|source| {
            LambdaType::from_string(source)
                .map(|parsed| (parsed, ConstantFunctionState::PotentiallyConstant))
        })
        .collect::<Result<Vec<_>, _>>()?;

        let mut c = Context {
            depth: 0,
            done: false,
            lambdas,
            possible_types: HashMap::default(),
            pool_index: 0,
            position: 0,
            open_nodes: 54,
            constant_function: ConstantFunctionState::PotentiallyConstant,
        };
        c.update_possible_types();

        let mut actual: Vec<(String, Vec<String>)> = c
            .possible_types
            .iter()
            .map(|(result, arguments)| {
                let mut names: Vec<String> = arguments.iter().map(|t| t.to_string()).collect();
                names.sort();
                (result.to_string(), names)
            })
            .collect();
        actual.sort();

        let expected: Vec<(String, Vec<String>)> = [
            ("<<a,t>,<e,t>>", vec!["<a,t>"]),
            ("<a,<a,t>>", vec!["e"]),
            ("<a,t>", vec!["<a,t>", "a"]),
            ("<e,t>", vec!["<a,t>"]),
            ("e", vec!["<a,t>"]),
            ("t", vec!["a", "e"]),
        ]
        .into_iter()
        .map(|(result, arguments)| {
            (
                result.to_string(),
                arguments
                    .into_iter()
                    .map(String::from)
                    .collect::<Vec<String>>(),
            )
        })
        .collect();

        assert_eq!(actual, expected);

        Ok(())
    }
}