Skip to content

Commit 5264e5d

Browse files
committed
t
Signed-off-by: Joe Isaacs <[email protected]>
1 parent 677c1fe commit 5264e5d

File tree

4 files changed

+105
-215
lines changed

4 files changed

+105
-215
lines changed

vortex-array/src/expr/session/rewrite.rs

Lines changed: 58 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -7,8 +7,8 @@ use std::sync::Arc;
77
use vortex_error::VortexResult;
88
use vortex_utils::aliases::hash_map::HashMap;
99

10-
use crate::expr::transform::TypedRewriteContext;
1110
use crate::expr::transform::rules::{ParentReduceRule, ReduceRule, RewriteContext};
11+
use crate::expr::transform::{Context, TypedRewriteContext};
1212
use crate::expr::{ExprId, Expression, VTable};
1313

1414
/// Type-erased wrapper for ReduceRule that allows dynamic dispatch.
@@ -120,7 +120,11 @@ pub(crate) trait DynParentReduceRule: Send + Sync {
120120
}
121121

122122
/// Concrete wrapper that implements DynParentReduceRule for a specific VTable type.
123-
struct ParentReduceRuleAdapter<V: VTable, R: ParentReduceRule<V>> {
123+
struct ParentReduceRuleAdapter<V: VTable, R: ParentReduceRule<V, C>>
124+
where
125+
V: VTable,
126+
for<'a> R: ReduceRule<V, &'a dyn RewriteContext>,
127+
{
124128
rule: R,
125129
_phantom: PhantomData<V>,
126130
}
@@ -150,7 +154,7 @@ impl<V: VTable, R: ParentReduceRule<V>> DynParentReduceRule for ParentReduceRule
150154
}
151155

152156
pub(crate) trait DynTypedParentReduceRule: Send + Sync {
153-
fn reduce_parent_dyn(
157+
fn reduce_parent_dyn_typed(
154158
&self,
155159
expr: &Expression,
156160
parent: &Expression,
@@ -159,12 +163,20 @@ pub(crate) trait DynTypedParentReduceRule: Send + Sync {
159163
) -> VortexResult<Option<Expression>>;
160164
}
161165

162-
struct TypedParentReduceRuleAdapter<V: VTable, R: ParentReduceRule<V>> {
166+
struct TypedParentReduceRuleAdapter<V: VTable, R>
167+
where
168+
V: VTable,
169+
for<'a> R: ParentReduceRule<V, &'a dyn TypedRewriteContext>,
170+
{
163171
rule: R,
164172
_phantom: PhantomData<V>,
165173
}
166174

167-
impl<V: VTable, R: ParentReduceRule<V>> TypedParentReduceRuleAdapter<V, R> {
175+
impl<V, R> TypedParentReduceRuleAdapter<V, R>
176+
where
177+
V: VTable,
178+
for<'a> R: ParentReduceRule<V, &'a dyn TypedRewriteContext>,
179+
{
168180
fn new(rule: R) -> Self {
169181
Self {
170182
rule,
@@ -173,10 +185,12 @@ impl<V: VTable, R: ParentReduceRule<V>> TypedParentReduceRuleAdapter<V, R> {
173185
}
174186
}
175187

176-
impl<V: VTable, R: ParentReduceRule<V>> DynTypedParentReduceRule
177-
for TypedParentReduceRuleAdapter<V, R>
188+
impl<V, R> DynTypedParentReduceRule for TypedParentReduceRuleAdapter<V, R>
189+
where
190+
V: VTable,
191+
for<'a> R: ParentReduceRule<V, &'a dyn TypedRewriteContext>,
178192
{
179-
fn reduce_parent_dyn(
193+
fn reduce_parent_dyn_typed(
180194
&self,
181195
expr: &Expression,
182196
parent: &Expression,
@@ -186,7 +200,7 @@ impl<V: VTable, R: ParentReduceRule<V>> DynTypedParentReduceRule
186200
let Some(view) = expr.as_opt::<V>() else {
187201
return Ok(None);
188202
};
189-
self.rule.reduce_parent(&view, parent, child_idx, ctx)
203+
self.rule.reduce(&view, parent, child_idx, ctx)
190204
}
191205
}
192206

@@ -203,7 +217,7 @@ pub struct RewriteRuleRegistry {
203217
/// Parent reduce rules, indexed by expression ID
204218
parent_rules: HashMap<ExprId, Vec<Arc<dyn DynParentReduceRule>>>,
205219
/// Parent reduce rules, indexed by expression ID
206-
typed_parent_rules: HashMap<ExprId, Vec<Arc<dyn DynParentReduceRule>>>,
220+
typed_parent_rules: HashMap<ExprId, Vec<Arc<dyn DynTypedParentReduceRule>>>,
207221
}
208222

209223
impl std::fmt::Debug for RewriteRuleRegistry {
@@ -253,7 +267,6 @@ impl RewriteRuleRegistry {
253267
.push(Arc::new(adapter));
254268
}
255269

256-
/// Register a parent reduce rule.
257270
pub fn register_parent_rule<V: VTable, R: ParentReduceRule<V> + 'static>(
258271
&mut self,
259272
vtable: &'static V,
@@ -267,29 +280,52 @@ impl RewriteRuleRegistry {
267280
.push(Arc::new(adapter));
268281
}
269282

283+
/// Register a parent reduce rule.
284+
pub fn register_typed_parent_rule<V: VTable, R: ParentReduceRule<V> + 'static>(
285+
&mut self,
286+
vtable: &'static V,
287+
rule: R,
288+
) {
289+
let id = vtable.id();
290+
let adapter = TypedParentReduceRuleAdapter::new(rule);
291+
self.parent_rules
292+
.entry(id)
293+
.or_default()
294+
.push(Arc::new(adapter));
295+
}
296+
270297
/// Get all typed reduce rules for a given expression ID.
271-
pub(crate) fn typed_reduce_rules_for(
272-
&self,
273-
id: &ExprId,
274-
) -> Option<&[Arc<dyn DynTypedReduceRule>]> {
275-
self.typed_reduce_rules.get(id).map(|v| v.as_slice())
298+
pub(crate) fn typed_reduce_rules_for(&self, id: &ExprId) -> &[Arc<dyn DynTypedReduceRule>] {
299+
self.typed_reduce_rules
300+
.get(id)
301+
.map(|v| v.as_slice())
302+
.unwrap_or_default()
276303
}
277304

278305
/// Get all untyped reduce rules for a given expression ID.
279-
pub(crate) fn reduce_rules_for(&self, id: &ExprId) -> Option<&[Arc<dyn DynReduceRule>]> {
280-
self.reduce_rules.get(id).map(|v| v.as_slice())
306+
pub(crate) fn reduce_rules_for(&self, id: &ExprId) -> &[Arc<dyn DynReduceRule>] {
307+
self.reduce_rules
308+
.get(id)
309+
.map(|v| v.as_slice())
310+
.unwrap_or_default()
281311
}
282312

283313
/// Get all parent reduce rules for a given expression ID.
284-
pub(crate) fn parent_rules_for(&self, id: &ExprId) -> Option<&[Arc<dyn DynParentReduceRule>]> {
285-
self.parent_rules.get(id).map(|v| v.as_slice())
314+
pub(crate) fn parent_rules_for(&self, id: &ExprId) -> &[Arc<dyn DynParentReduceRule>] {
315+
self.parent_rules
316+
.get(id)
317+
.map(|v| v.as_slice())
318+
.unwrap_or_default()
286319
}
287320

288321
/// Get all the typed parent reduce rules for a given expression ID.
289322
pub(crate) fn typed_parent_rules_for(
290323
&self,
291324
id: &ExprId,
292-
) -> Option<&[Arc<dyn DynParentReduceRule>]> {
293-
self.typed_parent_rules.get(id).map(|v| v.as_slice())
325+
) -> &[Arc<dyn DynTypedParentReduceRule>] {
326+
self.typed_parent_rules
327+
.get(id)
328+
.map(|v| v.as_slice())
329+
.unwrap_or_default()
294330
}
295331
}

vortex-array/src/expr/transform/rules.rs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -42,7 +42,7 @@ pub trait ReduceRule<V: VTable, C: Context>: Send + Sync {
4242
/// # Type Parameters
4343
/// * `V` - The VTable type this rule applies to. The rule will only be invoked for expressions
4444
/// with this vtable type, providing compile-time type safety.
45-
pub trait ParentReduceRule<V: VTable>: Send + Sync {
45+
pub trait ParentReduceRule<V: VTable, C: Context>: Send + Sync {
4646
/// Try to rewrite an expression based on its parent.
4747
///
4848
/// # Arguments

vortex-array/src/expr/transform/simplify.rs

Lines changed: 18 additions & 113 deletions
Original file line numberDiff line numberDiff line change
@@ -16,140 +16,45 @@ use crate::expr::traversal::{NodeExt, Transformed};
1616
pub fn simplify(e: Expression, session: &ExprSession) -> VortexResult<Expression> {
1717
let ctx = EmptyRewriteContext;
1818

19-
// First bottom-up (child rules)
19+
let e = apply_parent_rules(e, &ctx, session)?;
2020
let e = apply_child_rules_impl(e, &ctx, session)?;
21-
22-
let e = apply_parent_rules_impl(e, &ctx, session)?;
23-
2421
let e = find_between(e);
2522

2623
Ok(e)
2724
}
2825

29-
/// Internal implementation: Apply parent rules in a top-down manner.
30-
pub(crate) fn apply_parent_rules_impl(
31-
expr: Expression,
32-
ctx: &dyn RewriteContext,
33-
session: &ExprSession,
34-
) -> VortexResult<Expression> {
35-
apply_parent_rules_recursive(expr, None, ctx, session)
36-
}
37-
38-
/// Recursive helper for applying parent rules.
39-
///
40-
/// This applies parent rules bottom-up:
41-
/// 1. First recursively process all children
42-
/// 2. Rebuild expression with new children
43-
/// 3. Apply parent rules to each child with the rebuilt parent
44-
/// 4. If any child changes, recursively apply again
45-
fn apply_parent_rules_recursive(
26+
fn apply_parent_rules(
4627
expr: Expression,
47-
_parent: Option<&Expression>,
48-
ctx: &dyn RewriteContext,
49-
session: &ExprSession,
50-
) -> VortexResult<Expression> {
51-
// First, recursively process all children bottom-up
52-
let mut new_children = Vec::with_capacity(expr.children().len());
53-
let mut children_changed = false;
54-
55-
for child in expr.children().iter() {
56-
// Recursively process this child first
57-
let new_child = apply_parent_rules_recursive(child.clone(), Some(&expr), ctx, session)?;
58-
59-
new_children.push(new_child);
60-
}
61-
62-
// Rebuild the expression with new children if any changed
63-
let mut expr = if children_changed {
64-
expr.with_children(new_children)?
65-
} else {
66-
expr
67-
};
68-
69-
// Now apply parent rules to each child using the rebuilt parent
70-
loop {
71-
let mut any_child_changed = false;
72-
let mut updated_children = Vec::with_capacity(expr.children().len());
73-
74-
for (child_idx, child) in expr.children().iter().enumerate() {
75-
// Try to apply parent rules to this child given that expr is its parent
76-
let new_child =
77-
apply_parent_rules_to_child(child.clone(), &expr, child_idx, ctx, session)?;
78-
79-
if child != &new_child {
80-
any_child_changed = true;
81-
}
82-
83-
updated_children.push(new_child);
84-
}
85-
86-
if any_child_changed {
87-
expr = expr.with_children(updated_children)?;
88-
} else {
89-
break;
90-
}
91-
}
92-
93-
Ok(expr)
94-
}
95-
96-
/// Apply parent rules to a child expression given its parent and child index.
97-
fn apply_parent_rules_to_child(
98-
child: Expression,
99-
parent: &Expression,
100-
child_idx: usize,
10128
ctx: &dyn RewriteContext,
10229
session: &ExprSession,
10330
) -> VortexResult<Expression> {
104-
let child_id = child.id();
105-
if let Some(rules) = session.rewrite_rules().parent_rules_for(&child_id) {
106-
let mut current = child;
107-
for rule in rules {
108-
if let Some(new_expr) = rule.reduce_parent_dyn(&current, parent, child_idx, ctx)? {
109-
current = new_expr;
31+
expr.transform_up(|node| {
32+
for (idx, child) in node.children().iter().enumerate() {
33+
for rule in session.rewrite_rules().parent_rules_for(&child.id()) {
34+
if let Some(new_expr) = rule.reduce_parent_dyn(&child, &node, idx, ctx)? {
35+
return Ok(Transformed::yes(new_expr));
36+
}
11037
}
11138
}
112-
Ok(current)
113-
} else {
114-
Ok(child)
115-
}
39+
Ok(Transformed::no(node))
40+
})
41+
.map(|t| t.into_inner())
11642
}
11743

118-
/// Internal implementation: Apply child rules in a bottom-up manner with RewriteContext.
11944
pub(crate) fn apply_child_rules_impl(
12045
expr: Expression,
12146
ctx: &dyn RewriteContext,
12247
session: &ExprSession,
12348
) -> VortexResult<Expression> {
124-
expr.transform_up(|node| apply_reduce_rules_node(node, ctx, session))
125-
.map(|t| t.into_inner())
126-
}
127-
128-
/// Apply child rules to a single node with RewriteContext.
129-
fn apply_reduce_rules_node(
130-
expr: Expression,
131-
ctx: &dyn RewriteContext,
132-
session: &ExprSession,
133-
) -> VortexResult<Transformed<Expression>> {
134-
let expr_id = expr.id();
135-
let mut current = expr;
136-
let mut changed = false;
137-
138-
// Apply untyped generic reduce rules
139-
if let Some(rules) = session.rewrite_rules().reduce_rules_for(&expr_id) {
140-
for rule in rules {
141-
if let Some(new_expr) = rule.reduce_dyn(&current, ctx)? {
142-
current = new_expr;
143-
changed = true;
49+
expr.transform_down(|node| {
50+
for rule in session.rewrite_rules().reduce_rules_for(&node.id()) {
51+
if let Some(new_expr) = rule.reduce_dyn(&node, ctx)? {
52+
return Ok(Transformed::yes(new_expr));
14453
}
14554
}
146-
}
147-
148-
if changed {
149-
Ok(Transformed::yes(current))
150-
} else {
151-
Ok(Transformed::no(current))
152-
}
55+
Ok(Transformed::no(node))
56+
})
57+
.map(|t| t.into_inner())
15358
}
15459

15560
#[cfg(test)]

0 commit comments

Comments
 (0)