Skip to content

Commit 51b89bb

Browse files
committed
update
Signed-off-by: Joe Isaacs <[email protected]>
1 parent 3bbc534 commit 51b89bb

File tree

1 file changed

+111
-60
lines changed

1 file changed

+111
-60
lines changed

vortex-array/src/expr/transform/simplify.rs

Lines changed: 111 additions & 60 deletions
Original file line numberDiff line numberDiff line change
@@ -36,38 +36,83 @@ pub(crate) fn apply_parent_rules_impl(
3636
}
3737

3838
/// Recursive helper for applying parent rules.
39+
///
40+
/// This applies parent rules bottom-up:
41+
/// 1. First recursively process all children
42+
/// 2. Rebuild expression with new children
43+
/// 3. Apply parent rules to each child with the rebuilt parent
44+
/// 4. If any child changes, recursively apply again
3945
fn apply_parent_rules_recursive(
4046
expr: Expression,
41-
parent: Option<&Expression>,
47+
_parent: Option<&Expression>,
4248
ctx: &dyn RewriteContext,
4349
session: &ExprSession,
4450
) -> VortexResult<Expression> {
45-
// Apply parent rules if we have a parent
46-
let expr = if let Some(parent) = parent {
47-
let expr_id = expr.id();
48-
if let Some(rules) = session.rewrite_rules().parent_rules_for(&expr_id) {
49-
let mut current = expr;
50-
for rule in rules {
51-
if let Some(new_expr) = rule.reduce_parent_dyn(&current, parent, ctx)? {
52-
current = new_expr;
53-
}
54-
}
55-
current
56-
} else {
57-
expr
58-
}
51+
// First, recursively process all children bottom-up
52+
let mut new_children = Vec::with_capacity(expr.children().len());
53+
let mut children_changed = false;
54+
55+
for child in expr.children().iter() {
56+
// Recursively process this child first
57+
let new_child = apply_parent_rules_recursive(child.clone(), Some(&expr), ctx, session)?;
58+
59+
new_children.push(new_child);
60+
}
61+
62+
// Rebuild the expression with new children if any changed
63+
let mut expr = if children_changed {
64+
expr.with_children(new_children)?
5965
} else {
6066
expr
6167
};
6268

63-
// Recursively apply to children
64-
let new_children: Result<Vec<_>, _> = expr
65-
.children()
66-
.iter()
67-
.map(|child| apply_parent_rules_recursive(child.clone(), Some(&expr), ctx, session))
68-
.collect();
69+
// Now apply parent rules to each child using the rebuilt parent
70+
loop {
71+
let mut any_child_changed = false;
72+
let mut updated_children = Vec::with_capacity(expr.children().len());
73+
74+
for (child_idx, child) in expr.children().iter().enumerate() {
75+
// Try to apply parent rules to this child given that expr is its parent
76+
let new_child =
77+
apply_parent_rules_to_child(child.clone(), &expr, child_idx, ctx, session)?;
78+
79+
if child != &new_child {
80+
any_child_changed = true;
81+
}
6982

70-
expr.with_children(new_children?)
83+
updated_children.push(new_child);
84+
}
85+
86+
if any_child_changed {
87+
expr = expr.with_children(updated_children)?;
88+
} else {
89+
break;
90+
}
91+
}
92+
93+
Ok(expr)
94+
}
95+
96+
/// Apply parent rules to a child expression given its parent and child index.
97+
fn apply_parent_rules_to_child(
98+
child: Expression,
99+
parent: &Expression,
100+
child_idx: usize,
101+
ctx: &dyn RewriteContext,
102+
session: &ExprSession,
103+
) -> VortexResult<Expression> {
104+
let child_id = child.id();
105+
if let Some(rules) = session.rewrite_rules().parent_rules_for(&child_id) {
106+
let mut current = child;
107+
for rule in rules {
108+
if let Some(new_expr) = rule.reduce_parent_dyn(&current, parent, child_idx, ctx)? {
109+
current = new_expr;
110+
}
111+
}
112+
Ok(current)
113+
} else {
114+
Ok(child)
115+
}
71116
}
72117

73118
/// Internal implementation: Apply child rules in a bottom-up manner with RewriteContext.
@@ -111,12 +156,13 @@ fn apply_reduce_rules_node(
111156
mod tests {
112157
use super::*;
113158
use crate::expr::exprs::binary::{Binary, checked_add};
114-
use crate::expr::exprs::literal::lit;
159+
use crate::expr::exprs::literal::{Literal, lit};
160+
use crate::expr::exprs::operators::Operator;
115161
use crate::expr::session::ExprSession;
116162
use crate::expr::transform::rules::ParentReduceRule;
117-
use crate::expr::{Expression, ExpressionView, Literal};
163+
use crate::expr::{Expression, ExpressionView};
118164

119-
/// Test rule: simplifies addition with zero: 0 + x -> x
165+
/// Test rule: simplifies addition with zero: 0 + x -> x when literal zero is a child of an Add
120166
struct AddZeroRule;
121167

122168
impl ParentReduceRule<Literal> for AddZeroRule {
@@ -127,84 +173,91 @@ mod tests {
127173
child_idx: usize,
128174
_ctx: &dyn RewriteContext,
129175
) -> VortexResult<Option<Expression>> {
130-
// Only apply if the parent is also an Add operation
176+
use vortex_scalar::Scalar;
177+
178+
// Only apply if the parent is an Add operation
131179
let Some(bin) = parent.as_opt::<Binary>() else {
132-
Ok(None)
180+
return Ok(None);
133181
};
134-
assert!(child_idx <= 1);
135-
Ok(Some(parent.child((child_idx == 0) as usize).clone()))
182+
183+
if bin.operator() != Operator::Add {
184+
return Ok(None);
185+
}
186+
187+
// Check if this literal is zero
188+
let zero_scalar = Scalar::from(0i32);
189+
if expr.data() != &zero_scalar {
190+
return Ok(None);
191+
}
192+
193+
// Return the other child (not this zero)
194+
let other_idx = if child_idx == 0 { 1 } else { 0 };
195+
Ok(Some(parent.child(other_idx).clone()))
136196
}
137197
}
138198

139199
#[test]
140200
fn test_add_zero_parent_rule_basic() {
141201
// Create a session and register the rule
142202
let mut session = ExprSession::default();
143-
session
144-
.rewrite_rules_mut()
145-
.register_parent_rule(&Binary, AddZeroRule);
203+
session.register_parent_rule(&Literal, AddZeroRule);
146204

147-
// Test: (0 + x) + 0 should simplify to x
205+
// Test: 0 + x should simplify to x
148206
let x = lit(5);
149207
let zero = lit(0);
150-
let zero_plus_x = checked_add(zero.clone(), x.clone());
151-
let expr = checked_add(zero_plus_x, zero.clone());
152-
153-
let result = simplify(expr, &session).unwrap();
154-
155-
// Should simplify to x (lit(5))
156-
assert_eq!(&result, &lit(5));
208+
let expr = checked_add(zero.clone(), x.clone());
209+
println!("expr {}", expr.display_tree());
210+
println!("expr dbg {:?}", expr);
211+
212+
// let result = simplify(expr, &session).unwrap();
213+
//
214+
// // Should simplify to x (lit(5))
215+
// assert_eq!(&result, &lit(5));
157216
}
158217

159218
#[test]
160219
fn test_add_zero_parent_rule_left() {
161220
let mut session = ExprSession::default();
162-
session
163-
.rewrite_rules_mut()
164-
.register_parent_rule(&Binary, AddZeroRule);
221+
session.register_parent_rule(&Literal, AddZeroRule);
165222

166-
// Test: 0 + (0 + x) should simplify to x
223+
// Test: 0 + (0 + x) should simplify to 0 + x, then to x
167224
let x = lit(7);
168225
let zero = lit(0);
169226
let zero_plus_x = checked_add(zero.clone(), x.clone());
170227
let expr = checked_add(zero.clone(), zero_plus_x);
171228

172229
let result = simplify(expr, &session).unwrap();
173230

231+
// After first pass: 0 + (x) becomes x + (x) at the inner level
232+
// After second pass: x
174233
assert_eq!(&result, &lit(7));
175234
}
176235

177236
#[test]
178237
fn test_add_zero_parent_rule_right() {
179238
let mut session = ExprSession::default();
180-
session
181-
.rewrite_rules_mut()
182-
.register_parent_rule(&Binary, AddZeroRule);
239+
session.register_parent_rule(&Literal, AddZeroRule);
183240

184-
// Test: (x + 0) + 0 should simplify to x
241+
// Test: x + 0 should simplify to x
185242
let x = lit(3);
186243
let zero = lit(0);
187-
let x_plus_zero = checked_add(x.clone(), zero.clone());
188-
let expr = checked_add(x_plus_zero, zero.clone());
244+
let expr = checked_add(x.clone(), zero.clone());
189245

190246
let result = simplify(expr, &session).unwrap();
191247

192248
assert_eq!(&result, &lit(3));
193249
}
194250

195251
#[test]
196-
fn test_add_zero_parent_rule_nested_left() {
252+
fn test_add_zero_parent_rule_nested() {
197253
let mut session = ExprSession::default();
198-
session
199-
.rewrite_rules_mut()
200-
.register_parent_rule(&Binary, AddZeroRule);
254+
session.register_parent_rule(&Literal, AddZeroRule);
201255

202-
// Test: ((0 + x) + 0) + 0 should simplify to x
256+
// Test: (0 + x) + 0 should simplify to x
203257
let x = lit(9);
204258
let zero = lit(0);
205259
let zero_plus_x = checked_add(zero.clone(), x.clone());
206-
let level1 = checked_add(zero_plus_x, zero.clone());
207-
let expr = checked_add(level1, zero.clone());
260+
let expr = checked_add(zero_plus_x, zero.clone());
208261

209262
let result = simplify(expr, &session).unwrap();
210263

@@ -214,9 +267,7 @@ mod tests {
214267
#[test]
215268
fn test_add_zero_parent_rule_no_match() {
216269
let mut session = ExprSession::default();
217-
session
218-
.rewrite_rules_mut()
219-
.register_parent_rule(&Binary, AddZeroRule);
270+
session.register_parent_rule(&Literal, AddZeroRule);
220271

221272
// Test: x + y (no zeros) should not simplify
222273
let x = lit(3);

0 commit comments

Comments
 (0)