1use std::{fmt::Display, sync::LazyLock};
3
4use chumsky::{
5 extra::ParserExtra,
6 label::LabelError,
7 prelude::*,
8 text::{TextExpected, inline_whitespace},
9};
10
11#[cfg(feature = "sampling")]
12use rand::{Rng, seq::IteratorRandom};
13
14use thiserror::Error;
15
16#[derive(Debug, Clone, Error, PartialEq, Eq)]
18pub struct TypeParsingError(String);
19
20impl From<Vec<Rich<'_, char>>> for TypeParsingError {
21 fn from(value: Vec<Rich<'_, char>>) -> Self {
22 TypeParsingError(
23 value
24 .iter()
25 .map(|e| e.to_string())
26 .collect::<Vec<_>>()
27 .join("\n"),
28 )
29 }
30}
31
32impl Display for TypeParsingError {
33 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
34 writeln!(f, "{}", self.0)
35 }
36}
37
38#[derive(Debug, Clone, Error, PartialEq, Eq)]
39pub enum TypeError {
41 #[error("Cannot split a primitive type")]
43 NotAFunction,
44 #[error("Cannot apply {0} to {1}!")]
46 CantApply(LambdaType, LambdaType),
47}
48
/// A simple type: one of the primitives `a`, `e`, `t`, or a function type.
///
/// Variant order matters: the derived `Ord` sorts `A < E < T < Composition`.
#[derive(Clone, Debug, Default, Eq, Hash, Ord, PartialEq, PartialOrd)]
pub enum LambdaType {
    /// Primitive `a`; also the `Default` value.
    #[default]
    A,
    /// Primitive `e`.
    E,
    /// Primitive `t`.
    T,
    /// Function type `<lhs,rhs>`, taking `lhs` and yielding `rhs`.
    Composition(Box<LambdaType>, Box<LambdaType>),
}
62
63pub(crate) fn core_type_parser<'src, E>()
64-> impl Parser<'src, &'src str, LambdaType, E> + Clone + 'src
65where
66 E: ParserExtra<'src, &'src str> + 'src,
67 E::Error: LabelError<'src, &'src str, TextExpected<'src, &'src str>>,
68{
69 let atom = choice((
70 just('e').to(LambdaType::e().clone()),
71 just('t').to(LambdaType::t().clone()),
72 just('a').to(LambdaType::a().clone()),
73 ));
74 recursive(|expr| {
75 atom.or((expr.clone().then_ignore(just(',').padded()).then(expr))
76 .map(|(x, y)| LambdaType::compose(x, y))
77 .delimited_by(
78 just('<').then(inline_whitespace()),
79 inline_whitespace().then(just('>')),
80 ))
81 })
82}
83
84fn type_parser<'a>() -> impl Parser<'a, &'a str, LambdaType, extra::Err<Rich<'a, char>>> + Clone {
85 core_type_parser().padded().then_ignore(end())
86}
87
88#[derive(Debug, Clone, Copy, PartialEq, Eq)]
89pub struct RetrievableTypeIterator<'a>(&'a LambdaType);
91
92impl<'a> Iterator for RetrievableTypeIterator<'a> {
93 type Item = (&'a LambdaType, &'a LambdaType);
94
95 fn next(&mut self) -> Option<Self::Item> {
96 match self.0.split() {
97 Ok((lhs, rhs)) => {
98 self.0 = rhs;
99 Some((lhs, rhs))
100 }
101 Err(_) => None,
102 }
103 }
104}
105
106impl LambdaType {
107 pub fn lift_type(self) -> LambdaType {
109 let t = LambdaType::compose(self, LambdaType::T);
110
111 LambdaType::compose(t, LambdaType::T)
112 }
113
114 pub fn retrievable_types<'a>(&'a self) -> RetrievableTypeIterator<'a> {
133 RetrievableTypeIterator(self)
134 }
135
136 pub fn is_lifted_type_of(&self, t: &LambdaType) -> bool {
138 let Ok((a, LambdaType::T)) = self.split() else {
139 return false;
140 };
141 let Ok((a, LambdaType::T)) = a.split() else {
142 return false;
143 };
144
145 a == t
146 }
147
148 pub fn add_right_argument(&mut self, other: LambdaType) {
150 let mut t = LambdaType::A;
151 std::mem::swap(&mut t, self);
152 *self = LambdaType::Composition(Box::new(t), Box::new(other));
153 }
154 pub fn add_left_argument(&mut self, other: LambdaType) {
156 let mut t = LambdaType::A;
157 std::mem::swap(&mut t, self);
158 *self = LambdaType::Composition(Box::new(other), Box::new(t));
159 }
160
161 pub fn from_string(s: &str) -> Result<Self, TypeParsingError> {
170 type_parser().parse(s).into_result().map_err(|x| x.into())
171 }
172
173 pub fn a() -> &'static Self {
175 &LambdaType::A
176 }
177
178 pub fn e() -> &'static Self {
180 &LambdaType::E
181 }
182
183 pub fn t() -> &'static Self {
185 &LambdaType::T
186 }
187
188 pub fn compose(a: Self, b: Self) -> Self {
190 LambdaType::Composition(Box::new(a), Box::new(b))
191 }
192
193 pub fn at() -> &'static Self {
195 static VAL: LazyLock<LambdaType> =
196 LazyLock::new(|| LambdaType::compose(LambdaType::a().clone(), LambdaType::t().clone()));
197 &VAL
198 }
199
200 pub fn et() -> &'static Self {
202 static VAL: LazyLock<LambdaType> =
203 LazyLock::new(|| LambdaType::compose(LambdaType::e().clone(), LambdaType::t().clone()));
204 &VAL
205 }
206
207 pub fn eet() -> &'static Self {
209 static VAL: LazyLock<LambdaType> = LazyLock::new(|| {
210 LambdaType::compose(
211 LambdaType::e().clone(),
212 LambdaType::compose(LambdaType::e().clone(), LambdaType::t().clone()),
213 )
214 });
215 &VAL
216 }
217 pub fn ett() -> &'static Self {
219 static VAL: LazyLock<LambdaType> = LazyLock::new(|| {
220 LambdaType::compose(
221 LambdaType::compose(LambdaType::e().clone(), LambdaType::t().clone()),
222 LambdaType::t().clone(),
223 )
224 });
225 &VAL
226 }
227
228 pub fn size(&self) -> usize {
230 match self {
231 LambdaType::A | LambdaType::E | LambdaType::T => 1,
232 LambdaType::Composition(a, b) => a.size() + b.size(),
233 }
234 }
235
236 pub fn can_apply(&self, other: &Self) -> bool {
238 matches!(&self, LambdaType::Composition(a, _) if a.as_ref() == other)
239 }
240
241 pub fn split(&self) -> Result<(&LambdaType, &LambdaType), TypeError> {
244 match &self {
245 LambdaType::Composition(a, b) => Ok((a, b)),
246 _ => Err(TypeError::NotAFunction),
247 }
248 }
249
250 pub fn apply(&self, other: &Self) -> Result<&Self, TypeError> {
253 if !self.can_apply(other) {
254 return Err(TypeError::CantApply(other.clone(), self.clone()));
255 }
256 self.rhs()
257 }
258
259 pub fn is_function(&self) -> bool {
261 matches!(&self, LambdaType::Composition(_, _))
262 }
263
264 pub fn lhs(&self) -> Result<&Self, TypeError> {
267 Ok(self.split()?.0)
268 }
269
270 pub fn is_one_place_function(&self) -> bool {
272 if let Ok((lhs, rhs)) = self.split() {
273 !lhs.is_function() && !rhs.is_function()
274 } else {
275 false
276 }
277 }
278
279 pub fn rhs(&self) -> Result<&Self, TypeError> {
282 Ok(self.split()?.1)
283 }
284}
285
286impl Display for LambdaType {
287 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
288 match &self {
289 LambdaType::E => write!(f, "e"),
290 LambdaType::T => write!(f, "t"),
291 LambdaType::A => write!(f, "a"),
292 LambdaType::Composition(lhs, rhs) => write!(f, "<{lhs},{rhs}>"),
293 }
294 }
295}
296
/// Probability, at each node below `MAX_DEPTH`, of sampling a function type
/// instead of a primitive.
#[cfg(feature = "sampling")]
const RECURSE_PROB: f64 = 0.2;
/// Depth at which sampling stops introducing new function types.
#[cfg(feature = "sampling")]
const MAX_DEPTH: u8 = 64;

#[cfg(feature = "sampling")]
impl LambdaType {
    /// Samples one type node at `depth`.
    ///
    /// With probability `RECURSE_PROB` (only while `depth < MAX_DEPTH`) the node
    /// is a function type. Its argument is unrestricted, and its result inherits
    /// `no_e`. Otherwise the node is a primitive: `t` or `a` when `no_e` is set,
    /// and `t`, `a` or `e` otherwise.
    fn random_inner(r: &mut impl Rng, depth: u8, no_e: bool) -> Self {
        let make_function = depth < MAX_DEPTH && r.random_bool(RECURSE_PROB);
        if make_function {
            // Argument first, then result: keeps the RNG draw order fixed so
            // seeded runs are reproducible.
            let argument = LambdaType::random_inner(r, depth + 1, false);
            let result = LambdaType::random_inner(r, depth + 1, no_e);
            return LambdaType::compose(argument, result);
        }

        if no_e {
            let primitive = if r.random_bool(0.5) {
                LambdaType::t()
            } else {
                LambdaType::a()
            };
            return primitive.clone();
        }

        match (0..3usize).choose(r).unwrap() {
            0 => LambdaType::t().clone(),
            1 => LambdaType::a().clone(),
            _ => LambdaType::e().clone(),
        }
    }

    /// Samples a random type using `r`.
    pub fn random(r: &mut impl Rng) -> Self {
        LambdaType::random_inner(r, 0, false)
    }

    /// Samples a random type whose final result type (rightmost leaf) is never `e`.
    pub fn random_no_e(r: &mut impl Rng) -> Self {
        LambdaType::random_inner(r, 0, true)
    }
}
332
#[cfg(test)]
mod test {

    #[cfg(feature = "sampling")]
    use rand::SeedableRng;
    #[cfg(feature = "sampling")]
    use rand_chacha::ChaCha8Rng;

    use super::*;

    /// Randomly sampled types must survive a print → parse round trip.
    #[cfg(feature = "sampling")]
    #[test]
    fn random_lambdas() -> anyhow::Result<()> {
        let mut r = ChaCha8Rng::seed_from_u64(32);
        for _ in 0..10_000 {
            let t = LambdaType::random(&mut r);
            let reparsed = LambdaType::from_string(&t.to_string())?;
            assert_eq!(reparsed, t, "round trip failed for {t}");
        }
        Ok(())
    }

    /// `random_no_e` must never produce a type whose final result is `e`.
    #[cfg(feature = "sampling")]
    #[test]
    fn random_no_e_lambdas() {
        let mut r = ChaCha8Rng::seed_from_u64(32);
        for _ in 0..1_000 {
            let t = LambdaType::random_no_e(&mut r);
            let mut result = &t;
            while let Ok(rhs) = result.rhs() {
                result = rhs;
            }
            assert_ne!(result, LambdaType::e(), "{t} has final result e");
        }
    }

    #[test]
    fn check_application() -> anyhow::Result<()> {
        let et = LambdaType::et();
        let et_to_et = LambdaType::compose(et.clone(), et.clone());
        let et_squared_to_et_squared = LambdaType::compose(et_to_et.clone(), et_to_et.clone());
        assert!(et.can_apply(LambdaType::e()));
        assert!(et_to_et.can_apply(et));
        assert!(et_squared_to_et_squared.can_apply(&et_to_et));
        assert!(!et.can_apply(LambdaType::t()));
        assert!(!et_to_et.can_apply(&et_squared_to_et_squared));
        assert!(!et_squared_to_et_squared.can_apply(&et_squared_to_et_squared));

        assert_eq!(&et_to_et, et_squared_to_et_squared.rhs()?);

        assert_eq!(et, et_to_et.rhs()?);

        assert_eq!(LambdaType::t(), et.rhs()?);
        Ok(())
    }

    #[test]
    fn parse_types() {
        let parser = type_parser();
        assert_eq!(&parser.parse("e").unwrap(), LambdaType::e());
        assert_eq!(&parser.parse(" e ").unwrap(), LambdaType::e());
        assert_eq!(&parser.parse("e ").unwrap(), LambdaType::e());
        assert!(parser.parse("e z").has_errors());

        assert_eq!(&parser.parse("t").unwrap(), LambdaType::t());
        assert_eq!(&parser.parse("a").unwrap(), LambdaType::a());

        let et = LambdaType::et();
        assert_eq!(&parser.parse("<e, t>").unwrap(), et);

        let et_to_et = LambdaType::compose(et.clone(), et.clone());

        assert_eq!(parser.parse("<<e, t>, <e,t>>").unwrap(), et_to_et);

        let et_squared_to_et_squared = LambdaType::compose(et_to_et.clone(), et_to_et);
        assert_eq!(
            parser.parse("<< <e, t>, <e,t>>, <<e,t>, <e,t>>>").unwrap(),
            et_squared_to_et_squared
        );
    }

    #[test]
    fn check_printing() {
        let parser = type_parser();
        for s in ["e", "t", "a", "<a,t>", "<e,t>", "<e,<e,t>>", "<t,<<t,t>,<e,t>>>"] {
            assert_eq!(parser.parse(s).unwrap().to_string(), s);
        }
    }

    #[test]
    fn check_lifting() -> anyhow::Result<()> {
        for s in ["e", "t", "a", "<e,t>", "<e,<e,t>>", "<t,<<t,t>,<e,t>>>"] {
            let lifted_str = format!("<<{s},t>,t>");
            let lifted = LambdaType::from_string(&lifted_str)?;
            let base_type = LambdaType::from_string(s)?;
            assert!(lifted.is_lifted_type_of(&base_type));
            assert_eq!(base_type.lift_type(), lifted);
        }

        Ok(())
    }
}