@@ -7,8 +7,8 @@ use std::sync::Arc;
77use vortex_error:: VortexResult ;
88use vortex_utils:: aliases:: hash_map:: HashMap ;
99
10- use crate :: expr:: transform:: TypedRewriteContext ;
1110use crate :: expr:: transform:: rules:: { ParentReduceRule , ReduceRule , RewriteContext } ;
11+ use crate :: expr:: transform:: { Context , TypedRewriteContext } ;
1212use crate :: expr:: { ExprId , Expression , VTable } ;
1313
1414/// Type-erased wrapper for ReduceRule that allows dynamic dispatch.
@@ -120,7 +120,11 @@ pub(crate) trait DynParentReduceRule: Send + Sync {
120120}
121121
122122/// Concrete wrapper that implements DynParentReduceRule for a specific VTable type.
123- struct ParentReduceRuleAdapter < V : VTable , R : ParentReduceRule < V > > {
123+ struct ParentReduceRuleAdapter < V : VTable , R : ParentReduceRule < V , C > >
124+ where
125+ V : VTable ,
126+ for < ' a > R : ReduceRule < V , & ' a dyn RewriteContext > ,
127+ {
124128 rule : R ,
125129 _phantom : PhantomData < V > ,
126130}
@@ -150,7 +154,7 @@ impl<V: VTable, R: ParentReduceRule<V>> DynParentReduceRule for ParentReduceRule
150154}
151155
152156pub ( crate ) trait DynTypedParentReduceRule : Send + Sync {
153- fn reduce_parent_dyn (
157+ fn reduce_parent_dyn_typed (
154158 & self ,
155159 expr : & Expression ,
156160 parent : & Expression ,
@@ -159,12 +163,20 @@ pub(crate) trait DynTypedParentReduceRule: Send + Sync {
159163 ) -> VortexResult < Option < Expression > > ;
160164}
161165
162- struct TypedParentReduceRuleAdapter < V : VTable , R : ParentReduceRule < V > > {
166+ struct TypedParentReduceRuleAdapter < V : VTable , R >
167+ where
168+ V : VTable ,
169+ for < ' a > R : ParentReduceRule < V , & ' a dyn TypedRewriteContext > ,
170+ {
163171 rule : R ,
164172 _phantom : PhantomData < V > ,
165173}
166174
167- impl < V : VTable , R : ParentReduceRule < V > > TypedParentReduceRuleAdapter < V , R > {
175+ impl < V , R > TypedParentReduceRuleAdapter < V , R >
176+ where
177+ V : VTable ,
178+ for < ' a > R : ParentReduceRule < V , & ' a dyn TypedRewriteContext > ,
179+ {
168180 fn new ( rule : R ) -> Self {
169181 Self {
170182 rule,
@@ -173,10 +185,12 @@ impl<V: VTable, R: ParentReduceRule<V>> TypedParentReduceRuleAdapter<V, R> {
173185 }
174186}
175187
176- impl < V : VTable , R : ParentReduceRule < V > > DynTypedParentReduceRule
177- for TypedParentReduceRuleAdapter < V , R >
188+ impl < V , R > DynTypedParentReduceRule for TypedParentReduceRuleAdapter < V , R >
189+ where
190+ V : VTable ,
191+ for < ' a > R : ParentReduceRule < V , & ' a dyn TypedRewriteContext > ,
178192{
179- fn reduce_parent_dyn (
193+ fn reduce_parent_dyn_typed (
180194 & self ,
181195 expr : & Expression ,
182196 parent : & Expression ,
@@ -186,7 +200,7 @@ impl<V: VTable, R: ParentReduceRule<V>> DynTypedParentReduceRule
186200 let Some ( view) = expr. as_opt :: < V > ( ) else {
187201 return Ok ( None ) ;
188202 } ;
189- self . rule . reduce_parent ( & view, parent, child_idx, ctx)
203+ self . rule . reduce ( & view, parent, child_idx, ctx)
190204 }
191205}
192206
@@ -203,7 +217,7 @@ pub struct RewriteRuleRegistry {
203217 /// Parent reduce rules, indexed by expression ID
204218 parent_rules : HashMap < ExprId , Vec < Arc < dyn DynParentReduceRule > > > ,
205219 /// Parent reduce rules, indexed by expression ID
206- typed_parent_rules : HashMap < ExprId , Vec < Arc < dyn DynParentReduceRule > > > ,
220+ typed_parent_rules : HashMap < ExprId , Vec < Arc < dyn DynTypedParentReduceRule > > > ,
207221}
208222
209223impl std:: fmt:: Debug for RewriteRuleRegistry {
@@ -253,7 +267,6 @@ impl RewriteRuleRegistry {
253267 . push ( Arc :: new ( adapter) ) ;
254268 }
255269
256- /// Register a parent reduce rule.
257270 pub fn register_parent_rule < V : VTable , R : ParentReduceRule < V > + ' static > (
258271 & mut self ,
259272 vtable : & ' static V ,
@@ -267,29 +280,52 @@ impl RewriteRuleRegistry {
267280 . push ( Arc :: new ( adapter) ) ;
268281 }
269282
283+ /// Register a parent reduce rule.
284+ pub fn register_typed_parent_rule < V : VTable , R : ParentReduceRule < V > + ' static > (
285+ & mut self ,
286+ vtable : & ' static V ,
287+ rule : R ,
288+ ) {
289+ let id = vtable. id ( ) ;
290+ let adapter = TypedParentReduceRuleAdapter :: new ( rule) ;
291+ self . parent_rules
292+ . entry ( id)
293+ . or_default ( )
294+ . push ( Arc :: new ( adapter) ) ;
295+ }
296+
270297 /// Get all typed reduce rules for a given expression ID.
271- pub ( crate ) fn typed_reduce_rules_for (
272- & self ,
273- id : & ExprId ,
274- ) -> Option < & [ Arc < dyn DynTypedReduceRule > ] > {
275- self . typed_reduce_rules . get ( id ) . map ( |v| v . as_slice ( ) )
298+ pub ( crate ) fn typed_reduce_rules_for ( & self , id : & ExprId ) -> & [ Arc < dyn DynTypedReduceRule > ] {
299+ self . typed_reduce_rules
300+ . get ( id )
301+ . map ( |v| v . as_slice ( ) )
302+ . unwrap_or_default ( )
276303 }
277304
278305 /// Get all untyped reduce rules for a given expression ID.
279- pub ( crate ) fn reduce_rules_for ( & self , id : & ExprId ) -> Option < & [ Arc < dyn DynReduceRule > ] > {
280- self . reduce_rules . get ( id) . map ( |v| v. as_slice ( ) )
306+ pub ( crate ) fn reduce_rules_for ( & self , id : & ExprId ) -> & [ Arc < dyn DynReduceRule > ] {
307+ self . reduce_rules
308+ . get ( id)
309+ . map ( |v| v. as_slice ( ) )
310+ . unwrap_or_default ( )
281311 }
282312
283313 /// Get all parent reduce rules for a given expression ID.
284- pub ( crate ) fn parent_rules_for ( & self , id : & ExprId ) -> Option < & [ Arc < dyn DynParentReduceRule > ] > {
285- self . parent_rules . get ( id) . map ( |v| v. as_slice ( ) )
314+ pub ( crate ) fn parent_rules_for ( & self , id : & ExprId ) -> & [ Arc < dyn DynParentReduceRule > ] {
315+ self . parent_rules
316+ . get ( id)
317+ . map ( |v| v. as_slice ( ) )
318+ . unwrap_or_default ( )
286319 }
287320
288321 /// Get all the typed parent reduce rules for a given expression ID.
289322 pub ( crate ) fn typed_parent_rules_for (
290323 & self ,
291324 id : & ExprId ,
292- ) -> Option < & [ Arc < dyn DynParentReduceRule > ] > {
293- self . typed_parent_rules . get ( id) . map ( |v| v. as_slice ( ) )
325+ ) -> & [ Arc < dyn DynTypedParentReduceRule > ] {
326+ self . typed_parent_rules
327+ . get ( id)
328+ . map ( |v| v. as_slice ( ) )
329+ . unwrap_or_default ( )
294330 }
295331}
0 commit comments