1use crate::{
2 Direction,
3 lexicon::Feature,
4 parsing::{
5 RuleIndex,
6 rules::{
7 TraceId,
8 printing::{Lemma, MgNode, Storage},
9 },
10 trees::GornIndex,
11 },
12};
13use ahash::HashMap;
14use itertools::Itertools;
15use petgraph::graph::{DiGraph, NodeIndex};
16use serde::{
17 Serialize,
18 ser::{SerializeSeq, SerializeStruct, SerializeStructVariant},
19};
20use std::collections::HashSet;
21use std::fmt::{Debug, Display};
22use std::{collections::VecDeque, hash::Hash};
23
24#[cfg(not(feature = "semantics"))]
25use std::marker::PhantomData;
26
27#[cfg(feature = "semantics")]
28use regex::Regex;
29
30#[cfg(feature = "semantics")]
31use super::semantics::SemanticNode;
32
/// A derivation tree: one [`TreeNode`] plus its ordered child subtrees.
///
/// The traversals in this module treat child index 0 as the left branch and index 1 as the
/// right branch, and panic if a node has more than two children.
#[derive(Debug, Clone, Eq, PartialEq, Hash)]
pub struct Tree<'src, T, C: Eq + Display> {
    /// The payload stored at this position of the tree.
    node: TreeNode<'src, T, C>,
    /// The subtrees below this node, in left-to-right order (empty for a leaf).
    children: Vec<Tree<'src, T, C>>,
}
39
/// A list of movement arcs, each stored as a `(source, target)` pair of [`GornIndex`]
/// tree addresses.
#[derive(Debug, Clone, Eq, PartialEq, Hash)]
struct MovementTrace(Vec<(GornIndex, GornIndex)>);
42
43impl Serialize for MovementTrace {
44 fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
45 where
46 S: serde::Serializer,
47 {
48 let mut seq = serializer.serialize_seq(Some(self.0.len()))?;
49 for (source, tgt) in &self.0 {
50 seq.serialize_element(&(source.to_string(), tgt.to_string()))?;
51 }
52 seq.end()
53 }
54}
55
/// A [`Tree`] bundled with the movement arcs that connect positions inside it.
///
/// Both kinds of movement are stored as pairs of Gorn addresses into `tree`.
#[derive(Debug, Clone, Eq, PartialEq, Hash)]
pub struct TreeWithMovement<'src, T, C: Eq + Display> {
    /// The tree the movement arcs refer to.
    tree: Tree<'src, T, C>,
    /// Arcs rendered as [`TreeEdge::MoveHead`] edges by `petgraph`.
    head_movement: MovementTrace,
    /// Arcs rendered as [`TreeEdge::Move`] edges by `petgraph`.
    phrasal_movement: MovementTrace,
}
63
64impl<T: Serialize, C: Eq + Display + Clone> Serialize for TreeWithMovement<'_, T, C> {
65 fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
66 where
67 S: serde::Serializer,
68 {
69 let mut seq = serializer.serialize_struct("Tree", 2)?;
70
71 seq.serialize_field("tree", &self.tree)?;
72 seq.serialize_field("head_movement", &self.head_movement)?;
73 seq.serialize_field("phrasal_movement", &self.phrasal_movement)?;
74
75 seq.end()
76 }
77}
78impl<T: Display, C: Eq + Display + Clone> TreeWithMovement<'_, T, C> {
79 pub fn latex(&self) -> String {
81 format!(
82 "\\begin{{forest}}{}\\end{{forest}}",
83 self.tree.latex_inner()
84 )
85 }
86}
87
88impl<'src, T, C: Eq + Display> TreeWithMovement<'src, T, C> {
89 pub fn tree(&self) -> &Tree<'src, T, C> {
91 &self.tree
92 }
93
94 pub fn head_movement(&self) -> &[(GornIndex, GornIndex)] {
96 &self.head_movement.0
97 }
98
99 pub fn phrasal_movement(&self) -> &[(GornIndex, GornIndex)] {
101 &self.phrasal_movement.0
102 }
103
104 pub(crate) fn new(
105 tree: Tree<'src, T, C>,
106 head_movement: impl Iterator<Item = (RuleIndex, RuleIndex)>,
107 phrasal_movement: impl Iterator<Item = (RuleIndex, RuleIndex)>,
108 ) -> Self {
109 let look_up = tree.gorn_address();
110 TreeWithMovement {
111 tree,
112 head_movement: MovementTrace(
113 head_movement
114 .map(|(a, b)| {
115 (
116 look_up.get(&a).copied().unwrap(),
117 look_up.get(&b).copied().unwrap(),
118 )
119 })
120 .collect(),
121 ),
122 phrasal_movement: MovementTrace(
123 phrasal_movement
124 .map(|(a, b)| {
125 (
126 look_up.get(&a).copied().unwrap(),
127 look_up.get(&b).copied().unwrap(),
128 )
129 })
130 .collect(),
131 ),
132 }
133 }
134}
135
/// The label attached to an edge of the graph produced by `TreeWithMovement::petgraph`.
#[derive(Debug, Clone, Copy, Eq, PartialEq, Hash, PartialOrd, Ord)]
pub enum TreeEdge {
    /// An edge from a parent to one of its children; the [`Direction`] records whether the
    /// child is the left or the right one.
    Merge(Direction),
    /// An edge added for each entry of the phrasal-movement list.
    Move,
    /// An edge added for each entry of the head-movement list.
    MoveHead,
}
147
148impl Display for TreeEdge {
149 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
150 match self {
151 TreeEdge::Merge(Direction::Left) => write!(f, "Merge(Left)"),
152 TreeEdge::Merge(Direction::Right) => write!(f, "Merge(Right)"),
153 TreeEdge::Move => write!(f, "Move"),
154 TreeEdge::MoveHead => write!(f, "MoveHead"),
155 }
156 }
157}
158
159impl<'src, T: Debug + Clone, C: Debug + Clone + Eq + Display> TreeWithMovement<'src, T, C> {
160 pub fn petgraph(&self) -> (DiGraph<TreeNode<'src, T, C>, TreeEdge>, NodeIndex) {
162 let mut g = DiGraph::new();
163 let root = g.add_node(self.tree.node.clone());
164 let mut h = HashMap::default();
165 h.insert(GornIndex::default(), root);
166
167 let mut stack: VecDeque<_> = self
168 .tree
169 .children
170 .iter()
171 .enumerate()
172 .map(|(i, x)| {
173 let dir = match i {
174 0 => Direction::Left,
175 1 => Direction::Right,
176 _ => panic!("The library should only have binary branching!"),
177 };
178 (x, root, dir, GornIndex::new(dir))
179 })
180 .collect();
181
182 while let Some((tree, par, dir, gorn)) = stack.pop_front() {
183 let node = g.add_node(tree.node.clone());
184 h.insert(gorn, node);
185 g.add_edge(par, node, TreeEdge::Merge(dir));
186
187 stack.extend(tree.children.iter().enumerate().map(|(i, x)| {
188 let dir = match i {
189 0 => Direction::Left,
190 1 => Direction::Right,
191 _ => panic!("The library should only have binary branching!"),
192 };
193 (x, node, dir, gorn.clone_push(dir))
194 }));
195 }
196
197 for (a, b) in &self.head_movement.0 {
198 g.add_edge(*h.get(a).unwrap(), *h.get(b).unwrap(), TreeEdge::MoveHead);
199 }
200 for (a, b) in &self.phrasal_movement.0 {
201 g.add_edge(*h.get(a).unwrap(), *h.get(b).unwrap(), TreeEdge::Move);
202 }
203
204 (g, root)
205 }
206}
207
208impl<'src, T, C: Eq + Display> Tree<'src, T, C> {
209 pub(crate) fn gorn_address(&self) -> HashMap<RuleIndex, GornIndex> {
210 let mut h = HashMap::default();
211
212 let mut stack = vec![(self, GornIndex::default())];
213
214 while let Some((tree, gorn)) = stack.pop() {
215 h.insert(tree.node.rule, gorn);
216 stack.extend(tree.children.iter().enumerate().map(|(x, child)| {
217 (
218 child,
219 gorn.clone_push(match x {
220 0 => Direction::Left,
221 1 => Direction::Right,
222 _ => panic!("Trees should always be binary!"),
223 }),
224 )
225 }));
226 }
227
228 h
229 }
230
231 pub(crate) fn new(
232 node: MgNode<T, C>,
233 storage: Storage<C>,
234 children: Vec<Tree<'src, T, C>>,
235 rule: RuleIndex,
236 ) -> Self {
237 Tree {
238 node: TreeNode::new(node, storage, rule),
239 children,
240 }
241 }
242
243 #[cfg(feature = "semantics")]
244 pub(crate) fn new_with_semantics(
245 node: MgNode<T, C>,
246 semantic_node: Option<SemanticNode<'src>>,
247 storage: Storage<C>,
248 children: Vec<Tree<'src, T, C>>,
249 rule: RuleIndex,
250 ) -> Self {
251 Tree {
252 node: TreeNode::new_semantics(node, storage, semantic_node, rule),
253 children,
254 }
255 }
256
257 pub fn storage(&self) -> &Storage<C> {
259 &self.node.storage
260 }
261}
262
263impl<T: Display, C: Eq + Display> Tree<'_, T, C> {
264 fn latex_inner(&self) -> String {
265 let node = self.node.latex();
266
267 let children: Vec<_> = self.children.iter().map(Tree::latex_inner).collect();
268 if children.is_empty() {
269 format!("[{node}]")
270 } else {
271 format!("[{node} {}]", children.join(" "))
272 }
273 }
274}
275
/// The payload of a single position in a [`Tree`].
#[derive(Debug, Clone, Eq, PartialEq, Hash)]
pub struct TreeNode<'src, T, C: Eq + Display> {
    /// What kind of node this is (start, internal node, leaf or trace) and its contents.
    node: MgNode<T, C>,
    /// The rule index of this node; used as the key when computing Gorn addresses.
    rule: RuleIndex,

    /// The storage recorded at this node (serialized under the `movement` key for internal
    /// nodes).
    storage: Storage<C>,

    /// Optional semantic annotation of this node.
    #[cfg(feature = "semantics")]
    semantics: Option<SemanticNode<'src>>,

    /// Placeholder that keeps the `'src` lifetime in use when semantics are compiled out.
    #[cfg(not(feature = "semantics"))]
    semantics: PhantomData<&'src ()>,
}
292
293impl<T: Display, C: Eq + Display> Display for TreeNode<'_, T, C> {
294 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
295 match &self.node {
296 MgNode::Start => {
297 write!(f, "Start")?;
298 }
299 MgNode::Node { features } => {
300 write!(f, "{}", features.iter().join(" "))?;
301 #[cfg(feature = "semantics")]
302 if let Some(SemanticNode::Rich(_, Some(state))) = &self.semantics {
303 write!(f, "::{state}")?;
304 }
305 }
306
307 MgNode::Leaf { lemma, features } => {
308 write!(f, "{}::{}", lemma, features.iter().join(" "))?;
309
310 #[cfg(feature = "semantics")]
311 if let Some(SemanticNode::Rich(_, Some(state))) = &self.semantics {
312 write!(f, "::{state}")?;
313 }
314 }
315 MgNode::Trace { trace } => {
316 write!(f, "{trace}")?;
317 }
318 }
319 Ok(())
320 }
321}
322
323impl<T: Display> Lemma<T> {
324 fn to_string(&self, empty_string: &str, join: &str) -> String {
325 match self {
326 Lemma::Single(Some(x)) => x.to_string(),
327 Lemma::Single(None) => empty_string.to_string(),
328 Lemma::Multi { heads, .. } => heads
329 .iter()
330 .map(|x| {
331 x.as_ref().map_or_else(
332 || empty_string.to_string(),
333 std::string::ToString::to_string,
334 )
335 })
336 .collect::<Vec<_>>()
337 .join(join),
338 }
339 }
340}
341impl<T, C: Eq + Display> TreeNode<'_, T, C> {
342 pub fn is_trace(&self) -> bool {
344 matches!(self.node, MgNode::Trace { .. })
345 }
346
347 pub fn trace_id(&self) -> Option<TraceId> {
349 let MgNode::Trace { trace } = self.node else {
350 return None;
351 };
352 Some(trace)
353 }
354
355 pub fn lemma(&self) -> Option<&Lemma<T>> {
357 let MgNode::Leaf { lemma, .. } = &self.node else {
358 return None;
359 };
360 Some(lemma)
361 }
362}
363
364impl<T: Display, C: Eq + Display> TreeNode<'_, T, C> {
365 fn latex(&self) -> String {
366 match &self.node {
367 MgNode::Node { features } => {
368 let features = features
369 .iter()
370 .map(std::string::ToString::to_string)
371 .join(" ");
372
373 #[cfg(feature = "semantics")]
374 if let Some(meaning) = &self.semantics {
375 match meaning {
376 SemanticNode::Rich(..) => {
377 return format!(
378 "\\semder{{{features}}}{{\\texttt{{{}}}}}",
379 clean_up_expr(meaning.to_string().as_str())
380 );
381 }
382 SemanticNode::Simple(_) => {
383 return format!("\\semder{{{features}}}{{\\textsc{{{meaning}}}}}");
384 }
385 }
386 }
387 format!("\\der{{{features}}}")
388 }
389 MgNode::Start => "\\textsc{Start}".to_string(),
390 MgNode::Leaf {
391 lemma, features, ..
392 } => {
393 let features = features
394 .iter()
395 .map(std::string::ToString::to_string)
396 .join(" ");
397 let lemma = lemma.to_string("$\\epsilon$", "-");
398 #[cfg(feature = "semantics")]
399 if let Some(meaning) = &self.semantics {
400 match meaning {
401 SemanticNode::Rich(..) => {
402 return format!(
403 "\\lex{{{features}}}{{{lemma}}}{{\\texttt{{{}}}}}",
404 clean_up_expr(meaning.to_string().as_str())
405 );
406 }
407 SemanticNode::Simple(_) => {
408 return format!(
409 "\\lex{{{features}}}{{{lemma}}}{{\\textsc{{{meaning}}}}}"
410 );
411 }
412 }
413 }
414 format!("\\plainlex{{{features}}}{{{lemma}}}")
415 }
416 MgNode::Trace { trace } => format!("$t_{}$", trace.0),
417 }
418 }
419}
420
421impl<'src, T, C: Eq + Display> TreeNode<'src, T, C> {
422 fn new(node: MgNode<T, C>, storage: Storage<C>, rule: RuleIndex) -> TreeNode<'static, T, C> {
423 TreeNode {
424 node,
425 rule,
426 storage,
427
428 #[cfg(feature = "semantics")]
429 semantics: None,
430
431 #[cfg(not(feature = "semantics"))]
432 semantics: PhantomData,
433 }
434 }
435
436 #[cfg(feature = "semantics")]
437 fn new_semantics(
438 node: MgNode<T, C>,
439 storage: Storage<C>,
440 semantics: Option<SemanticNode<'src>>,
441 rule: RuleIndex,
442 ) -> Self {
443 TreeNode {
444 node,
445 rule,
446 storage,
447 semantics,
448 }
449 }
450}
451
452impl<C: Eq + Display> Serialize for Feature<C> {
453 fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
454 where
455 S: serde::Serializer,
456 {
457 serializer.serialize_str(self.to_string().as_str())
458 }
459}
460
/// Prepares the textual form of a semantic expression for inclusion in LaTeX.
///
/// In order:
/// 1. `&`, `_` and `#` are backslash-escaped;
/// 2. every `lambda <type> ` binder (where `<type>` consists of the characters `e`, `a`,
///    `t`, `,`, `<`, `>` and space) becomes `{$\lambda_{<type>}$}`;
/// 3. every remaining `<` and `>` becomes `\left\langle ` and `\right\rangle `.
#[cfg(feature = "semantics")]
fn clean_up_expr(s: &str) -> String {
    // This runs once per annotated node while rendering a tree, so compile the pattern a
    // single time instead of on every call.
    static LAMBDA_BINDER: std::sync::OnceLock<Regex> = std::sync::OnceLock::new();
    let re = LAMBDA_BINDER.get_or_init(|| {
        Regex::new(r"lambda (?<t>[eat,< >]+) ").expect("the lambda-binder pattern is a valid regex")
    });

    let escaped = s
        .replace('&', "\\&")
        .replace('_', "\\_")
        .replace('#', "\\#");

    re.replace_all(escaped.as_str(), "{$\\lambda_{$t}$}")
        .to_string()
        .replace('<', "\\left\\langle ")
        .replace('>', "\\right\\rangle ")
}
473
474impl<C: Eq + Display + Clone> Serialize for Storage<C> {
475 fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
476 where
477 S: serde::Serializer,
478 {
479 let mut seq = serializer.serialize_seq(Some(self.len()))?;
480 for mover in self.values() {
481 seq.serialize_element(
482 &mover
483 .iter()
484 .map(|c| Feature::Licensee(c.clone()))
485 .collect::<Vec<_>>(),
486 )?;
487 }
488 seq.end()
489 }
490}
491
492impl<T, C: Eq + Clone> Serialize for TreeNode<'_, T, C>
493where
494 C: Display,
495 T: Serialize,
496{
497 fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
498 where
499 S: serde::Serializer,
500 {
501 match &self.node {
502 MgNode::Node { features, .. } => {
503 #[cfg(not(feature = "semantics"))]
504 let n = 3;
505
506 #[cfg(feature = "semantics")]
507 let n = if self.semantics.is_some() { 4 } else { 3 };
508
509 let mut seq = serializer.serialize_struct_variant("MgNode", 0, "Node", n)?;
510
511 seq.serialize_field("features", features)?;
512 seq.serialize_field("movement", &self.storage)?;
513
514 #[cfg(feature = "semantics")]
515 if let Some(semantics) = &self.semantics {
516 seq.serialize_field("semantics", &semantics)?;
517 }
518
519 seq.end()
520 }
521 MgNode::Leaf { lemma, features } => {
522 #[cfg(not(feature = "semantics"))]
523 let n = 4;
524
525 #[cfg(feature = "semantics")]
526 let n = if self.semantics.is_some() { 5 } else { 4 };
527
528 let mut seq = serializer.serialize_struct_variant("MgNode", 1, "Leaf", n)?;
529
530 seq.serialize_field("features", features)?;
531 seq.serialize_field("lemma", lemma)?;
532
533 #[cfg(feature = "semantics")]
534 if let Some(semantics) = &self.semantics {
535 seq.serialize_field("semantics", semantics)?;
536 }
537
538 seq.end()
539 }
540 MgNode::Trace { trace } => {
541 #[cfg(not(feature = "semantics"))]
542 let n = 2;
543
544 #[cfg(feature = "semantics")]
545 let n = if self.semantics.is_some() { 3 } else { 2 };
546
547 let mut seq = serializer.serialize_struct_variant("MgNode", 2, "Trace", n)?;
548
549 seq.serialize_field("trace", trace)?;
550
551 #[cfg(feature = "semantics")]
552 if let Some(semantics) = &self.semantics {
553 seq.serialize_field("semantics", semantics)?;
554 }
555
556 seq.end()
557 }
558 MgNode::Start => "Start".serialize(serializer),
559 }
560 }
561}
562
563impl<T, C: Eq + Clone> Serialize for Tree<'_, T, C>
564where
565 C: Display,
566 T: Serialize,
567{
568 fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
569 where
570 S: serde::Serializer,
571 {
572 if self.children.is_empty() {
574 self.node.serialize(serializer)
575 } else {
576 let mut seq = serializer.serialize_seq(Some(self.children.len() + 1))?;
577 seq.serialize_element(&self.node)?;
578 for tree in &self.children {
579 seq.serialize_element(tree)?;
580 }
581 seq.end()
582 }
583 }
584}
585
586impl Serialize for TraceId {
587 fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
588 where
589 S: serde::Serializer,
590 {
591 self.0.serialize(serializer)
592 }
593}
594
595impl<'src, T, C: Eq + Display> Tree<'src, T, C> {
596 pub fn children(&self) -> impl Iterator<Item = &Tree<'src, T, C>> {
598 self.children.iter()
599 }
600
601 pub fn node(&self) -> &TreeNode<'src, T, C> {
603 &self.node
604 }
605
606 pub fn gorn_addresses(&self) -> HashSet<GornIndex> {
608 let mut h = HashSet::default();
609
610 let mut stack = vec![(self, GornIndex::default())];
611
612 while let Some((tree, gorn)) = stack.pop() {
613 h.insert(gorn);
614 stack.extend(tree.children.iter().enumerate().map(|(x, child)| {
615 (
616 child,
617 gorn.clone_push(match x {
618 0 => Direction::Left,
619 1 => Direction::Right,
620 _ => panic!("Trees should always be binary!"),
621 }),
622 )
623 }));
624 }
625
626 h
627 }
628}
629
#[cfg(test)]
mod test {
    use crate::grammars::STABLER2011;
    use crate::{Lexicon, ParsingConfig, PhonContent};
    use petgraph::dot::Dot;

    /// Parses a sentence with the `STABLER2011` grammar, converts the first parse's
    /// derivation tree into a graph, and compares its DOT rendering against a fixed string.
    ///
    /// The expected string pins the node numbering, the `Merge` edge order and the single
    /// `Move` edge, so it is sensitive to the traversal order used when building the graph.
    #[test]
    fn petgraph() -> anyhow::Result<()> {
        let lex = Lexicon::from_string(STABLER2011)?;
        // Only the first parse is inspected; its third component is handed to `derivation`.
        let (_, _, r) = lex
            .parse(
                &PhonContent::from(["which", "queen", "the", "king", "prefers"]),
                "C",
                &ParsingConfig::default(),
            )?
            .next()
            .unwrap();

        let d = lex.derivation(r);
        let (g, _root) = d.tree().petgraph();

        let s = format!("{}", Dot::new(&g));
        // Printed so the actual graph is visible in the test output when the assertion fails.
        println!("{s}");
        assert_eq!(
            s,
            "digraph {\n 0 [ label = \"C\" ]\n 1 [ label = \"D -W\" ]\n 2 [ label = \"+W C\" ]\n 3 [ label = \"which::N= D -W\" ]\n 4 [ label = \"queen::N\" ]\n 5 [ label = \"ε::V= +W C\" ]\n 6 [ label = \"V\" ]\n 7 [ label = \"D\" ]\n 8 [ label = \"=D V\" ]\n 9 [ label = \"the::N= D\" ]\n 10 [ label = \"king::N\" ]\n 11 [ label = \"prefers::D= =D V\" ]\n 12 [ label = \"t0\" ]\n 0 -> 1 [ label = \"Merge(Left)\" ]\n 0 -> 2 [ label = \"Merge(Right)\" ]\n 1 -> 3 [ label = \"Merge(Left)\" ]\n 1 -> 4 [ label = \"Merge(Right)\" ]\n 2 -> 5 [ label = \"Merge(Left)\" ]\n 2 -> 6 [ label = \"Merge(Right)\" ]\n 6 -> 7 [ label = \"Merge(Left)\" ]\n 6 -> 8 [ label = \"Merge(Right)\" ]\n 7 -> 9 [ label = \"Merge(Left)\" ]\n 7 -> 10 [ label = \"Merge(Right)\" ]\n 8 -> 11 [ label = \"Merge(Left)\" ]\n 8 -> 12 [ label = \"Merge(Right)\" ]\n 12 -> 1 [ label = \"Move\" ]\n}\n"
        );
        Ok(())
    }
}