|
| 1 | +// SPDX-License-Identifier: Apache-2.0 |
| 2 | +// SPDX-FileCopyrightText: Copyright the Vortex contributors |
| 3 | + |
| 4 | +use std::hash::Hash; |
| 5 | + |
| 6 | +use vortex_error::{VortexExpect, VortexResult}; |
| 7 | +use vortex_utils::aliases::hash_map::HashMap; |
| 8 | +use vortex_utils::aliases::hash_set::HashSet; |
| 9 | + |
| 10 | +use crate::expr::Expression; |
| 11 | +use crate::expr::traversal::{NodeExt, NodeVisitor, TraversalOrder}; |
| 12 | + |
| 13 | +pub trait Annotation: Clone + Hash + Eq {} |
| 14 | + |
| 15 | +impl<A> Annotation for A where A: Clone + Hash + Eq {} |
| 16 | + |
| 17 | +pub trait AnnotationFn: Fn(&Expression) -> Vec<Self::Annotation> { |
| 18 | + type Annotation: Annotation; |
| 19 | +} |
| 20 | + |
| 21 | +impl<A, F> AnnotationFn for F |
| 22 | +where |
| 23 | + A: Annotation, |
| 24 | + F: Fn(&Expression) -> Vec<A>, |
| 25 | +{ |
| 26 | + type Annotation = A; |
| 27 | +} |
| 28 | + |
| 29 | +pub type Annotations<'a, A> = HashMap<&'a Expression, HashSet<A>>; |
| 30 | + |
| 31 | +/// Walk the expression tree and annotate each expression with zero or more annotations. |
| 32 | +/// |
| 33 | +/// Returns a map of each expression to all annotations that any of its descendent (child) |
| 34 | +/// expressions are annotated with. |
| 35 | +/// |
| 36 | +/// This uses a specialized traversal strategy with early termination: |
| 37 | +/// - If a node is directly annotated (non-empty), it uses only those annotations and |
| 38 | +/// **skips traversing its children entirely** |
| 39 | +/// - If a node is not directly annotated (empty), it traverses children and bubbles up |
| 40 | +/// their annotations |
| 41 | +/// |
| 42 | +/// This "skip" behavior makes this function different from [`label_tree`], which always |
| 43 | +/// visits all nodes. Use this when you want to find the "shallowest" matches in a tree. |
| 44 | +/// |
| 45 | +/// Note: This cannot use [`label_tree`] because the early termination (skip) requires |
| 46 | +/// conditional traversal based on the node's direct annotations. |
| 47 | +pub fn descendent_annotation_union_set<A: AnnotationFn>( |
| 48 | + expr: &Expression, |
| 49 | + annotate: A, |
| 50 | +) -> Annotations<'_, A::Annotation> { |
| 51 | + let mut visitor = AnnotationVisitor { |
| 52 | + annotations: Default::default(), |
| 53 | + annotate, |
| 54 | + }; |
| 55 | + expr.accept(&mut visitor).vortex_expect("Infallible"); |
| 56 | + visitor.annotations |
| 57 | +} |
| 58 | + |
| 59 | +struct AnnotationVisitor<'a, A: AnnotationFn> { |
| 60 | + annotations: Annotations<'a, A::Annotation>, |
| 61 | + annotate: A, |
| 62 | +} |
| 63 | + |
| 64 | +impl<'a, A: AnnotationFn> NodeVisitor<'a> for AnnotationVisitor<'a, A> { |
| 65 | + type NodeTy = Expression; |
| 66 | + |
| 67 | + fn visit_down(&mut self, node: &'a Self::NodeTy) -> VortexResult<TraversalOrder> { |
| 68 | + let annotations = (self.annotate)(node); |
| 69 | + if annotations.is_empty() { |
| 70 | + // If the annotate fn returns empty, we do not annotate this node directly. |
| 71 | + // Continue traversing to check children. |
| 72 | + Ok(TraversalOrder::Continue) |
| 73 | + } else { |
| 74 | + // Node is directly annotated - store these annotations and skip children |
| 75 | + self.annotations |
| 76 | + .entry(node) |
| 77 | + .or_default() |
| 78 | + .extend(annotations); |
| 79 | + Ok(TraversalOrder::Skip) |
| 80 | + } |
| 81 | + } |
| 82 | + |
| 83 | + fn visit_up(&mut self, node: &'a Expression) -> VortexResult<TraversalOrder> { |
| 84 | + // Bubble up child annotations to this node |
| 85 | + let child_annotations = node |
| 86 | + .children() |
| 87 | + .iter() |
| 88 | + .filter_map(|c| self.annotations.get(c).cloned()) |
| 89 | + .collect::<Vec<_>>(); |
| 90 | + |
| 91 | + let annotations = self.annotations.entry(node).or_default(); |
| 92 | + child_annotations |
| 93 | + .into_iter() |
| 94 | + .for_each(|ps| annotations.extend(ps.iter().cloned())); |
| 95 | + |
| 96 | + Ok(TraversalOrder::Continue) |
| 97 | + } |
| 98 | +} |
| 99 | + |
| 100 | +// Keep the old name for backwards compatibility |
| 101 | +pub use descendent_annotation_union_set as descendent_annotations; |
0 commit comments