minimalist_grammar_parser/parsing/beam.rs

//!Module which defines how the different beams used by [`Lexicon::parse`] or [`Lexicon::generate`]
//!work.
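//!
//!Each beam wraps a `Scanner`: `ParseScan` checks generated words against an input sentence,
//!`GeneratorScan` records whatever the grammar produces, `FuzzyScan` records generated words while
//!tracking which guide sentences still match, and `ContinuationScan` backs
//![`Lexicon::valid_continuations`].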

use crate::{
    ParseHeap, ParsingConfig, PhonContent, expand,
    lexicon::{Lexicon, ParsingError},
    parsing::FilledHeadTree,
};

use super::{BeamWrapper, RuleHolder, rules::RulePool};
use ahash::HashSet;
use logprob::LogProb;
use std::{fmt::Debug, hash::Hash};

///A trait which allows a struct to be used by the parsing algorithm by defining how scanning
///works. Parsing uses scan to check that the next word matches the input sentence, whereas
///generation uses scan to iteratively build strings.
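///
///`scan` consumes a single optional word, while `multiscan` consumes an affixed cluster of heads
///(a `FilledHeadTree`). Returning `false` indicates that scanning failed, e.g. the scanned word
///does not match the sentence being parsed.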
pub(crate) trait Scanner<T>: Sized {
    fn scan(&mut self, s: &Option<T>) -> bool;

    fn multiscan(&mut self, heads: &FilledHeadTree<T>) -> bool;
}

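///Scanner used for parsing: it tracks the candidate input sentences together with how far each
///has been scanned, and drops any candidate whose next word does not match what the grammar
///produced.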
#[derive(Debug, Clone, PartialEq, Eq)]
pub(crate) struct ParseScan<'a, T> {
    pub sentence: Vec<(&'a [PhonContent<T>], usize)>,
}

impl<'a, T> Scanner<T> for ParseScan<'a, T>
where
    T: std::cmp::Eq + std::fmt::Debug,
{
    fn scan(&mut self, s: &Option<T>) -> bool {
        self.sentence.retain_mut(|(sentence, position)| match s {
            Some(s) => {
                if let Some(PhonContent::Normal(string)) = sentence.get(*position) {
                    if s == string {
                        *position += 1;
                        true
                    } else {
                        false
                    }
                } else {
                    false
                }
            }
            None => true,
        });
        !self.sentence.is_empty()
    }

    fn multiscan(&mut self, heads: &FilledHeadTree<T>) -> bool {
        let heads = heads.inorder();
        if heads.is_empty() {
            return true;
        }

        self.sentence.retain_mut(|(sentence, position)| {
            if let Some(s) = sentence.get(*position) {
                match s {
                    PhonContent::Normal(s) => {
                        if heads.len() == 1 && heads.first().unwrap() == &s {
                            *position += 1;
                            true
                        } else {
                            false
                        }
                    }
                    PhonContent::Affixed(string) => {
                        if heads.len() == string.len()
                            && heads.iter().zip(string.iter()).all(|(a, b)| *a == b)
                        {
                            *position += 1;
                            true
                        } else {
                            false
                        }
                    }
                }
            } else {
                false
            }
        });
        !self.sentence.is_empty()
    }
}

impl<'a, T: Eq + std::fmt::Debug + Clone> ParseScan<'a, T> {
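    ///If the beam has finished, returns the input sentences that were fully consumed along with
    ///the beam's log probability and the rules used in the derivation.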
    #[allow(clippy::complexity)]
    pub(crate) fn yield_good_parse(
        b: BeamWrapper<T, Self>,
        rules: &[RuleHolder],
    ) -> Option<(
        impl Iterator<Item = &'a [PhonContent<T>]> + 'a,
        LogProb<f64>,
        RulePool,
    )> {
        if b.is_empty() {
            Some((
                b.beam
                    .sentence
                    .into_iter()
                    .filter(|(s, pos)| s.len() == *pos)
                    .map(|(s, _)| s),
                b.log_prob,
                b.rules.into_rule_pool(rules),
            ))
        } else {
            None
        }
    }
}

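///Scanner that records the sentence being generated while also tracking which of the
///`sentence_guides` are still consistent with it. Unlike `ParseScan`, a mismatch never discards
///the beam; only the set of matching guides shrinks.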
#[derive(Debug, Clone, PartialEq, Eq)]
pub(crate) struct FuzzyScan<'b, T> {
    pub generated_sentences: Vec<PhonContent<T>>,
    pub sentence_guides: Vec<(&'b [PhonContent<T>], usize)>,
}

impl<'a, T: Eq + std::fmt::Debug + Clone> FuzzyScan<'a, T> {
    pub fn yield_good_parse(
        b: BeamWrapper<T, Self>,
        rules: &[RuleHolder],
    ) -> Option<(LogProb<f64>, Vec<PhonContent<T>>, RulePool)> {
        if b.is_empty() {
            Some((
                b.log_prob,
                b.beam.generated_sentences.to_vec(),
                b.rules.into_rule_pool(rules),
            ))
        } else {
            None
        }
    }
}

impl<T> Scanner<T> for FuzzyScan<'_, T>
where
    T: std::cmp::Eq + std::fmt::Debug + Clone,
{
    fn scan(&mut self, s: &Option<T>) -> bool {
        if let Some(s) = s {
            self.generated_sentences
                .push(PhonContent::Normal(s.clone()));
        }
        self.sentence_guides
            .retain_mut(|(sentence, position)| match s {
                Some(s) => {
                    if let Some(PhonContent::Normal(string)) = sentence.get(*position) {
                        if s == string {
                            *position += 1;
                            true
                        } else {
                            false
                        }
                    } else {
                        false
                    }
                }
                None => true,
            });
        true
    }

    fn multiscan(&mut self, heads: &FilledHeadTree<T>) -> bool {
        let mut heads: Vec<_> = heads.inorder().into_iter().cloned().collect();
        self.sentence_guides.retain_mut(|(sentence, position)| {
            if let Some(s) = sentence.get(*position) {
                match s {
                    PhonContent::Normal(s) => {
                        if heads.len() == 1 && heads.first().unwrap() == s {
                            *position += 1;
                            true
                        } else {
                            false
                        }
                    }
                    PhonContent::Affixed(string) => {
                        if heads.len() == string.len()
                            && heads.iter().zip(string.iter()).all(|(a, b)| a == b)
                        {
                            *position += 1;
                            true
                        } else {
                            false
                        }
                    }
                }
            } else {
                false
            }
        });
        if !heads.is_empty() {
            if heads.len() == 1 {
                self.generated_sentences
                    .push(PhonContent::Normal(heads.pop().unwrap()));
            } else {
                self.generated_sentences.push(PhonContent::Affixed(heads));
            }
        }
        true
    }
}

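///Scanner used for generation: it simply records every word the grammar produces, so scanning
///never fails.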
#[derive(Debug, Clone, Eq, PartialEq)]
pub(crate) struct GeneratorScan<T> {
    pub sentence: Vec<PhonContent<T>>,
}

impl<T: Clone> Scanner<T> for GeneratorScan<T>
where
    T: std::cmp::Eq + std::fmt::Debug,
{
    fn scan(&mut self, s: &Option<T>) -> bool {
        if let Some(s) = s {
            //If the word is None there is nothing to add
            self.sentence.push(PhonContent::Normal(s.clone()));
        }
        true
    }

    fn multiscan(&mut self, heads: &FilledHeadTree<T>) -> bool {
        let mut heads: Vec<_> = heads.inorder().into_iter().cloned().collect();
        if !heads.is_empty() {
            if heads.len() == 1 {
                self.sentence
                    .push(PhonContent::Normal(heads.pop().unwrap()));
            } else {
                self.sentence.push(PhonContent::Affixed(heads));
            }
        }
        true
    }
}

impl<T: Eq + std::fmt::Debug + Clone> GeneratorScan<T> {
    pub(crate) fn yield_good_parse(
        b: BeamWrapper<T, Self>,
        rules: &[RuleHolder],
    ) -> Option<(LogProb<f64>, Vec<PhonContent<T>>, RulePool)> {
        if b.is_empty() {
            Some((
                b.log_prob,
                b.beam.sentence.to_vec(),
                b.rules.into_rule_pool(rules),
            ))
        } else {
            None
        }
    }
}

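///Scanner behind [`Lexicon::valid_continuations`]: it follows the given `prefix`, and the first
///word produced past the end of the prefix is stored in `final_char` as the candidate
///[`Continuation`].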
#[derive(Debug, PartialEq, Eq, Clone)]
struct ContinuationScan<'a, T> {
    prefix: &'a [PhonContent<T>],
    position: usize,
    final_char: Option<Continuation<T>>,
}

impl<T> Scanner<T> for ContinuationScan<'_, T>
where
    T: std::cmp::Eq + std::fmt::Debug + Clone,
{
    fn scan(&mut self, word: &Option<T>) -> bool {
        match word {
            Some(word) => {
                if let Some(string) = self.prefix.get(self.position) {
                    if let PhonContent::Normal(string) = string
                        && string == word
                    {
                        self.position += 1;
                        true
                    } else {
                        false
                    }
                } else if self.position == self.prefix.len() {
                    self.final_char = Some(Continuation::Word(word.clone()));
                    self.position += 1;
                    true
                } else {
                    self.position += 1;
                    true
                }
            }
            None => true,
        }
    }

    fn multiscan(&mut self, heads: &FilledHeadTree<T>) -> bool {
        let heads = heads.inorder();
        if heads.is_empty() {
            return true;
        }

        if let Some(s) = self.prefix.get(self.position) {
            match s {
                PhonContent::Normal(s) => {
                    if heads.len() == 1 && heads.first().unwrap() == &s {
                        self.position += 1;
                        true
                    } else {
                        false
                    }
                }
                PhonContent::Affixed(string) => {
                    if heads.len() == string.len()
                        && heads.iter().zip(string.iter()).all(|(a, b)| *a == b)
                    {
                        self.position += 1;
                        true
                    } else {
                        false
                    }
                }
            }
        } else if self.position == self.prefix.len() {
            self.final_char = Some(Continuation::AffixedWord(
                heads.into_iter().cloned().collect(),
            ));
            self.position += 1;
            true
        } else {
            self.position += 1;
            true
        }
    }
}

impl<'a, T: Eq + Debug + Clone> ContinuationScan<'a, T> {
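    ///If the beam has finished, returns the recorded continuation, or
    ///[`Continuation::EndOfSentence`] when exactly the prefix was consumed with nothing after it.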
    pub fn yield_good_parse(b: BeamWrapper<T, Self>) -> Option<Continuation<T>> {
        if b.is_empty() {
            match b.beam.final_char {
                Some(x) => Some(x),
                None if b.beam.position == b.beam.prefix.len() => Some(Continuation::EndOfSentence),
                None => None,
            }
        } else {
            None
        }
    }
}

///Enum describing a possible next token of a grammar given a prefix.
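///
///A minimal sketch of how the variants can be consumed, reusing the toy grammar from the
///[`Lexicon::valid_continuations`] example:
///```
///# use minimalist_grammar_parser::{ParsingConfig, Lexicon, PhonContent};
///# use minimalist_grammar_parser::parsing::beam::Continuation;
///let lex = Lexicon::from_string("a::S= b= S\n::S\nb::b")?;
///for c in lex.valid_continuations("S", &PhonContent::from(["a"]), &ParsingConfig::default())? {
///    match c {
///        Continuation::Word(w) => println!("next word: {w}"),
///        Continuation::AffixedWord(ws) => println!("next affixed word: {ws:?}"),
///        Continuation::EndOfSentence => println!("the sentence may end here"),
///    }
///}
///# Ok::<(), anyhow::Error>(())
///```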
#[derive(Debug, Clone, Eq, PartialEq, Hash)]
pub enum Continuation<T> {
    ///The following word is a valid token given the prefix in [`Lexicon::valid_continuations`].
    Word(T),
    ///The following affixed word is a valid token given the prefix in [`Lexicon::valid_continuations`].
    AffixedWord(Vec<T>),
    ///The sentence may validly end after the prefix.
    EndOfSentence,
}

impl<T, C> Lexicon<T, C>
where
    T: Eq + std::fmt::Debug + Clone + Hash,
    C: Eq + Clone + std::fmt::Debug + Hash,
{
    ///Given a grammar and a prefix string, return a [`HashSet`] of the possible [`Continuation`]s (i.e. next words) that are valid
    ///in the grammar.
    ///Returns a [`ParsingError`] if there is no node with the category `initial_category`.
    ///
    ///```
    ///# use minimalist_grammar_parser::{ParsingConfig, Lexicon, PhonContent};
    ///# use ahash::HashSet;
    ///# use minimalist_grammar_parser::parsing::beam::Continuation;
    ///
    ///let lex = Lexicon::from_string("a::S= b= S\n::S\nb::b")?;
    ///let continuations = lex.valid_continuations("S", &PhonContent::from(["a"]), &ParsingConfig::default())?;
    ///assert_eq!(continuations, HashSet::from_iter([Continuation::Word("b"), Continuation::Word("a")].into_iter()));
    ///let continuations = lex.valid_continuations("S", &PhonContent::from(["a", "b"]), &ParsingConfig::default())?;
    ///assert_eq!(continuations, HashSet::from_iter([Continuation::EndOfSentence]));
    ///# Ok::<(), anyhow::Error>(())
    ///```
    pub fn valid_continuations(
        &self,
        initial_category: C,
        prefix: &[PhonContent<T>],
        config: &ParsingConfig,
    ) -> Result<HashSet<Continuation<T>>, ParsingError<C>> {
        let cat = self.find_category(&initial_category)?;

        let cont = ContinuationScan {
            prefix,
            position: 0,
            final_char: None,
        };

        let mut valid_chars: HashSet<Continuation<T>> = HashSet::default();

        let mut parse_heap = ParseHeap::new(BeamWrapper::new(cont, cat), config, cat);

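        //Beam search over partial parses: each beam is either expanded further or, once finished,
        //contributes whichever continuation it recorded (deduplicated via `valid_chars`).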
        while let Some(mut beam) = parse_heap.pop() {
            if let Some(word) = beam.beam.final_char.as_ref()
                && valid_chars.contains(word)
            {
                //We don't care since there's already a successful parse with that character.
                continue;
            }

            if let Some(moment) = beam.pop_moment() {
                expand(&mut parse_heap, moment, beam, self, config);
            } else if let Some(cont) = ContinuationScan::yield_good_parse(beam) {
                valid_chars.insert(cont);
            }
        }
        Ok(valid_chars)
    }
}

#[cfg(test)]
mod test {
    use crate::{
        ParsingConfig, PhonContent,
        grammars::{DYCK_LANGUAGE, STABLER2011},
        lexicon::Lexicon,
        parsing::beam::Continuation,
    };

    #[test]
    fn continuations() -> anyhow::Result<()> {
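        //Check valid_continuations against several grammars: STABLER2011, the Dyck language,
        //an aⁿbⁿ grammar, and a larger grammar containing affixed heads.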
        let lex = Lexicon::from_string(STABLER2011)?;

        let strings = [
            "the",
            "the king",
            "which",
            "which king",
            "the king knows",
            "the king drinks the beer",
        ]
        .map(|x| x.split(" ").collect::<Vec<_>>());

        let continuations = [
            vec![
                Continuation::Word("king"),
                Continuation::Word("beer"),
                Continuation::Word("wine"),
                Continuation::Word("queen"),
            ],
            vec![
                Continuation::Word("knows"),
                Continuation::Word("says"),
                Continuation::Word("drinks"),
                Continuation::Word("prefers"),
            ],
            vec![
                Continuation::Word("wine"),
                Continuation::Word("king"),
                Continuation::Word("beer"),
                Continuation::Word("queen"),
            ],
            vec![
                Continuation::Word("drinks"),
                Continuation::Word("knows"),
                Continuation::Word("the"),
                Continuation::Word("says"),
                Continuation::Word("prefers"),
            ],
            vec![Continuation::Word("which"), Continuation::Word("the")],
            vec![Continuation::EndOfSentence],
        ]
        .into_iter()
        .map(|x| x.into_iter().collect());

        for (s, valid) in strings.into_iter().map(PhonContent::new).zip(continuations) {
            let cont = lex.valid_continuations("C", &s, &ParsingConfig::default())?;
            assert_eq!(cont, valid);
        }
        let lex = Lexicon::from_string(DYCK_LANGUAGE)?;

        let strings = ["(", "( )", "( ( )", "( ( ) )", "( ) ( )", "( ( ( ) )"]
            .map(|x| x.split(" ").collect::<Vec<_>>());

        let continuations = [
            vec![Continuation::Word(")"), Continuation::Word("(")],
            vec![Continuation::Word("("), Continuation::EndOfSentence],
            vec![Continuation::Word(")"), Continuation::Word("(")],
            vec![Continuation::Word("("), Continuation::EndOfSentence],
            vec![Continuation::Word("("), Continuation::EndOfSentence],
            vec![Continuation::Word(")"), Continuation::Word("(")],
        ]
        .into_iter()
        .map(|x| x.into_iter().collect());

        for (s, valid) in strings.into_iter().map(PhonContent::new).zip(continuations) {
            let cont = lex.valid_continuations("S", &s, &ParsingConfig::default())?;
            assert_eq!(cont, valid);
        }

        let lex = Lexicon::from_string("a::S= b= S\n::S\nb::b")?;

        let mut strings: Vec<_> = ["a", "a b", "a a b", "a a b b"]
            .iter()
            .map(|x| x.split(" ").collect::<Vec<_>>())
            .collect();
        strings.push(vec![]);

        let continuations = [
            vec![Continuation::Word("b"), Continuation::Word("a")],
            vec![Continuation::EndOfSentence],
            vec![Continuation::Word("b")],
            vec![Continuation::EndOfSentence],
            vec![Continuation::Word("a"), Continuation::EndOfSentence],
        ]
        .into_iter()
        .map(|x| x.into_iter().collect());

        for (s, valid) in strings.into_iter().map(PhonContent::new).zip(continuations) {
            let cont = lex.valid_continuations("S", &s, &ParsingConfig::default())?;
            assert_eq!(cont, valid);
        }

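        //A larger grammar whose derivations contain affixed heads (e.g. "see" + "3PRES"),
        //exercising multiscan/PhonContent::Affixed in valid_continuations.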
        let lexicon = "::T<= +q Q
what::d[in] -subj3 -q -wh
what::d[in] -acc -wh
who::d[an] -subj3 -q -wh
who::d[an] -acc -wh
::T<= +q +wh Q
::q -q
does::V= q= +subj3 T
do::V= q= +subj2 T
do::V= q= +subj1 T
did::V= q= +subj3 T
did::V= q= +subj2 T
did::V= q= +subj1 T
::q -q
to::theme[an]= p
talk::p= v
see::d[an]= +acc v
see::d[in]= +acc v
devour::d[in]= +acc v
want::d[in]= +acc v
run::v
you::d[an] -subj2
you::d[an] -acc
I::d[an] -subj1
me::d[an] -acc
he::d[an] -subj3
him::d[an] -acc
she::d[an] -subj3
her::d[an] -acc
::d[an]= +theme theme[an]
that::C= +r +rel[in] d[in] -acc
that::C= +r +rel[in] d[in] -subj3
who::C= +r +rel[an] d[an] -acc
who::C= +r +rel[an] d[an] -subj3
::=>v =d[an] V
man::N[an]
woman::N[an]
cake::N[in]
John::d[an] -subj3
John::d[an] -acc
Mary::d[an] -subj3
Mary::d[an] -acc
the::N[in]= d[in] -theme
the::N[in]= d[in] -subj3
the::N[in]= d[in] -acc
the::N[in]= d[in] -acc -rel[in]
the::N[in]= d[in] -subj3 -rel[in]
the::N[an]= d[an] -theme
the::N[an]= d[an] -subj3
the::N[an]= d[an] -acc
the::N[an]= d[an] -acc -rel[an]
the::N[an]= d[an] -subj3 -rel[an]
a::N[in]= d[in] -theme
a::N[in]= d[in] -subj3
a::N[in]= d[in] -acc
a::N[in]= d[in] -acc -rel[in]
a::N[in]= d[in] -subj3 -rel[in]
a::N[an]= d[an] -theme
a::N[an]= d[an] -subj3
a::N[an]= d[an] -acc
a::N[an]= d[an] -acc -rel[an]
a::N[an]= d[an] -subj3 -rel[an]
can::V= +subj3 T
can::V= +subj2 T
can::V= +subj1 T
can::V= q= +subj3 T
can::V= q= +subj2 T
can::V= q= +subj1 T
can::V= r= +subj3 T
can::V= r= +subj2 T
can::V= r= +subj1 T
am::prog= +subj1 T
are::prog= +subj2 T
is::prog= +subj3 T
am::prog= q= +subj1 T
are::prog= q= +subj2 T
is::prog= q= +subj3 T
am::prog= r= +subj1 T
are::prog= r= +subj2 T
is::prog= r= +subj3 T
ing::=>V prog
PAST::=>V +subj3 t
PAST::=>V +subj2 t
PAST::=>V +subj1 t
::T= C
::t= T
::t= r= T
::r -r
3PRES::=>V +subj3 t
2PRES::=>V +subj2 t
1PRES::=>V +subj1 t
";

        let lexicon = Lexicon::from_string(lexicon)?;
        assert_eq!(
            lexicon.valid_continuations(
                "C",
                &[
                    PhonContent::Normal("I"),
                    PhonContent::Normal("can"),
                    PhonContent::Normal("see"),
                    PhonContent::Normal("a"),
                    PhonContent::Normal("woman"),
                    PhonContent::Normal("who"),
                    PhonContent::Normal("a"),
                    PhonContent::Normal("man"),
                    PhonContent::Affixed(vec!["see", "3PRES"]),
                ],
                &ParsingConfig::empty().with_max_steps(50)
            )?,
            [Continuation::EndOfSentence].into_iter().collect()
        );
        Ok(())
    }
}
628}