1use chumsky::{
3 extra::ParserExtra,
4 label::LabelError,
5 prelude::*,
6 text::{TextExpected, inline_whitespace},
7};
8#[cfg(feature = "sampling")]
9use rand::{Rng, RngExt, seq::IteratorRandom};
10use std::{fmt::Display, sync::LazyLock};
11use thiserror::Error;
12
/// Error returned when a textual type (such as `<e,t>`) cannot be parsed.
///
/// Holds the rendered parser diagnostics, one per line (see the
/// `From<Vec<Rich<_, char>>>` impl). `Display` is implemented by hand rather
/// than through a `#[error(...)]` attribute.
#[derive(Debug, Clone, Error, PartialEq, Eq)]
pub struct TypeParsingError(String);
16
17impl From<Vec<Rich<'_, char>>> for TypeParsingError {
18 fn from(value: Vec<Rich<'_, char>>) -> Self {
19 TypeParsingError(
20 value
21 .iter()
22 .map(std::string::ToString::to_string)
23 .collect::<Vec<_>>()
24 .join("\n"),
25 )
26 }
27}
28
29impl Display for TypeParsingError {
30 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
31 writeln!(f, "{}", self.0)
32 }
33}
34
/// Errors produced by operations on [`LambdaType`] values.
#[derive(Debug, Clone, Error, PartialEq, Eq)]
pub enum TypeError {
    /// An atomic type (`a`, `e` or `t`) was used where a function type
    /// `<lhs,rhs>` was required, e.g. by [`LambdaType::split`].
    #[error("Cannot split a primitive type")]
    NotAFunction,
    /// Returned by [`LambdaType::apply`] when the argument type does not match
    /// the function's input type. As constructed there, field 0 is the
    /// argument type and field 1 is the function type.
    #[error("Cannot apply {0} to {1}!")]
    CantApply(LambdaType, LambdaType),
}
45
/// A simple type: one of the atomic types `a`, `e`, `t`, or a function type
/// `<lhs,rhs>` that takes an argument of type `lhs` and yields `rhs`.
///
/// The `Default` value is [`LambdaType::A`].
#[derive(Debug, Clone, Eq, PartialEq, Default, Hash, PartialOrd, Ord)]
pub enum LambdaType {
    /// The atomic type written (and parsed) as `a`.
    #[default]
    A,
    /// The atomic type written (and parsed) as `e`.
    E,
    /// The atomic type written (and parsed) as `t`.
    T,
    /// The function type `<lhs,rhs>`: argument type first, result type second.
    Composition(Box<LambdaType>, Box<LambdaType>),
}
59
60pub(crate) fn core_type_parser<'src, E>()
61-> impl Parser<'src, &'src str, LambdaType, E> + Clone + 'src
62where
63 E: ParserExtra<'src, &'src str> + 'src,
64 E::Error: LabelError<'src, &'src str, TextExpected<&'src str>>,
65 E::Error: LabelError<'src, &'src str, TextExpected<()>>,
66{
67 let atom = choice((
68 just('e').to(LambdaType::e().clone()),
69 just('t').to(LambdaType::t().clone()),
70 just('a').to(LambdaType::a().clone()),
71 ));
72 recursive(|expr| {
73 atom.or((expr.clone().then_ignore(just(',').padded()).then(expr))
74 .map(|(x, y)| LambdaType::compose(x, y))
75 .delimited_by(
76 just('<').then(inline_whitespace()),
77 inline_whitespace().then(just('>')),
78 ))
79 })
80}
81
82fn type_parser<'a>() -> impl Parser<'a, &'a str, LambdaType, extra::Err<Rich<'a, char>>> + Clone {
83 core_type_parser().padded().then_ignore(end())
84}
85
/// Iterator that walks a type as a curried function, created by
/// [`LambdaType::retrievable_types`].
///
/// Each step yields `(argument, result)` and continues from `result`. For
/// `<a,<b,c>>` with atomic `c` it yields `(a, <b,c>)`, then `(b, c)`, then
/// stops. The field holds the type that remains to be split.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct RetrievableTypeIterator<'a>(&'a LambdaType);
89
90impl<'a> Iterator for RetrievableTypeIterator<'a> {
91 type Item = (&'a LambdaType, &'a LambdaType);
92
93 fn next(&mut self) -> Option<Self::Item> {
94 match self.0.split() {
95 Ok((lhs, rhs)) => {
96 self.0 = rhs;
97 Some((lhs, rhs))
98 }
99 Err(_) => None,
100 }
101 }
102}
103
104impl LambdaType {
105 #[must_use]
107 pub fn lift_type(self) -> LambdaType {
108 let t = LambdaType::compose(self, LambdaType::T);
109
110 LambdaType::compose(t, LambdaType::T)
111 }
112
113 #[must_use]
132 pub fn retrievable_types(&self) -> RetrievableTypeIterator<'_> {
133 RetrievableTypeIterator(self)
134 }
135
136 #[must_use]
138 pub fn is_lifted_type_of(&self, t: &LambdaType) -> bool {
139 let Ok((a, LambdaType::T)) = self.split() else {
140 return false;
141 };
142 let Ok((a, LambdaType::T)) = a.split() else {
143 return false;
144 };
145
146 a == t
147 }
148
149 pub fn add_right_argument(&mut self, other: LambdaType) {
151 let mut t = LambdaType::A;
152 std::mem::swap(&mut t, self);
153 *self = LambdaType::Composition(Box::new(t), Box::new(other));
154 }
155 pub fn add_left_argument(&mut self, other: LambdaType) {
157 let mut t = LambdaType::A;
158 std::mem::swap(&mut t, self);
159 *self = LambdaType::Composition(Box::new(other), Box::new(t));
160 }
161
162 pub fn from_string(s: &str) -> Result<Self, TypeParsingError> {
174 type_parser()
175 .parse(s)
176 .into_result()
177 .map_err(std::convert::Into::into)
178 }
179
180 #[must_use]
182 pub fn a() -> &'static Self {
183 &LambdaType::A
184 }
185
186 #[must_use]
188 pub fn e() -> &'static Self {
189 &LambdaType::E
190 }
191
192 #[must_use]
194 pub fn t() -> &'static Self {
195 &LambdaType::T
196 }
197
198 #[must_use]
200 pub fn compose(a: Self, b: Self) -> Self {
201 LambdaType::Composition(Box::new(a), Box::new(b))
202 }
203
204 #[must_use]
206 pub fn at() -> &'static Self {
207 static VAL: LazyLock<LambdaType> =
208 LazyLock::new(|| LambdaType::compose(LambdaType::a().clone(), LambdaType::t().clone()));
209 &VAL
210 }
211
212 #[must_use]
214 pub fn et() -> &'static Self {
215 static VAL: LazyLock<LambdaType> =
216 LazyLock::new(|| LambdaType::compose(LambdaType::e().clone(), LambdaType::t().clone()));
217 &VAL
218 }
219
220 #[must_use]
222 pub fn eet() -> &'static Self {
223 static VAL: LazyLock<LambdaType> = LazyLock::new(|| {
224 LambdaType::compose(
225 LambdaType::e().clone(),
226 LambdaType::compose(LambdaType::e().clone(), LambdaType::t().clone()),
227 )
228 });
229 &VAL
230 }
231 #[must_use]
233 pub fn ett() -> &'static Self {
234 static VAL: LazyLock<LambdaType> = LazyLock::new(|| {
235 LambdaType::compose(
236 LambdaType::compose(LambdaType::e().clone(), LambdaType::t().clone()),
237 LambdaType::t().clone(),
238 )
239 });
240 &VAL
241 }
242
243 #[must_use]
245 pub fn size(&self) -> usize {
246 match self {
247 LambdaType::A | LambdaType::E | LambdaType::T => 1,
248 LambdaType::Composition(a, b) => a.size() + b.size(),
249 }
250 }
251
252 #[must_use]
254 pub fn can_apply(&self, other: &Self) -> bool {
255 matches!(&self, LambdaType::Composition(a, _) if a.as_ref() == other)
256 }
257
258 pub fn split(&self) -> Result<(&LambdaType, &LambdaType), TypeError> {
263 match &self {
264 LambdaType::Composition(a, b) => Ok((a, b)),
265 _ => Err(TypeError::NotAFunction),
266 }
267 }
268
269 pub fn apply(&self, other: &Self) -> Result<&Self, TypeError> {
275 if !self.can_apply(other) {
276 return Err(TypeError::CantApply(other.clone(), self.clone()));
277 }
278 self.rhs()
279 }
280
281 #[must_use]
283 pub fn is_function(&self) -> bool {
284 matches!(&self, LambdaType::Composition(_, _))
285 }
286
287 pub fn lhs(&self) -> Result<&Self, TypeError> {
292 Ok(self.split()?.0)
293 }
294
295 #[must_use]
297 pub fn is_one_place_function(&self) -> bool {
298 if let Ok((lhs, rhs)) = self.split() {
299 !lhs.is_function() && !rhs.is_function()
300 } else {
301 false
302 }
303 }
304
305 pub fn rhs(&self) -> Result<&Self, TypeError> {
310 Ok(self.split()?.1)
311 }
312}
313
314impl Display for LambdaType {
315 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
316 match &self {
317 LambdaType::E => write!(f, "e"),
318 LambdaType::T => write!(f, "t"),
319 LambdaType::A => write!(f, "a"),
320 LambdaType::Composition(lhs, rhs) => write!(f, "<{lhs},{rhs}>"),
321 }
322 }
323}
324
/// Probability that `random_inner` expands a node into a function type rather
/// than stopping at an atomic type (only consulted below `MAX_DEPTH`).
#[cfg(feature = "sampling")]
const RECURSE_PROB: f64 = 0.2;
/// Depth at which `random_inner` stops expanding and always returns an atomic
/// type, bounding the nesting of sampled types.
#[cfg(feature = "sampling")]
const MAX_DEPTH: u8 = 64;
329
#[cfg(feature = "sampling")]
impl LambdaType {
    /// Recursively samples a type.
    ///
    /// While `depth` is below `MAX_DEPTH`, a node becomes a function type with
    /// probability `RECURSE_PROB`; otherwise it is an atom. With `no_e` set,
    /// the atom is `t` or `a` (50/50), otherwise uniformly one of `t`, `a`,
    /// `e`. The flag is forwarded only to the result side of a function type,
    /// so argument positions may still contain `e`.
    fn random_inner(r: &mut impl Rng, depth: u8, no_e: bool) -> Self {
        let expand = depth < MAX_DEPTH && r.random_bool(RECURSE_PROB);
        if expand {
            // Argument first, then result: this order fixes which draws from
            // `r` each side consumes.
            let argument = LambdaType::random_inner(r, depth + 1, false);
            let result = LambdaType::random_inner(r, depth + 1, no_e);
            return LambdaType::compose(argument, result);
        }

        if no_e {
            return if r.random_bool(0.5) {
                LambdaType::t().clone()
            } else {
                LambdaType::a().clone()
            };
        }

        let atoms = [LambdaType::t(), LambdaType::a(), LambdaType::e()];
        let index = (0..3).choose(r).unwrap();
        atoms[index].clone()
    }

    /// Samples a random type that may contain any of `a`, `e` and `t`.
    pub fn random(r: &mut impl Rng) -> Self {
        Self::random_inner(r, 0, false)
    }

    /// Samples a random type whose innermost result type is never `e`. `e` can
    /// still appear in argument positions.
    pub fn random_no_e(r: &mut impl Rng) -> Self {
        Self::random_inner(r, 0, true)
    }
}
360
#[cfg(test)]
mod test {

    #[cfg(feature = "sampling")]
    use rand::SeedableRng;
    #[cfg(feature = "sampling")]
    use rand_chacha::ChaCha8Rng;

    use super::*;

    /// Sampled types must print in a form the parser reads back, and
    /// `random_no_e` must never end in the result type `e`.
    #[cfg(feature = "sampling")]
    #[test]
    fn random_lambdas() -> anyhow::Result<()> {
        let mut r = ChaCha8Rng::seed_from_u64(32);
        for _ in 0..10_000 {
            let t = LambdaType::random(&mut r);
            assert_eq!(LambdaType::from_string(&t.to_string())?, t);

            let no_e = LambdaType::random_no_e(&mut r);
            let mut result = &no_e;
            while let Ok(rhs) = result.rhs() {
                result = rhs;
            }
            assert_ne!(result, LambdaType::e(), "{no_e} ends in e");
        }
        Ok(())
    }

    #[test]
    fn check_application() -> anyhow::Result<()> {
        let et = LambdaType::et();
        let et_to_et = LambdaType::compose(et.clone(), et.clone());
        let et_squared_to_et_squared = LambdaType::compose(et_to_et.clone(), et_to_et.clone());
        assert!(et.can_apply(LambdaType::e()));
        assert!(et_to_et.can_apply(et));
        assert!(et_squared_to_et_squared.can_apply(&et_to_et));
        assert!(!et.can_apply(LambdaType::t()));
        assert!(!et_to_et.can_apply(&et_squared_to_et_squared));
        assert!(!et_squared_to_et_squared.can_apply(&et_squared_to_et_squared));

        assert_eq!(&et_to_et, et_squared_to_et_squared.rhs()?);

        assert_eq!(et, et_to_et.rhs()?);

        assert_eq!(LambdaType::t(), et.rhs()?);
        Ok(())
    }

    /// `apply`/`split`/`lhs`/`rhs` report the documented errors.
    #[test]
    fn check_apply_and_split_errors() {
        let et = LambdaType::et();
        assert_eq!(et.apply(LambdaType::e()), Ok(LambdaType::t()));
        assert_eq!(
            et.apply(LambdaType::t()),
            Err(TypeError::CantApply(LambdaType::T, et.clone()))
        );
        assert_eq!(et.split(), Ok((LambdaType::e(), LambdaType::t())));
        assert_eq!(LambdaType::e().split(), Err(TypeError::NotAFunction));
        assert_eq!(LambdaType::t().lhs(), Err(TypeError::NotAFunction));
        assert_eq!(LambdaType::a().rhs(), Err(TypeError::NotAFunction));
    }

    #[test]
    fn check_predicates_and_size() {
        let et = LambdaType::et();
        assert!(et.is_function());
        assert!(!LambdaType::e().is_function());

        assert!(et.is_one_place_function());
        assert!(LambdaType::at().is_one_place_function());
        assert!(!LambdaType::eet().is_one_place_function());
        assert!(!LambdaType::ett().is_one_place_function());
        assert!(!LambdaType::t().is_one_place_function());

        assert_eq!(LambdaType::e().size(), 1);
        assert_eq!(et.size(), 2);
        assert_eq!(LambdaType::eet().size(), 3);
        assert_eq!(LambdaType::ett().size(), 3);
    }

    #[test]
    fn check_adding_arguments() -> anyhow::Result<()> {
        let mut ty = LambdaType::e().clone();
        ty.add_right_argument(LambdaType::T);
        assert_eq!(&ty, LambdaType::et());
        ty.add_left_argument(LambdaType::E);
        assert_eq!(&ty, LambdaType::eet());
        ty.add_right_argument(LambdaType::T);
        assert_eq!(ty, LambdaType::from_string("<<e,<e,t>>,t>")?);
        Ok(())
    }

    #[test]
    fn check_retrievable_types() -> anyhow::Result<()> {
        let ty = LambdaType::from_string("<e,<<e,t>,<a,t>>>")?;
        let steps: Vec<_> = ty.retrievable_types().collect();
        let expected = [("e", "<<e,t>,<a,t>>"), ("<e,t>", "<a,t>"), ("a", "t")];
        assert_eq!(steps.len(), expected.len());
        for ((arg, rest), (want_arg, want_rest)) in steps.iter().zip(expected) {
            assert_eq!(arg.to_string(), want_arg);
            assert_eq!(rest.to_string(), want_rest);
        }
        assert_eq!(LambdaType::t().retrievable_types().next(), None);
        Ok(())
    }

    #[test]
    fn parse_types() {
        let parser = type_parser();
        assert_eq!(&parser.parse("e").unwrap(), LambdaType::e());
        assert_eq!(&parser.parse(" e ").unwrap(), LambdaType::e());
        assert_eq!(&parser.parse("e ").unwrap(), LambdaType::e());
        assert!(parser.parse("e z").has_errors());

        assert_eq!(&parser.parse("t").unwrap(), LambdaType::t());

        let et = LambdaType::et();
        assert_eq!(&parser.parse("<e, t>").unwrap(), et);

        let et_to_et = LambdaType::compose(et.clone(), et.clone());

        assert_eq!(parser.parse("<<e, t>, <e,t>>").unwrap(), et_to_et);

        let et_squared_to_et_squared = LambdaType::compose(et_to_et.clone(), et_to_et);
        assert_eq!(
            parser.parse("<< <e, t>, <e,t>>, <<e,t>, <e,t>>>").unwrap(),
            et_squared_to_et_squared
        );
    }

    /// Malformed or incomplete inputs are rejected by `from_string`.
    #[test]
    fn reject_malformed_types() -> anyhow::Result<()> {
        for bad in ["", "x", "e z", "<e,>", "<e,t", "<e t>"] {
            assert!(
                LambdaType::from_string(bad).is_err(),
                "{bad:?} should not parse"
            );
        }
        assert_eq!(&LambdaType::from_string("<a,t>")?, LambdaType::at());
        Ok(())
    }

    #[test]
    fn check_printing() {
        let parser = type_parser();
        for s in ["e", "t", "<e,t>", "<e,<e,t>>", "<t,<<t,t>,<e,t>>>"] {
            assert_eq!(parser.parse(s).unwrap().to_string(), s);
        }
    }

    #[test]
    fn check_lifting() -> anyhow::Result<()> {
        for s in ["e", "t", "<e,t>", "<e,<e,t>>", "<t,<<t,t>,<e,t>>>"] {
            let lifted_str = format!("<<{s},t>,t>");
            let lifted = LambdaType::from_string(&lifted_str)?;
            let base_type = LambdaType::from_string(s)?;
            assert!(lifted.is_lifted_type_of(&base_type));
            assert_eq!(base_type.lift_type(), lifted);
        }

        // `<<e,t>,t>` lifts `e` but not `t`, and `<e,t>` is not a lifted type.
        assert!(LambdaType::ett().is_lifted_type_of(LambdaType::e()));
        assert!(!LambdaType::ett().is_lifted_type_of(LambdaType::t()));
        assert!(!LambdaType::et().is_lifted_type_of(LambdaType::e()));
        assert!(!LambdaType::e().is_lifted_type_of(LambdaType::e()));

        Ok(())
    }
}