1use std::{collections::hash_map::Entry, f64::consts::LN_2, fmt::Debug, hash::Hash};
4
5use thiserror::Error;
6
7use crate::{
8 Direction,
9 lexicon::{LexemeId, LexicalEntry, fix_weights, fix_weights_per_node},
10};
11
12use super::{Feature, FeatureOrLemma, Lexicon};
13use ahash::{AHashMap, AHashSet, HashMap};
14use logprob::{LogProb, LogSumExp};
15use petgraph::{
16 Direction::{Incoming, Outgoing},
17 graph::NodeIndex,
18 prelude::StableDiGraph,
19 visit::EdgeRef,
20};
21use rand::{
22 Rng,
23 seq::{IndexedRandom, IteratorRandom},
24};
25use rand_distr::{Distribution, Geometric};
26
/// State machine tracking where we are while sampling the feature sequence
/// of a single lexical entry (see [`Position::sample`]).
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
enum Position {
    /// Before the category feature has been emitted; licensees may still be added.
    PreCategory,
    /// After the category feature; selectors/licensors are emitted here.
    PostCategory,
    /// Sampling is finished; `sample` must not be called in this state.
    Done,
}
33
impl Position {
    /// Sample the next feature of a lexical entry and advance the state
    /// machine. Features are produced in reverse order (the caller reverses
    /// the collected vector), so licensees are emitted first, then the
    /// category, then selectors/licensors.
    ///
    /// # Panics
    /// Panics if called in `Position::Done`, or (via `unwrap`) if
    /// `categories` is empty when a category/selector must be drawn.
    fn sample<C: Eq + Clone>(
        &mut self,
        categories: &[C],
        licensors: &[C],
        config: LexicalProbConfig,
        rng: &mut impl Rng,
    ) -> Feature<C> {
        match self {
            Position::PreCategory => {
                // Keep emitting licensees with probability `licensee_prob`
                // (only possible if any licensor names exist); otherwise emit
                // the single category feature and move past it.
                if licensors.is_empty() || !rng.random_bool(config.licensee_prob) {
                    *self = Position::PostCategory;
                    Feature::Category(categories.choose(rng).unwrap().clone())
                } else {
                    Feature::Licensee(licensors.choose(rng).unwrap().clone())
                }
            }
            Position::PostCategory => {
                if licensors.is_empty() || !rng.random_bool(config.mover_prob) {
                    if rng.random_bool(config.lemma_prob) {
                        *self = Position::Done;
                    }
                    let dir = config.direction(rng);
                    let category = categories.choose(rng).unwrap().clone();

                    // NOTE(review): the else-arm below also sets `Done`, so a
                    // selector always terminates the entry and `lemma_prob`
                    // effectively only gates whether an Affix is possible
                    // (never multiple selectors in a row) — confirm intended.
                    if *self == Position::Done && rng.random_bool(config.affix_prob) {
                        Feature::Affix(category, dir)
                    } else {
                        *self = Position::Done;
                        Feature::Selector(category, dir)
                    }
                } else {
                    // Licensors may repeat without ending the entry.
                    Feature::Licensor(licensors.choose(rng).unwrap().clone())
                }
            }
            // Calling `sample` after `Done` is a caller bug.
            Position::Done => panic!(),
        }
    }
}
73
74impl<T: Eq, Category: Eq + Clone> LexicalEntry<T, Category> {
75 fn sample(
76 categories: &[Category],
77 licensors: &[Category],
78 lemma: Option<T>,
79 config: LexicalProbConfig,
80 rng: &mut impl Rng,
81 ) -> Self {
82 let mut pos = Position::PreCategory;
83 let mut features = vec![];
84 while pos != Position::Done {
85 features.push(pos.sample(categories, licensors, config, rng));
86 }
87 features.reverse();
88
89 LexicalEntry { lemma, features }
90 }
91}
92
/// DFS worklist used by `Lexicon::prune` to discover which trie nodes are
/// reachable from a start category and which can never be satisfied.
#[derive(Debug)]
struct AccessibilityChecker<'a, T: Eq, Category: Eq> {
    /// Nodes still to visit.
    stack: Vec<NodeIndex>,
    /// Nodes already visited (seeded with the root and the start node).
    seen: AHashSet<NodeIndex>,
    /// Nodes whose feature requirements can never be met.
    unsatisfiable: AHashSet<NodeIndex>,
    /// The lexicon trie being inspected (read-only).
    lex: &'a Lexicon<T, Category>,
}
100
impl<'a, T, C> AccessibilityChecker<'a, T, C>
where
    T: Eq + Debug + Clone,
    C: Eq + Debug + Clone,
{
    /// Start a traversal from `node`: its direct children are the initial
    /// worklist, and the root plus `node` itself count as already seen.
    fn new(node: NodeIndex, lex: &'a Lexicon<T, C>) -> Self {
        Self {
            stack: lex.graph.neighbors_directed(node, Outgoing).collect(),
            seen: [lex.root, node].into_iter().collect(),
            unsatisfiable: AHashSet::default(),
            lex,
        }
    }

    /// Pop the next node to process, recording it as seen.
    fn pop(&mut self) -> Option<NodeIndex> {
        match self.stack.pop() {
            Some(x) => {
                self.seen.insert(x);
                Some(x)
            }
            None => None,
        }
    }

    /// Push all not-yet-seen direct children of `node` onto the worklist.
    fn add_direct_children(&mut self, node: NodeIndex) {
        self.stack.extend(
            self.lex
                .graph
                .neighbors_directed(node, Outgoing)
                .filter(|x| !self.seen.contains(x)),
        );
    }

    /// Mark `node` unsatisfiable and propagate the mark upwards; the root is
    /// never marked.
    ///
    /// NOTE(review): the loop breaks on `node`'s own out-degree (> 1) rather
    /// than the parent's; verify that checking the parent's branching factor
    /// was not the intent here.
    fn mark_unsatisfiable(&mut self, mut node: NodeIndex) {
        self.unsatisfiable.insert(node);
        let get_parent = |node| self.lex.graph.neighbors_directed(node, Incoming).next();
        while let Some(parent) = get_parent(node) {
            if self.lex.graph.neighbors_directed(node, Outgoing).count() > 1 {
                break;
            } else if !matches!(
                self.lex.graph.node_weight(parent).unwrap(),
                FeatureOrLemma::Root
            ) {
                self.unsatisfiable.insert(parent);
            }
            node = parent;
        }
    }

    /// Follow a selector/licensor at `node` to the subtree that could satisfy
    /// it (the matching category or licensee branch), pushing it onto the
    /// worklist; if no provider exists, mark `node` unsatisfiable.
    fn add_indirect_children(&mut self, node: NodeIndex) {
        match self.lex.graph.node_weight(node).unwrap() {
            FeatureOrLemma::Feature(Feature::Selector(c, _) | Feature::Affix(c, _))
            | FeatureOrLemma::Complement(c, _) => match self.lex.find_category(c) {
                Ok(x) => {
                    if !self.seen.contains(&x) {
                        self.stack.push(x);
                    }
                }
                // No plain category branch, but a moving category (one that
                // surfaces after licensee checking) can still satisfy it.
                Err(_) if !self.lex.has_moving_category(c) => self.mark_unsatisfiable(node),
                Err(_) => (),
            },
            FeatureOrLemma::Feature(Feature::Licensor(c)) => {
                match self.lex.find_licensee(c) {
                    Ok(x) => {
                        if !self.seen.contains(&x) {
                            self.stack.push(x);
                        }
                    }
                    Err(_) => self.mark_unsatisfiable(node),
                }
            }
            // These node kinds impose no external requirement.
            FeatureOrLemma::Root
            | FeatureOrLemma::Lemma(_)
            | FeatureOrLemma::Feature(Feature::Licensee(_) | Feature::Category(_)) => (),
        }
    }
}
182
impl<T, C> Lexicon<T, C>
where
    T: Eq + Debug + Clone,
    C: Eq + Debug + Clone,
{
    /// True if category `cat` is reachable at the end of some licensee chain
    /// hanging off the root, i.e. it can be provided by a moving constituent.
    fn has_moving_category(&self, cat: &C) -> bool {
        let mut stack: Vec<_> = self
            .graph
            .neighbors_directed(self.root, Outgoing)
            .filter(|a| {
                matches!(
                    self.graph.node_weight(*a).unwrap(),
                    FeatureOrLemma::Feature(Feature::Licensee(_))
                )
            })
            .collect();
        while let Some(x) = stack.pop() {
            for x in self.graph.neighbors_directed(x, Outgoing) {
                match &self.graph[x] {
                    FeatureOrLemma::Feature(Feature::Licensee(_)) => stack.push(x),
                    FeatureOrLemma::Feature(Feature::Category(c)) if c == cat => return true,
                    _ => (),
                }
            }
        }
        false
    }

    /// Iteratively remove every node that is unreachable from the `start`
    /// category or whose requirements can never be satisfied, repeating until
    /// a fixpoint (removals can make further nodes unsatisfiable). If `start`
    /// itself disappears, everything but the root is dropped. `self.leaves`
    /// is rebuilt at the end.
    pub fn prune(&mut self, start: &C) {
        loop {
            let start = if let Ok(x) = self.find_category(start) {
                x
            } else {
                // Start category gone: keep only the root, clear all leaves.
                self.graph.retain_nodes(|g, n| {
                    matches!(g.node_weight(n).unwrap(), FeatureOrLemma::Root)
                });
                self.leaves.clear();
                return;
            };
            let mut checker = AccessibilityChecker::new(start, self);

            while let Some(node) = checker.pop() {
                checker.add_direct_children(node);
                checker.add_indirect_children(node);
            }

            // Fixpoint: everything reachable and nothing unsatisfiable.
            if checker.unsatisfiable.is_empty() && checker.seen.len() == self.graph.node_count()
            {
                break;
            }
            // `&` on bools (non-short-circuiting) — both sides are cheap.
            self.graph.retain_nodes(|_, n| {
                checker.seen.contains(&n) & !checker.unsatisfiable.contains(&n)
            });
        }

        // Recompute the leaf list from the surviving lemma nodes.
        self.leaves = self
            .graph
            .node_indices()
            .filter_map(|x| {
                if matches!(self.graph.node_weight(x).unwrap(), FeatureOrLemma::Lemma(_)) {
                    Some(LexemeId(x))
                } else {
                    None
                }
            })
            .collect();
    }
}
252
/// Types that can mint a category identifier guaranteed not to collide with
/// any of the existing ones.
pub trait FreshCategory: Sized {
    /// Return a value not present in `categories`.
    fn fresh(categories: &[Self]) -> Self;
}
259
260impl FreshCategory for usize {
261 fn fresh(categories: &[Self]) -> Self {
262 match categories.iter().max() {
263 Some(x) => x + 1,
264 None => 0,
265 }
266 }
267}
268
269impl FreshCategory for u8 {
270 fn fresh(categories: &[Self]) -> Self {
271 match categories.iter().max() {
272 Some(x) => x + 1,
273 None => 0,
274 }
275 }
276}
277impl FreshCategory for u16 {
278 fn fresh(categories: &[Self]) -> Self {
279 match categories.iter().max() {
280 Some(x) => x + 1,
281 None => 0,
282 }
283 }
284}
285impl FreshCategory for u32 {
286 fn fresh(categories: &[Self]) -> Self {
287 match categories.iter().max() {
288 Some(x) => x + 1,
289 None => 0,
290 }
291 }
292}
293
294impl FreshCategory for u64 {
295 fn fresh(categories: &[Self]) -> Self {
296 match categories.iter().max() {
297 Some(x) => x + 1,
298 None => 0,
299 }
300 }
301}
302
/// Errors produced by lexicon mutation operations.
#[derive(Error, Debug, Clone)]
pub enum MutationError {
    /// The deletion target is an internal node, not a lemma leaf.
    #[error("Node {0:?} is not a leaf so we can't delete it.")]
    CantDeleteNonLeaf(LexemeId),

    /// Deleting the last remaining leaf would leave an empty lexicon.
    #[error("Node {0:?} is the only leaf so we can't delete it.")]
    LastLeaf(LexemeId),
}
314
/// Result of `add_new_lexeme_from_sibling`: the freshly created lexeme and
/// the existing leaf it was attached next to.
#[derive(Debug, Copy, Clone, Eq, PartialEq, Hash)]
pub struct NewLexeme {
    /// The newly added lemma leaf.
    pub new_lexeme: LexemeId,
    /// The pre-existing leaf sharing the same parent.
    pub sibling: LexemeId,
}
324
/// Pairing produced by `Lexicon::unify`: a leaf in the other lexicon and the
/// corresponding leaf created in this one.
pub struct LexemeDetails {
    /// The leaf as it appears in the *other* lexicon.
    pub other_node: LexemeId,
    /// The matching (newly created) leaf in *this* lexicon.
    pub this_node: LexemeId,
}
333
/// Internal stack entry for `Lexicon::unify`: a node in the other trie
/// aligned with its counterpart in this trie.
struct NodeDetails {
    other_node: NodeIndex,
    this_node: NodeIndex,
}
338
impl<T, C> Lexicon<T, C>
where
    T: Eq + Debug + Clone + Hash,
    C: Eq + Debug + Clone + FreshCategory + Hash,
{
    /// Reset every edge weight to probability one, then renormalize so each
    /// node's outgoing edges form a uniform distribution.
    pub fn uniform_distribution(&mut self) {
        for e in self.graph.edge_weights_mut() {
            *e = LogProb::prob_of_one();
        }
        fix_weights(&mut self.graph);
    }

    /// Merge `other`'s trie into this one, walking both tries in lockstep and
    /// creating any path segments this trie is missing. Returns the leaf
    /// pairings for every lemma node that had to be created. Weights are
    /// renormalized afterwards.
    pub fn unify(&mut self, other: &Self) -> Vec<LexemeDetails> {
        let mut new_leaves = vec![];
        let mut stack = vec![NodeDetails {
            other_node: other.root,
            this_node: self.root,
        }];

        while let Some(NodeDetails {
            other_node,
            this_node,
        }) = stack.pop()
        {
            // Index this node's children by their feature so matching
            // children in `other` can be found in O(1).
            let mut this_children: HashMap<FeatureOrLemma<T, C>, _> = self
                .children_of(this_node)
                .map(|v| (self.graph.node_weight(v).unwrap().clone(), v))
                .collect();
            for (other_child, child_prob) in other
                .graph
                .edges_directed(other_node, Outgoing)
                .map(|x| (x.target(), x.weight()))
            {
                let weight = other.graph.node_weight(other_child).unwrap();

                let this_child = match this_children.entry(weight.clone()) {
                    Entry::Occupied(occupied_entry) => *occupied_entry.get(),
                    Entry::Vacant(vacant_entry) => {
                        // Missing here: copy the node (and its incoming edge
                        // weight) from `other`.
                        let this_child = self.graph.add_node(weight.clone());
                        self.graph.add_edge(this_node, this_child, *child_prob);
                        vacant_entry.insert(this_child);
                        if matches!(weight, FeatureOrLemma::Lemma(_)) {
                            self.leaves.push(LexemeId(this_child));
                            new_leaves.push(LexemeDetails {
                                other_node: LexemeId(other_child),
                                this_node: LexemeId(this_child),
                            });
                        }

                        this_child
                    }
                };

                stack.push(NodeDetails {
                    other_node: other_child,
                    this_node: this_child,
                });
            }
        }
        fix_weights(&mut self.graph);
        new_leaves
    }

    /// Add a new lemma leaf as a sibling of a randomly chosen existing leaf
    /// whose lemma is non-empty and different from `lemma`. Returns `None` if
    /// no such sibling exists. Only the affected parent's weights are
    /// renormalized.
    pub fn add_new_lexeme_from_sibling(
        &mut self,
        lemma: T,
        rng: &mut impl Rng,
    ) -> Option<NewLexeme> {
        if let Some(&leaf) = self
            .leaves
            .iter()
            .filter(|&&x| matches!(self.graph.node_weight(x.0).unwrap(), FeatureOrLemma::Lemma(Some(s)) if s != &lemma))
            .choose(rng)
        {
            let parent = self.parent_of(leaf.0).unwrap();
            let node = self.graph.add_node(FeatureOrLemma::Lemma(Some(lemma)));
            self.graph.add_edge(parent, node, LogProb::prob_of_one());
            fix_weights_per_node(&mut self.graph, parent);
            self.leaves.push(LexemeId(node));
            Some(NewLexeme {
                new_lexeme: LexemeId(node),
                sibling: leaf,
            })
        } else {
            None
        }
    }

    /// Sample an entirely new lexical entry (drawing categories/licensors
    /// from those already in the lexicon) and insert it into the trie.
    pub fn add_new_lexeme(
        &mut self,
        lemma: Option<T>,
        config: Option<LexicalProbConfig>,
        rng: &mut impl Rng,
    ) -> Option<LexemeId> {
        let categories: Vec<_> = self.categories().cloned().collect();
        let licensors: Vec<_> = self.licensor_types().cloned().collect();
        let config = config.unwrap_or_default();

        let x = LexicalEntry::sample(&categories, &licensors, lemma, config, rng);
        self.add_lexical_entry(x)
    }

    /// Delete a lemma leaf and any ancestors left childless by the removal.
    ///
    /// # Errors
    /// Fails if `lexeme` is not a leaf or is the last remaining leaf.
    pub fn delete_lexeme(&mut self, lexeme: LexemeId) -> Result<(), MutationError> {
        if !self.leaves.contains(&lexeme) {
            return Err(MutationError::CantDeleteNonLeaf(lexeme));
        }
        if self.leaves.len() == 1 {
            return Err(MutationError::LastLeaf(lexeme));
        }

        // Walk upwards, removing now-childless nodes; stop (and renormalize)
        // at the first ancestor that still has other children.
        let mut next_node = Some(lexeme.0);
        while let Some(node) = next_node {
            if self.n_children(node) == 0 {
                next_node = self.parent_of(node);
                self.graph.remove_node(node);
            } else {
                fix_weights_per_node(&mut self.graph, node);
                next_node = None;
            }
        }
        self.leaves.retain(|x| x != &lexeme);

        Ok(())
    }

    /// Pick a random non-root node that has at least one sibling and delete
    /// its entire subtree; no-op if no such node exists.
    pub fn delete_from_node(&mut self, rng: &mut impl Rng) {
        if let Some(&node) = self
            .graph
            .node_indices()
            .filter(|&nx| if self.graph.node_weight(nx).unwrap() == &FeatureOrLemma::Root { false } else {
                // Only nodes with a sibling: removing an only child would
                // leave a dangling internal node.
                let parent = self.parent_of(nx).unwrap();
                self.children_of(parent).count() > 1
            })
            .collect::<Vec<_>>()
            .choose(rng)
        {
            let parent = self.parent_of(node).unwrap();
            let mut stack = vec![node];
            while let Some(nx) = stack.pop() {
                stack.extend(self.children_of(nx));
                self.graph.remove_node(nx);
            }
            fix_weights_per_node(&mut self.graph, parent);
            // Drop leaf records for anything removed above.
            self.leaves.retain(|&x| self.graph.contains_node(x.0));
        }
    }

    /// Pick a random deletable node and splice it out, reconnecting its
    /// children to its parent with the combined edge weights. Returns the
    /// removed node's id if it was a lemma leaf.
    ///
    /// Not deletable: the root, category nodes, a lemma with no siblings, and
    /// complement/affix nodes directly under a licensor.
    pub fn delete_node(&mut self, rng: &mut impl Rng) -> Option<LexemeId> {
        if let Some(&node) = self
            .graph
            .node_indices()
            .filter(|&nx| match self.graph.node_weight(nx).unwrap() {
                FeatureOrLemma::Root | FeatureOrLemma::Feature(Feature::Category(_)) => false,
                FeatureOrLemma::Lemma(_) => {
                    let parent = self.parent_of(nx).unwrap();
                    self.graph.edges_directed(parent, Outgoing).count() > 1
                }
                FeatureOrLemma::Complement(..) | FeatureOrLemma::Feature(Feature::Affix(..)) => {
                    let parent = self.parent_of(nx).unwrap();
                    !matches!(
                        self.graph[parent],
                        FeatureOrLemma::Feature(Feature::Licensor(_))
                    )
                }
                _ => true,
            })
            .collect::<Vec<_>>()
            .choose(rng)
        {
            let e = self.graph.edges_directed(node, Incoming).next().unwrap();
            let parent = e.source();
            let w = *e.weight();

            // Children inherit the deleted node's incoming log-probability
            // (log-space addition = probability product).
            let edges = self
                .graph
                .edges_directed(node, Outgoing)
                .map(|e| (e.target(), w + e.weight()))
                .collect::<Vec<_>>();

            if !edges.is_empty() {
                if matches!(
                    self.graph.node_weight(parent).unwrap(),
                    FeatureOrLemma::Feature(Feature::Selector(_, _))
                ) {
                    // A selector parent may not mix lemma children with
                    // feature children: split the inherited children into the
                    // two kinds and rehome them appropriately.
                    let mut complement_edges = vec![];
                    let mut selector_edges = vec![];

                    for (child, w) in edges {
                        if matches!(
                            self.graph.node_weight(child).unwrap(),
                            FeatureOrLemma::Lemma(_)
                        ) {
                            complement_edges.push((child, w));
                        } else {
                            selector_edges.push((child, w));
                        }
                    }
                    if complement_edges.is_empty() {
                        // Only feature children: attach directly.
                        for (child, w) in selector_edges {
                            self.graph.add_edge(parent, child, w);
                        }
                    } else if selector_edges.is_empty()
                        && self.graph.edges_directed(parent, Outgoing).count() == 1
                    {
                        // Only lemma children and the parent has no other
                        // children: the parent itself becomes a complement.
                        for (child, w) in complement_edges {
                            self.graph.add_edge(parent, child, w);
                        }
                        self.graph[parent].to_complement_mut();
                    } else {
                        // Mixed case: clone the parent as a complement node
                        // and split the probability mass between the two.
                        let mut f = self.graph[parent].clone();
                        f.to_complement_mut();
                        let alt_parent = self.graph.add_node(f);
                        let grand_parent = self.parent_of(parent).unwrap();
                        // NOTE(review): this grabs the edge *into the
                        // grandparent*, not the grandparent→parent edge
                        // (`edges_directed(parent, Incoming)`); confirm which
                        // edge was meant — `clean_up` renormalizes afterwards,
                        // which may mask the difference.
                        let parent_e = self
                            .graph
                            .edges_directed(grand_parent, Incoming)
                            .next()
                            .unwrap()
                            .id();

                        // Halve the probability (log-space minus ln 2).
                        self.graph[parent_e] += LogProb::new(-LN_2).unwrap();
                        self.graph
                            .add_edge(grand_parent, alt_parent, self.graph[parent_e]);

                        for (child, w) in selector_edges {
                            self.graph.add_edge(parent, child, w);
                        }
                        for (child, w) in complement_edges {
                            self.graph.add_edge(alt_parent, child, w);
                        }
                    }
                } else {
                    for (child, w) in edges {
                        self.graph.add_edge(parent, child, w);
                    }
                }
            }

            let x = if matches!(self.graph[node], FeatureOrLemma::Lemma(_)) {
                self.leaves.retain(|&a| a.0 != node);
                Some(LexemeId(node))
            } else {
                None
            };
            self.graph.remove_node(node);
            self.clean_up();
            x
        } else {
            None
        }
    }

    /// Build a fresh random lexicon rooted in `base_category`, sampling
    /// branches with `LexicalProbs` and normalizing/merging at the end.
    pub fn random(
        base_category: &C,
        lemmas: &[T],
        config: Option<LexicalProbConfig>,
        rng: &mut impl Rng,
    ) -> Self {
        let mut graph = StableDiGraph::new();
        let root = graph.add_node(FeatureOrLemma::Root);
        let node = graph.add_node(FeatureOrLemma::Feature(Feature::Category(
            base_category.clone(),
        )));
        graph.add_edge(root, node, LogProb::prob_of_one());

        let mut lexicon = Lexicon {
            graph,
            root,
            leaves: vec![],
        };
        let config = config.unwrap_or_default();
        let mut probs = LexicalProbs::from_lexicon(&mut lexicon, lemmas, &config);
        probs.descend_from(node, rng);
        // Any freshly minted categories/licensees get their own branches.
        probs.add_novel_branches(rng);
        lexicon.clean_up();
        lexicon
    }

    /// Resample the feature stored at one random non-root node in place,
    /// then merge/normalize the trie.
    pub fn change_feature(
        &mut self,
        lemmas: &[T],
        config: Option<LexicalProbConfig>,
        rng: &mut impl Rng,
    ) {
        let config = config.unwrap_or_default();
        if let Some(&node) = self
            .graph
            .node_indices()
            .filter(|nx| !matches!(self.graph.node_weight(*nx).unwrap(), FeatureOrLemma::Root))
            .collect::<Vec<_>>()
            .choose(rng)
        {
            let mut probs = LexicalProbs::from_lexicon(self, lemmas, &config);
            probs.set_node(node, rng);
            self.clean_up();
        }
    }

    /// Pick a random internal node, delete its old subtree, and resample a
    /// fresh subtree below it.
    pub fn resample_below_node(
        &mut self,
        lemmas: &[T],
        config: Option<LexicalProbConfig>,
        rng: &mut impl Rng,
    ) {
        let config = config.unwrap_or_default();
        if let Some(&node) = self
            .graph
            .node_indices()
            .filter(|nx| {
                !matches!(
                    self.graph.node_weight(*nx).unwrap(),
                    FeatureOrLemma::Root | FeatureOrLemma::Lemma(_)
                )
            })
            .collect::<Vec<_>>()
            .choose(rng)
        {
            // Snapshot the old children before sampling new ones.
            let mut children: Vec<_> = self.children_of(node).collect();

            let mut probs = LexicalProbs::from_lexicon(self, lemmas, &config);
            probs.descend_from(node, rng);
            probs.add_novel_branches(rng);
            // Remove the old subtree (new children were appended separately).
            while let Some(child) = children.pop() {
                children.extend(
                    self.graph
                        .edges_directed(child, Outgoing)
                        .map(|x| x.target()),
                );
                if matches!(
                    self.graph.node_weight(child).unwrap(),
                    FeatureOrLemma::Lemma(_)
                ) {
                    self.leaves.retain(|&x| x != LexemeId(child));
                }
                self.graph.remove_node(child);
            }
            self.clean_up();
        }
    }

    /// Renormalize all weights, then merge sibling nodes that carry the same
    /// feature (lemma leaves are left as duplicates), combining their edge
    /// probabilities so each merged node keeps a proper distribution.
    fn clean_up(&mut self) {
        fix_weights(&mut self.graph);

        let mut stack = vec![self.root];

        while let Some(n) = stack.pop() {
            // Group this node's children by their feature.
            let mut features: AHashMap<_, Vec<_>> = AHashMap::default();

            for child in self.children_of(n) {
                let feature = self.graph.node_weight(child).unwrap().clone();
                features.entry(feature).or_default().push(child);
            }

            for (key, mut nodes_to_merge) in features {
                if nodes_to_merge.len() == 1 {
                    stack.push(nodes_to_merge.pop().unwrap());
                } else if matches!(key, FeatureOrLemma::Lemma(_)) {
                    // Duplicate lemmas are permitted; no merging.
                    stack.extend(nodes_to_merge);
                } else {
                    // Normalizer for the merged outgoing distribution.
                    let sum = nodes_to_merge
                        .iter()
                        .flat_map(|&a| self.graph.edges_directed(a, Outgoing))
                        .map(|x| x.weight())
                        .log_sum_exp_float_no_alloc();

                    // Combined probability of reaching any of the duplicates.
                    let incoming_weight = nodes_to_merge
                        .iter()
                        .flat_map(|&a| self.graph.edges_directed(a, Incoming))
                        .map(|x| x.weight())
                        .log_sum_exp_clamped_no_alloc();

                    let node_to_keep = nodes_to_merge.pop().unwrap();

                    // Outgoing edges of the discarded duplicates, renormalized.
                    let new_edges: Vec<_> = nodes_to_merge
                        .iter()
                        .flat_map(|&a| self.graph.edges_directed(a, Outgoing))
                        .map(|e| {
                            (
                                e.target(),
                                LogProb::new(e.weight().into_inner() - sum).unwrap(),
                            )
                        })
                        .collect();

                    // Renormalize the kept node's own outgoing edges too.
                    for (e, p) in self
                        .graph
                        .edges_directed(node_to_keep, Outgoing)
                        .map(|e| (e.id(), LogProb::new(e.weight().into_inner() - sum).unwrap()))
                        .collect::<Vec<_>>()
                    {
                        self.graph[e] = p;
                    }

                    new_edges.into_iter().for_each(|(tgt, weight)| {
                        self.graph.add_edge(node_to_keep, tgt, weight);
                    });

                    if let Some(e) = self
                        .graph
                        .edges_directed(node_to_keep, Incoming)
                        .next()
                        .map(|e| e.id())
                    {
                        self.graph[e] = incoming_weight;
                    }
                    nodes_to_merge.into_iter().for_each(|x| {
                        self.graph.remove_node(x);
                    });
                    stack.push(node_to_keep);
                }
            }
        }
    }
}
767
/// A freshly minted feature name that still needs its own branch sampled
/// under the root: either a new selector category or a new mover (licensee).
#[derive(Debug, Clone)]
enum MoverOrSelector<C> {
    Selector(C),
    Mover(C),
}
773
/// Probability parameters steering random lexicon / lexical-entry sampling.
#[derive(Debug, Clone, Copy)]
pub struct LexicalProbConfig {
    // Success probability of the geometric distribution over child counts.
    children_width: f64,
    // Probability of terminating a branch in a lemma.
    lemma_prob: f64,
    // Probability that a sampled lemma is empty (`None`).
    empty_prob: f64,
    // Probability that a selector/affix points left (vs right).
    left_prob: f64,
    // Probability of minting a brand-new category instead of reusing one.
    add_cat_prob: f64,
    // Probability of emitting a licensor (movement) feature.
    mover_prob: f64,
    // Probability of emitting a licensee feature before the category.
    licensee_prob: f64,
    // Probability that a lemma-adjacent selector is an affix.
    affix_prob: f64,
}
786
impl Default for LexicalProbConfig {
    /// Hand-tuned defaults; licensee features are deliberately rare (0.05).
    fn default() -> Self {
        Self {
            children_width: 0.8,
            lemma_prob: 0.75,
            empty_prob: 0.25,
            left_prob: 0.5,
            add_cat_prob: 0.25,
            mover_prob: 0.2,
            licensee_prob: 0.05,
            affix_prob: 0.25,
        }
    }
}
801
802impl LexicalProbConfig {
803 fn direction(&self, rng: &mut impl Rng) -> Direction {
804 if rng.random_bool(self.left_prob) {
805 Direction::Left
806 } else {
807 Direction::Right
808 }
809 }
810}
811
impl<'a, 'b, 'c, T: Eq + Clone + Debug, C: Eq + FreshCategory + Clone + Debug>
    LexicalProbs<'a, 'b, 'c, T, C>
{
    /// Build a sampler over an existing lexicon, seeding the category and
    /// licensee pools from what the lexicon already contains.
    fn from_lexicon(
        lexicon: &'b mut Lexicon<T, C>,
        lemmas: &'c [T],
        config: &'a LexicalProbConfig,
    ) -> Self {
        LexicalProbs {
            children_distr: Geometric::new(config.children_width).unwrap(),
            categories: lexicon.categories().cloned().collect(),
            licensee_features: lexicon.licensor_types().cloned().collect(),
            to_branch: vec![],
            config,
            lemmas,
            lexicon,
        }
    }

    /// How many children to create for a node (geometric, at least one).
    fn n_children(&self, rng: &mut impl Rng) -> u64 {
        self.children_distr.sample(rng) + 1
    }

    /// Coin flip: should the next child terminate in a lemma?
    fn is_lemma(&self, rng: &mut impl Rng) -> bool {
        rng.random_bool(self.config.lemma_prob)
    }

    /// Sample a non-terminal trie node: a licensor with probability
    /// `mover_prob`, otherwise a selector-like feature (affix/complement when
    /// the branch is about to end in a lemma, plain selector otherwise).
    fn get_feature(&mut self, rng: &mut impl Rng) -> FeatureOrLemma<T, C> {
        if rng.random_bool(self.config.mover_prob) {
            let c = self.choose_category_for_licensor(rng);
            FeatureOrLemma::Feature(Feature::Licensor(c))
        } else {
            let c = self.choose_category_for_feature(rng);
            if self.is_lemma(rng) {
                if rng.random_bool(self.config.affix_prob) {
                    FeatureOrLemma::Feature(Feature::Affix(c, self.config.direction(rng)))
                } else {
                    FeatureOrLemma::Complement(c, self.config.direction(rng))
                }
            } else {
                FeatureOrLemma::Feature(Feature::Selector(c, self.config.direction(rng)))
            }
        }
    }

    /// Under a licensee node: continue with another licensee (probability
    /// `licensee_prob`) or close the chain with a category.
    fn get_licensee_or_category(&mut self, rng: &mut impl Rng) -> Feature<C> {
        if rng.random_bool(self.config.licensee_prob) {
            let c = self.choose_category_for_licensee(rng);
            Feature::Licensee(c)
        } else {
            let c = self.choose_category_for_category(rng);
            Feature::Category(c)
        }
    }

    /// For every category/licensee minted during sampling, add its own
    /// branch directly under the root and sample a subtree for it, so the
    /// new name is actually providable.
    fn add_novel_branches(&mut self, rng: &mut impl Rng) {
        while let Some(p) = self.to_branch.pop() {
            let f = FeatureOrLemma::Feature(match p {
                MoverOrSelector::Selector(c) => Feature::Category(c),
                MoverOrSelector::Mover(c) => Feature::Licensee(c),
            });
            let node = self.lexicon.graph.add_node(f);
            self.lexicon
                .graph
                .add_edge(self.lexicon.root, node, LogProb::prob_of_one());
            self.descend_from(node, rng);
        }
    }

    /// Sample a subtree below `node`, choosing each child's kind according to
    /// what is legal under `node`'s feature. All new edges get probability
    /// one; normalization happens later in `clean_up`.
    fn descend_from(&mut self, node: NodeIndex, rng: &mut impl Rng) {
        let mut stack = vec![node];

        while let Some(node) = stack.pop() {
            let feature = self.lexicon.graph.node_weight(node).unwrap();
            match feature {
                // Never called on the root.
                FeatureOrLemma::Root => unimplemented!(),
                // Lemmas are terminal.
                FeatureOrLemma::Lemma(_) => (),
                FeatureOrLemma::Feature(Feature::Licensee(_)) => {
                    let n_children = self.n_children(rng);
                    for _ in 0..n_children {
                        let feature = FeatureOrLemma::Feature(self.get_licensee_or_category(rng));
                        let child = self.lexicon.graph.add_node(feature);
                        self.lexicon
                            .graph
                            .add_edge(node, child, LogProb::prob_of_one());
                        stack.push(child);
                    }
                }
                FeatureOrLemma::Feature(Feature::Licensor(_) | Feature::Selector(_, _)) => {
                    let n_children = self.n_children(rng);
                    for _ in 0..n_children {
                        let feature = self.get_feature(rng);
                        let child = self.lexicon.graph.add_node(feature);
                        self.lexicon
                            .graph
                            .add_edge(node, child, LogProb::prob_of_one());
                        stack.push(child);
                    }
                }
                FeatureOrLemma::Feature(Feature::Category(_)) => {
                    let n_children = self.n_children(rng);
                    for _ in 0..n_children {
                        // A category may be followed directly by a lemma.
                        let is_lemma = self.is_lemma(rng);
                        let feature = if is_lemma {
                            FeatureOrLemma::Lemma(self.get_lemma(rng))
                        } else {
                            self.get_feature(rng)
                        };
                        let child = self.lexicon.graph.add_node(feature);
                        self.lexicon
                            .graph
                            .add_edge(node, child, LogProb::prob_of_one());
                        if is_lemma {
                            self.lexicon.leaves.push(LexemeId(child));
                        } else {
                            stack.push(child);
                        }
                    }
                }
                // Complements and affixes can only dominate lemmas.
                FeatureOrLemma::Complement(_, _)
                | FeatureOrLemma::Feature(Feature::Affix(_, _)) => {
                    let n_children = self.n_children(rng);
                    for _ in 0..n_children {
                        let child = self
                            .lexicon
                            .graph
                            .add_node(FeatureOrLemma::Lemma(self.get_lemma(rng)));
                        self.lexicon
                            .graph
                            .add_edge(node, child, LogProb::prob_of_one());
                        self.lexicon.leaves.push(LexemeId(child));
                    }
                }
            }
        }
    }

    /// Sample a lemma: empty with probability `empty_prob`, otherwise drawn
    /// from the lemma pool (also `None` if the pool is empty).
    fn get_lemma(&self, rng: &mut impl Rng) -> Option<T> {
        if rng.random_bool(self.config.empty_prob) {
            None
        } else {
            self.lemmas.choose(rng).cloned()
        }
    }

    /// Resample the feature stored at `node` in place, keeping its kind
    /// compatible with the node's position in the trie (root stays root,
    /// lemmas stay lemmas, etc.).
    fn set_node(&mut self, node: NodeIndex, rng: &mut impl Rng) {
        let n = self.lexicon.graph.node_weight(node).unwrap();

        let new_feature = match n {
            FeatureOrLemma::Root => FeatureOrLemma::Root,
            FeatureOrLemma::Lemma(_) => FeatureOrLemma::Lemma(self.get_lemma(rng)),
            FeatureOrLemma::Feature(f) => match f {
                Feature::Category(_) => FeatureOrLemma::Feature(Feature::Category(
                    self.choose_category_for_category(rng),
                )),
                Feature::Licensee(_) => FeatureOrLemma::Feature(Feature::Licensee(
                    self.choose_category_for_licensee(rng),
                )),
                Feature::Selector(_, _) | Feature::Licensor(_) => {
                    // A licensor may not directly dominate a lemma, so only
                    // switch to a licensor when no child is a lemma.
                    let lemma_children =
                        self.lexicon
                            .graph
                            .neighbors_directed(node, Outgoing)
                            .any(|x| {
                                matches!(
                                    self.lexicon.graph.node_weight(x).unwrap(),
                                    FeatureOrLemma::Lemma(_)
                                )
                            });

                    if rng.random_bool(self.config.mover_prob) && !lemma_children {
                        let feature = self.choose_category_for_licensor(rng);
                        FeatureOrLemma::Feature(Feature::Licensor(feature))
                    } else {
                        let feature = self.choose_category_for_feature(rng);
                        FeatureOrLemma::Feature(Feature::Selector(
                            feature,
                            self.config.direction(rng),
                        ))
                    }
                }
                Feature::Affix(_, _) => {
                    // Affixes and complements are interchangeable here.
                    if rng.random_bool(self.config.affix_prob) {
                        FeatureOrLemma::Feature(Feature::Affix(
                            self.choose_category_for_feature(rng),
                            self.config.direction(rng),
                        ))
                    } else {
                        FeatureOrLemma::Complement(
                            self.choose_category_for_feature(rng),
                            self.config.direction(rng),
                        )
                    }
                }
            },
            FeatureOrLemma::Complement(..) => {
                if rng.random_bool(self.config.affix_prob) {
                    FeatureOrLemma::Feature(Feature::Affix(
                        self.choose_category_for_feature(rng),
                        self.config.direction(rng),
                    ))
                } else {
                    FeatureOrLemma::Complement(
                        self.choose_category_for_feature(rng),
                        self.config.direction(rng),
                    )
                }
            }
        };
        *self.lexicon.graph.node_weight_mut(node).unwrap() = new_feature;
    }

    /// Pick a category for a `Category` feature, minting one if the pool is
    /// empty.
    ///
    /// NOTE(review): the fresh id is seeded from `licensee_features` (and the
    /// licensee variant below from `categories`) — presumably to avoid
    /// colliding with the *other* pool; confirm this cross-seeding is intended.
    fn choose_category_for_category(&mut self, rng: &mut impl Rng) -> C {
        if self.categories.is_empty() {
            self.categories = vec![C::fresh(&self.licensee_features)];
        }
        self.categories.choose(rng).cloned().unwrap()
    }

    /// Pick a licensee name, minting one if the pool is empty.
    fn choose_category_for_licensee(&mut self, rng: &mut impl Rng) -> C {
        if self.licensee_features.is_empty() {
            self.licensee_features = vec![C::fresh(&self.categories)];
        }
        self.licensee_features.choose(rng).cloned().unwrap()
    }

    /// Pick a licensee name for a licensor feature; with probability
    /// `add_cat_prob` (or when the pool is empty) mint a new one and queue a
    /// providing branch for it.
    fn choose_category_for_licensor(&mut self, rng: &mut impl Rng) -> C {
        if self.licensee_features.is_empty() || rng.random_bool(self.config.add_cat_prob) {
            // Fresh w.r.t. both pools so the new name is globally unused.
            let new_cat = C::fresh(
                &[
                    self.categories.as_slice(),
                    self.licensee_features.as_slice(),
                ]
                .concat(),
            );
            self.licensee_features.push(new_cat.clone());
            self.to_branch.push(MoverOrSelector::Mover(new_cat.clone()));
            new_cat
        } else {
            self.licensee_features.choose(rng).cloned().unwrap()
        }
    }

    /// Pick a category for a selector-like feature; with probability
    /// `add_cat_prob` mint a new one and queue a providing branch for it.
    fn choose_category_for_feature(&mut self, rng: &mut impl Rng) -> C {
        if rng.random_bool(self.config.add_cat_prob) {
            let new_cat = C::fresh(
                &[
                    self.categories.as_slice(),
                    self.licensee_features.as_slice(),
                ]
                .concat(),
            );
            self.categories.push(new_cat.clone());
            self.to_branch
                .push(MoverOrSelector::Selector(new_cat.clone()));
            new_cat
        } else {
            self.categories.choose(rng).cloned().unwrap()
        }
    }
}
1074
/// Stateful sampler that grows/mutates a [`Lexicon`] trie at random.
struct LexicalProbs<'a, 'b, 'c, T: Eq, C: Eq> {
    // Geometric distribution over (number of children - 1).
    children_distr: Geometric,
    // Category names available so far (may grow while sampling).
    categories: Vec<C>,
    // Freshly minted names still needing their own branch under the root.
    to_branch: Vec<MoverOrSelector<C>>,
    // Licensee feature names available so far (may grow while sampling).
    licensee_features: Vec<C>,
    config: &'a LexicalProbConfig,
    // Pool of lemmas to draw from.
    lemmas: &'c [T],
    // The lexicon being built or mutated in place.
    lexicon: &'b mut Lexicon<T, C>,
}
1084
1085#[cfg(test)]
1086mod test {
1087 use ahash::HashSet;
1088 use anyhow::bail;
1089 use itertools::Itertools;
1090 use rand::SeedableRng;
1091 use rand_chacha::ChaCha8Rng;
1092
1093 use super::*;
1094
1095 fn total_prob<T: Eq, C: Eq + Debug>(lex: &Lexicon<T, C>, node: NodeIndex) -> f64 {
1096 lex.graph
1097 .edges_directed(node, Outgoing)
1098 .map(|x| x.weight())
1099 .log_sum_exp_float()
1100 }
1101
    /// Position of a node along a root-to-lemma path in the lexicon trie,
    /// used by `validate_lexicon` to check that features appear in a legal
    /// order.
    #[derive(Debug, Clone, Copy, PartialEq, Eq)]
    enum Position {
        Root,
        PreCategory,
        Category,
        PostCategory,
        NextIsLemma,
        Done,
    }
1111
    impl Position {
        /// Return the position reached after consuming feature `f`, or an
        /// error if `f` may not appear here. Encodes the legal trie-path
        /// order: licensee* category (selector|licensor)* (complement|affix)?
        /// lemma.
        fn next_pos<T: Eq + Debug, C: Eq + Debug>(
            &self,
            f: &FeatureOrLemma<T, C>,
        ) -> anyhow::Result<Self> {
            match self {
                Position::Root | Position::PreCategory => match f {
                    FeatureOrLemma::Feature(Feature::Category(_)) => Ok(Position::Category),
                    FeatureOrLemma::Feature(Feature::Licensee(_)) => Ok(Position::PreCategory),
                    _ => bail!("can't go from {:?} to {:?} !", self, f),
                },
                Position::Category => match f {
                    FeatureOrLemma::Feature(Feature::Selector(..))
                    | FeatureOrLemma::Feature(Feature::Licensor(_)) => Ok(Position::PostCategory),
                    FeatureOrLemma::Complement(..)
                    | FeatureOrLemma::Feature(Feature::Affix(_, _)) => Ok(Position::NextIsLemma),
                    // A category may directly dominate a lemma.
                    FeatureOrLemma::Lemma(_) => Ok(Position::Done),
                    _ => bail!("can't go from {:?} to {:?} !", self, f),
                },
                Position::PostCategory => match f {
                    FeatureOrLemma::Feature(Feature::Selector(..))
                    | FeatureOrLemma::Feature(Feature::Licensor(_)) => Ok(Position::PostCategory),
                    FeatureOrLemma::Complement(..)
                    | FeatureOrLemma::Feature(Feature::Affix(_, _)) => Ok(Position::NextIsLemma),
                    _ => bail!("can't go from {:?} to {:?} !", self, f),
                },
                Position::NextIsLemma => match f {
                    FeatureOrLemma::Lemma(_) => Ok(Position::Done),
                    _ => bail!("can't go from {:?} to {:?} !", self, f),
                },
                Position::Done => bail!("Done can't be continued"),
            }
        }
    }
1146
    /// Exhaustively check a lexicon's structural invariants: legal feature
    /// ordering on every path, a single root, single parents, no duplicate
    /// non-lemma siblings, normalized outgoing distributions, and agreement
    /// between `lex.leaves` and the lemma nodes actually in the graph.
    fn validate_lexicon<T: Eq + Debug + Clone, C: Eq + Debug + Clone>(
        lex: &Lexicon<T, C>,
    ) -> anyhow::Result<()> {
        let mut at_least_one_category = false;
        let mut found_leaves = AHashSet::default();
        let mut found_root = None;
        let mut stack = vec![(lex.root, Position::Root)];

        // Walk every root-to-leaf path checking feature-order legality.
        while let Some((nx, pos)) = stack.pop() {
            let children = lex.children_of(nx).collect::<Vec<_>>();

            // Only lemma nodes (position Done) may be childless.
            if pos == Position::Done {
                assert!(children.is_empty())
            } else {
                assert!(!children.is_empty())
            }
            for child in children {
                let f = lex.graph.node_weight(child).unwrap();
                if matches!(f, FeatureOrLemma::Feature(Feature::Category(_))) {
                    at_least_one_category = true;
                }
                let next_pos = pos.next_pos(f)?;
                stack.push((child, next_pos))
            }
        }
        assert!(at_least_one_category);

        for node in lex.graph.node_indices() {
            // Every node has at most one parent (the graph is a trie).
            let mut parent_iter = lex.graph.neighbors_directed(node, Incoming);
            let parent = parent_iter.next();
            assert!(parent_iter.next().is_none());
            let children: Vec<_> = lex.graph.neighbors_directed(node, Outgoing).collect();
            // Non-lemma siblings must be pairwise distinct (clean_up merges
            // duplicates); lemma duplicates are allowed.
            let de_duped: Vec<_> = children
                .iter()
                .map(|&x| lex.graph.node_weight(x))
                .filter(|x| !matches!(x.unwrap(), FeatureOrLemma::Lemma(_)))
                .dedup()
                .collect();

            assert_eq!(
                de_duped.len(),
                children
                    .iter()
                    .filter(|&&a| !matches!(
                        lex.graph.node_weight(a).unwrap(),
                        FeatureOrLemma::Lemma(_)
                    ))
                    .count()
            );

            match lex.graph.node_weight(node).unwrap() {
                FeatureOrLemma::Root => {
                    assert!(parent.is_none());
                    if found_root.is_some() {
                        panic!("Multiple roots!");
                    }
                    found_root = Some(node);
                }
                FeatureOrLemma::Lemma(_) => {
                    assert!(parent.is_some());
                    found_leaves.insert(LexemeId(node));

                    // Lemmas only appear under complements, categories, or
                    // affixes.
                    assert!(matches!(
                        lex.graph.node_weight(parent.unwrap()).unwrap(),
                        FeatureOrLemma::Complement(_, _)
                            | FeatureOrLemma::Feature(Feature::Category(_))
                            | FeatureOrLemma::Feature(Feature::Affix(_, _))
                    ));
                    assert!(children.is_empty());
                }
                FeatureOrLemma::Feature(_) => {
                    assert!(parent.is_some());
                    assert!(!children.is_empty());
                    // Outgoing edges sum to probability one (log-prob 0).
                    approx::assert_relative_eq!(total_prob(lex, node), 0.0, epsilon = 1e-10);
                }
                FeatureOrLemma::Complement(_, _) => {
                    assert!(parent.is_some());
                    assert!(!children.is_empty());
                    assert!(children.into_iter().all(|x| matches!(
                        lex.graph.node_weight(x).unwrap(),
                        FeatureOrLemma::Lemma(_)
                    )));
                    approx::assert_relative_eq!(total_prob(lex, node), 0.0, epsilon = 1e-10);
                }
            }
        }

        // The cached leaf list matches the graph exactly and has no dupes.
        let leaves: AHashSet<_> = lex.leaves.iter().copied().collect();
        assert_eq!(leaves, found_leaves);
        assert_eq!(leaves.len(), lex.leaves.len());

        Ok(())
    }
1240
1241 #[test]
1242 fn sample_lexeme() -> anyhow::Result<()> {
1243 let config = LexicalProbConfig::default();
1244 let mut rng = ChaCha8Rng::seed_from_u64(0);
1245 LexicalEntry::sample(&["a", "b"], &["c", "d"], Some("john"), config, &mut rng);
1246 Ok(())
1247 }
1248
1249 #[test]
1250 fn pruning() -> anyhow::Result<()> {
1251 let mut lex = Lexicon::from_string("A::c= s\nB::d\nC::c")?;
1252 lex.prune(&"s");
1253 assert_eq!(lex.to_string(), "A::c= s\nC::c");
1254
1255 let mut lex = Lexicon::from_string("A::z= c= s\nB::d\nC::c")?;
1256 lex.prune(&"s");
1257 assert_eq!(lex.to_string(), "");
1258
1259 let mut lex = Lexicon::from_string("A::z= c= s\nB::d\nC::d= c\nD::z")?;
1260 lex.prune(&"s");
1261 assert_eq!(lex.to_string(), "A::z= c= s\nB::d\nC::d= c\nD::z");
1262 let mut lex = Lexicon::from_string("A::z= +w s\nD::z -w")?;
1263 lex.prune(&"s");
1264 assert_eq!(lex.to_string(), "A::z= +w s\nD::z -w");
1265 Ok(())
1266 }
1267
1268 #[test]
1269 fn random_lexicon() -> anyhow::Result<()> {
1270 let mut rng = ChaCha8Rng::seed_from_u64(0);
1271 let mut at_least_one_affix = false;
1272 for _ in 0..1000 {
1273 let x = Lexicon::<_, usize>::random(&0, &["the", "dog", "runs"], None, &mut rng);
1274 for leaf in x.leaves() {
1275 let nx = x.parent_of(leaf.0).unwrap();
1276 let (f, _) = x.get(nx).unwrap();
1277 match f {
1278 FeatureOrLemma::Feature(Feature::Affix(..)) => at_least_one_affix = true,
1279 FeatureOrLemma::Complement(..)
1280 | FeatureOrLemma::Feature(Feature::Category(_)) => (),
1281 _ => panic!("Invalid lexicon"),
1282 }
1283 }
1284 validate_lexicon(&x)?;
1285 }
1286 assert!(at_least_one_affix);
1287 Ok(())
1288 }
1289
    /// Resampling below a node of a random lexicon keeps it valid.
    #[test]
    fn random_redone_lexicon() -> anyhow::Result<()> {
        let mut main_rng = ChaCha8Rng::seed_from_u64(0);
        let lemmas = &["the", "dog", "runs"];
        for _ in 0..1000 {
            // NOTE(review): this inner rng is re-seeded with the same constant
            // (497) on every iteration, so all 1000 iterations start from the
            // identical base lexicon and only the resampling (driven by
            // `main_rng`) varies. Looks like leftover debug state — if varied
            // base lexicons were intended, seed from `main_rng`; TODO confirm.
            let mut rng = ChaCha8Rng::seed_from_u64(497);
            let mut lex = Lexicon::<_, usize>::random(&0, lemmas, None, &mut rng);
            validate_lexicon(&lex)?;
            lex.resample_below_node(lemmas, None, &mut main_rng);
            validate_lexicon(&lex)?;
        }
        Ok(())
    }
1303
1304 #[test]
1305 fn random_delete_branch() -> anyhow::Result<()> {
1306 let mut rng = ChaCha8Rng::seed_from_u64(0);
1307 let lemmas = &["the", "dog", "runs"];
1308 for _ in 0..10000 {
1309 let mut lex = Lexicon::<_, usize>::random(&0, lemmas, None, &mut rng);
1310 validate_lexicon(&lex)?;
1311 lex.delete_from_node(&mut rng);
1312 validate_lexicon(&lex)?;
1313 }
1314 Ok(())
1315 }
1316
1317 #[test]
1318 fn random_delete_feat() -> anyhow::Result<()> {
1319 let mut rng = ChaCha8Rng::seed_from_u64(0);
1320 let lemmas = &["the", "dog", "runs"];
1321 for _ in 0..10000 {
1322 let mut lex = Lexicon::<_, usize>::random(&0, lemmas, None, &mut rng);
1323 validate_lexicon(&lex)?;
1324 lex.delete_node(&mut rng);
1325 validate_lexicon(&lex)?;
1326 }
1327 Ok(())
1328 }
1329 #[test]
1330 fn random_add_lexeme() -> anyhow::Result<()> {
1331 let mut rng = ChaCha8Rng::seed_from_u64(0);
1332 let lemmas = &["the", "dog", "runs"];
1333 for _ in 0..10000 {
1334 let mut lex = Lexicon::<_, usize>::random(&0, lemmas, None, &mut rng);
1335 validate_lexicon(&lex)?;
1336 lex.add_new_lexeme_from_sibling("lbarg", &mut rng);
1337 validate_lexicon(&lex)?;
1338 let mut lex = Lexicon::<_, usize>::random(&0, lemmas, None, &mut rng);
1339 lex.add_new_lexeme(Some("lbarg"), None, &mut rng);
1340 validate_lexicon(&lex)?;
1341 }
1342 Ok(())
1343 }
1344
1345 #[test]
1346 fn random_change_feat() -> anyhow::Result<()> {
1347 let mut rng = ChaCha8Rng::seed_from_u64(0);
1348 let lemmas = &["the", "dog", "runs"];
1349 for _ in 0..10000 {
1350 let mut lex = Lexicon::<_, usize>::random(&0, lemmas, None, &mut rng);
1351 println!("{lex}");
1352 validate_lexicon(&lex)?;
1353 println!("NEW VERSION");
1354 lex.change_feature(lemmas, None, &mut rng);
1355 validate_lexicon(&lex)?;
1356 println!("{lex}");
1357 println!("_______________________________________________");
1358 }
1359 Ok(())
1360 }
1361
1362 #[test]
1363 fn random_unify() -> anyhow::Result<()> {
1364 let mut rng = ChaCha8Rng::seed_from_u64(0);
1365 let lemmas = &["the", "dog", "runs"];
1366 for _ in 0..100 {
1367 let mut a = Lexicon::<_, usize>::random(&0, lemmas, None, &mut rng);
1368 let b = Lexicon::<_, usize>::random(&0, lemmas, None, &mut rng);
1369
1370 let lexemes: HashSet<_> = a
1371 .lexemes()?
1372 .into_iter()
1373 .chain(b.lexemes()?.into_iter())
1374 .collect();
1375
1376 validate_lexicon(&a)?;
1377 validate_lexicon(&b)?;
1378 a.unify(&b);
1379 validate_lexicon(&a)?;
1380
1381 let unified_lexemes: HashSet<_> = a.lexemes()?.into_iter().collect();
1382
1383 assert_eq!(lexemes, unified_lexemes);
1384 }
1385 Ok(())
1386 }
1387
1388 #[test]
1389 fn uniform_distribution() -> anyhow::Result<()> {
1390 let mut rng = ChaCha8Rng::seed_from_u64(0);
1391 let lemmas = &["the", "dog", "runs"];
1392 for _ in 0..100 {
1393 let mut a = Lexicon::<_, usize>::random(&0, lemmas, None, &mut rng);
1394 let b = Lexicon::<_, usize>::random(&0, lemmas, None, &mut rng);
1395
1396 a.unify(&b);
1397 a.uniform_distribution();
1398 validate_lexicon(&a)?;
1399 }
1400 Ok(())
1401 }
1402
1403 #[test]
1404 fn delete_lexeme() -> anyhow::Result<()> {
1405 let mut rng = ChaCha8Rng::seed_from_u64(0);
1406 let lemmas = &["the", "dog", "runs"];
1407 for _ in 0..100 {
1408 let a = Lexicon::<_, usize>::random(&0, lemmas, None, &mut rng);
1409 let lexemes_and_leaves: HashSet<_> = a
1410 .lexemes()?
1411 .into_iter()
1412 .zip(a.leaves.iter().copied())
1413 .collect();
1414
1415 if lexemes_and_leaves.len() > 1 {
1416 validate_lexicon(&a)?;
1417 for (lexeme, leaf) in lexemes_and_leaves.clone() {
1418 println!("Deleting {lexeme} from ({a})");
1419 let mut a = a.clone();
1420 a.delete_lexeme(leaf)?;
1421 validate_lexicon(&a)?;
1422
1423 let new_lexemes: HashSet<_> = a
1424 .lexemes()?
1425 .into_iter()
1426 .zip(a.leaves.iter().copied())
1427 .collect();
1428 let mut old_lexemes: HashSet<_> = lexemes_and_leaves.clone();
1429 old_lexemes.remove(&(lexeme, leaf));
1430 assert_eq!(new_lexemes, old_lexemes);
1431 }
1432 }
1433 }
1434 Ok(())
1435 }
1436}