1use chumsky::{
3 extra::ParserExtra,
4 label::LabelError,
5 prelude::*,
6 text::{TextExpected, inline_whitespace},
7};
8#[cfg(feature = "sampling")]
9use rand::{Rng, RngExt, seq::IteratorRandom};
10use std::{fmt::Display, sync::LazyLock};
11use thiserror::Error;
12
13#[derive(Debug, Clone, Error, PartialEq, Eq)]
15pub struct TypeParsingError(String);
16
17impl From<Vec<Rich<'_, char>>> for TypeParsingError {
18 fn from(value: Vec<Rich<'_, char>>) -> Self {
19 TypeParsingError(
20 value
21 .iter()
22 .map(std::string::ToString::to_string)
23 .collect::<Vec<_>>()
24 .join("\n"),
25 )
26 }
27}
28
29impl Display for TypeParsingError {
30 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
31 writeln!(f, "{}", self.0)
32 }
33}
34
/// Errors raised when inspecting or applying a [`LambdaType`].
#[derive(Debug, Clone, Error, PartialEq, Eq)]
pub enum TypeError {
    /// Returned when an atomic type (`a`, `e` or `t`) is asked for its
    /// argument/result parts, e.g. by `LambdaType::split`, `lhs` or `rhs`.
    #[error("Cannot split a primitive type")]
    NotAFunction,
    /// Returned by `LambdaType::apply` when `self` is not a function whose
    /// argument type equals the supplied type.
    ///
    /// Field order is `(argument, function)`. The first field is the type that
    /// was supplied; the second is the type it was applied to.
    #[error("Cannot apply {0} to {1}!")]
    CantApply(LambdaType, LambdaType),
}
45
/// A simple type: one of three atomic types, or a function type.
///
/// Displayed and parsed as `a`, `e`, `t`, or `<argument,result>`.
///
/// NOTE: the derived `PartialOrd`/`Ord` follow variant declaration order.
/// Reordering the variants changes how types compare.
#[derive(Debug, Clone, Eq, PartialEq, Default, Hash, PartialOrd, Ord)]
pub enum LambdaType {
    /// Atomic type printed as `a`; this is the `Default` value.
    /// NOTE(review): the intended denotation of `a` is not visible here — confirm.
    #[default]
    A,
    /// Atomic type printed as `e` (presumably entities, as in Montague-style
    /// type theory — confirm).
    E,
    /// Atomic type printed as `t` (presumably truth values — confirm).
    T,
    /// Function type `<lhs,rhs>`: accepts an argument of type `lhs` and yields
    /// `rhs` (see `LambdaType::can_apply` / `LambdaType::apply`).
    Composition(Box<LambdaType>, Box<LambdaType>),
}
59
60pub(crate) fn core_type_parser<'src, E>()
61-> impl Parser<'src, &'src str, LambdaType, E> + Clone + 'src
62where
63 E: ParserExtra<'src, &'src str> + 'src,
64 E::Error: LabelError<'src, &'src str, TextExpected<&'src str>>,
65 E::Error: LabelError<'src, &'src str, TextExpected<()>>,
66{
67 let atom = choice((
68 just('e').to(LambdaType::e().clone()),
69 just('t').to(LambdaType::t().clone()),
70 just('a').to(LambdaType::a().clone()),
71 ));
72 recursive(|expr| {
73 atom.or((expr.clone().then_ignore(just(',').padded()).then(expr))
74 .map(|(x, y)| LambdaType::compose(x, y))
75 .delimited_by(
76 just('<').then(inline_whitespace()),
77 inline_whitespace().then(just('>')),
78 ))
79 })
80}
81
82fn type_parser<'a>() -> impl Parser<'a, &'a str, LambdaType, extra::Err<Rich<'a, char>>> + Clone {
83 core_type_parser().padded().then_ignore(end())
84}
85
86#[derive(Debug, Clone, Copy, PartialEq, Eq)]
87pub struct RetrievableTypeIterator<'a>(&'a LambdaType);
89
90impl<'a> Iterator for RetrievableTypeIterator<'a> {
91 type Item = (&'a LambdaType, &'a LambdaType);
92
93 fn next(&mut self) -> Option<Self::Item> {
94 match self.0.split() {
95 Ok((lhs, rhs)) => {
96 self.0 = rhs;
97 Some((lhs, rhs))
98 }
99 Err(_) => None,
100 }
101 }
102}
103
104impl LambdaType {
105 #[must_use]
107 pub fn lift_type(self) -> LambdaType {
108 let t = LambdaType::compose(self, LambdaType::T);
109
110 LambdaType::compose(t, LambdaType::T)
111 }
112
113 #[must_use]
132 pub fn retrievable_types(&self) -> RetrievableTypeIterator<'_> {
133 RetrievableTypeIterator(self)
134 }
135
136 #[must_use]
138 pub fn is_lifted_type_of(&self, t: &LambdaType) -> bool {
139 let Ok((a, LambdaType::T)) = self.split() else {
140 return false;
141 };
142 let Ok((a, LambdaType::T)) = a.split() else {
143 return false;
144 };
145
146 a == t
147 }
148
149 pub fn add_right_argument(&mut self, other: LambdaType) {
151 let mut t = LambdaType::A;
152 std::mem::swap(&mut t, self);
153 *self = LambdaType::Composition(Box::new(t), Box::new(other));
154 }
155 pub fn add_left_argument(&mut self, other: LambdaType) {
157 let mut t = LambdaType::A;
158 std::mem::swap(&mut t, self);
159 *self = LambdaType::Composition(Box::new(other), Box::new(t));
160 }
161
162 pub fn from_string(s: &str) -> Result<Self, TypeParsingError> {
171 type_parser()
172 .parse(s)
173 .into_result()
174 .map_err(std::convert::Into::into)
175 }
176
177 #[must_use]
179 pub fn a() -> &'static Self {
180 &LambdaType::A
181 }
182
183 #[must_use]
185 pub fn e() -> &'static Self {
186 &LambdaType::E
187 }
188
189 #[must_use]
191 pub fn t() -> &'static Self {
192 &LambdaType::T
193 }
194
195 #[must_use]
197 pub fn compose(a: Self, b: Self) -> Self {
198 LambdaType::Composition(Box::new(a), Box::new(b))
199 }
200
201 #[must_use]
203 pub fn at() -> &'static Self {
204 static VAL: LazyLock<LambdaType> =
205 LazyLock::new(|| LambdaType::compose(LambdaType::a().clone(), LambdaType::t().clone()));
206 &VAL
207 }
208
209 #[must_use]
211 pub fn et() -> &'static Self {
212 static VAL: LazyLock<LambdaType> =
213 LazyLock::new(|| LambdaType::compose(LambdaType::e().clone(), LambdaType::t().clone()));
214 &VAL
215 }
216
217 #[must_use]
219 pub fn eet() -> &'static Self {
220 static VAL: LazyLock<LambdaType> = LazyLock::new(|| {
221 LambdaType::compose(
222 LambdaType::e().clone(),
223 LambdaType::compose(LambdaType::e().clone(), LambdaType::t().clone()),
224 )
225 });
226 &VAL
227 }
228 #[must_use]
230 pub fn ett() -> &'static Self {
231 static VAL: LazyLock<LambdaType> = LazyLock::new(|| {
232 LambdaType::compose(
233 LambdaType::compose(LambdaType::e().clone(), LambdaType::t().clone()),
234 LambdaType::t().clone(),
235 )
236 });
237 &VAL
238 }
239
240 #[must_use]
242 pub fn size(&self) -> usize {
243 match self {
244 LambdaType::A | LambdaType::E | LambdaType::T => 1,
245 LambdaType::Composition(a, b) => a.size() + b.size(),
246 }
247 }
248
249 #[must_use]
251 pub fn can_apply(&self, other: &Self) -> bool {
252 matches!(&self, LambdaType::Composition(a, _) if a.as_ref() == other)
253 }
254
255 pub fn split(&self) -> Result<(&LambdaType, &LambdaType), TypeError> {
258 match &self {
259 LambdaType::Composition(a, b) => Ok((a, b)),
260 _ => Err(TypeError::NotAFunction),
261 }
262 }
263
264 pub fn apply(&self, other: &Self) -> Result<&Self, TypeError> {
267 if !self.can_apply(other) {
268 return Err(TypeError::CantApply(other.clone(), self.clone()));
269 }
270 self.rhs()
271 }
272
273 #[must_use]
275 pub fn is_function(&self) -> bool {
276 matches!(&self, LambdaType::Composition(_, _))
277 }
278
279 pub fn lhs(&self) -> Result<&Self, TypeError> {
282 Ok(self.split()?.0)
283 }
284
285 #[must_use]
287 pub fn is_one_place_function(&self) -> bool {
288 if let Ok((lhs, rhs)) = self.split() {
289 !lhs.is_function() && !rhs.is_function()
290 } else {
291 false
292 }
293 }
294
295 pub fn rhs(&self) -> Result<&Self, TypeError> {
298 Ok(self.split()?.1)
299 }
300}
301
302impl Display for LambdaType {
303 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
304 match &self {
305 LambdaType::E => write!(f, "e"),
306 LambdaType::T => write!(f, "t"),
307 LambdaType::A => write!(f, "a"),
308 LambdaType::Composition(lhs, rhs) => write!(f, "<{lhs},{rhs}>"),
309 }
310 }
311}
312
/// Probability that a sampled node becomes a function type instead of an atom.
#[cfg(feature = "sampling")]
const RECURSE_PROB: f64 = 0.2;
/// Nesting depth at which sampling stops branching and always yields an atom.
#[cfg(feature = "sampling")]
const MAX_DEPTH: u8 = 64;

#[cfg(feature = "sampling")]
impl LambdaType {
    /// Recursively samples a type at nesting level `depth`.
    ///
    /// While `depth < MAX_DEPTH`, the node becomes `<arg,result>` with
    /// probability `RECURSE_PROB`. `arg` is sampled with `e` allowed; `no_e`
    /// is forwarded to `result` only.
    ///
    /// Otherwise the node is an atom: uniform over `{t, a}` when `no_e` is
    /// set, else uniform over `{t, a, e}`.
    fn random_inner(r: &mut impl Rng, depth: u8, no_e: bool) -> Self {
        let branch = depth < MAX_DEPTH && r.random_bool(RECURSE_PROB);
        if branch {
            // Argument first, then result. This RNG draw order must stay fixed
            // for seeded sampling to remain reproducible.
            let arg = Self::random_inner(r, depth + 1, false);
            let result = Self::random_inner(r, depth + 1, no_e);
            return Self::compose(arg, result);
        }

        if no_e {
            return if r.random_bool(0.5) { Self::T } else { Self::A };
        }

        let idx: usize = (0..3).choose(r).unwrap();
        match idx {
            0 => Self::T,
            1 => Self::A,
            _ => Self::E,
        }
    }

    /// Samples a random type from `r`; any atom (`e`, `t`, `a`) may appear.
    pub fn random(r: &mut impl Rng) -> Self {
        Self::random_inner(r, 0, false)
    }

    /// Samples a random type whose final result type (the rightmost atom) is
    /// never `e`. Argument positions may still contain `e`.
    pub fn random_no_e(r: &mut impl Rng) -> Self {
        Self::random_inner(r, 0, true)
    }
}
348
#[cfg(test)]
mod test {

    #[cfg(feature = "sampling")]
    use rand::SeedableRng;
    #[cfg(feature = "sampling")]
    use rand_chacha::ChaCha8Rng;

    use super::*;

    /// Every sampled type must survive a `Display` -> `from_string` round trip.
    #[cfg(feature = "sampling")]
    #[test]
    fn random_lambdas() -> anyhow::Result<()> {
        let mut r = ChaCha8Rng::seed_from_u64(32);
        for _ in 0..10_000 {
            let t = LambdaType::random(&mut r);
            let reparsed = LambdaType::from_string(&t.to_string())?;
            assert_eq!(reparsed, t, "round trip failed for {t}");
        }
        Ok(())
    }

    /// `random_no_e` may use `e` in argument positions, but never as the final
    /// result type.
    #[cfg(feature = "sampling")]
    #[test]
    fn random_no_e_result_is_not_e() {
        let mut r = ChaCha8Rng::seed_from_u64(32);
        for _ in 0..1_000 {
            let t = LambdaType::random_no_e(&mut r);
            let mut result = &t;
            while let Ok(rhs) = result.rhs() {
                result = rhs;
            }
            assert_ne!(result, LambdaType::e(), "{t} ends in e");
        }
    }

    #[test]
    fn check_application() -> anyhow::Result<()> {
        let et = LambdaType::et();
        let et_to_et = LambdaType::compose(et.clone(), et.clone());
        let et_squared_to_et_squared = LambdaType::compose(et_to_et.clone(), et_to_et.clone());
        assert!(et.can_apply(LambdaType::e()));
        assert!(et_to_et.can_apply(et));
        assert!(et_squared_to_et_squared.can_apply(&et_to_et));
        assert!(!et.can_apply(LambdaType::t()));
        assert!(!et_to_et.can_apply(&et_squared_to_et_squared));
        assert!(!et_squared_to_et_squared.can_apply(&et_squared_to_et_squared));

        assert_eq!(&et_to_et, et_squared_to_et_squared.rhs()?);

        assert_eq!(et, et_to_et.rhs()?);

        assert_eq!(LambdaType::t(), et.rhs()?);
        Ok(())
    }

    #[test]
    fn check_apply_errors() {
        let et = LambdaType::et();
        assert_eq!(et.apply(LambdaType::e()), Ok(LambdaType::t()));
        assert_eq!(
            et.apply(LambdaType::t()),
            Err(TypeError::CantApply(LambdaType::T, et.clone()))
        );
        assert_eq!(LambdaType::e().split(), Err(TypeError::NotAFunction));
    }

    #[test]
    fn add_arguments() {
        let mut ty = LambdaType::E;
        ty.add_right_argument(LambdaType::T);
        assert_eq!(&ty, LambdaType::et());
        ty.add_left_argument(LambdaType::E);
        assert_eq!(&ty, LambdaType::eet());
    }

    #[test]
    fn retrievable_types_walk_right_spine() {
        let layers: Vec<_> = LambdaType::eet().retrievable_types().collect();
        assert_eq!(
            layers,
            vec![
                (LambdaType::e(), LambdaType::et()),
                (LambdaType::e(), LambdaType::t()),
            ]
        );
        assert_eq!(LambdaType::e().retrievable_types().next(), None);
    }

    #[test]
    fn parse_types() {
        let parser = type_parser();
        assert_eq!(&parser.parse("e").unwrap(), LambdaType::e());
        assert_eq!(&parser.parse(" e ").unwrap(), LambdaType::e());
        assert_eq!(&parser.parse("e ").unwrap(), LambdaType::e());
        assert!(parser.parse("e z").has_errors());

        assert_eq!(&parser.parse("t").unwrap(), LambdaType::t());

        let et = LambdaType::et();
        assert_eq!(&parser.parse("<e, t>").unwrap(), et);

        let et_to_et = LambdaType::compose(et.clone(), et.clone());

        assert_eq!(parser.parse("<<e, t>, <e,t>>").unwrap(), et_to_et);

        let et_squared_to_et_squared = LambdaType::compose(et_to_et.clone(), et_to_et);
        assert_eq!(
            parser.parse("<< <e, t>, <e,t>>, <<e,t>, <e,t>>>").unwrap(),
            et_squared_to_et_squared
        );
    }

    #[test]
    fn check_printing() {
        let parser = type_parser();
        for s in ["e", "t", "<e,t>", "<e,<e,t>>", "<t,<<t,t>,<e,t>>>"] {
            assert_eq!(parser.parse(s).unwrap().to_string(), s);
        }
    }

    #[test]
    fn check_lifting() -> anyhow::Result<()> {
        for s in ["e", "t", "<e,t>", "<e,<e,t>>", "<t,<<t,t>,<e,t>>>"] {
            let lifted_str = format!("<<{s},t>,t>");
            let lifted = LambdaType::from_string(&lifted_str)?;
            let base_type = LambdaType::from_string(s)?;
            assert!(lifted.is_lifted_type_of(&base_type));
            assert_eq!(base_type.lift_type(), lifted);
        }
        // `<e,t>` has only one `t`-layer, so it is not a lifted type.
        assert!(!LambdaType::et().is_lifted_type_of(LambdaType::e()));

        Ok(())
    }
}