simple_semantics/language/mutations/
context.rs

1use super::{
2    Debug, LambdaExpr, LambdaExprRef, LambdaLanguageOfThought, LambdaPool, LambdaType, RandomPQ,
3    RootedLambdaPool, TypeError,
4};
5use ahash::{HashMap, HashSet};
6use itertools::Either;
7
8#[derive(Debug, Copy, Clone, Eq, PartialEq, Ord, PartialOrd, Hash)]
9pub(super) enum ConstantFunctionState {
10    Constant,
11    PotentiallyConstant,
12    NonConstant,
13}
14
15impl ConstantFunctionState {
16    fn update(&mut self, new: Self) {
17        match self {
18            ConstantFunctionState::Constant => (),
19            ConstantFunctionState::PotentiallyConstant => match new {
20                ConstantFunctionState::Constant => *self = ConstantFunctionState::Constant,
21                ConstantFunctionState::PotentiallyConstant | ConstantFunctionState::NonConstant => {
22                }
23            },
24            ConstantFunctionState::NonConstant => match new {
25                ConstantFunctionState::Constant => *self = ConstantFunctionState::Constant,
26                ConstantFunctionState::PotentiallyConstant => {
27                    *self = ConstantFunctionState::PotentiallyConstant;
28                }
29                ConstantFunctionState::NonConstant => *self = ConstantFunctionState::NonConstant,
30            },
31        }
32    }
33
34    fn use_var(&mut self) {
35        match self {
36            ConstantFunctionState::PotentiallyConstant => {
37                *self = ConstantFunctionState::NonConstant;
38            }
39            ConstantFunctionState::Constant | ConstantFunctionState::NonConstant => (),
40        }
41    }
42
43    fn done(&mut self) {
44        match self {
45            ConstantFunctionState::PotentiallyConstant => *self = ConstantFunctionState::Constant,
46            ConstantFunctionState::Constant | ConstantFunctionState::NonConstant => (),
47        }
48    }
49}
50
///A struct which keeps track of the context leading up to some expression, e.g. its depth, what
///variables are accessible, and whether the context has a constant function.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct Context {
    ///Stack of lambda-bound variables currently in scope, outermost first. Each entry pairs the
    ///variable's type with a state recording whether the variable has been used so far.
    lambdas: Vec<(LambdaType, ConstantFunctionState)>,
    ///Maps a result type to the argument types that can produce it by application. Rebuilt from
    ///scratch by `update_possible_types` whenever `lambdas` changes.
    possible_types: HashMap<LambdaType, HashSet<LambdaType>>,
    ///Final tie-breaker in the `Ord` implementation of `Context`.
    ///NOTE(review): presumably an index into an external pool; confirm against callers.
    pub(super) pool_index: usize,
    ///Value supplied to `Context::new`; its meaning is defined by the callers.
    pub(super) position: usize,
    ///Number of expression nodes visited so far; this is what `len` returns.
    pub(super) depth: usize,
    ///Flag that takes part in the orderings of `Context` and `RandomPQ`.
    ///NOTE(review): presumably marks a completed expression; confirm where it is set.
    done: bool,
    ///Starts at 1 in `Context::new`; contributes quadratically to `open_depth_score`.
    pub(super) open_nodes: usize,
    ///Aggregate constant-function state of the whole context, read by `is_constant`.
    constant_function: ConstantFunctionState,
}
64
65impl PartialOrd for RandomPQ {
66    fn partial_cmp(&self, other: &Self) -> Option<std::cmp::Ordering> {
67        Some(self.cmp(other))
68    }
69}
70
71///Reversed to deal with pq
72impl Ord for RandomPQ {
73    fn cmp(&self, other: &Self) -> std::cmp::Ordering {
74        let c = &self.0;
75        let o = &other.0;
76
77        c.done
78            .cmp(&o.done)
79            .then(o.open_depth_score().cmp(&c.open_depth_score()))
80            .then(o.lambdas.len().cmp(&c.lambdas.len()))
81            .then(o.constant_function.cmp(&c.constant_function))
82            .then(self.1.partial_cmp(&other.1).unwrap())
83    }
84}
85
86impl PartialOrd for Context {
87    fn partial_cmp(&self, other: &Self) -> Option<std::cmp::Ordering> {
88        Some(self.cmp(other))
89    }
90}
91
92impl Ord for Context {
93    fn cmp(&self, other: &Self) -> std::cmp::Ordering {
94        other
95            .done
96            .cmp(&self.done)
97            .then(self.open_depth_score().cmp(&other.open_depth_score()))
98            .then(self.constant_function.cmp(&other.constant_function))
99            .then(self.pool_index.cmp(&other.pool_index))
100    }
101}
102impl Context {
103    fn open_depth_score(&self) -> usize {
104        self.depth + self.open_nodes.pow(2) + self.lambdas.len()
105    }
106}
107
108impl<T: LambdaLanguageOfThought> LambdaPool<'_, T> {
109    //If something at position `pos` can be moved from `old_context` to `new_context`
110    fn compatible_with(
111        &self,
112        pos: LambdaExprRef,
113        new_context: &Context,
114        old_context: &Context,
115    ) -> bool {
116        for (x, d) in self.bfs_from(pos) {
117            if let LambdaExpr::BoundVariable(b, _) = self.get(x)
118                && b + 1 > d
119            {
120                //this involves the outside context;
121                let old_lambda_pos = old_context.lambdas.len() + d - b - 1;
122
123                if b + 1 > d + new_context.lambdas.len() {
124                    //Impossible to access
125                    //TODO: Maybe some remapping system if old_context is contained by
126                    //new_context
127                    return false;
128                }
129                let new_lambda_pos = new_context.lambdas.len() + d - b - 1;
130                if new_context.lambdas.get(new_lambda_pos)
131                    != old_context.lambdas.get(old_lambda_pos)
132                {
133                    return false;
134                }
135            }
136        }
137        true
138    }
139}
140
141impl Context {
142    #[allow(clippy::len_without_is_empty)]
143    ///The length of the context thus far.
144    #[must_use]
145    pub fn len(&self) -> usize {
146        self.depth
147    }
148
149    ///The number of variables in the current [`Context`]
150    #[must_use]
151    pub fn n_vars(&self) -> usize {
152        self.lambdas.len()
153    }
154
155    pub(super) fn from_pos<T: LambdaLanguageOfThought>(
156        pool: &RootedLambdaPool<'_, T>,
157        pos: LambdaExprRef,
158    ) -> (Context, bool) {
159        let mut context = Context::new(0, vec![]);
160        let mut stack = vec![(pool.root, 0, false)];
161        let mut return_is_subformula = false;
162
163        while let Some((x, n_lambdas, is_subformula)) = stack.pop() {
164            context.depth += 1;
165            let e = pool.get(x);
166            if context.lambdas.len() != n_lambdas {
167                for _ in 0..(context.lambdas.len() - n_lambdas) {
168                    context.pop_lambda();
169                }
170            }
171
172            if pos == x {
173                return_is_subformula = is_subformula;
174                break;
175            }
176
177            if let Some(v) = e.var_type() {
178                context.add_lambda(v);
179            } else if let LambdaExpr::BoundVariable(n, _) = e {
180                context.use_bvar(*n);
181            }
182
183            if let LambdaExpr::Application {
184                subformula,
185                argument,
186            } = e
187            {
188                stack.push((*subformula, context.lambdas.len(), true));
189                stack.push((*argument, context.lambdas.len(), false));
190            } else {
191                stack.extend(e.get_children().map(|x| (x, context.lambdas.len(), false)));
192            }
193        }
194        (context, return_is_subformula)
195    }
196
197    pub(super) fn find_compatible<T: LambdaLanguageOfThought>(
198        &self,
199        pool: &RootedLambdaPool<'_, T>,
200        pos: LambdaExprRef,
201    ) -> Result<Vec<LambdaExprRef>, TypeError> {
202        let t = pool.pool.get_type(pos)?;
203
204        let mut this_context = Context::new(0, vec![]);
205        let mut stack = vec![(pool.root, 0)];
206        let mut options: Vec<_> = vec![];
207        while let Some((x, n_lambdas)) = stack.pop() {
208            this_context.depth += 1;
209            let e = pool.get(x);
210            if this_context.lambdas.len() != n_lambdas {
211                for _ in 0..(this_context.lambdas.len() - n_lambdas) {
212                    this_context.pop_lambda();
213                }
214            }
215            if pos != x
216                && t == pool.pool.get_type(x)?
217                && pool.pool.compatible_with(x, self, &this_context)
218            {
219                options.push(x);
220            }
221
222            if let Some(v) = e.var_type() {
223                this_context.add_lambda(v);
224            } else if let LambdaExpr::BoundVariable(n, _) = e {
225                this_context.use_bvar(*n);
226            }
227
228            stack.extend(e.get_children().map(|x| (x, this_context.lambdas.len())));
229        }
230        Ok(options)
231    }
232
233    fn update_possible_types(&mut self) {
234        self.possible_types.clear();
235
236        let mut new_types: HashSet<(&LambdaType, &LambdaType)> = HashSet::default();
237        let mut base_types: HashSet<_> = self.lambdas.iter().map(|(x, _)| x).collect();
238        base_types.insert(LambdaType::a());
239        base_types.insert(LambdaType::e());
240        base_types.insert(LambdaType::t());
241        base_types.insert(LambdaType::at());
242        base_types.insert(LambdaType::et());
243
244        loop {
245            for subformula in &base_types {
246                if let Ok((argument, result_type)) = subformula.split() {
247                    let already_has_type = self
248                        .possible_types
249                        .get(result_type)
250                        .is_some_and(|x| x.contains(argument));
251
252                    if base_types.contains(argument) && !already_has_type {
253                        new_types.insert((result_type, argument));
254                    }
255                }
256            }
257            if new_types.is_empty() {
258                break;
259            }
260            for (result, argument) in &new_types {
261                self.possible_types
262                    .entry((*result).clone())
263                    .or_default()
264                    .insert((*argument).clone());
265            }
266            base_types.extend(new_types.drain().map(|(result, _arg)| result));
267        }
268    }
269
270    pub(super) fn new(position: usize, lambdas: Vec<(LambdaType, ConstantFunctionState)>) -> Self {
271        let mut c = Context {
272            pool_index: 0,
273            position,
274            done: false,
275            depth: 0,
276            possible_types: HashMap::default(),
277            open_nodes: 1,
278            constant_function: if lambdas.is_empty() {
279                ConstantFunctionState::NonConstant
280            } else {
281                ConstantFunctionState::PotentiallyConstant
282            },
283            lambdas,
284        };
285        c.update_possible_types();
286        c
287    }
288
289    pub(super) fn add_lambda(&mut self, t: &LambdaType) {
290        self.constant_function
291            .update(ConstantFunctionState::PotentiallyConstant);
292        self.lambdas
293            .push((t.clone(), ConstantFunctionState::PotentiallyConstant));
294        self.update_possible_types();
295    }
296
297    pub(super) fn pop_lambda(&mut self) {
298        let (_, mut function_state) = self.lambdas.pop().unwrap();
299        function_state.done();
300        self.constant_function.update(function_state);
301        if self.lambdas.is_empty() {
302            self.constant_function.use_var();
303        }
304        self.update_possible_types();
305    }
306
307    pub(super) fn use_bvar(&mut self, b: usize) {
308        let n = self.lambdas.len() - b - 1;
309        self.lambdas.get_mut(n).unwrap().1.use_var();
310    }
311
312    ///Does the context have any constant functions preceding it?
313    #[must_use]
314    pub fn is_constant(&self) -> bool {
315        self.constant_function == ConstantFunctionState::Constant
316    }
317
318    ///What are the *current* possible variable types
319    pub fn current_variable_types(&self) -> impl Iterator<Item = &LambdaType> {
320        self.lambdas.iter().map(|(x, _)| x)
321    }
322
323    ///What possible applications can be created at this point?
324    pub fn applications<'a, 'b: 'a>(
325        &'a self,
326        lambda_type: &'b LambdaType,
327    ) -> impl Iterator<Item = (LambdaType, LambdaType)> + 'a {
328        match self.possible_types.get(lambda_type) {
329            Some(x) => Either::Left(x.iter().map(|x| {
330                (
331                    LambdaType::compose(x.clone(), lambda_type.clone()),
332                    x.clone(),
333                )
334            })),
335            None => Either::Right(std::iter::empty()),
336        }
337    }
338
339    ///What are the current variables?
340    pub fn variables<'src, T: LambdaLanguageOfThought>(
341        &self,
342        lambda_type: &LambdaType,
343    ) -> impl Iterator<Item = LambdaExpr<'src, T>> {
344        let n = self.lambdas.len();
345        self.lambdas
346            .iter()
347            .enumerate()
348            .filter_map(move |(i, (lambda, _))| {
349                if lambda == lambda_type {
350                    Some(LambdaExpr::BoundVariable(n - i - 1, lambda.clone()))
351                } else {
352                    None
353                }
354            })
355    }
356}
357
#[cfg(test)]
mod test {
    use super::*;

    ///Builds a context without lambdas in which only the fields relevant to ordering vary.
    fn bare_context(depth: usize, done: bool, open_nodes: usize) -> Context {
        Context {
            lambdas: vec![],
            possible_types: HashMap::default(),
            pool_index: 0,
            position: 0,
            depth,
            done,
            open_nodes,
            constant_function: ConstantFunctionState::NonConstant,
        }
    }

    ///Finished contexts sort first; otherwise a lower depth / open-node score sorts first.
    #[test]
    fn test_context() -> anyhow::Result<()> {
        let a = bare_context(1, false, 0);
        let b = bare_context(2, false, 0);
        let c = bare_context(5, true, 0);
        let d = bare_context(5, true, 54);

        assert!(a < b);
        assert!(c < b);
        assert!(c < a);
        assert!(c < d);

        Ok(())
    }

    ///Checks the application table derived from a fixed set of variable types.
    #[test]
    fn possible_type_check() -> anyhow::Result<()> {
        ///One expected table row: a result type and its sorted argument types.
        fn row(result: &str, arguments: &[&str]) -> (String, Vec<String>) {
            (
                result.to_string(),
                arguments.iter().map(ToString::to_string).collect(),
            )
        }

        let variable_types: [&'static str; 5] = [
            "<a,t>",
            "<<a,t>, <a,t>>",
            "<<a,t>, <<a,t>, <e,t>>>",
            "<<a,t>, e>",
            "<e, <a,<a,t>>>",
        ];
        let mut lambdas = Vec::with_capacity(variable_types.len());
        for source in variable_types {
            lambdas.push((
                LambdaType::from_string(source)?,
                ConstantFunctionState::PotentiallyConstant,
            ));
        }

        let mut c = Context {
            lambdas,
            possible_types: HashMap::default(),
            pool_index: 0,
            position: 0,
            depth: 0,
            done: false,
            open_nodes: 54,
            constant_function: ConstantFunctionState::PotentiallyConstant,
        };
        c.update_possible_types();

        // Render the table as sorted strings so the comparison is independent of hash order.
        let mut actual: Vec<(String, Vec<String>)> = c
            .possible_types
            .iter()
            .map(|(result, arguments)| {
                let mut names: Vec<String> = arguments.iter().map(ToString::to_string).collect();
                names.sort();
                (result.to_string(), names)
            })
            .collect();
        actual.sort();

        let expected = vec![
            row("<<a,t>,<e,t>>", &["<a,t>"]),
            row("<a,<a,t>>", &["e"]),
            row("<a,t>", &["<a,t>", "a"]),
            row("<e,t>", &["<a,t>"]),
            row("e", &["<a,t>"]),
            row("t", &["a", "e"]),
        ];
        assert_eq!(actual, expected);

        Ok(())
    }
}