minimalist_grammar_parser/lexicon/
mdl.rs

1//! Module which defines helper functions to calculate the MDL of MGs according to Ermolaeva, 2021
2//!
3//! - Ermolaeva, M. (2021). Learning Syntax via Decomposition [The University of Chicago]. <https://doi.org/10.6082/uchicago.3015>
4
5use super::{Feature, FeatureOrLemma, Lexicon, LexiconError};
6use ahash::AHashSet;
7use std::hash::Hash;
8
/// Describes how many symbols a lemma type contributes to the MDL score of a lexicon.
pub trait SymbolCost: Sized {
    /// Returns the length of a lemma, measured in sub-units that each have `n_phonemes`
    /// possible encodings. The argument is an `Option` so that an absent lemma can be
    /// costed as well.
    ///
    /// # Example
    /// Let Phon = $\{a,b,c\}$, so `n_phonemes` (as passed to [`Lexicon::mdl_score`]) should be 3.
    /// The string `abcabc` should then have a symbol cost of 6.
    fn symbol_cost(x: &Option<Self>) -> u16;
}
17
18impl SymbolCost for &str {
19    fn symbol_cost(x: &Option<Self>) -> u16 {
20        match x {
21            Some(x) => x.len().try_into().unwrap(),
22            None => 0,
23        }
24    }
25}
26
27impl SymbolCost for String {
28    fn symbol_cost(x: &Option<Self>) -> u16 {
29        match x {
30            Some(x) => x.len().try_into().unwrap(),
31            None => 0,
32        }
33    }
34}
35
36impl SymbolCost for char {
37    fn symbol_cost(x: &Option<Self>) -> u16 {
38        match x {
39            Some(_) => 1,
40            None => 0,
41        }
42    }
43}
44
45impl SymbolCost for u8 {
46    fn symbol_cost(x: &Option<Self>) -> u16 {
47        match x {
48            Some(_) => 1,
49            None => 0,
50        }
51    }
52}
53
/// Number of feature types, i.e. the size of the space of feature kinds in the [`Feature`] enum.
/// It is six so that left and right attachment are counted as separate types.
///
/// The MDL computation multiplies this by the number of categories to obtain the number of
/// distinct feature symbols.
const MG_TYPES: u16 = 6;
57
58impl<T: Eq + std::fmt::Debug + Clone + SymbolCost, Category: Eq + std::fmt::Debug + Clone + Hash>
59    Lexicon<T, Category>
60{
61    /// Returns the MDL score of a lexicon assuming the number of phonemes is fixed.
62    pub fn mdl_score_fixed_category_size(
63        &self,
64        n_phonemes: u16,
65        n_categories: u16,
66    ) -> Result<f64, LexiconError> {
67        self.mdl_inner(n_phonemes, Some(n_categories))
68    }
69
70    ///Returns the MDL Score of the lexicon
71    ///
72    /// # Arguments
73    /// * `n_phonemes` - The size of required to encode a symbol of the phonology. e.g in English orthography, it would be 26.
74    pub fn mdl_score(&self, n_phonemes: u16) -> Result<f64, LexiconError> {
75        self.mdl_inner(n_phonemes, None)
76    }
77
78    fn mdl_inner(&self, n_phonemes: u16, n_categories: Option<u16>) -> Result<f64, LexiconError> {
79        let mut category_symbols = AHashSet::new();
80        let mut lexemes: Vec<(f64, f64)> = Vec::with_capacity(self.leaves.len());
81
82        for leaf in self.leaves.iter() {
83            if let FeatureOrLemma::Lemma(lemma) = &self.graph[leaf.0] {
84                let n_phonemes = T::symbol_cost(lemma);
85
86                let mut nx = leaf.0;
87                let mut n_features = 0;
88                while let Some(parent) = self.parent_of(nx) {
89                    if parent == self.root {
90                        break;
91                    } else if let FeatureOrLemma::Feature(f) = &self.graph[parent] {
92                        category_symbols.insert(match f {
93                            Feature::Category(c)
94                            | Feature::Licensor(c)
95                            | Feature::Licensee(c)
96                            | Feature::Affix(c, _)
97                            | Feature::Selector(c, _) => c,
98                        });
99                        n_features += 1;
100                    } else if let FeatureOrLemma::Complement(c, _d) = &self.graph[parent] {
101                        category_symbols.insert(c);
102                        n_features += 1;
103                    }
104                    nx = parent;
105                }
106                lexemes.push((n_phonemes.into(), n_features.into()));
107            } else {
108                return Err(LexiconError::MissingLexeme(*leaf));
109            }
110        }
111        let n_categories: u16 =
112            n_categories.unwrap_or_else(|| category_symbols.len().try_into().unwrap());
113
114        let bits_per_feature: f64 = (MG_TYPES * n_categories).into();
115        let bits_per_feature = bits_per_feature.ln();
116        let bits_per_phoneme: f64 = (Into::<f64>::into(n_phonemes)).ln();
117
118        Ok(lexemes
119            .into_iter()
120            .map(|(n_phonemes, n_categories)| {
121                n_phonemes * bits_per_phoneme + bits_per_feature * n_categories
122            })
123            .sum())
124    }
125}
126
#[cfg(test)]
mod test {
    use super::*;
    use approx::assert_relative_eq;

    /// Compares [`Lexicon::mdl_score`] with hand-counted totals for two small grammars, one
    /// listing fully inflected verbs and one splitting stems from suffixes.
    #[test]
    fn mdl_score_test() -> anyhow::Result<()> {
        let ga: &str = "mary::d -k
laughs::=d +k t
laughed::=d +k t
jumps::=d +k t
jumped::=d +k t";
        let gb: &str = "mary::d -k
laugh::=d v
jump::=d v
s::=v +k t
ed::=v +k t";
        // (grammar, distinct categories, total lemma characters, total features)
        let cases = [(ga, 3_u16, 28_f64, 14_f64), (gb, 4_u16, 16_f64, 12_f64)];

        for (grammar, n_categories, string_size, feature_size) in cases {
            let lex = Lexicon::from_string(grammar)?;
            let bits_per_symbol = f64::from(MG_TYPES * n_categories);
            for alphabet_size in [26_u16, 32, 37] {
                let bits_per_phoneme = f64::from(alphabet_size);
                let expected =
                    string_size * bits_per_phoneme.ln() + feature_size * bits_per_symbol.ln();
                assert_relative_eq!(lex.mdl_score(alphabet_size)?, expected);
            }
        }
        Ok(())
    }
}