1#![warn(missing_docs)]
34
35use std::borrow::Borrow;
36use std::fmt::{Debug, Display};
37use std::hash::Hash;
38use std::marker::PhantomData;
39
40#[cfg(not(target_arch = "wasm32"))]
41use std::time::{Duration, Instant};
42
43pub use lexicon::{Lexicon, ParsingError};
44
45use logprob::LogProb;
46use min_max_heap::MinMaxHeap;
47use parsing::RuleHolder;
48
49pub use parsing::RulePool;
50use parsing::beam::{FuzzyScan, GeneratorScan, ParseScan, Scanner};
51use parsing::{BeamWrapper, PartialRulePool, expand};
52use petgraph::graph::NodeIndex;
53
54#[cfg(feature = "sampling")]
55use rand::Rng;
56#[cfg(feature = "sampling")]
57use rand_distr::Distribution;
58#[cfg(feature = "sampling")]
59use rand_distr::weighted::WeightedIndex;
60
61use serde::{Deserialize, Serialize};
62use thiserror::Error;
63
/// One element of a sentence handled by the parser.
///
/// An element is either a single item or a group of items that together form one unit
/// ([`PhonContent::flatten`] joins such a group into a single string).
#[derive(Debug, Clone, PartialEq, Eq, PartialOrd, Ord, Hash, Serialize, Deserialize)]
pub enum PhonContent<T> {
    /// A single item.
    Normal(T),
    /// Several items grouped into one unit.
    Affixed(Vec<T>),
}
72
/// Error returned when a [`PhonContent::Affixed`] value is encountered where a
/// [`PhonContent::Normal`] value is required (see [`PhonContent::try_inner`]).
#[derive(Debug, Copy, Clone, PartialEq, Eq, Error)]
pub struct FlattenError {}
76impl Display for FlattenError {
77 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
78 write!(f, "This Input is not a Normal variant")
79 }
80}
81
82impl<T> PhonContent<T> {
83 pub fn try_inner(self) -> Result<T, FlattenError> {
85 match self {
86 PhonContent::Normal(x) => Ok(x),
87 PhonContent::Affixed(_) => Err(FlattenError {}),
88 }
89 }
90
91 pub fn new(x: Vec<T>) -> Vec<PhonContent<T>> {
93 x.into_iter().map(PhonContent::Normal).collect()
94 }
95
96 pub fn from<const N: usize>(x: [T; N]) -> [PhonContent<T>; N] {
98 x.map(PhonContent::Normal)
99 }
100
101 pub fn try_flatten(x: Vec<PhonContent<T>>) -> Result<Vec<T>, FlattenError> {
104 x.into_iter().map(PhonContent::try_inner).collect()
105 }
106}
107impl PhonContent<&str> {
108 #[must_use]
110 pub fn flatten(x: Vec<PhonContent<&str>>) -> Vec<String> {
111 let mut v = vec![];
112 for content in x {
113 match content {
114 PhonContent::Normal(val) => v.push(val.to_string()),
115 PhonContent::Affixed(items) => v.push(items.join("")),
116 }
117 }
118 v
119 }
120}
121
/// A two-valued direction, left or right. [`Direction::Left`] is the default.
///
/// Converts to and from `bool`, with `false` standing for left and `true` for right.
// Variant names are self-explanatory, hence the lint exemption.
#[allow(missing_docs)]
#[derive(
    Debug, Copy, Clone, PartialEq, Eq, PartialOrd, Ord, Hash, Default, Serialize, Deserialize,
)]
pub enum Direction {
    #[default]
    Left,
    Right,
}
133
134impl Direction {
135 #[must_use]
137 pub fn flip(&self) -> Self {
138 match self {
139 Direction::Left => Direction::Right,
140 Direction::Right => Direction::Left,
141 }
142 }
143}
144
145impl From<Direction> for bool {
146 fn from(value: Direction) -> Self {
147 match value {
148 Direction::Left => false,
149 Direction::Right => true,
150 }
151 }
152}
153
154impl From<bool> for Direction {
155 fn from(value: bool) -> Self {
156 if value { Direction::Right } else { Direction::Left }
157 }
158}
159
/// Limits and probabilities that steer the beam search.
///
/// Every limit is optional; `None` means the corresponding check is skipped.
#[derive(Debug, Copy, Clone, PartialEq, Eq)]
pub struct ParsingConfig {
    /// A beam is only queued while its log-probability is strictly greater than this.
    min_log_prob: Option<LogProb<f64>>,
    /// Probability assigned to moving (presumably the grammar's "move" operation; confirm in
    /// `parsing`).
    move_prob: LogProb<f64>,
    /// Always kept equal to `move_prob.opposite_prob()`.
    dont_move_prob: LogProb<f64>,
    /// A beam is only queued while its step count is strictly below this.
    max_steps: Option<usize>,
    /// Threshold on the number of queued beams above which the least probable one is evicted.
    max_beams: Option<usize>,
    /// A beam is only queued while its count of consecutive empties does not exceed this.
    max_consecutive_empty: Option<usize>,

    /// Wall-clock budget checked by the parsing iterators; unavailable on `wasm32`.
    #[cfg(not(target_arch = "wasm32"))]
    max_time: Option<Duration>,
}
180
181impl ParsingConfig {
182 #[must_use]
186 pub fn empty() -> ParsingConfig {
187 let move_prob = LogProb::from_raw_prob(0.5).unwrap();
188 let dont_move_prob = move_prob.opposite_prob();
189
190 ParsingConfig {
191 min_log_prob: None,
192 move_prob,
193 dont_move_prob,
194 max_consecutive_empty: None,
195 max_steps: None,
196 max_beams: None,
197 #[cfg(not(target_arch = "wasm32"))]
198 max_time: None,
199 }
200 }
201
202 #[must_use]
204 pub fn new(
205 min_log_prob: LogProb<f64>,
206 move_prob: LogProb<f64>,
207 max_steps: usize,
208 max_beams: usize,
209 ) -> ParsingConfig {
210 let max_steps = usize::min(parsing::MAX_STEPS, max_steps);
211 let merge_prob = move_prob.opposite_prob();
212 ParsingConfig {
213 min_log_prob: Some(min_log_prob),
214 move_prob,
215 dont_move_prob: merge_prob,
216 max_consecutive_empty: None,
217 max_steps: Some(max_steps),
218 max_beams: Some(max_beams),
219 #[cfg(not(target_arch = "wasm32"))]
220 max_time: None,
221 }
222 }
223
224 #[cfg(not(target_arch = "wasm32"))]
226 #[must_use]
227 pub fn with_max_time(mut self, duration: Duration) -> Self {
228 self.max_time = Some(duration);
229 self
230 }
231
232 #[must_use]
234 pub fn with_max_consecutive_empty(mut self, n: usize) -> Self {
235 self.max_consecutive_empty = Some(n);
236 self
237 }
238
239 #[must_use]
241 pub fn with_min_log_prob(mut self, min_log_prob: LogProb<f64>) -> Self {
242 self.min_log_prob = Some(min_log_prob);
243 self
244 }
245
246 #[must_use]
248 pub fn with_max_steps(mut self, max_steps: usize) -> Self {
249 self.max_steps = Some(max_steps);
250 self
251 }
252
253 #[must_use]
255 pub fn with_max_beams(mut self, max_beams: usize) -> Self {
256 self.max_beams = Some(max_beams);
257 self
258 }
259
260 #[must_use]
262 pub fn with_move_prob(mut self, move_prob: LogProb<f64>) -> Self {
263 self.move_prob = move_prob;
264 self.dont_move_prob = self.move_prob.opposite_prob();
265 self
266 }
267}
268
269impl Default for ParsingConfig {
270 fn default() -> Self {
271 ParsingConfig::new(
272 LogProb::new(-256.0).unwrap(),
273 LogProb::from_raw_prob(0.5).unwrap(),
274 128,
275 256,
276 )
277 }
278}
279
280#[derive(Debug, Copy, Clone, Eq, PartialEq)]
281struct BeamKey(LogProb<f64>, usize);
282
283impl PartialOrd for BeamKey {
284 fn partial_cmp(&self, other: &Self) -> Option<std::cmp::Ordering> {
285 Some(self.cmp(other))
286 }
287}
288
289impl Ord for BeamKey {
290 fn cmp(&self, other: &Self) -> std::cmp::Ordering {
291 self.0.cmp(&other.0)
292 }
293}
294
/// Priority queue of beams together with the storage that backs them.
#[derive(Debug, Clone)]
struct ParseHeap<T, B: Scanner<T>> {
    /// Keys ordered by log-probability; each key carries an index into `beam_arena`.
    parse_heap: MinMaxHeap<BeamKey>,
    phantom: PhantomData<T>,
    /// Limits consulted when a beam is pushed.
    config: ParsingConfig,
    /// Rule storage, initialised from `PartialRulePool::default_pool`.
    rule_arena: Vec<RuleHolder>,
    /// One slot per stored beam; a slot is `None` once its beam was popped or evicted.
    beam_arena: Vec<Option<BeamWrapper<T, B>>>,
    /// Beams waiting for `process_randoms` to choose the next head (random order only).
    #[cfg(feature = "sampling")]
    random_buffer: Vec<BeamWrapper<T, B>>,
    /// Beam chosen by `process_randoms`; `pop` returns it before consulting the heap.
    #[cfg(feature = "sampling")]
    head: Option<BeamWrapper<T, B>>,
    /// When `true`, `push` stores beams in `random_buffer` instead of the heap.
    #[cfg(feature = "sampling")]
    random_order: bool,
}
309
310impl<T: Eq + std::fmt::Debug + Clone, B: Scanner<T> + Eq + Clone> ParseHeap<T, B> {
311 fn retain_map<F: FnMut(BeamWrapper<T, B>) -> Option<BeamWrapper<T, B>>>(&mut self, mut f: F) {
313 let mut heap = MinMaxHeap::new();
314 std::mem::swap(&mut heap, &mut self.parse_heap);
315 self.parse_heap.extend(heap.into_iter().filter(|x| {
316 let v = self.beam_arena.get_mut(x.1).unwrap().take().unwrap();
317 if let Some(v) = f(v) {
318 self.beam_arena[x.1] = Some(v);
319 true
320 } else {
321 false
322 }
323 }));
324 }
325
326 #[cfg(feature = "sampling")]
327 fn pop(&mut self) -> Option<BeamWrapper<T, B>> {
328 self.head.take().or_else(|| {
329 self.parse_heap
330 .pop_max()
331 .map(|x| self.beam_arena[x.1].take().unwrap())
332 })
333 }
334 #[cfg(not(feature = "sampling"))]
335 fn pop(&mut self) -> Option<BeamWrapper<T, B>> {
336 self.parse_heap
337 .pop_max()
338 .map(|x| self.beam_arena[x.1].take().unwrap())
339 }
340
341 fn can_push(&self, v: &BeamWrapper<T, B>) -> bool {
342 let is_probable_enough = self
343 .config
344 .min_log_prob
345 .is_none_or(|p| v.log_prob() > p);
346 let is_short_enough = self
347 .config
348 .max_steps
349 .is_none_or(|max_steps| v.n_steps() < max_steps);
350 let is_not_fake_structure = self
351 .config
352 .max_consecutive_empty
353 .is_none_or(|n_empty| v.n_consecutive_empty() <= n_empty);
354 is_short_enough && is_probable_enough && is_not_fake_structure
355 }
356
357 #[cfg(feature = "sampling")]
358 fn process_randoms(&mut self, rng: &mut impl Rng) {
359 let weights = self
360 .random_buffer
361 .iter()
362 .map(|x| -x.log_prob().into_inner())
363 .collect::<Vec<_>>();
364 if !weights.is_empty() {
365 let head_id = if weights.len() > 1 {
366 let weights = WeightedIndex::new(weights).unwrap();
367 weights.sample(rng)
368 } else {
369 0
370 };
371
372 let buffer = std::mem::take(&mut self.random_buffer);
374 for (i, beam) in buffer.into_iter().enumerate() {
375 if i == head_id {
376 self.head = Some(beam);
377 } else {
378 self.add_to_heap(beam);
379 }
380 }
381 }
382 }
383
384 fn add_to_heap(&mut self, v: BeamWrapper<T, B>) {
385 let key = BeamKey(v.log_prob(), self.beam_arena.len());
386 if let Some(max_beams) = self.config.max_beams
387 && self.parse_heap.len() > max_beams
388 {
389 let x = self.parse_heap.push_pop_min(key);
390 if x.1 != key.1 {
391 self.beam_arena[x.1] = None;
393 }
394 } else {
395 self.parse_heap.push(key);
396 }
397 self.beam_arena.push(Some(v));
398 }
399
400 fn push(&mut self, v: BeamWrapper<T, B>) {
401 if self.can_push(&v) {
402 #[cfg(feature = "sampling")]
403 if self.random_order {
404 self.random_buffer.push(v);
405 } else {
406 self.add_to_heap(v);
407 }
408
409 #[cfg(not(feature = "sampling"))]
410 self.add_to_heap(v);
411 }
412 }
413
414 fn new(start: BeamWrapper<T, B>, config: &ParsingConfig, cat: NodeIndex) -> Self {
415 let mut parse_heap = MinMaxHeap::with_capacity(config.max_beams.unwrap_or(50));
416 let key = BeamKey(start.log_prob(), 0);
417 parse_heap.push(key);
418 let beam_arena = vec![Some(start)];
419 ParseHeap {
420 parse_heap,
421 beam_arena,
422 phantom: PhantomData,
423 config: *config,
424 #[cfg(feature = "sampling")]
425 random_order: false,
426 #[cfg(feature = "sampling")]
427 random_buffer: vec![],
428 #[cfg(feature = "sampling")]
429 head: None,
430 rule_arena: PartialRulePool::default_pool(cat),
431 }
432 }
433
434 fn rules_mut(&mut self) -> &mut Vec<RuleHolder> {
435 &mut self.rule_arena
436 }
437}
438
/// Item yielded by the parsing iterators: a log-probability, the sentence that was parsed
/// (borrowed from the caller's input) and the rules of the parse.
type ParserOutput<'a, T> = (LogProb<f64>, &'a [PhonContent<T>], RulePool);
/// Item yielded by the generating iterators: a log-probability, the produced sentence and the
/// rules of the derivation.
type GeneratorOutput<T> = (LogProb<f64>, Vec<PhonContent<T>>, RulePool);
441
/// Iterator created by [`Lexicon::fuzzy_parse`].
///
/// Each item is a log-probability, an owned sentence and the rules used to derive it.
pub struct FuzzyParser<
    'a,
    'b,
    T: Eq + std::fmt::Debug + Clone,
    Category: Eq + Clone + std::fmt::Debug,
> {
    /// Lexicon the search expands against.
    lexicon: &'a Lexicon<T, Category>,
    /// Beams still to be explored.
    parse_heap: ParseHeap<T, FuzzyScan<'b, T>>,
    /// Configuration handed to `expand`.
    config: &'a ParsingConfig,
}
453
454impl<T, Category> Iterator for FuzzyParser<'_, '_, T, Category>
455where
456 T: Eq + std::fmt::Debug + Clone,
457 Category: Eq + Clone + std::fmt::Debug + Hash,
458{
459 type Item = GeneratorOutput<T>;
460
461 fn next(&mut self) -> Option<Self::Item> {
462 while let Some(mut beam) = self.parse_heap.pop() {
463 if let Some(moment) = beam.pop_moment() {
464 expand(
465 &mut self.parse_heap,
466 moment,
467 beam,
468 self.lexicon,
469 self.config,
470 );
471 } else if let Some(x) = FuzzyScan::yield_good_parse(beam, &self.parse_heap.rule_arena) {
472 return Some(x);
473 }
474 }
475
476 None
477 }
478}
479
/// Iterator over parses, created by [`Lexicon::parse`] and [`Lexicon::parse_multiple`].
///
/// Each item is a log-probability, the input sentence that was parsed and the rules of the
/// parse.
pub struct Parser<'a, 'b, T: Eq + std::fmt::Debug + Clone, Category: Eq + Clone + std::fmt::Debug> {
    /// Lexicon the search expands against.
    lexicon: &'a Lexicon<T, Category>,
    /// Beams still to be explored.
    parse_heap: ParseHeap<T, ParseScan<'b, T>>,

    /// Set on the first call to `next`; used to enforce the time limit.
    #[cfg(not(target_arch = "wasm32"))]
    start_time: Option<Instant>,
    /// Configuration handed to `expand`; also supplies the time limit.
    config: &'a ParsingConfig,
    /// Results produced alongside an already returned one, waiting to be handed out.
    buffer: Vec<ParserOutput<'b, T>>,
}
490
491impl<'b, T, Category> Iterator for Parser<'_, 'b, T, Category>
492where
493 T: Eq + std::fmt::Debug + Clone,
494 Category: Eq + Clone + std::fmt::Debug,
495{
496 type Item = ParserOutput<'b, T>;
497
498 fn next(&mut self) -> Option<Self::Item> {
499 #[cfg(not(target_arch = "wasm32"))]
500 if self.start_time.is_none() {
501 self.start_time = Some(Instant::now());
502 }
503
504 if self.buffer.is_empty() {
505 while let Some(mut beam) = self.parse_heap.pop() {
506 #[cfg(not(target_arch = "wasm32"))]
507 if let Some(max_time) = self.config.max_time
508 && max_time < self.start_time.unwrap().elapsed()
509 {
510 return None;
511 }
512
513 if let Some(moment) = beam.pop_moment() {
514 expand(
515 &mut self.parse_heap,
516 moment,
517 beam,
518 self.lexicon,
519 self.config,
520 );
521 } else if let Some((mut good_parses, p, rules)) =
522 ParseScan::yield_good_parse(beam, &self.parse_heap.rule_arena)
523 && let Some(next_sentence) = good_parses.next()
524 {
525 self.buffer
526 .extend(good_parses.map(|x| (p, x, rules.clone())));
527 let next = Some((p, next_sentence, rules));
528 return next;
529 }
530 }
531 } else {
532 return self.buffer.pop();
533 }
534
535 None
536 }
537}
538
/// Iterator over parses found in randomised order, created by
/// [`Lexicon::random_parse_multiple`] (and used internally by [`Lexicon::random_parse`]).
///
/// After each expansion one of the newly produced beams is drawn at random to be explored
/// next.
#[cfg(feature = "sampling")]
pub struct RandomParser<
    'a,
    'b,
    T: Eq + std::fmt::Debug + Clone,
    Category: Eq + Clone + std::fmt::Debug,
    R: Rng,
> {
    /// Lexicon the search expands against.
    lexicon: &'a Lexicon<T, Category>,
    /// Beams still to be explored.
    parse_heap: ParseHeap<T, ParseScan<'b, T>>,

    /// Set on the first call to `next`; used to enforce the time limit.
    #[cfg(not(target_arch = "wasm32"))]
    start_time: Option<Instant>,
    /// Configuration handed to `expand`; also supplies the time limit.
    config: &'a ParsingConfig,
    /// Results produced alongside an already returned one, waiting to be handed out.
    buffer: Vec<ParserOutput<'b, T>>,
    /// Source of randomness for choosing the next beam.
    rng: &'a mut R,
}
557
#[cfg(feature = "sampling")]
impl<'b, T, Category, R> Iterator for RandomParser<'_, 'b, T, Category, R>
where
    T: Eq + std::fmt::Debug + Clone,
    Category: Eq + Clone + std::fmt::Debug,
    R: Rng,
{
    type Item = ParserOutput<'b, T>;

    /// Returns a buffered result if one is pending; otherwise searches for the next parse,
    /// choosing at random which freshly expanded beam to follow.
    ///
    /// Returns `None` when no beams remain or, where supported, when the configured time
    /// limit has been exceeded (measured from the first call).
    fn next(&mut self) -> Option<Self::Item> {
        // The clock starts on the first call and is never reset.
        #[cfg(not(target_arch = "wasm32"))]
        let started = *self.start_time.get_or_insert_with(Instant::now);

        if let Some(pending) = self.buffer.pop() {
            return Some(pending);
        }

        loop {
            let mut beam = self.parse_heap.pop()?;

            #[cfg(not(target_arch = "wasm32"))]
            if self
                .config
                .max_time
                .is_some_and(|limit| limit < started.elapsed())
            {
                return None;
            }

            match beam.pop_moment() {
                Some(moment) => {
                    expand(
                        &mut self.parse_heap,
                        moment,
                        beam,
                        self.lexicon,
                        self.config,
                    );
                    // Pick which of the buffered beams is explored next.
                    self.parse_heap.process_randoms(self.rng);
                }
                None => {
                    let Some((mut sentences, log_prob, rules)) =
                        ParseScan::yield_good_parse(beam, &self.parse_heap.rule_arena)
                    else {
                        continue;
                    };
                    let Some(next_sentence) = sentences.next() else {
                        continue;
                    };

                    // The sentence now has a parse: stop the remaining beams from working on
                    // it, and discard beams that have no other sentence left.
                    self.parse_heap.retain_map(|mut candidate| {
                        candidate
                            .beam
                            .sentence
                            .retain(|(s, _)| s != &next_sentence);
                        let still_useful = !candidate.beam.sentence.is_empty();
                        still_useful.then_some(candidate)
                    });

                    self.buffer
                        .extend(sentences.map(|sentence| (log_prob, sentence, rules.clone())));
                    return Some((log_prob, next_sentence, rules));
                }
            }
        }
    }
}
617
618impl<T, Category> Lexicon<T, Category>
619where
620 T: Eq + std::fmt::Debug + Clone,
621 Category: Eq + Clone + std::fmt::Debug,
622{
623 pub fn generate(
630 &self,
631 category: Category,
632 config: &ParsingConfig,
633 ) -> Result<Generator<&Self, T, Category>, ParsingError<Category>> {
634 let cat = self.find_category(&category)?;
635 let beam = BeamWrapper::new(GeneratorScan { sentence: vec![] }, cat);
636 let parse_heap = ParseHeap::new(beam, config, cat);
637 Ok(Generator {
638 lexicon: self,
639 config: *config,
640 parse_heap,
641 phantom: PhantomData,
642 })
643 }
644
645 pub fn into_generate(
647 self,
648 category: Category,
649 config: &ParsingConfig,
650 ) -> Result<Generator<Self, T, Category>, ParsingError<Category>> {
651 let cat = self.find_category(&category)?;
652 let beam = BeamWrapper::new(GeneratorScan { sentence: vec![] }, cat);
653 let parse_heap = ParseHeap::new(beam, config, cat);
654 Ok(Generator {
655 lexicon: self,
656 config: *config,
657 parse_heap,
658 phantom: PhantomData,
659 })
660 }
661 #[cfg(feature = "sampling")]
666 pub fn random_parse<'a, 'b, R: Rng>(
667 &'a self,
668 s: &'b [PhonContent<T>],
669 category: Category,
670 config: &'a ParsingConfig,
671 rng: &'a mut R,
672 ) -> Result<Option<ParserOutput<'b, T>>, ParsingError<Category>> {
673 let cat = self.find_category(&category)?;
674
675 let beam = BeamWrapper::new(
676 ParseScan {
677 sentence: vec![(s, 0)],
678 },
679 cat,
680 );
681 let mut parse_heap = ParseHeap::new(beam, config, cat);
682 parse_heap.random_order = true;
683 Ok(RandomParser {
684 lexicon: self,
685 config,
686 #[cfg(not(target_arch = "wasm32"))]
687 start_time: None,
688 buffer: vec![],
689 parse_heap,
690 rng,
691 }
692 .next())
693 }
694
695 #[cfg(feature = "sampling")]
701 pub fn random_parse_multiple<'a, 'b, U, R: Rng>(
702 &'a self,
703 sentences: &'b [U],
704 category: Category,
705 config: &'a ParsingConfig,
706 rng: &'a mut R,
707 ) -> Result<RandomParser<'a, 'b, T, Category, R>, ParsingError<Category>>
708 where
709 U: AsRef<[PhonContent<T>]>,
710 {
711 let cat = self.find_category(&category)?;
712
713 let beam = BeamWrapper::new(
714 ParseScan {
715 sentence: sentences.iter().map(|x| (x.as_ref(), 0)).collect(),
716 },
717 cat,
718 );
719 let mut parse_heap = ParseHeap::new(beam, config, cat);
720 parse_heap.random_order = true;
721 Ok(RandomParser {
722 lexicon: self,
723 config,
724 #[cfg(not(target_arch = "wasm32"))]
725 start_time: None,
726 buffer: vec![],
727 parse_heap,
728 rng,
729 })
730 }
731
732 pub fn parse<'a, 'b>(
737 &'a self,
738 s: &'b [PhonContent<T>],
739 category: Category,
740 config: &'a ParsingConfig,
741 ) -> Result<Parser<'a, 'b, T, Category>, ParsingError<Category>> {
742 let cat = self.find_category(&category)?;
743
744 let beam = BeamWrapper::new(
745 ParseScan {
746 sentence: vec![(s, 0)],
747 },
748 cat,
749 );
750 let parse_heap = ParseHeap::new(beam, config, cat);
751 Ok(Parser {
752 lexicon: self,
753 config,
754 #[cfg(not(target_arch = "wasm32"))]
755 start_time: None,
756 buffer: vec![],
757 parse_heap,
758 })
759 }
760
761 pub fn parse_multiple<'a, 'b, U>(
763 &'a self,
764 sentences: &'b [U],
765 category: Category,
766 config: &'a ParsingConfig,
767 ) -> Result<Parser<'a, 'b, T, Category>, ParsingError<Category>>
768 where
769 U: AsRef<[PhonContent<T>]>,
770 {
771 let cat = self.find_category(&category)?;
772 let beams = BeamWrapper::new(
773 ParseScan {
774 sentence: sentences.iter().map(|x| (x.as_ref(), 0)).collect(),
775 },
776 cat,
777 );
778 let parse_heap = ParseHeap::new(beams, config, cat);
779 Ok(Parser {
780 lexicon: self,
781 buffer: vec![],
782 #[cfg(not(target_arch = "wasm32"))]
783 start_time: None,
784 config,
785 parse_heap,
786 })
787 }
788
789 pub fn fuzzy_parse<'a, 'b, U>(
792 &'a self,
793 sentences: &'b [U],
794 category: Category,
795 config: &'a ParsingConfig,
796 ) -> Result<FuzzyParser<'a, 'b, T, Category>, ParsingError<Category>>
797 where
798 U: AsRef<[PhonContent<T>]>,
799 {
800 let cat = self.find_category(&category)?;
801
802 let beams = BeamWrapper::new(
803 FuzzyScan {
804 sentence_guides: sentences.iter().map(|x| (x.as_ref(), 0)).collect(),
805 generated_sentences: vec![],
806 },
807 cat,
808 );
809
810 let parse_heap = ParseHeap::new(beams, config, cat);
811
812 Ok(FuzzyParser {
813 lexicon: self,
814 config,
815 parse_heap,
816 })
817 }
818}
819
/// Iterator over generated sentences, created by [`Lexicon::generate`] (borrowing the
/// lexicon) or [`Lexicon::into_generate`] (owning it).
///
/// Each item is a log-probability, the generated sentence and the rules used to derive it.
#[derive(Debug, Clone)]
pub struct Generator<L, T: Eq + std::fmt::Debug + Clone, Category: Eq + Clone + std::fmt::Debug> {
    /// The lexicon, held either by reference or by value.
    lexicon: L,
    phantom: PhantomData<Category>,
    /// Beams still to be explored.
    parse_heap: ParseHeap<T, GeneratorScan<T>>,
    /// Own copy of the configuration handed to `expand`.
    config: ParsingConfig,
}
828
829impl<L, T, Category> Iterator for Generator<L, T, Category>
830where
831 L: Borrow<Lexicon<T, Category>>,
832 T: Eq + std::fmt::Debug + Clone,
833 Category: Eq + Clone + std::fmt::Debug + Hash,
834{
835 type Item = GeneratorOutput<T>;
836
837 fn next(&mut self) -> Option<Self::Item> {
838 while let Some(mut beam) = self.parse_heap.pop() {
839 if let Some(moment) = beam.pop_moment() {
840 expand(
841 &mut self.parse_heap,
842 moment,
843 beam,
844 self.lexicon.borrow(),
845 &self.config,
846 );
847 } else if let Some(x) =
848 GeneratorScan::yield_good_parse(beam, &self.parse_heap.rule_arena)
849 {
850 return Some(x);
851 }
852 }
853 None
854 }
855}
856
857pub mod grammars;
858pub mod lexicon;
859pub mod parsing;
860
861#[cfg(test)]
862mod tests;