Skip to content

Commit 56a05f4

Browse files
committed
reduce
Signed-off-by: Joe Isaacs <[email protected]>
1 parent 5264e5d commit 56a05f4

File tree

6 files changed

+96
-69
lines changed

6 files changed

+96
-69
lines changed

vortex-array/src/expr/session/mod.rs

Lines changed: 16 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -79,13 +79,24 @@ impl ExprSession {
7979
}
8080

8181
/// Register a parent reduce rule in the session.
82-
pub fn register_parent_rule<V: VTable>(
83-
&mut self,
84-
vtable: &'static V,
85-
rule: impl ParentReduceRule<V> + 'static,
86-
) {
82+
pub fn register_parent_rule<V, R>(&mut self, vtable: &'static V, rule: R)
83+
where
84+
V: VTable,
85+
R: 'static,
86+
for<'a> R: ParentReduceRule<V, &'a dyn RewriteContext>,
87+
{
8788
self.rewrite_rules.register_parent_rule(vtable, rule);
8889
}
90+
91+
/// Register a typed parent reduce rule in the session.
92+
pub fn register_typed_parent_rule<V, R>(&mut self, vtable: &'static V, rule: R)
93+
where
94+
V: VTable,
95+
R: 'static,
96+
for<'a> R: ParentReduceRule<V, &'a dyn crate::expr::transform::TypedRewriteContext>,
97+
{
98+
self.rewrite_rules.register_typed_parent_rule(vtable, rule);
99+
}
89100
}
90101

91102
impl Default for ExprSession {

vortex-array/src/expr/session/rewrite.rs

Lines changed: 27 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -7,8 +7,8 @@ use std::sync::Arc;
77
use vortex_error::VortexResult;
88
use vortex_utils::aliases::hash_map::HashMap;
99

10+
use crate::expr::transform::TypedRewriteContext;
1011
use crate::expr::transform::rules::{ParentReduceRule, ReduceRule, RewriteContext};
11-
use crate::expr::transform::{Context, TypedRewriteContext};
1212
use crate::expr::{ExprId, Expression, VTable};
1313

1414
/// Type-erased wrapper for ReduceRule that allows dynamic dispatch.
@@ -120,16 +120,20 @@ pub(crate) trait DynParentReduceRule: Send + Sync {
120120
}
121121

122122
/// Concrete wrapper that implements DynParentReduceRule for a specific VTable type.
123-
struct ParentReduceRuleAdapter<V: VTable, R: ParentReduceRule<V, C>>
123+
struct ParentReduceRuleAdapter<V, R>
124124
where
125125
V: VTable,
126-
for<'a> R: ReduceRule<V, &'a dyn RewriteContext>,
126+
for<'a> R: ParentReduceRule<V, &'a dyn RewriteContext>,
127127
{
128128
rule: R,
129129
_phantom: PhantomData<V>,
130130
}
131131

132-
impl<V: VTable, R: ParentReduceRule<V>> ParentReduceRuleAdapter<V, R> {
132+
impl<V, R> ParentReduceRuleAdapter<V, R>
133+
where
134+
V: VTable,
135+
for<'a> R: ParentReduceRule<V, &'a dyn RewriteContext>,
136+
{
133137
fn new(rule: R) -> Self {
134138
Self {
135139
rule,
@@ -138,7 +142,11 @@ impl<V: VTable, R: ParentReduceRule<V>> ParentReduceRuleAdapter<V, R> {
138142
}
139143
}
140144

141-
impl<V: VTable, R: ParentReduceRule<V>> DynParentReduceRule for ParentReduceRuleAdapter<V, R> {
145+
impl<V, R> DynParentReduceRule for ParentReduceRuleAdapter<V, R>
146+
where
147+
V: VTable,
148+
for<'a> R: ParentReduceRule<V, &'a dyn RewriteContext>,
149+
{
142150
fn reduce_parent_dyn(
143151
&self,
144152
expr: &Expression,
@@ -200,7 +208,7 @@ where
200208
let Some(view) = expr.as_opt::<V>() else {
201209
return Ok(None);
202210
};
203-
self.rule.reduce(&view, parent, child_idx, ctx)
211+
self.rule.reduce_parent(&view, parent, child_idx, ctx)
204212
}
205213
}
206214

@@ -267,11 +275,12 @@ impl RewriteRuleRegistry {
267275
.push(Arc::new(adapter));
268276
}
269277

270-
pub fn register_parent_rule<V: VTable, R: ParentReduceRule<V> + 'static>(
271-
&mut self,
272-
vtable: &'static V,
273-
rule: R,
274-
) {
278+
pub fn register_parent_rule<V, R>(&mut self, vtable: &'static V, rule: R)
279+
where
280+
V: VTable,
281+
R: 'static,
282+
for<'a> R: ParentReduceRule<V, &'a dyn RewriteContext>,
283+
{
275284
let id = vtable.id();
276285
let adapter = ParentReduceRuleAdapter::new(rule);
277286
self.parent_rules
@@ -281,14 +290,15 @@ impl RewriteRuleRegistry {
281290
}
282291

283292
/// Register a parent reduce rule.
284-
pub fn register_typed_parent_rule<V: VTable, R: ParentReduceRule<V> + 'static>(
285-
&mut self,
286-
vtable: &'static V,
287-
rule: R,
288-
) {
293+
pub fn register_typed_parent_rule<V, R>(&mut self, vtable: &'static V, rule: R)
294+
where
295+
V: VTable,
296+
R: 'static,
297+
for<'a> R: ParentReduceRule<V, &'a dyn TypedRewriteContext>,
298+
{
289299
let id = vtable.id();
290300
let adapter = TypedParentReduceRuleAdapter::new(rule);
291-
self.parent_rules
301+
self.typed_parent_rules
292302
.entry(id)
293303
.or_default()
294304
.push(Arc::new(adapter));

vortex-array/src/expr/transform/rules.rs

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -59,7 +59,7 @@ pub trait ParentReduceRule<V: VTable, C: Context>: Send + Sync {
5959
expr: &ExpressionView<V>,
6060
parent: &Expression,
6161
child_idx: usize,
62-
ctx: &dyn RewriteContext,
62+
ctx: C,
6363
) -> VortexResult<Option<Expression>>;
6464
}
6565

@@ -71,6 +71,9 @@ impl<T: Context + ?Sized> Context for &T {}
7171
/// Base context for rewrite rules.
7272
pub trait RewriteContext: Context {}
7373

74+
// Blanket implementation: all references to RewriteContext implementors also implement RewriteContext
75+
impl<T: RewriteContext + ?Sized> RewriteContext for &T {}
76+
7477
/// Context available to rewrite rules during expression optimization.
7578
/// Extends `RewriteContext` and provides access to dtype information.
7679
///

vortex-array/src/expr/transform/simplify.rs

Lines changed: 17 additions & 37 deletions
Original file line numberDiff line numberDiff line change
@@ -31,7 +31,7 @@ fn apply_parent_rules(
3131
expr.transform_up(|node| {
3232
for (idx, child) in node.children().iter().enumerate() {
3333
for rule in session.rewrite_rules().parent_rules_for(&child.id()) {
34-
if let Some(new_expr) = rule.reduce_parent_dyn(&child, &node, idx, ctx)? {
34+
if let Some(new_expr) = rule.reduce_parent_dyn(child, &node, idx, ctx)? {
3535
return Ok(Transformed::yes(new_expr));
3636
}
3737
}
@@ -46,7 +46,7 @@ pub(crate) fn apply_child_rules_impl(
4646
ctx: &dyn RewriteContext,
4747
session: &ExprSession,
4848
) -> VortexResult<Expression> {
49-
expr.transform_down(|node| {
49+
expr.transform_up(|node| {
5050
for rule in session.rewrite_rules().reduce_rules_for(&node.id()) {
5151
if let Some(new_expr) = rule.reduce_dyn(&node, ctx)? {
5252
return Ok(Transformed::yes(new_expr));
@@ -59,27 +59,27 @@ pub(crate) fn apply_child_rules_impl(
5959

6060
#[cfg(test)]
6161
mod tests {
62+
use vortex_scalar::Scalar;
63+
6264
use super::*;
6365
use crate::expr::exprs::binary::{Binary, checked_add};
6466
use crate::expr::exprs::literal::{Literal, lit};
6567
use crate::expr::exprs::operators::Operator;
6668
use crate::expr::session::ExprSession;
6769
use crate::expr::transform::rules::ParentReduceRule;
68-
use crate::expr::{Expression, ExpressionView};
70+
use crate::expr::{Expression, ExpressionView, col};
6971

7072
/// Test rule: simplifies addition with zero: 0 + x -> x when literal zero is a child of an Add
7173
struct AddZeroRule;
7274

73-
impl ParentReduceRule<Literal> for AddZeroRule {
75+
impl<C: RewriteContext> ParentReduceRule<Literal, C> for AddZeroRule {
7476
fn reduce_parent(
7577
&self,
7678
expr: &ExpressionView<Literal>,
7779
parent: &Expression,
7880
child_idx: usize,
79-
_ctx: &dyn RewriteContext,
81+
_ctx: C,
8082
) -> VortexResult<Option<Expression>> {
81-
use vortex_scalar::Scalar;
82-
8383
// Only apply if the parent is an Add operation
8484
let Some(bin) = parent.as_opt::<Binary>() else {
8585
return Ok(None);
@@ -108,16 +108,13 @@ mod tests {
108108
session.register_parent_rule(&Literal, AddZeroRule);
109109

110110
// Test: 0 + x should simplify to x
111-
let x = lit(5);
111+
let x = col("x");
112112
let zero = lit(0);
113113
let expr = checked_add(zero.clone(), x.clone());
114-
println!("expr {}", expr.display_tree());
115-
println!("expr dbg {:?}", expr);
116114

117-
// let result = simplify(expr, &session).unwrap();
118-
//
119-
// // Should simplify to x (lit(5))
120-
// assert_eq!(&result, &lit(5));
115+
let result = simplify(expr, &session).unwrap();
116+
117+
assert_eq!(&result, &x);
121118
}
122119

123120
#[test]
@@ -126,16 +123,14 @@ mod tests {
126123
session.register_parent_rule(&Literal, AddZeroRule);
127124

128125
// Test: 0 + (0 + x) should simplify to 0 + x, then to x
129-
let x = lit(7);
126+
let x = col("x");
130127
let zero = lit(0);
131128
let zero_plus_x = checked_add(zero.clone(), x.clone());
132129
let expr = checked_add(zero.clone(), zero_plus_x);
133130

134131
let result = simplify(expr, &session).unwrap();
135132

136-
// After first pass: 0 + (x) becomes x + (x) at the inner level
137-
// After second pass: x
138-
assert_eq!(&result, &lit(7));
133+
assert_eq!(&result, &x);
139134
}
140135

141136
#[test]
@@ -144,13 +139,13 @@ mod tests {
144139
session.register_parent_rule(&Literal, AddZeroRule);
145140

146141
// Test: x + 0 should simplify to x
147-
let x = lit(3);
142+
let x = col("x");
148143
let zero = lit(0);
149144
let expr = checked_add(x.clone(), zero.clone());
150145

151146
let result = simplify(expr, &session).unwrap();
152147

153-
assert_eq!(&result, &lit(3));
148+
assert_eq!(&result, &x);
154149
}
155150

156151
#[test]
@@ -159,28 +154,13 @@ mod tests {
159154
session.register_parent_rule(&Literal, AddZeroRule);
160155

161156
// Test: (0 + x) + 0 should simplify to x
162-
let x = lit(9);
157+
let x = col("x");
163158
let zero = lit(0);
164159
let zero_plus_x = checked_add(zero.clone(), x.clone());
165160
let expr = checked_add(zero_plus_x, zero.clone());
166161

167162
let result = simplify(expr, &session).unwrap();
168163

169-
assert_eq!(&result, &lit(9));
170-
}
171-
172-
#[test]
173-
fn test_add_zero_parent_rule_no_match() {
174-
let mut session = ExprSession::default();
175-
session.register_parent_rule(&Literal, AddZeroRule);
176-
177-
// Test: x + y (no zeros) should not simplify
178-
let x = lit(3);
179-
let y = lit(4);
180-
let expr = checked_add(x.clone(), y.clone());
181-
182-
let result = simplify(expr.clone(), &session).unwrap();
183-
184-
assert_eq!(&result, &expr);
164+
assert_eq!(&result, &x);
185165
}
186166
}

vortex-array/src/expr/transform/simplify_typed.rs

Lines changed: 12 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -32,7 +32,11 @@ fn apply_child_rules_impl_typed(
3232
ctx: &dyn TypedRewriteContext,
3333
session: &ExprSession,
3434
) -> VortexResult<Expression> {
35-
expr.transform_down(|node| {
35+
fn rewrite(
36+
node: Expression,
37+
ctx: &dyn TypedRewriteContext,
38+
session: &ExprSession,
39+
) -> VortexResult<Transformed<Expression>> {
3640
for rule in session.rewrite_rules().typed_reduce_rules_for(&node.id()) {
3741
if let Some(new_expr) = rule.reduce_dyn_typed(&node, ctx)? {
3842
return Ok(Transformed::yes(new_expr));
@@ -44,7 +48,11 @@ fn apply_child_rules_impl_typed(
4448
}
4549
}
4650
Ok(Transformed::no(node))
47-
})
51+
}
52+
expr.transform(
53+
|node| rewrite(node, ctx, session),
54+
|node| rewrite(node, ctx, session),
55+
)
4856
.map(|t| t.into_inner())
4957
}
5058

@@ -56,12 +64,12 @@ fn apply_parent_rules_impl_typed(
5664
expr.transform_up(|node| {
5765
for (idx, child) in node.children().iter().enumerate() {
5866
for rule in session.rewrite_rules().typed_parent_rules_for(&child.id()) {
59-
if let Some(new_expr) = rule.reduce_parent_dyn_typed(&child, &node, idx, ctx)? {
67+
if let Some(new_expr) = rule.reduce_parent_dyn_typed(child, &node, idx, ctx)? {
6068
return Ok(Transformed::yes(new_expr));
6169
}
6270
}
6371
for rule in session.rewrite_rules().parent_rules_for(&child.id()) {
64-
if let Some(new_expr) = rule.reduce_parent_dyn(&child, &node, idx, ctx)? {
72+
if let Some(new_expr) = rule.reduce_parent_dyn(child, &node, idx, ctx)? {
6573
return Ok(Transformed::yes(new_expr));
6674
}
6775
}

vortex-array/src/expr/traversal/mod.rs

Lines changed: 20 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -199,7 +199,7 @@ pub trait NodeExt: Node {
199199
self,
200200
f: F,
201201
) -> VortexResult<Transformed<Self>> {
202-
let mut rewriter = FnRewriter {
202+
let mut rewriter = FnRewriter::<F, F, _> {
203203
f_down: Some(f),
204204
f_up: None,
205205
_data: PhantomData,
@@ -208,12 +208,26 @@ pub trait NodeExt: Node {
208208
self.rewrite(&mut rewriter)
209209
}
210210

211+
fn transform<F, G>(self, down: F, up: G) -> VortexResult<Transformed<Self>>
212+
where
213+
F: FnMut(Self) -> VortexResult<Transformed<Self>>,
214+
G: FnMut(Self) -> VortexResult<Transformed<Self>>,
215+
{
216+
let mut rewriter = FnRewriter {
217+
f_down: Some(down),
218+
f_up: Some(up),
219+
_data: PhantomData,
220+
};
221+
222+
self.rewrite(&mut rewriter)
223+
}
224+
211225
/// A post-order transform
212226
fn transform_up<F: FnMut(Self) -> VortexResult<Transformed<Self>>>(
213227
self,
214228
f: F,
215229
) -> VortexResult<Transformed<Self>> {
216-
let mut rewriter = FnRewriter {
230+
let mut rewriter = FnRewriter::<F, F, _> {
217231
f_down: None,
218232
f_up: Some(f),
219233
_data: PhantomData,
@@ -272,16 +286,17 @@ pub trait NodeExt: Node {
272286

273287
impl<T: Node> NodeExt for T {}
274288

275-
struct FnRewriter<F, T> {
289+
struct FnRewriter<F, G, T> {
276290
f_down: Option<F>,
277-
f_up: Option<F>,
291+
f_up: Option<G>,
278292
_data: PhantomData<T>,
279293
}
280294

281-
impl<F, T> NodeRewriter for FnRewriter<F, T>
295+
impl<F, G, T> NodeRewriter for FnRewriter<F, G, T>
282296
where
283297
T: Node,
284298
F: FnMut(T) -> VortexResult<Transformed<T>>,
299+
G: FnMut(T) -> VortexResult<Transformed<T>>,
285300
{
286301
type NodeTy = T;
287302

0 commit comments

Comments
 (0)