Skip to content

Commit a7c6283

Browse files
committed
yye
Signed-off-by: Joe Isaacs <[email protected]>
1 parent 4c18c2b commit a7c6283

File tree

3 files changed

+29
-89
lines changed

3 files changed

+29
-89
lines changed

vortex-array/src/expr/session/rewrite.rs

Lines changed: 4 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -283,23 +283,22 @@ impl RewriteRuleRegistry {
283283
}
284284

285285
/// Get all untyped parent reduce rules for a given child and parent expression ID pair.
286-
/// Returns both specific parent rules AND wildcard "any parent" rules.
286+
///
287+
/// Returns both specific parent rules and wildcard "any parent" rules.
287288
pub(crate) fn parent_rules_for(
288289
&self,
289290
child_id: &ExprId,
290291
parent_id: &ExprId,
291292
) -> Vec<Arc<dyn DynParentReduceRule>> {
292293
let mut rules = Vec::new();
293294

294-
// Add specific parent rules first
295295
if let Some(specific) = self
296296
.parent_rules
297297
.get(&(child_id.clone(), parent_id.clone()))
298298
{
299299
rules.extend_from_slice(specific);
300300
}
301301

302-
// Add wildcard "any parent" rules
303302
if let Some(wildcard) = self.any_parent_rules.get(child_id) {
304303
rules.extend_from_slice(wildcard);
305304
}
@@ -308,23 +307,22 @@ impl RewriteRuleRegistry {
308307
}
309308

310309
/// Get all the typed parent reduce rules for a given child and parent expression ID pair.
311-
/// Returns both specific parent rules AND wildcard "any parent" rules.
310+
///
311+
/// Returns both specific parent rules and wildcard "any parent" rules.
312312
pub(crate) fn typed_parent_rules_for(
313313
&self,
314314
child_id: &ExprId,
315315
parent_id: &ExprId,
316316
) -> Vec<Arc<dyn DynTypedParentReduceRule>> {
317317
let mut rules = Vec::new();
318318

319-
// Add specific parent rules first
320319
if let Some(specific) = self
321320
.typed_parent_rules
322321
.get(&(child_id.clone(), parent_id.clone()))
323322
{
324323
rules.extend_from_slice(specific);
325324
}
326325

327-
// Add wildcard "any parent" rules
328326
if let Some(wildcard) = self.typed_any_parent_rules.get(child_id) {
329327
rules.extend_from_slice(wildcard);
330328
}

vortex-array/src/expr/transform/rules.rs

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -28,7 +28,7 @@ impl ParentMatcher for AnyParent {
2828
type View<'a> = &'a Expression;
2929

3030
fn try_match(parent: &Expression) -> Option<Self::View<'_>> {
31-
Some(parent) // Always matches!
31+
Some(parent)
3232
}
3333
}
3434

@@ -83,7 +83,7 @@ pub trait ParentReduceRule<Child: VTable, Parent: ParentMatcher, C: RewriteConte
8383
/// # Arguments
8484
/// * `expr` - The expression to potentially rewrite (already downcast to type Child)
8585
/// * `parent` - The parent view (type depends on Parent matcher - typed for specific VTables,
86-
/// untyped `&Expression` for `AnyParent`)
86+
/// untyped `&Expression` for `AnyParent`)
8787
/// * `child_idx` - The index of the child expression within the parent.
8888
/// * `ctx` - Context for the rewrite (dtype, etc.)
8989
///

vortex-array/src/expr/transform/simplify.rs

Lines changed: 23 additions & 81 deletions
Original file line numberDiff line numberDiff line change
@@ -108,40 +108,42 @@ mod tests {
108108
}
109109
}
110110

111-
/// Test rule: removes any literal "1" regardless of parent type (wildcard rule)
112-
struct RemoveOneLiteralRule;
111+
/// Test rule: remove identity 0 + x -> x without matching parent directly (equiv to above).
112+
struct AddZeroRuleAnyParent;
113113

114-
impl ParentReduceRule<Literal, AnyParent, RuleContext> for RemoveOneLiteralRule {
114+
impl ParentReduceRule<Literal, AnyParent, RuleContext> for AddZeroRuleAnyParent {
115115
fn reduce_parent(
116116
&self,
117117
expr: &ExpressionView<Literal>,
118-
parent: &Expression, // ← Untyped! AnyParent gives us &Expression
118+
parent: &Expression,
119119
child_idx: usize,
120120
_ctx: &RuleContext,
121121
) -> VortexResult<Option<Expression>> {
122-
// Check if this literal is 1
123-
let one_scalar = Scalar::from(1i32);
124-
if expr.data() != &one_scalar {
122+
// Only apply if the parent is an Add operation
123+
let Some(parent) = parent.as_opt::<Binary>() else {
124+
return Ok(None);
125+
};
126+
if parent.operator() != Operator::Add {
125127
return Ok(None);
126128
}
127129

128-
// Return the OTHER child from the parent (works for any binary parent)
129-
if parent.children().len() == 2 {
130-
let other_idx = if child_idx == 0 { 1 } else { 0 };
131-
return Ok(Some(parent.child(other_idx).clone()));
130+
// Check if this literal is zero
131+
let zero_scalar = Scalar::from(0i32);
132+
if expr.data() != &zero_scalar {
133+
return Ok(None);
132134
}
133135

134-
Ok(None)
136+
// Return the other child (not this zero)
137+
let other_idx = if child_idx == 0 { 1 } else { 0 };
138+
Ok(Some(parent.child(other_idx).clone()))
135139
}
136140
}
137141

138142
#[test]
139-
fn test_add_zero_parent_rule_basic() {
140-
// Create a session and register the rule (specific parent: Binary)
143+
fn test_add_zero_with_specific_parent_rule() {
141144
let mut session = ExprSession::default();
142145
session.register_parent_rule(&Literal, &Binary, AddZeroRule);
143146

144-
// Test: 0 + x should simplify to x
145147
let x = col("x");
146148
let zero = lit(0);
147149
let expr = checked_add(zero, x.clone());
@@ -152,91 +154,31 @@ mod tests {
152154
}
153155

154156
#[test]
155-
fn test_add_zero_parent_rule_left() {
157+
fn test_add_zero_with_any_parent_rule() {
156158
let mut session = ExprSession::default();
157-
session.register_parent_rule(&Literal, &Binary, AddZeroRule);
159+
session.register_any_parent_rule(&Literal, AddZeroRuleAnyParent);
158160

159-
// Test: 0 + (0 + x) should simplify to 0 + x, then to x
160161
let x = col("x");
161162
let zero = lit(0);
162-
let zero_plus_x = checked_add(lit(0), x.clone());
163-
let expr = checked_add(zero, zero_plus_x);
164-
165-
let result = simplify(expr, &session).unwrap();
166-
167-
assert_eq!(&result, &x);
168-
}
169-
170-
#[test]
171-
fn test_add_zero_parent_rule_right() {
172-
let mut session = ExprSession::default();
173-
session.register_parent_rule(&Literal, &Binary, AddZeroRule);
174-
175-
// Test: x + 0 should simplify to x
176-
let x = col("x");
177-
let zero = lit(0);
178-
let expr = checked_add(x.clone(), zero);
163+
let expr = checked_add(zero, x.clone());
179164

180165
let result = simplify(expr, &session).unwrap();
181166

182167
assert_eq!(&result, &x);
183168
}
184169

185170
#[test]
186-
fn test_add_zero_parent_rule_nested() {
171+
fn test_add_zero_with_both_rules() {
187172
let mut session = ExprSession::default();
188173
session.register_parent_rule(&Literal, &Binary, AddZeroRule);
174+
session.register_any_parent_rule(&Literal, AddZeroRuleAnyParent);
189175

190-
// Test: (0 + x) + 0 should simplify to x
191176
let x = col("x");
192177
let zero = lit(0);
193-
let zero_plus_x = checked_add(lit(0), x.clone());
194-
let expr = checked_add(zero_plus_x, zero);
195-
196-
let result = simplify(expr, &session).unwrap();
197-
198-
assert_eq!(&result, &x);
199-
}
200-
201-
#[test]
202-
fn test_any_parent_wildcard_rule() {
203-
// Test AnyParent - rule works with ANY parent type
204-
let mut session = ExprSession::default();
205-
session.register_any_parent_rule(&Literal, RemoveOneLiteralRule);
206-
207-
// Test: x + 1 should simplify to x (works with Add)
208-
let x = col("x");
209-
let one = lit(1);
210-
let expr = checked_add(x.clone(), one);
178+
let expr = checked_add(zero, x.clone());
211179

212180
let result = simplify(expr, &session).unwrap();
213181

214182
assert_eq!(&result, &x);
215183
}
216-
217-
#[test]
218-
fn test_specific_and_wildcard_rules_together() {
219-
// Test both specific and wildcard rules registered at the same time
220-
let mut session = ExprSession::default();
221-
222-
// Specific rule: removes 0 from Add operations only
223-
session.register_parent_rule(&Literal, &Binary, AddZeroRule);
224-
225-
// Wildcard rule: removes 1 from ANY operation
226-
session.register_any_parent_rule(&Literal, RemoveOneLiteralRule);
227-
228-
// Test 1: 0 + x -> x (specific rule applies)
229-
let x = col("x");
230-
let zero = lit(0);
231-
let expr = checked_add(zero, x.clone());
232-
let result = simplify(expr, &session).unwrap();
233-
assert_eq!(&result, &x);
234-
235-
// Test 2: 1 + y -> y (wildcard rule applies)
236-
let y = col("y");
237-
let one = lit(1);
238-
let expr = checked_add(one, y.clone());
239-
let result = simplify(expr, &session).unwrap();
240-
assert_eq!(&result, &y);
241-
}
242184
}

0 commit comments

Comments
 (0)