1use ahash::{HashMap, HashSet};
2use itertools::Either;
3
4use super::*;
5
/// Tracks whether a lambda's bound variable has been used, for detecting constant
/// functions.
///
/// A lambda is pushed as `PotentiallyConstant` (see `Context::add_lambda`), is demoted
/// to `NonConstant` when its variable is used (`use_var`), and is promoted to
/// `Constant` if it is closed (`done`) while still unused.
///
/// Variant order is significant: the derived `Ord` gives
/// `Constant < PotentiallyConstant < NonConstant`, and the `Ord` impls for `Context`
/// and `RandomPQ` compare on it.
#[derive(Debug, Copy, Clone, Eq, PartialEq, Ord, PartialOrd, Hash)]
pub(super) enum ConstantFunctionState {
    /// The lambda was closed without its variable ever being used.
    Constant,
    /// The lambda is still open and its variable has not been used so far.
    PotentiallyConstant,
    /// The variable has been used.
    NonConstant,
}
12
13impl ConstantFunctionState {
14 fn update(&mut self, new: Self) {
15 match self {
16 ConstantFunctionState::Constant => (),
17 ConstantFunctionState::PotentiallyConstant => match new {
18 ConstantFunctionState::Constant => *self = ConstantFunctionState::Constant,
19 ConstantFunctionState::PotentiallyConstant => (),
20 ConstantFunctionState::NonConstant => (),
21 },
22 ConstantFunctionState::NonConstant => match new {
23 ConstantFunctionState::Constant => *self = ConstantFunctionState::Constant,
24 ConstantFunctionState::PotentiallyConstant => {
25 *self = ConstantFunctionState::PotentiallyConstant
26 }
27 ConstantFunctionState::NonConstant => *self = ConstantFunctionState::NonConstant,
28 },
29 }
30 }
31
32 fn use_var(&mut self) {
33 match self {
34 ConstantFunctionState::Constant => (),
35 ConstantFunctionState::PotentiallyConstant => {
36 *self = ConstantFunctionState::NonConstant
37 }
38 ConstantFunctionState::NonConstant => (),
39 }
40 }
41
42 fn done(&mut self) {
43 match self {
44 ConstantFunctionState::Constant => (),
45 ConstantFunctionState::PotentiallyConstant => *self = ConstantFunctionState::Constant,
46 ConstantFunctionState::NonConstant => (),
47 }
48 }
49}
50
/// Lambda scopes in effect at a point in an expression, plus the bookkeeping that the
/// `Ord` impls use to rank partial expressions.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct Context {
    /// Types of the lambdas currently in scope, outermost first (the last entry is the
    /// one a bound variable with index 0 refers to), each paired with whether that
    /// lambda's variable has been used yet.
    lambdas: Vec<(LambdaType, ConstantFunctionState)>,
    /// Maps a result type to the argument types from which it can be produced by
    /// applying an available function type; rebuilt by `update_possible_types`.
    possible_types: HashMap<LambdaType, HashSet<LambdaType>>,
    /// Final tie-breaker in `Ord for Context`.
    /// NOTE(review): presumably an index into a pool of candidates — confirm.
    pub(super) pool_index: usize,
    /// Set by `Context::new`; not read in this part of the file.
    /// NOTE(review): meaning unconfirmed here.
    pub(super) position: usize,
    /// Counter bumped once per visited node by `from_pos` / `find_compatible`;
    /// returned by `len` and added into `open_depth_score`.
    pub(super) depth: usize,
    /// Compared first by both `Ord` impls; under `Ord for Context`, contexts with
    /// `done == true` sort first.
    done: bool,
    /// Squared in `open_depth_score`; `Context::new` starts it at 1.
    /// NOTE(review): presumably the number of still-unfilled nodes — confirm.
    pub(super) open_nodes: usize,
    /// Aggregate constant-function state. `pop_lambda` folds each closed lambda's
    /// state into it, so it becomes `Constant` once a lambda closes with its
    /// variable unused.
    constant_function: ConstantFunctionState,
}
64
65impl PartialOrd for RandomPQ {
66 fn partial_cmp(&self, other: &Self) -> Option<std::cmp::Ordering> {
67 Some(self.cmp(other))
68 }
69}
70
71impl Ord for RandomPQ {
73 fn cmp(&self, other: &Self) -> std::cmp::Ordering {
74 let c = &self.0;
75 let o = &other.0;
76
77 c.done
78 .cmp(&o.done)
79 .then(o.open_depth_score().cmp(&c.open_depth_score()))
80 .then(o.lambdas.len().cmp(&c.lambdas.len()))
81 .then(o.constant_function.cmp(&c.constant_function))
82 .then(self.1.partial_cmp(&other.1).unwrap())
83 }
84}
85
86impl PartialOrd for Context {
87 fn partial_cmp(&self, other: &Self) -> Option<std::cmp::Ordering> {
88 Some(self.cmp(other))
89 }
90}
91
92impl Ord for Context {
93 fn cmp(&self, other: &Self) -> std::cmp::Ordering {
94 other
95 .done
96 .cmp(&self.done)
97 .then(self.open_depth_score().cmp(&other.open_depth_score()))
98 .then(self.constant_function.cmp(&other.constant_function))
99 .then(self.pool_index.cmp(&other.pool_index))
100 }
101}
102impl Context {
103 fn open_depth_score(&self) -> usize {
104 self.depth + self.open_nodes.pow(2) + self.lambdas.len()
105 }
106}
107
108impl<'src, T: LambdaLanguageOfThought> LambdaPool<'src, T> {
109 fn compatible_with(
111 &self,
112 pos: LambdaExprRef,
113 new_context: &Context,
114 old_context: &Context,
115 ) -> bool {
116 for (x, d) in self.bfs_from(pos) {
117 if let LambdaExpr::BoundVariable(b, _) = self.get(x)
118 && b + 1 > d
119 {
120 let old_lambda_pos = old_context.lambdas.len() + d - b - 1;
122
123 if b + 1 > d + new_context.lambdas.len() {
124 return false;
128 }
129 let new_lambda_pos = new_context.lambdas.len() + d - b - 1;
130 if new_context.lambdas.get(new_lambda_pos)
131 != old_context.lambdas.get(old_lambda_pos)
132 {
133 return false;
134 }
135 }
136 }
137 true
138 }
139}
140
141impl Context {
142 #[allow(clippy::len_without_is_empty)]
143 pub fn len(&self) -> usize {
145 self.depth
146 }
147
148 pub fn n_vars(&self) -> usize {
150 self.lambdas.len()
151 }
152
153 pub(super) fn from_pos<'src, T: LambdaLanguageOfThought>(
154 pool: &RootedLambdaPool<'src, T>,
155 pos: LambdaExprRef,
156 ) -> (Context, bool) {
157 let mut context = Context::new(0, vec![]);
158 let mut stack = vec![(pool.root, 0, false)];
159 let mut return_is_subformula = false;
160
161 while let Some((x, n_lambdas, is_subformula)) = stack.pop() {
162 context.depth += 1;
163 let e = pool.get(x);
164 if context.lambdas.len() != n_lambdas {
165 for _ in 0..(context.lambdas.len() - n_lambdas) {
166 context.pop_lambda();
167 }
168 }
169
170 if pos == x {
171 return_is_subformula = is_subformula;
172 break;
173 }
174
175 if let Some(v) = e.var_type() {
176 context.add_lambda(v);
177 } else if let LambdaExpr::BoundVariable(n, _) = e {
178 context.use_bvar(*n);
179 }
180
181 if let LambdaExpr::Application {
182 subformula,
183 argument,
184 } = e
185 {
186 stack.push((*subformula, context.lambdas.len(), true));
187 stack.push((*argument, context.lambdas.len(), false));
188 } else {
189 stack.extend(e.get_children().map(|x| (x, context.lambdas.len(), false)));
190 }
191 }
192 (context, return_is_subformula)
193 }
194
195 pub(super) fn find_compatible<'src, T: LambdaLanguageOfThought>(
196 &self,
197 pool: &RootedLambdaPool<'src, T>,
198 pos: LambdaExprRef,
199 ) -> Result<Vec<LambdaExprRef>, TypeError> {
200 let t = pool.pool.get_type(pos)?;
201
202 let mut this_context = Context::new(0, vec![]);
203 let mut stack = vec![(pool.root, 0)];
204 let mut options: Vec<_> = vec![];
205 while let Some((x, n_lambdas)) = stack.pop() {
206 this_context.depth += 1;
207 let e = pool.get(x);
208 if this_context.lambdas.len() != n_lambdas {
209 for _ in 0..(this_context.lambdas.len() - n_lambdas) {
210 this_context.pop_lambda();
211 }
212 }
213 if pos != x
214 && t == pool.pool.get_type(x)?
215 && pool.pool.compatible_with(x, self, &this_context)
216 {
217 options.push(x);
218 }
219
220 if let Some(v) = e.var_type() {
221 this_context.add_lambda(v);
222 } else if let LambdaExpr::BoundVariable(n, _) = e {
223 this_context.use_bvar(*n);
224 }
225
226 stack.extend(e.get_children().map(|x| (x, this_context.lambdas.len())));
227 }
228 Ok(options)
229 }
230
231 fn update_possible_types(&mut self) {
232 self.possible_types.clear();
233
234 let mut new_types: HashSet<(&LambdaType, &LambdaType)> = HashSet::default();
235 let mut base_types: HashSet<_> = self.lambdas.iter().map(|(x, _)| x).collect();
236 base_types.insert(LambdaType::a());
237 base_types.insert(LambdaType::e());
238 base_types.insert(LambdaType::t());
239 base_types.insert(LambdaType::at());
240 base_types.insert(LambdaType::et());
241
242 loop {
243 for subformula in base_types.iter() {
244 if let Ok((argument, result_type)) = subformula.split() {
245 let already_has_type = self
246 .possible_types
247 .get(result_type)
248 .map(|x| x.contains(argument))
249 .unwrap_or(false);
250
251 if base_types.contains(argument) && !already_has_type {
252 new_types.insert((result_type, argument));
253 }
254 }
255 }
256 if new_types.is_empty() {
257 break;
258 } else {
259 for (result, argument) in new_types.iter() {
260 self.possible_types
261 .entry((*result).clone())
262 .or_default()
263 .insert((*argument).clone());
264 }
265 base_types.extend(new_types.drain().map(|(result, _arg)| result));
266 }
267 }
268 }
269
270 pub(super) fn new(position: usize, lambdas: Vec<(LambdaType, ConstantFunctionState)>) -> Self {
271 let mut c = Context {
272 pool_index: 0,
273 position,
274 done: false,
275 depth: 0,
276 possible_types: HashMap::default(),
277 open_nodes: 1,
278 constant_function: if lambdas.is_empty() {
279 ConstantFunctionState::NonConstant
280 } else {
281 ConstantFunctionState::PotentiallyConstant
282 },
283 lambdas,
284 };
285 c.update_possible_types();
286 c
287 }
288
289 pub(super) fn add_lambda(&mut self, t: &LambdaType) {
290 self.constant_function
291 .update(ConstantFunctionState::PotentiallyConstant);
292 self.lambdas
293 .push((t.clone(), ConstantFunctionState::PotentiallyConstant));
294 self.update_possible_types();
295 }
296
297 pub(super) fn pop_lambda(&mut self) {
298 let (_, mut function_state) = self.lambdas.pop().unwrap();
299 function_state.done();
300 self.constant_function.update(function_state);
301 if self.lambdas.is_empty() {
302 self.constant_function.use_var();
303 }
304 self.update_possible_types();
305 }
306
307 pub(super) fn use_bvar(&mut self, b: usize) {
308 let n = self.lambdas.len() - b - 1;
309 self.lambdas.get_mut(n).unwrap().1.use_var();
310 }
311
312 pub fn is_constant(&self) -> bool {
314 self.constant_function == ConstantFunctionState::Constant
315 }
316
317 pub fn current_variable_types(&self) -> impl Iterator<Item = &LambdaType> {
319 self.lambdas.iter().map(|(x, _)| x)
320 }
321
322 pub fn applications<'a, 'b: 'a>(
324 &'a self,
325 lambda_type: &'b LambdaType,
326 ) -> impl Iterator<Item = (LambdaType, LambdaType)> + 'a {
327 match self.possible_types.get(lambda_type) {
328 Some(x) => Either::Left(x.iter().map(|x| {
329 (
330 LambdaType::compose(x.clone(), lambda_type.clone()),
331 x.clone(),
332 )
333 })),
334 None => Either::Right(std::iter::empty()),
335 }
336 }
337
338 pub fn variables<'src, T: LambdaLanguageOfThought>(
340 &self,
341 lambda_type: &LambdaType,
342 ) -> impl Iterator<Item = LambdaExpr<'src, T>> {
343 let n = self.lambdas.len();
344 self.lambdas
345 .iter()
346 .enumerate()
347 .filter_map(move |(i, (lambda, _))| {
348 if lambda == lambda_type {
349 Some(LambdaExpr::BoundVariable(n - i - 1, lambda.clone()))
350 } else {
351 None
352 }
353 })
354 }
355}
356
#[cfg(test)]
mod test {
    use super::*;

    /// A lambda-free, non-constant `Context` that differs only in the fields the
    /// ordering test varies.
    fn bare_context(depth: usize, done: bool, open_nodes: usize) -> Context {
        Context {
            depth,
            done,
            lambdas: vec![],
            possible_types: HashMap::default(),
            pool_index: 0,
            position: 0,
            open_nodes,
            constant_function: ConstantFunctionState::NonConstant,
        }
    }

    #[test]
    fn test_context() -> anyhow::Result<()> {
        let a = bare_context(1, false, 0);
        let b = bare_context(2, false, 0);
        let c = bare_context(5, true, 0);
        let d = bare_context(5, true, 54);

        // Lower open/depth score sorts first among unfinished contexts...
        assert!(a < b);
        // ...finished contexts sort ahead of unfinished ones...
        assert!(c < b);
        assert!(c < a);
        // ...and among finished ones the score still breaks the tie.
        assert!(c < d);

        Ok(())
    }

    #[test]
    fn possible_type_check() -> anyhow::Result<()> {
        let mut lambdas = Vec::new();
        for spec in [
            "<a,t>",
            "<<a,t>, <a,t>>",
            "<<a,t>, <<a,t>, <e,t>>>",
            "<<a,t>, e>",
            "<e, <a,<a,t>>>",
        ] {
            lambdas.push((
                LambdaType::from_string(spec)?,
                ConstantFunctionState::PotentiallyConstant,
            ));
        }

        let mut c = Context {
            depth: 0,
            done: false,
            lambdas,
            possible_types: HashMap::default(),
            pool_index: 0,
            position: 0,
            open_nodes: 54,
            constant_function: ConstantFunctionState::PotentiallyConstant,
        };
        c.update_possible_types();

        let mut observed: Vec<(String, Vec<String>)> = Vec::new();
        for (result, arguments) in c.possible_types.iter() {
            let mut rendered: Vec<String> = arguments.iter().map(|t| t.to_string()).collect();
            rendered.sort();
            observed.push((result.to_string(), rendered));
        }
        observed.sort();

        assert_eq!(
            observed,
            vec![
                ("<<a,t>,<e,t>>".to_string(), vec!["<a,t>".to_string()]),
                ("<a,<a,t>>".to_string(), vec!["e".to_string()]),
                (
                    "<a,t>".to_string(),
                    vec!["<a,t>".to_string(), "a".to_string()]
                ),
                ("<e,t>".to_string(), vec!["<a,t>".to_string()]),
                ("e".to_string(), vec!["<a,t>".to_string()]),
                ("t".to_string(), vec!["a".to_string(), "e".to_string()]),
            ]
        );

        Ok(())
    }
}