Skip to content

Commit 0c84c7d

Browse files
committed
feat[vortex-expr]: reduce_parent has typed parent
Signed-off-by: Joe Isaacs <[email protected]>
1 parent 23967d7 commit 0c84c7d

File tree

5 files changed

+137
-80
lines changed

5 files changed

+137
-80
lines changed

vortex-array/src/expr/session/mod.rs

Lines changed: 22 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -79,23 +79,35 @@ impl ExprSession {
7979
}
8080

8181
/// Register a parent reduce rule in the session.
82-
pub fn register_parent_rule<V, R>(&mut self, vtable: &'static V, rule: R)
83-
where
84-
V: VTable,
82+
pub fn register_parent_rule<Child, Parent, R>(
83+
&mut self,
84+
child_vtable: &'static Child,
85+
parent_vtable: &'static Parent,
86+
rule: R,
87+
) where
88+
Child: VTable,
89+
Parent: VTable,
8590
R: 'static,
86-
R: ParentReduceRule<V, RuleContext>,
91+
R: ParentReduceRule<Child, Parent, RuleContext>,
8792
{
88-
self.rewrite_rules.register_parent_rule(vtable, rule);
93+
self.rewrite_rules
94+
.register_parent_rule(child_vtable, parent_vtable, rule);
8995
}
9096

9197
/// Register a typed parent reduce rule in the session.
92-
pub fn register_typed_parent_rule<V, R>(&mut self, vtable: &'static V, rule: R)
93-
where
94-
V: VTable,
98+
pub fn register_typed_parent_rule<Child, Parent, R>(
99+
&mut self,
100+
child_vtable: &'static Child,
101+
parent_vtable: &'static Parent,
102+
rule: R,
103+
) where
104+
Child: VTable,
105+
Parent: VTable,
95106
R: 'static,
96-
R: ParentReduceRule<V, TypedRuleContext>,
107+
R: ParentReduceRule<Child, Parent, TypedRuleContext>,
97108
{
98-
self.rewrite_rules.register_typed_parent_rule(vtable, rule);
109+
self.rewrite_rules
110+
.register_typed_parent_rule(child_vtable, parent_vtable, rule);
99111
}
100112
}
101113

vortex-array/src/expr/session/rewrite.rs

Lines changed: 86 additions & 46 deletions
Original file line numberDiff line numberDiff line change
@@ -14,22 +14,19 @@ use crate::expr::transform::{
1414
};
1515
use crate::expr::{ExprId, Expression, VTable};
1616

17-
/// Universal adapter for both ReduceRule and ParentReduceRule with any context type.
18-
struct RuleAdapter<V: VTable, R> {
17+
/// Adapter for ReduceRule
18+
struct ReduceRuleAdapter<V: VTable, R> {
1919
rule: R,
2020
_phantom: PhantomData<V>,
2121
}
2222

23-
impl<V: VTable, R> RuleAdapter<V, R> {
24-
fn new(rule: R) -> Self {
25-
Self {
26-
rule,
27-
_phantom: PhantomData,
28-
}
29-
}
23+
/// Adapter for ParentReduceRule
24+
struct ReduceParentRuleAdapter<Child: VTable, Parent: VTable, R> {
25+
rule: R,
26+
_phantom: PhantomData<(Child, Parent)>,
3027
}
3128

32-
impl<V, R> DynReduceRule for RuleAdapter<V, R>
29+
impl<V, R> DynReduceRule for ReduceRuleAdapter<V, R>
3330
where
3431
V: VTable,
3532
R: ReduceRule<V, RuleContext>,
@@ -42,7 +39,7 @@ where
4239
}
4340
}
4441

45-
impl<V, R> DynTypedReduceRule for RuleAdapter<V, R>
42+
impl<V, R> DynTypedReduceRule for ReduceRuleAdapter<V, R>
4643
where
4744
V: VTable,
4845
R: ReduceRule<V, TypedRuleContext>,
@@ -59,10 +56,11 @@ where
5956
}
6057
}
6158

62-
impl<V, R> DynParentReduceRule for RuleAdapter<V, R>
59+
impl<Child, Parent, R> DynParentReduceRule for ReduceParentRuleAdapter<Child, Parent, R>
6360
where
64-
V: VTable,
65-
R: ParentReduceRule<V, RuleContext>,
61+
Child: VTable,
62+
Parent: VTable,
63+
R: ParentReduceRule<Child, Parent, RuleContext>,
6664
{
6765
fn reduce_parent(
6866
&self,
@@ -71,17 +69,21 @@ where
7169
child_idx: usize,
7270
ctx: &RuleContext,
7371
) -> VortexResult<Option<Expression>> {
74-
let Some(view) = expr.as_opt::<V>() else {
72+
let Some(view) = expr.as_opt::<Child>() else {
73+
return Ok(None);
74+
};
75+
let Some(parent_view) = parent.as_opt::<Parent>() else {
7576
return Ok(None);
7677
};
77-
self.rule.reduce_parent(&view, parent, child_idx, ctx)
78+
self.rule.reduce_parent(&view, &parent_view, child_idx, ctx)
7879
}
7980
}
8081

81-
impl<V, R> DynTypedParentReduceRule for RuleAdapter<V, R>
82+
impl<Child, Parent, R> DynTypedParentReduceRule for ReduceParentRuleAdapter<Child, Parent, R>
8283
where
83-
V: VTable,
84-
R: ParentReduceRule<V, TypedRuleContext>,
84+
Child: VTable,
85+
Parent: VTable,
86+
R: ParentReduceRule<Child, Parent, TypedRuleContext>,
8587
{
8688
fn reduce_parent(
8789
&self,
@@ -90,14 +92,18 @@ where
9092
child_idx: usize,
9193
ctx: &TypedRuleContext,
9294
) -> VortexResult<Option<Expression>> {
93-
let Some(view) = expr.as_opt::<V>() else {
95+
let Some(view) = expr.as_opt::<Child>() else {
96+
return Ok(None);
97+
};
98+
let Some(parent_view) = parent.as_opt::<Parent>() else {
9499
return Ok(None);
95100
};
96-
self.rule.reduce_parent(&view, parent, child_idx, ctx)
101+
self.rule.reduce_parent(&view, &parent_view, child_idx, ctx)
97102
}
98103
}
99104

100105
type RuleRegistry<Rule> = HashMap<ExprId, Vec<Arc<Rule>>>;
106+
type ParentRuleRegistry<Rule> = HashMap<(ExprId, ExprId), Vec<Arc<Rule>>>;
101107

102108
/// Registry of expression rewrite rules.
103109
///
@@ -109,10 +115,10 @@ pub struct RewriteRuleRegistry {
109115
typed_reduce_rules: RuleRegistry<dyn DynTypedReduceRule>,
110116
/// Untyped reduce rules (require only RewriteContext), indexed by expression ID
111117
reduce_rules: RuleRegistry<dyn DynReduceRule>,
112-
/// Parent reduce rules, indexed by expression ID
113-
typed_parent_rules: RuleRegistry<dyn DynTypedParentReduceRule>,
114-
/// Parent reduce rules, indexed by expression ID
115-
parent_rules: RuleRegistry<dyn DynParentReduceRule>,
118+
/// Parent reduce rules, indexed by (child_id, parent_id)
119+
typed_parent_rules: ParentRuleRegistry<dyn DynTypedParentReduceRule>,
120+
/// Parent reduce rules, indexed by (child_id, parent_id)
121+
parent_rules: ParentRuleRegistry<dyn DynParentReduceRule>,
116122
}
117123

118124
// TODO(joe): follow up with rule debug info.
@@ -140,7 +146,10 @@ impl RewriteRuleRegistry {
140146
R: 'static,
141147
R: ReduceRule<V, TypedRuleContext>,
142148
{
143-
let adapter = RuleAdapter::new(rule);
149+
let adapter = ReduceRuleAdapter {
150+
rule,
151+
_phantom: PhantomData,
152+
};
144153
self.typed_reduce_rules
145154
.entry(vtable.id())
146155
.or_default()
@@ -155,36 +164,62 @@ impl RewriteRuleRegistry {
155164
R: 'static,
156165
R: ReduceRule<V, RuleContext>,
157166
{
158-
let adapter = RuleAdapter::new(rule);
167+
let adapter = ReduceRuleAdapter {
168+
rule,
169+
_phantom: PhantomData,
170+
};
159171
self.reduce_rules
160172
.entry(vtable.id())
161173
.or_default()
162174
.push(Arc::new(adapter));
163175
}
164176

165-
pub fn register_parent_rule<V, R>(&mut self, vtable: &'static V, rule: R)
166-
where
167-
V: VTable,
177+
pub fn register_parent_rule<Child, Parent, R>(
178+
&mut self,
179+
child_vtable: &'static Child,
180+
parent_vtable: &'static Parent,
181+
rule: R,
182+
) where
183+
Child: VTable,
184+
Parent: VTable,
168185
R: 'static,
169-
R: ParentReduceRule<V, RuleContext>,
186+
R: ParentReduceRule<Child, Parent, RuleContext>,
170187
{
171-
let adapter = RuleAdapter::new(rule);
188+
let adapter = ReduceParentRuleAdapter {
189+
rule,
190+
_phantom: PhantomData,
191+
};
172192
self.parent_rules
173-
.entry(vtable.id())
193+
.entry((child_vtable.id(), parent_vtable.id()))
174194
.or_default()
175195
.push(Arc::new(adapter));
176196
}
177197

178-
/// Register a parent reduce rule.
179-
pub fn register_typed_parent_rule<V, R>(&mut self, vtable: &'static V, rule: R)
180-
where
181-
V: VTable,
198+
/// Register a typed parent reduce rule.
199+
///
200+
/// # Type Parameters
201+
/// * `Child` - The child expression VTable type
202+
/// * `Parent` - The parent expression VTable type
203+
/// * `R` - The rule implementation
204+
///
205+
/// The rule will only be invoked when both the child has type Child and the parent has type Parent.
206+
pub fn register_typed_parent_rule<Child, Parent, R>(
207+
&mut self,
208+
child_vtable: &'static Child,
209+
parent_vtable: &'static Parent,
210+
rule: R,
211+
) where
212+
Child: VTable,
213+
Parent: VTable,
182214
R: 'static,
183-
R: ParentReduceRule<V, TypedRuleContext>,
215+
R: ParentReduceRule<Child, Parent, TypedRuleContext>,
184216
{
185-
let adapter = RuleAdapter::new(rule);
217+
let adapter = ReduceParentRuleAdapter {
218+
rule,
219+
_phantom: PhantomData,
220+
};
186221
self.typed_parent_rules
187-
.entry(vtable.id())
222+
.entry((child_vtable.id(), parent_vtable.id()))
188223
.or_default()
189224
.push(Arc::new(adapter));
190225
}
@@ -205,21 +240,26 @@ impl RewriteRuleRegistry {
205240
.unwrap_or_default()
206241
}
207242

208-
/// Get all untyped parent reduce rules for a given expression ID.
209-
pub(crate) fn parent_rules_for(&self, id: &ExprId) -> &[Arc<dyn DynParentReduceRule>] {
243+
/// Get all untyped parent reduce rules for a given child and parent expression ID pair.
244+
pub(crate) fn parent_rules_for(
245+
&self,
246+
child_id: &ExprId,
247+
parent_id: &ExprId,
248+
) -> &[Arc<dyn DynParentReduceRule>] {
210249
self.parent_rules
211-
.get(id)
250+
.get(&(child_id.clone(), parent_id.clone()))
212251
.map(|v| v.as_slice())
213252
.unwrap_or_default()
214253
}
215254

216-
/// Get all the typed parent reduce rules for a given expression ID.
255+
/// Get all the typed parent reduce rules for a given child and parent expression ID pair.
217256
pub(crate) fn typed_parent_rules_for(
218257
&self,
219-
id: &ExprId,
258+
child_id: &ExprId,
259+
parent_id: &ExprId,
220260
) -> &[Arc<dyn DynTypedParentReduceRule>] {
221261
self.typed_parent_rules
222-
.get(id)
262+
.get(&(child_id.clone(), parent_id.clone()))
223263
.map(|v| v.as_slice())
224264
.unwrap_or_default()
225265
}

vortex-array/src/expr/transform/rules.rs

Lines changed: 10 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -40,14 +40,17 @@ pub trait ReduceRule<V: VTable, C: RewriteContext>: Send + Sync {
4040
/// Note: This rule is only called for non-root expressions (i.e., when there is a parent).
4141
///
4242
/// # Type Parameters
43-
/// * `V` - The VTable type this rule applies to. The rule will only be invoked for expressions
44-
/// with this vtable type, providing compile-time type safety.
45-
pub trait ParentReduceRule<V: VTable, C: RewriteContext>: Send + Sync {
43+
/// * `Child` - The VTable type this rule applies to (the child expression type). The rule will only
44+
/// be invoked for expressions with this vtable type, providing compile-time type safety.
45+
/// * `Parent` - The VTable type of the parent expression. The rule will only be invoked when
46+
/// the parent has this vtable type, providing compile-time type safety.
47+
/// * `C` - The rewrite context type (RuleContext or TypedRuleContext)
48+
pub trait ParentReduceRule<Child: VTable, Parent: VTable, C: RewriteContext>: Send + Sync {
4649
/// Try to rewrite an expression based on its parent.
4750
///
4851
/// # Arguments
49-
/// * `expr` - The expression to potentially rewrite (already downcast to type V)
50-
/// * `parent` - The parent expression (always present - rule not called for root)
52+
/// * `expr` - The expression to potentially rewrite (already downcast to type Child)
53+
/// * `parent` - The parent expression (already downcast to type Parent)
5154
/// * `child_idx` - The index of the child expression within the parent.
5255
/// * `ctx` - Context for the rewrite (dtype, etc.)
5356
///
@@ -56,8 +59,8 @@ pub trait ParentReduceRule<V: VTable, C: RewriteContext>: Send + Sync {
5659
/// * `None` if the rule does not apply
5760
fn reduce_parent(
5861
&self,
59-
expr: &ExpressionView<V>,
60-
parent: &Expression,
62+
expr: &ExpressionView<Child>,
63+
parent: &ExpressionView<Parent>,
6164
child_idx: usize,
6265
ctx: &C,
6366
) -> VortexResult<Option<Expression>>;
@@ -94,9 +97,6 @@ impl TypedRuleContext {
9497
impl private::Sealed for TypedRuleContext {}
9598
impl RewriteContext for TypedRuleContext {}
9699

97-
impl private::Sealed for &TypedRuleContext {}
98-
impl RewriteContext for &TypedRuleContext {}
99-
100100
/// A context for rewrite rules that don't need dtype information.
101101
#[derive(Debug, Clone, Copy, Default)]
102102
pub struct RuleContext;

0 commit comments

Comments
 (0)