Skip to content

Commit a778204

Browse files
authored
Add expression transformation to replace all occurrences of a node (#3795)
Signed-off-by: Nicholas Gates <[email protected]>
1 parent ee8121d commit a778204

File tree

3 files changed

+74
-1
lines changed

3 files changed

+74
-1
lines changed

vortex-expr/src/exprs/pack.rs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -199,7 +199,7 @@ impl Display for PackExpr {
199199
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
200200
write!(
201201
f,
202-
"pack({{{}}}){}",
202+
"pack({}){}",
203203
self.names
204204
.iter()
205205
.zip(&self.values)

vortex-expr/src/transform/mod.rs

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@ pub(crate) mod match_between;
99
pub mod partition;
1010
mod remove_merge;
1111
mod remove_select;
12+
pub mod replace;
1213
pub mod simplify;
1314
pub mod simplify_typed;
1415
pub mod var_partition;
Lines changed: 72 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,72 @@
1+
// SPDX-License-Identifier: Apache-2.0
2+
// SPDX-FileCopyrightText: Copyright the Vortex contributors
3+
4+
use vortex_error::{VortexExpect, VortexResult};
5+
6+
use crate::traversal::{MutNodeVisitor, Node, TransformResult, TraversalOrder};
7+
use crate::{ExprRef, VortexExpr};
8+
9+
/// Replaces all occurrences of `needle` in the expression `expr` with `replacement`.
10+
pub fn replace(expr: ExprRef, needle: &dyn VortexExpr, replacement: ExprRef) -> ExprRef {
11+
let mut transform = ReplaceVisitor {
12+
needle,
13+
replacement,
14+
};
15+
expr.transform(&mut transform)
16+
.vortex_expect("ReplaceVisitor should not fail")
17+
.into_inner()
18+
}
19+
20+
/// A visitor that replaces occurrences of a specific expression (`needle`) with a replacement
21+
/// expression (`replacement`).
22+
struct ReplaceVisitor<'a> {
23+
needle: &'a dyn VortexExpr,
24+
replacement: ExprRef,
25+
}
26+
27+
impl MutNodeVisitor for ReplaceVisitor<'_> {
28+
type NodeTy = ExprRef;
29+
30+
fn visit_down(&mut self, node: &Self::NodeTy) -> VortexResult<TraversalOrder> {
31+
if self.needle.eq(node.as_ref()) {
32+
// Short-circuit traversal if the needle is found
33+
Ok(TraversalOrder::Skip)
34+
} else {
35+
Ok(TraversalOrder::Continue)
36+
}
37+
}
38+
39+
fn visit_up(&mut self, node: Self::NodeTy) -> VortexResult<TransformResult<Self::NodeTy>> {
40+
if self.needle.eq(node.as_ref()) {
41+
Ok(TransformResult::yes(self.replacement.clone()))
42+
} else {
43+
Ok(TransformResult::no(node))
44+
}
45+
}
46+
}
47+
48+
#[cfg(test)]
49+
mod test {
50+
use vortex_dtype::Nullability::NonNullable;
51+
52+
use super::replace;
53+
use crate::{get_item, lit, pack};
54+
55+
#[test]
56+
fn test_replace_full_tree() {
57+
let e = get_item("b", pack([("a", lit(1)), ("b", lit(2))], NonNullable));
58+
let needle = get_item("b", pack([("a", lit(1)), ("b", lit(2))], NonNullable));
59+
let replacement = lit(42);
60+
let replaced_expr = replace(e, needle.as_ref(), replacement.clone());
61+
assert_eq!(&replaced_expr, &replacement);
62+
}
63+
64+
#[test]
65+
fn test_replace_leaf() {
66+
let e = pack([("a", lit(1)), ("b", lit(2))], NonNullable);
67+
let needle = lit(2);
68+
let replacement = lit(42);
69+
let replaced_expr = replace(e, needle.as_ref(), replacement.clone());
70+
assert_eq!(replaced_expr.to_string(), "pack(a: 1i32, b: 42i32)");
71+
}
72+
}

0 commit comments

Comments
 (0)