minimalist_grammar_parser/parsing/beam.rs

//!Module which defines how the different beams used by [`Lexicon::parse`] or [`Lexicon::generate`]
//!work.

use crate::{
    ParseHeap, ParsingConfig, PhonContent, expand,
    lexicon::{Lexicon, ParsingError},
};

use super::{BeamWrapper, RuleHolder, rules::RulePool};
use ahash::HashSet;
use logprob::LogProb;
use std::{fmt::Debug, hash::Hash};

///A trait which allows a struct to be used by the parsing algorithm by defining how scanning
///works. When parsing, scanning checks that the next word matches the input sentence, whereas
///generation uses scanning to iteratively build strings.
pub(crate) trait Scanner<T>: Sized {
    fn scan(&mut self, s: &Option<T>) -> bool;

    fn multiscan(&mut self, heads: Vec<&T>) -> bool;
}

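///Scanner used when parsing a fixed input: it keeps every candidate sentence together with how
///far into that sentence has been scanned so far, dropping candidates whose next word does not
///match.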
#[derive(Debug, Clone, PartialEq, Eq)]
pub(crate) struct ParseScan<'a, T> {
    pub sentence: Vec<(&'a [PhonContent<T>], usize)>,
}

impl<T> Scanner<T> for ParseScan<'_, T>
where
    T: std::cmp::Eq + std::fmt::Debug,
{
    fn scan(&mut self, s: &Option<T>) -> bool {
        self.sentence.retain_mut(|(sentence, position)| match s {
            Some(s) => {
                if let Some(PhonContent::Normal(string)) = sentence.get(*position) {
                    if s == string {
                        *position += 1;
                        true
                    } else {
                        false
                    }
                } else {
                    false
                }
            }
            None => true,
        });
        !self.sentence.is_empty()
    }

    fn multiscan(&mut self, heads: Vec<&T>) -> bool {
        if heads.is_empty() {
            return true;
        }

        self.sentence.retain_mut(|(sentence, position)| {
            if let Some(s) = sentence.get(*position) {
                match s {
                    PhonContent::Normal(s) => {
                        if heads.len() == 1 && heads.first().unwrap() == &s {
                            *position += 1;
                            true
                        } else {
                            false
                        }
                    }
                    PhonContent::Affixed(string) => {
                        if heads.len() == string.len()
                            && heads.iter().zip(string.iter()).all(|(a, b)| *a == b)
                        {
                            *position += 1;
                            true
                        } else {
                            false
                        }
                    }
                }
            } else {
                false
            }
        });
        !self.sentence.is_empty()
    }
}

impl<'a, T: Eq + std::fmt::Debug + Clone> ParseScan<'a, T> {
    #[allow(clippy::complexity)]
    pub(crate) fn yield_good_parse(
        b: BeamWrapper<T, Self>,
        rules: &[RuleHolder],
    ) -> Option<(
        impl Iterator<Item = &'a [PhonContent<T>]> + 'a,
        LogProb<f64>,
        RulePool,
    )> {
        if b.is_empty() {
            Some((
                b.beam
                    .sentence
                    .into_iter()
                    .filter(|(s, pos)| s.len() == *pos)
                    .map(|(s, _)| s),
                b.log_prob,
                b.rules.into_rule_pool(rules),
            ))
        } else {
            None
        }
    }
}

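///Scanner that builds up a generated sentence while also tracking which of the provided guide
///sentences are still consistent with what has been generated so far.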
#[derive(Debug, Clone, PartialEq, Eq)]
pub(crate) struct FuzzyScan<'b, T> {
    pub generated_sentences: Vec<PhonContent<T>>,
    pub sentence_guides: Vec<(&'b [PhonContent<T>], usize)>,
}

impl<T: Eq + std::fmt::Debug + Clone> FuzzyScan<'_, T> {
    pub fn yield_good_parse(
        b: BeamWrapper<T, Self>,
        rules: &[RuleHolder],
    ) -> Option<(LogProb<f64>, Vec<PhonContent<T>>, RulePool)> {
        if b.is_empty() {
            Some((
                b.log_prob,
                b.beam.generated_sentences.clone(),
                b.rules.into_rule_pool(rules),
            ))
        } else {
            None
        }
    }
}

impl<T> Scanner<T> for FuzzyScan<'_, T>
where
    T: std::cmp::Eq + std::fmt::Debug + Clone,
{
    fn scan(&mut self, s: &Option<T>) -> bool {
        if let Some(s) = s {
            self.generated_sentences
                .push(PhonContent::Normal(s.clone()));
        }
        self.sentence_guides
            .retain_mut(|(sentence, position)| match s {
                Some(s) => {
                    if let Some(PhonContent::Normal(string)) = sentence.get(*position) {
                        if s == string {
                            *position += 1;
                            true
                        } else {
                            false
                        }
                    } else {
                        false
                    }
                }
                None => true,
            });
        true
    }

    fn multiscan(&mut self, mut heads: Vec<&T>) -> bool {
        self.sentence_guides.retain_mut(|(sentence, position)| {
            if let Some(s) = sentence.get(*position) {
                match s {
                    PhonContent::Normal(s) => {
                        if heads.len() == 1 && heads.first().unwrap() == &s {
                            *position += 1;
                            true
                        } else {
                            false
                        }
                    }
                    PhonContent::Affixed(string) => {
                        if heads.len() == string.len()
                            && heads.iter().zip(string.iter()).all(|(a, b)| *a == b)
                        {
                            *position += 1;
                            true
                        } else {
                            false
                        }
                    }
                }
            } else {
                false
            }
        });
        if !heads.is_empty() {
            if heads.len() == 1 {
                self.generated_sentences
                    .push(PhonContent::Normal(heads.pop().unwrap().clone()));
            } else {
                self.generated_sentences
                    .push(PhonContent::Affixed(heads.into_iter().cloned().collect()));
            }
        }
        true
    }
}

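///Scanner used when generating: it simply accumulates every scanned word into the output
///sentence.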
#[derive(Debug, Clone, Eq, PartialEq)]
pub(crate) struct GeneratorScan<T> {
    pub sentence: Vec<PhonContent<T>>,
}

impl<T: Clone> Scanner<T> for GeneratorScan<T>
where
    T: std::cmp::Eq + std::fmt::Debug,
{
    fn scan(&mut self, s: &Option<T>) -> bool {
        //If the word is None there is nothing to add
        if let Some(s) = s {
            self.sentence.push(PhonContent::Normal(s.clone()));
        }
        true
    }

    fn multiscan(&mut self, mut heads: Vec<&T>) -> bool {
        if !heads.is_empty() {
            if heads.len() == 1 {
                self.sentence
                    .push(PhonContent::Normal(heads.pop().unwrap().clone()));
            } else {
                self.sentence
                    .push(PhonContent::Affixed(heads.into_iter().cloned().collect()));
            }
        }
        true
    }
}

impl<T: Eq + std::fmt::Debug + Clone> GeneratorScan<T> {
    pub(crate) fn yield_good_parse(
        b: BeamWrapper<T, Self>,
        rules: &[RuleHolder],
    ) -> Option<(LogProb<f64>, Vec<PhonContent<T>>, RulePool)> {
        if b.is_empty() {
            Some((
                b.log_prob,
                b.beam.sentence.clone(),
                b.rules.into_rule_pool(rules),
            ))
        } else {
            None
        }
    }
}

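///Scanner used by [`Lexicon::valid_continuations`]: it follows the given prefix and records the
///first word produced immediately after the end of the prefix.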
#[derive(Debug, PartialEq, Eq, Clone)]
struct ContinuationScan<'a, T> {
    prefix: &'a [PhonContent<T>],
    position: usize,
    final_char: Option<Continuation<T>>,
}

impl<T> Scanner<T> for ContinuationScan<'_, T>
where
    T: std::cmp::Eq + std::fmt::Debug + Clone,
{
    fn scan(&mut self, word: &Option<T>) -> bool {
        match word {
            Some(word) => {
                if let Some(string) = self.prefix.get(self.position) {
                    if let PhonContent::Normal(string) = string
                        && string == word
                    {
                        self.position += 1;
                        true
                    } else {
                        false
                    }
                } else if self.position == self.prefix.len() {
                    self.final_char = Some(Continuation::Word(word.clone()));
                    self.position += 1;
                    true
                } else {
                    self.position += 1;
                    true
                }
            }
            None => true,
        }
    }

    fn multiscan(&mut self, heads: Vec<&T>) -> bool {
        if heads.is_empty() {
            return true;
        }

        if let Some(s) = self.prefix.get(self.position) {
            match s {
                PhonContent::Normal(s) => {
                    if heads.len() == 1 && heads.first().unwrap() == &s {
                        self.position += 1;
                        true
                    } else {
                        false
                    }
                }
                PhonContent::Affixed(string) => {
                    if heads.len() == string.len()
                        && heads.iter().zip(string.iter()).all(|(a, b)| *a == b)
                    {
                        self.position += 1;
                        true
                    } else {
                        false
                    }
                }
            }
        } else if self.position == self.prefix.len() {
            self.final_char = Some(Continuation::AffixedWord(
                heads.into_iter().cloned().collect(),
            ));
            self.position += 1;
            true
        } else {
            self.position += 1;
            true
        }
    }
}

impl<T: Eq + Debug + Clone> ContinuationScan<'_, T> {
    pub fn yield_good_parse(b: BeamWrapper<T, Self>) -> Option<Continuation<T>> {
        if b.is_empty() {
            match b.beam.final_char {
                Some(x) => Some(x),
                None if b.beam.position == b.beam.prefix.len() => Some(Continuation::EndOfSentence),
                None => None,
            }
        } else {
            None
        }
    }
}

///Enum describing a possible continuation of a prefix in a grammar: either the next token or the
///end of the sentence.
#[derive(Debug, Clone, Eq, PartialEq, Hash)]
pub enum Continuation<T> {
    ///The following word is a valid token given the prefix in [`Lexicon::valid_continuations`].
    Word(T),
    ///The following affixed word is a valid token given the prefix in [`Lexicon::valid_continuations`].
    AffixedWord(Vec<T>),
    ///The prefix in [`Lexicon::valid_continuations`] already forms a complete sentence.
    EndOfSentence,
}

impl<T, C> Lexicon<T, C>
where
    T: Eq + std::fmt::Debug + Clone + Hash,
    C: Eq + Clone + std::fmt::Debug + Hash,
{
    ///Given a grammar and a prefix string, return a [`HashSet`] of the possible [`Continuation`]s (i.e. next words) that are valid
    ///in the grammar.
    ///Returns a [`ParsingError`] if there is no node with the category of `initial_category`.
    ///
    ///```
    ///# use minimalist_grammar_parser::{ParsingConfig, Lexicon, PhonContent};
    ///# use ahash::HashSet;
    ///# use minimalist_grammar_parser::parsing::beam::Continuation;
    ///
    ///let lex = Lexicon::from_string("a::S= b= S\n::S\nb::b")?;
    ///let continuations = lex.valid_continuations("S", &PhonContent::from(["a"]), &ParsingConfig::default())?;
    ///assert_eq!(continuations, HashSet::from_iter([Continuation::Word("b"), Continuation::Word("a")].into_iter()));
    ///let continuations = lex.valid_continuations("S", &PhonContent::from(["a", "b"]), &ParsingConfig::default())?;
    ///assert_eq!(continuations, HashSet::from_iter([Continuation::EndOfSentence]));
    ///# Ok::<(), anyhow::Error>(())
    /// ```
    pub fn valid_continuations(
        &self,
        initial_category: C,
        prefix: &[PhonContent<T>],
        config: &ParsingConfig,
    ) -> Result<HashSet<Continuation<T>>, ParsingError<C>> {
        let cat = self.find_category(&initial_category)?;

        let cont = ContinuationScan {
            prefix,
            position: 0,
            final_char: None,
        };

        let mut valid_chars: HashSet<Continuation<T>> = HashSet::default();

        let mut parse_heap = ParseHeap::new(BeamWrapper::new(cont, cat), config, cat);

        while let Some(mut beam) = parse_heap.pop() {
            if let Some(word) = beam.beam.final_char.as_ref()
                && valid_chars.contains(word)
            {
                //We don't care since there's already a successful parse with that character.
                continue;
            }

            if let Some(moment) = beam.pop_moment() {
                expand(&mut parse_heap, moment, beam, self, config);
            } else if let Some(cont) = ContinuationScan::yield_good_parse(beam) {
                valid_chars.insert(cont);
            }
        }
        Ok(valid_chars)
    }
}

#[cfg(test)]
mod test {
    use crate::{
        ParsingConfig, PhonContent,
        grammars::{DYCK_LANGUAGE, STABLER2011},
        lexicon::Lexicon,
        parsing::beam::Continuation,
    };

    #[test]
    fn continuations() -> anyhow::Result<()> {
        let lex = Lexicon::from_string(STABLER2011)?;

        let strings = [
            "the",
            "the king",
            "which",
            "which king",
            "the king knows",
            "the king drinks the beer",
        ]
        .map(|x| x.split(" ").collect::<Vec<_>>());

        let continuations = [
            vec![
                Continuation::Word("king"),
                Continuation::Word("beer"),
                Continuation::Word("wine"),
                Continuation::Word("queen"),
            ],
            vec![
                Continuation::Word("knows"),
                Continuation::Word("says"),
                Continuation::Word("drinks"),
                Continuation::Word("prefers"),
            ],
            vec![
                Continuation::Word("wine"),
                Continuation::Word("king"),
                Continuation::Word("beer"),
                Continuation::Word("queen"),
            ],
            vec![
                Continuation::Word("drinks"),
                Continuation::Word("knows"),
                Continuation::Word("the"),
                Continuation::Word("says"),
                Continuation::Word("prefers"),
            ],
            vec![Continuation::Word("which"), Continuation::Word("the")],
            vec![Continuation::EndOfSentence],
        ]
        .into_iter()
        .map(|x| x.into_iter().collect());

        for (s, valid) in strings.into_iter().map(PhonContent::new).zip(continuations) {
            let cont = lex.valid_continuations("C", &s, &ParsingConfig::default())?;
            assert_eq!(cont, valid);
        }
        let lex = Lexicon::from_string(DYCK_LANGUAGE)?;

        let strings = ["(", "( )", "( ( )", "( ( ) )", "( ) ( )", "( ( ( ) )"]
            .map(|x| x.split(" ").collect::<Vec<_>>());

        let continuations = [
            vec![Continuation::Word(")"), Continuation::Word("(")],
            vec![Continuation::Word("("), Continuation::EndOfSentence],
            vec![Continuation::Word(")"), Continuation::Word("(")],
            vec![Continuation::Word("("), Continuation::EndOfSentence],
            vec![Continuation::Word("("), Continuation::EndOfSentence],
            vec![Continuation::Word(")"), Continuation::Word("(")],
        ]
        .into_iter()
        .map(|x| x.into_iter().collect());

        for (s, valid) in strings.into_iter().map(PhonContent::new).zip(continuations) {
            let cont = lex.valid_continuations("S", &s, &ParsingConfig::default())?;
            assert_eq!(cont, valid);
        }

        let lex = Lexicon::from_string("a::S= b= S\n::S\nb::b")?;

        let mut strings: Vec<_> = ["a", "a b", "a a b", "a a b b"]
            .iter()
            .map(|x| x.split(" ").collect::<Vec<_>>())
            .collect();
        strings.push(vec![]);

        let continuations = [
            vec![Continuation::Word("b"), Continuation::Word("a")],
            vec![Continuation::EndOfSentence],
            vec![Continuation::Word("b")],
            vec![Continuation::EndOfSentence],
            vec![Continuation::Word("a"), Continuation::EndOfSentence],
        ]
        .into_iter()
        .map(|x| x.into_iter().collect());

        for (s, valid) in strings.into_iter().map(PhonContent::new).zip(continuations) {
            let cont = lex.valid_continuations("S", &s, &ParsingConfig::default())?;
            assert_eq!(cont, valid);
        }

        let lexicon = "::T<= +q Q
what::d[in] -subj3 -q -wh
what::d[in] -acc -wh
who::d[an] -subj3 -q -wh
who::d[an] -acc -wh
::T<= +q +wh Q
::q -q
does::V= q= +subj3 T
do::V= q= +subj2 T
do::V= q= +subj1 T
did::V= q= +subj3 T
did::V= q= +subj2 T
did::V= q= +subj1 T
::q -q
to::theme[an]= p
talk::p= v
see::d[an]= +acc v
see::d[in]= +acc v
devour::d[in]= +acc v
want::d[in]= +acc v
run::v
you::d[an] -subj2
you::d[an] -acc
I::d[an] -subj1
me::d[an] -acc
he::d[an] -subj3
him::d[an] -acc
she::d[an] -subj3
her::d[an] -acc
::d[an]= +theme theme[an]
that::C= +r +rel[in] d[in] -acc
that::C= +r +rel[in] d[in] -subj3
who::C= +r +rel[an] d[an] -acc
who::C= +r +rel[an] d[an] -subj3
::=>v =d[an] V
man::N[an]
woman::N[an]
cake::N[in]
John::d[an] -subj3
John::d[an] -acc
Mary::d[an] -subj3
Mary::d[an] -acc
the::N[in]= d[in] -theme
the::N[in]= d[in] -subj3
the::N[in]= d[in] -acc
the::N[in]= d[in] -acc -rel[in]
the::N[in]= d[in] -subj3 -rel[in]
the::N[an]= d[an] -theme
the::N[an]= d[an] -subj3
the::N[an]= d[an] -acc
the::N[an]= d[an] -acc -rel[an]
the::N[an]= d[an] -subj3 -rel[an]
a::N[in]= d[in] -theme
a::N[in]= d[in] -subj3
a::N[in]= d[in] -acc
a::N[in]= d[in] -acc -rel[in]
a::N[in]= d[in] -subj3 -rel[in]
a::N[an]= d[an] -theme
a::N[an]= d[an] -subj3
a::N[an]= d[an] -acc
a::N[an]= d[an] -acc -rel[an]
a::N[an]= d[an] -subj3 -rel[an]
can::V= +subj3 T
can::V= +subj2 T
can::V= +subj1 T
can::V= q= +subj3 T
can::V= q= +subj2 T
can::V= q= +subj1 T
can::V= r= +subj3 T
can::V= r= +subj2 T
can::V= r= +subj1 T
am::prog= +subj1 T
are::prog= +subj2 T
is::prog= +subj3 T
am::prog= q= +subj1 T
are::prog= q= +subj2 T
is::prog= q= +subj3 T
am::prog= r= +subj1 T
are::prog= r= +subj2 T
is::prog= r= +subj3 T
ing::=>V prog
PAST::=>V +subj3 t
PAST::=>V +subj2 t
PAST::=>V +subj1 t
::T= C
::t= T
::t= r= T
::r -r
3PRES::=>V +subj3 t
2PRES::=>V +subj2 t
1PRES::=>V +subj1 t
";

        let lexicon = Lexicon::from_string(lexicon)?;

        assert!(
            lexicon
                .parse(
                    &[
                        PhonContent::Normal("I"),
                        PhonContent::Normal("can"),
                        PhonContent::Normal("see"),
                        PhonContent::Normal("a"),
                        PhonContent::Normal("woman"),
                        PhonContent::Normal("who"),
                        PhonContent::Normal("a"),
                        PhonContent::Normal("man"),
                        PhonContent::Affixed(vec!["see", "3PRES"]),
                    ],
                    "C",
                    &ParsingConfig::empty().with_max_steps(50),
                )?
                .next()
                .is_some()
        );
        assert_eq!(
            lexicon.valid_continuations(
                "C",
                &[
                    PhonContent::Normal("I"),
                    PhonContent::Normal("can"),
                    PhonContent::Normal("see"),
                    PhonContent::Normal("a"),
                    PhonContent::Normal("woman"),
                    PhonContent::Normal("who"),
                    PhonContent::Normal("a"),
                    PhonContent::Normal("man"),
                    PhonContent::Affixed(vec!["see", "3PRES"]),
                ],
                &ParsingConfig::empty().with_max_steps(50)
            )?,
            [Continuation::EndOfSentence].into_iter().collect()
        );
        Ok(())
    }
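
    //The two tests below are illustrative sketches added alongside the original test: they are
    //not taken from the upstream test suite, but exercise the scanners directly to show how
    //`scan` and `multiscan` behave.
    #[test]
    fn parse_scan_drops_non_matching_sentences() {
        use super::{ParseScan, Scanner};
        let words = [PhonContent::Normal("a"), PhonContent::Normal("b")];
        let mut scan = ParseScan {
            sentence: vec![(&words[..], 0)],
        };
        //"a" matches the next word of the only candidate, so scanning succeeds and advances.
        assert!(scan.scan(&Some("a")));
        //"c" does not match "b", so the candidate is dropped and scanning fails.
        assert!(!scan.scan(&Some("c")));
    }

    #[test]
    fn generator_scan_records_single_and_affixed_heads() {
        use super::{GeneratorScan, Scanner};
        let mut scan = GeneratorScan::<&str> { sentence: vec![] };
        //A single word is recorded as a normal token.
        assert!(scan.scan(&Some("the")));
        //Several heads scanned at once are recorded as one affixed token.
        assert!(scan.multiscan(vec![&"see", &"3PRES"]));
        assert!(matches!(
            scan.sentence.as_slice(),
            [PhonContent::Normal("the"), PhonContent::Affixed(affixes)] if affixes == &["see", "3PRES"]
        ));
    }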
}