minimalist_grammar_parser/lexicon/
use super::{Feature, FeatureOrLemma, Lexicon, LexiconError};
use ahash::AHashSet;
use std::hash::Hash;

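/// Cost of a lemma in output symbols: string-like lemmas cost their length,
/// single-symbol lemmas cost one, and an empty (`None`) lemma costs nothing.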
pub trait SymbolCost: Sized {
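    /// Number of symbols needed to spell out `x`; `None` costs 0.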
    fn symbol_cost(x: &Option<Self>) -> u16;
}

impl SymbolCost for &str {
    fn symbol_cost(x: &Option<Self>) -> u16 {
        match x {
            Some(x) => x.len().try_into().unwrap(),
            None => 0,
        }
    }
}

impl SymbolCost for String {
    fn symbol_cost(x: &Option<Self>) -> u16 {
        match x {
            Some(x) => x.len().try_into().unwrap(),
            None => 0,
        }
    }
}

impl SymbolCost for char {
    fn symbol_cost(x: &Option<Self>) -> u16 {
        match x {
            Some(_) => 1,
            None => 0,
        }
    }
}

impl SymbolCost for u8 {
    fn symbol_cost(x: &Option<Self>) -> u16 {
        match x {
            Some(_) => 1,
            None => 0,
        }
    }
}

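/// Number of feature types in the formalism; together with the number of
/// categories it determines how many distinct features a grammar can use.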
const MG_TYPES: u16 = 6;

impl<T: Eq + std::fmt::Debug + Clone + SymbolCost, Category: Eq + std::fmt::Debug + Clone + Hash>
    Lexicon<T, Category>
{
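    /// Like [`Self::mdl_score`], but with the number of categories fixed to
    /// `n_categories` instead of counted from the lexicon.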
    pub fn mdl_score_fixed_category_size(
        &self,
        n_phonemes: u16,
        n_categories: u16,
    ) -> Result<f64, LexiconError> {
        self.mdl_inner(n_phonemes, Some(n_categories))
    }

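    /// Minimum description length (MDL) score of the lexicon, where each
    /// lexical entry costs `ln(n_phonemes)` per phoneme of its lemma plus
    /// `ln(MG_TYPES * n_categories)` per feature, with the number of
    /// categories taken from the category symbols occurring in the lexicon.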
    pub fn mdl_score(&self, n_phonemes: u16) -> Result<f64, LexiconError> {
        self.mdl_inner(n_phonemes, None)
    }

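    /// Shared implementation of the MDL score: walks from each lemma leaf up
    /// to the root, counting its features and collecting the category symbols
    /// used along the way.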
    fn mdl_inner(&self, n_phonemes: u16, n_categories: Option<u16>) -> Result<f64, LexiconError> {
        let mut category_symbols = AHashSet::new();
        let mut lexemes: Vec<(f64, f64)> = Vec::with_capacity(self.leaves.len());

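        // Every leaf must hold a lemma; climb from it towards the root, counting the
        // features of its lexical entry and recording the category symbols they use.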
        for leaf in self.leaves.iter() {
            if let FeatureOrLemma::Lemma(lemma) = &self.graph[leaf.0] {
                let n_phonemes = T::symbol_cost(lemma);

                let mut nx = leaf.0;
                let mut n_features = 0;
                while let Some(parent) = self.parent_of(nx) {
                    if parent == self.root {
                        break;
                    } else if let FeatureOrLemma::Feature(f) = &self.graph[parent] {
                        category_symbols.insert(match f {
                            Feature::Category(c)
                            | Feature::Licensor(c)
                            | Feature::Licensee(c)
                            | Feature::Affix(c, _)
                            | Feature::Selector(c, _) => c,
                        });
                        n_features += 1;
                    } else if let FeatureOrLemma::Complement(c, _d) = &self.graph[parent] {
                        category_symbols.insert(c);
                        n_features += 1;
                    }
                    nx = parent;
                }
                lexemes.push((n_phonemes.into(), n_features.into()));
            } else {
                return Err(LexiconError::MissingLexeme(*leaf));
            }
        }
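        // Unless the caller fixed the number of categories, use the number of
        // distinct category symbols actually observed above.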
        let n_categories: u16 =
            n_categories.unwrap_or_else(|| category_symbols.len().try_into().unwrap());

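        // Per-symbol code lengths (natural log): a feature is one choice out of
        // MG_TYPES * n_categories, a phoneme one choice out of n_phonemes.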
        let bits_per_feature: f64 = (MG_TYPES * n_categories).into();
        let bits_per_feature = bits_per_feature.ln();
        let bits_per_phoneme: f64 = (Into::<f64>::into(n_phonemes)).ln();

        Ok(lexemes
            .into_iter()
            .map(|(n_phonemes, n_features)| {
                n_phonemes * bits_per_phoneme + bits_per_feature * n_features
            })
            .sum())
    }
}

#[cfg(test)]
mod test {
    use super::*;
    use approx::assert_relative_eq;

    #[test]
    fn mdl_score_test() -> anyhow::Result<()> {
        let ga: &str = "mary::d -k
laughs::=d +k t
laughed::=d +k t
jumps::=d +k t
jumped::=d +k t";
        let gb: &str = "mary::d -k
laugh::=d v
jump::=d v
s::=v +k t
ed::=v +k t";
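        // Hand-counted totals: `ga` has 28 lemma characters, 14 features and 3
        // categories (d, k, t); `gb` has 16 characters, 12 features and 4
        // categories (d, k, v, t).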
        for (g, n_categories, string_size, feature_size) in
            [(ga, 3, 28_f64, 14_f64), (gb, 4, 16_f64, 12_f64)]
        {
            let lex = Lexicon::from_string(g)?;
            for alphabet_size in [26, 32, 37] {
                let bits_per_symbol: f64 = (MG_TYPES * n_categories).into();
                let bits_per_phoneme: f64 = alphabet_size.into();
                assert_relative_eq!(
                    lex.mdl_score(alphabet_size)?,
                    string_size * bits_per_phoneme.ln() + feature_size * bits_per_symbol.ln()
                );
            }
        }
        Ok(())
    }
}