Skip to content

Commit da256c6

Browse files
committed
wip
Signed-off-by: Joe Isaacs <[email protected]>
1 parent 7f8ec35 commit da256c6

File tree

5 files changed

+237
-99
lines changed

5 files changed

+237
-99
lines changed

vortex-array/src/expr/session.rs

Lines changed: 181 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@
33

44
use std::sync::Arc;
55

6+
use vortex_error::VortexResult;
67
use vortex_session::registry::Registry;
78
use vortex_session::{Ref, SessionExt};
89
use vortex_utils::aliases::hash_map::HashMap;
@@ -22,23 +23,142 @@ use crate::expr::exprs::root::Root;
2223
use crate::expr::exprs::select::Select;
2324
use crate::expr::transform::remove_select::RemoveSelectRule;
2425
use crate::expr::transform::simplify::PackGetItemRule;
25-
use crate::expr::transform::traits::{ChildReduceRule, ParentReduceRule, ReduceRule};
26-
use crate::expr::{ExprId, ExprVTable};
26+
use crate::expr::transform::traits::{
27+
ChildReduceRule, ParentReduceRule, ReduceRule, RewriteContext,
28+
};
29+
use crate::expr::{ExprId, ExprVTable, Expression, VTable};
2730

2831
/// Registry of expression vtables.
2932
pub type ExprRegistry = Registry<ExprVTable>;
3033

34+
/// Type-erased wrapper for ReduceRule that allows dynamic dispatch.
35+
pub(crate) trait DynReduceRule: Send + Sync {
36+
fn reduce_dyn(
37+
&self,
38+
expr: &Expression,
39+
ctx: &dyn RewriteContext,
40+
) -> VortexResult<Option<Expression>>;
41+
}
42+
43+
/// Concrete wrapper that implements DynReduceRule for a specific VTable type.
44+
struct ReduceRuleAdapter<V: VTable, R: ReduceRule<V>> {
45+
rule: R,
46+
_phantom: std::marker::PhantomData<V>,
47+
}
48+
49+
impl<V: VTable, R: ReduceRule<V>> ReduceRuleAdapter<V, R> {
50+
fn new(rule: R) -> Self {
51+
Self {
52+
rule,
53+
_phantom: std::marker::PhantomData,
54+
}
55+
}
56+
}
57+
58+
impl<V: VTable, R: ReduceRule<V>> DynReduceRule for ReduceRuleAdapter<V, R> {
59+
fn reduce_dyn(
60+
&self,
61+
expr: &Expression,
62+
ctx: &dyn RewriteContext,
63+
) -> VortexResult<Option<Expression>> {
64+
let Some(view) = expr.as_opt::<V>() else {
65+
return Ok(None);
66+
};
67+
self.rule.reduce(&view, ctx)
68+
}
69+
}
70+
71+
/// Type-erased wrapper for ChildReduceRule that allows dynamic dispatch.
72+
pub(crate) trait DynChildReduceRule: Send + Sync {
73+
fn reduce_child_dyn(
74+
&self,
75+
expr: &Expression,
76+
child: &Expression,
77+
child_idx: usize,
78+
ctx: &dyn RewriteContext,
79+
) -> VortexResult<Option<Expression>>;
80+
}
81+
82+
/// Concrete wrapper that implements DynChildReduceRule for a specific VTable type.
83+
struct ChildReduceRuleAdapter<V: VTable, R: ChildReduceRule<V>> {
84+
rule: R,
85+
_phantom: std::marker::PhantomData<V>,
86+
}
87+
88+
impl<V: VTable, R: ChildReduceRule<V>> ChildReduceRuleAdapter<V, R> {
89+
fn new(rule: R) -> Self {
90+
Self {
91+
rule,
92+
_phantom: std::marker::PhantomData,
93+
}
94+
}
95+
}
96+
97+
impl<V: VTable, R: ChildReduceRule<V>> DynChildReduceRule for ChildReduceRuleAdapter<V, R> {
98+
fn reduce_child_dyn(
99+
&self,
100+
expr: &Expression,
101+
child: &Expression,
102+
child_idx: usize,
103+
ctx: &dyn RewriteContext,
104+
) -> VortexResult<Option<Expression>> {
105+
let Some(view) = expr.as_opt::<V>() else {
106+
return Ok(None);
107+
};
108+
self.rule.reduce_child(&view, child, child_idx, ctx)
109+
}
110+
}
111+
112+
/// Type-erased wrapper for ParentReduceRule that allows dynamic dispatch.
113+
pub(crate) trait DynParentReduceRule: Send + Sync {
114+
fn reduce_parent_dyn(
115+
&self,
116+
expr: &Expression,
117+
parent: &Expression,
118+
ctx: &dyn RewriteContext,
119+
) -> VortexResult<Option<Expression>>;
120+
}
121+
122+
/// Concrete wrapper that implements DynParentReduceRule for a specific VTable type.
123+
struct ParentReduceRuleAdapter<V: VTable, R: ParentReduceRule<V>> {
124+
rule: R,
125+
_phantom: std::marker::PhantomData<V>,
126+
}
127+
128+
impl<V: VTable, R: ParentReduceRule<V>> ParentReduceRuleAdapter<V, R> {
129+
fn new(rule: R) -> Self {
130+
Self {
131+
rule,
132+
_phantom: std::marker::PhantomData,
133+
}
134+
}
135+
}
136+
137+
impl<V: VTable, R: ParentReduceRule<V>> DynParentReduceRule for ParentReduceRuleAdapter<V, R> {
138+
fn reduce_parent_dyn(
139+
&self,
140+
expr: &Expression,
141+
parent: &Expression,
142+
ctx: &dyn RewriteContext,
143+
) -> VortexResult<Option<Expression>> {
144+
let Some(view) = expr.as_opt::<V>() else {
145+
return Ok(None);
146+
};
147+
self.rule.reduce_parent(&view, parent, ctx)
148+
}
149+
}
150+
31151
/// Registry of expression rewrite rules.
32152
///
33153
/// Stores rewrite rules indexed by the expression ID they apply to.
34154
#[derive(Default)]
35155
pub struct RewriteRuleRegistry {
36156
/// Generic reduce rules (no context needed), indexed by expression ID
37-
reduce_rules: HashMap<ExprId, Vec<Arc<dyn ReduceRule>>>,
157+
reduce_rules: HashMap<ExprId, Vec<Arc<dyn DynReduceRule>>>,
38158
/// Child reduce rules, indexed by expression ID
39-
child_rules: HashMap<ExprId, Vec<Arc<dyn ChildReduceRule>>>,
159+
child_rules: HashMap<ExprId, Vec<Arc<dyn DynChildReduceRule>>>,
40160
/// Parent reduce rules, indexed by expression ID
41-
parent_rules: HashMap<ExprId, Vec<Arc<dyn ParentReduceRule>>>,
161+
parent_rules: HashMap<ExprId, Vec<Arc<dyn DynParentReduceRule>>>,
42162
}
43163

44164
impl std::fmt::Debug for RewriteRuleRegistry {
@@ -57,35 +177,59 @@ impl RewriteRuleRegistry {
57177
}
58178

59179
/// Register a generic reduce rule.
60-
pub fn register_reduce_rule(&mut self, rule: Arc<dyn ReduceRule>) {
61-
let id = rule.id();
62-
self.reduce_rules.entry(id).or_default().push(rule);
180+
pub fn register_reduce_rule<V: VTable, R: ReduceRule<V> + 'static>(
181+
&mut self,
182+
vtable: &'static V,
183+
rule: R,
184+
) {
185+
let id = vtable.id();
186+
let adapter = ReduceRuleAdapter::new(rule);
187+
self.reduce_rules
188+
.entry(id)
189+
.or_default()
190+
.push(Arc::new(adapter));
63191
}
64192

65193
/// Register a child reduce rule.
66-
pub fn register_child_rule(&mut self, rule: Arc<dyn ChildReduceRule>) {
67-
let id = rule.id();
68-
self.child_rules.entry(id).or_default().push(rule);
194+
pub fn register_child_rule<V: VTable, R: ChildReduceRule<V> + 'static>(
195+
&mut self,
196+
vtable: &'static V,
197+
rule: R,
198+
) {
199+
let id = vtable.id();
200+
let adapter = ChildReduceRuleAdapter::new(rule);
201+
self.child_rules
202+
.entry(id)
203+
.or_default()
204+
.push(Arc::new(adapter));
69205
}
70206

71207
/// Register a parent reduce rule.
72-
pub fn register_parent_rule(&mut self, rule: Arc<dyn ParentReduceRule>) {
73-
let id = rule.id();
74-
self.parent_rules.entry(id).or_default().push(rule);
208+
pub fn register_parent_rule<V: VTable, R: ParentReduceRule<V> + 'static>(
209+
&mut self,
210+
vtable: &'static V,
211+
rule: R,
212+
) {
213+
let id = vtable.id();
214+
let adapter = ParentReduceRuleAdapter::new(rule);
215+
self.parent_rules
216+
.entry(id)
217+
.or_default()
218+
.push(Arc::new(adapter));
75219
}
76220

77221
/// Get all generic reduce rules for a given expression ID.
78-
pub fn reduce_rules_for(&self, id: &ExprId) -> Option<&[Arc<dyn ReduceRule>]> {
222+
pub(crate) fn reduce_rules_for(&self, id: &ExprId) -> Option<&[Arc<dyn DynReduceRule>]> {
79223
self.reduce_rules.get(id).map(|v| v.as_slice())
80224
}
81225

82226
/// Get all child reduce rules for a given expression ID.
83-
pub fn child_rules_for(&self, id: &ExprId) -> Option<&[Arc<dyn ChildReduceRule>]> {
227+
pub(crate) fn child_rules_for(&self, id: &ExprId) -> Option<&[Arc<dyn DynChildReduceRule>]> {
84228
self.child_rules.get(id).map(|v| v.as_slice())
85229
}
86230

87231
/// Get all parent reduce rules for a given expression ID.
88-
pub fn parent_rules_for(&self, id: &ExprId) -> Option<&[Arc<dyn ParentReduceRule>]> {
232+
pub(crate) fn parent_rules_for(&self, id: &ExprId) -> Option<&[Arc<dyn DynParentReduceRule>]> {
89233
self.parent_rules.get(id).map(|v| v.as_slice())
90234
}
91235
}
@@ -118,18 +262,30 @@ impl ExprSession {
118262
}
119263

120264
/// Register a generic reduce rule in the session.
121-
pub fn register_reduce_rule(&mut self, rule: impl ReduceRule + 'static) {
122-
self.rewrite_rules.register_reduce_rule(Arc::new(rule));
265+
pub fn register_reduce_rule<V: VTable>(
266+
&mut self,
267+
vtable: &'static V,
268+
rule: impl ReduceRule<V> + 'static,
269+
) {
270+
self.rewrite_rules.register_reduce_rule(vtable, rule);
123271
}
124272

125273
/// Register a child reduce rule in the session.
126-
pub fn register_child_rule(&mut self, rule: impl ChildReduceRule + 'static) {
127-
self.rewrite_rules.register_child_rule(Arc::new(rule));
274+
pub fn register_child_rule<V: VTable>(
275+
&mut self,
276+
vtable: &'static V,
277+
rule: impl ChildReduceRule<V> + 'static,
278+
) {
279+
self.rewrite_rules.register_child_rule(vtable, rule);
128280
}
129281

130282
/// Register a parent reduce rule in the session.
131-
pub fn register_parent_rule(&mut self, rule: impl ParentReduceRule + 'static) {
132-
self.rewrite_rules.register_parent_rule(Arc::new(rule));
283+
pub fn register_parent_rule<V: VTable>(
284+
&mut self,
285+
vtable: &'static V,
286+
rule: impl ParentReduceRule<V> + 'static,
287+
) {
288+
self.rewrite_rules.register_parent_rule(vtable, rule);
133289
}
134290
}
135291

@@ -156,8 +312,8 @@ impl Default for ExprSession {
156312

157313
// Register built-in rewrite rules
158314
let mut rewrite_rules = RewriteRuleRegistry::new();
159-
rewrite_rules.register_reduce_rule(Arc::new(RemoveSelectRule));
160-
rewrite_rules.register_child_rule(Arc::new(PackGetItemRule));
315+
rewrite_rules.register_reduce_rule(&Select, RemoveSelectRule);
316+
rewrite_rules.register_child_rule(&GetItem, PackGetItemRule);
161317

162318
Self {
163319
registry: expressions,

vortex-array/src/expr/transform/remove_select.rs

Lines changed: 6 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,7 @@ use crate::expr::exprs::pack::pack;
99
use crate::expr::exprs::select::Select;
1010
use crate::expr::transform::traits::{ReduceRule, RewriteContext};
1111
use crate::expr::traversal::{NodeExt, Transformed};
12-
use crate::expr::{ExprId, Expression};
12+
use crate::expr::{Expression, ExpressionView};
1313

1414
/// Replaces [crate::SelectExpr] with combination of [crate::GetItem] and [crate::Pack] expressions.
1515
pub(crate) fn remove_select(e: Expression, ctx: &DType) -> VortexResult<Expression> {
@@ -60,20 +60,12 @@ fn remove_select_transformer(
6060
/// Transforms: `select(["a", "b"], expr)` → `pack(a: get_item("a", expr), b: get_item("b", expr))`
6161
pub struct RemoveSelectRule;
6262

63-
impl ReduceRule for RemoveSelectRule {
64-
fn id(&self) -> ExprId {
65-
ExprId::new_ref("vortex.select")
66-
}
67-
63+
impl ReduceRule<Select> for RemoveSelectRule {
6864
fn reduce(
6965
&self,
70-
expr: &Expression,
66+
select: &ExpressionView<Select>,
7167
ctx: &dyn RewriteContext,
7268
) -> VortexResult<Option<Expression>> {
73-
let Some(select) = expr.as_opt::<Select>() else {
74-
return Ok(None);
75-
};
76-
7769
let child = select.child();
7870
let child_dtype = child.return_dtype(ctx.dtype())?;
7971
let child_nullability = child_dtype.nullability();
@@ -114,7 +106,7 @@ mod tests {
114106
use super::{RemoveSelectRule, remove_select};
115107
use crate::expr::exprs::pack::Pack;
116108
use crate::expr::exprs::root::root;
117-
use crate::expr::exprs::select::select;
109+
use crate::expr::exprs::select::{Select, select};
118110
use crate::expr::session::ExprSession;
119111
use crate::expr::transform::simplify_typed::apply_child_rules;
120112
use crate::expr::transform::traits::{ReduceRule, SimpleRewriteContext};
@@ -142,7 +134,8 @@ mod tests {
142134

143135
let rule = RemoveSelectRule;
144136
let ctx = SimpleRewriteContext { dtype: &dtype };
145-
let result = rule.reduce(&e, &ctx).unwrap();
137+
let select_view = e.as_::<Select>();
138+
let result = rule.reduce(&select_view, &ctx).unwrap();
146139

147140
assert!(result.is_some());
148141
let transformed = result.unwrap();

0 commit comments

Comments
 (0)