@@ -14,22 +14,19 @@ use crate::expr::transform::{
1414} ;
1515use crate :: expr:: { ExprId , Expression , VTable } ;
1616
17- /// Universal adapter for both ReduceRule and ParentReduceRule with any context type.
18- struct RuleAdapter < V : VTable , R > {
17+ /// Adapter for ReduceRule
18+ struct ReduceRuleAdapter < V : VTable , R > {
1919 rule : R ,
2020 _phantom : PhantomData < V > ,
2121}
2222
23- impl < V : VTable , R > RuleAdapter < V , R > {
24- fn new ( rule : R ) -> Self {
25- Self {
26- rule,
27- _phantom : PhantomData ,
28- }
29- }
23+ /// Adapter for ParentReduceRule
24+ struct ReduceParentRuleAdapter < Child : VTable , Parent : VTable , R > {
25+ rule : R ,
26+ _phantom : PhantomData < ( Child , Parent ) > ,
3027}
3128
32- impl < V , R > DynReduceRule for RuleAdapter < V , R >
29+ impl < V , R > DynReduceRule for ReduceRuleAdapter < V , R >
3330where
3431 V : VTable ,
3532 R : ReduceRule < V , RuleContext > ,
4239 }
4340}
4441
45- impl < V , R > DynTypedReduceRule for RuleAdapter < V , R >
42+ impl < V , R > DynTypedReduceRule for ReduceRuleAdapter < V , R >
4643where
4744 V : VTable ,
4845 R : ReduceRule < V , TypedRuleContext > ,
@@ -59,10 +56,11 @@ where
5956 }
6057}
6158
62- impl < V , R > DynParentReduceRule for RuleAdapter < V , R >
59+ impl < Child , Parent , R > DynParentReduceRule for ReduceParentRuleAdapter < Child , Parent , R >
6360where
64- V : VTable ,
65- R : ParentReduceRule < V , RuleContext > ,
61+ Child : VTable ,
62+ Parent : VTable ,
63+ R : ParentReduceRule < Child , Parent , RuleContext > ,
6664{
6765 fn reduce_parent (
6866 & self ,
@@ -71,17 +69,21 @@ where
7169 child_idx : usize ,
7270 ctx : & RuleContext ,
7371 ) -> VortexResult < Option < Expression > > {
74- let Some ( view) = expr. as_opt :: < V > ( ) else {
72+ let Some ( view) = expr. as_opt :: < Child > ( ) else {
73+ return Ok ( None ) ;
74+ } ;
75+ let Some ( parent_view) = parent. as_opt :: < Parent > ( ) else {
7576 return Ok ( None ) ;
7677 } ;
77- self . rule . reduce_parent ( & view, parent , child_idx, ctx)
78+ self . rule . reduce_parent ( & view, & parent_view , child_idx, ctx)
7879 }
7980}
8081
81- impl < V , R > DynTypedParentReduceRule for RuleAdapter < V , R >
82+ impl < Child , Parent , R > DynTypedParentReduceRule for ReduceParentRuleAdapter < Child , Parent , R >
8283where
83- V : VTable ,
84- R : ParentReduceRule < V , TypedRuleContext > ,
84+ Child : VTable ,
85+ Parent : VTable ,
86+ R : ParentReduceRule < Child , Parent , TypedRuleContext > ,
8587{
8688 fn reduce_parent (
8789 & self ,
@@ -90,14 +92,18 @@ where
9092 child_idx : usize ,
9193 ctx : & TypedRuleContext ,
9294 ) -> VortexResult < Option < Expression > > {
93- let Some ( view) = expr. as_opt :: < V > ( ) else {
95+ let Some ( view) = expr. as_opt :: < Child > ( ) else {
96+ return Ok ( None ) ;
97+ } ;
98+ let Some ( parent_view) = parent. as_opt :: < Parent > ( ) else {
9499 return Ok ( None ) ;
95100 } ;
96- self . rule . reduce_parent ( & view, parent , child_idx, ctx)
101+ self . rule . reduce_parent ( & view, & parent_view , child_idx, ctx)
97102 }
98103}
99104
100105type RuleRegistry < Rule > = HashMap < ExprId , Vec < Arc < Rule > > > ;
106+ type ParentRuleRegistry < Rule > = HashMap < ( ExprId , ExprId ) , Vec < Arc < Rule > > > ;
101107
102108/// Registry of expression rewrite rules.
103109///
@@ -109,10 +115,10 @@ pub struct RewriteRuleRegistry {
109115 typed_reduce_rules : RuleRegistry < dyn DynTypedReduceRule > ,
110116 /// Untyped reduce rules (require only RewriteContext), indexed by expression ID
111117 reduce_rules : RuleRegistry < dyn DynReduceRule > ,
112- /// Parent reduce rules, indexed by expression ID
113- typed_parent_rules : RuleRegistry < dyn DynTypedParentReduceRule > ,
114- /// Parent reduce rules, indexed by expression ID
115- parent_rules : RuleRegistry < dyn DynParentReduceRule > ,
118+ /// Parent reduce rules, indexed by (child_id, parent_id)
119+ typed_parent_rules : ParentRuleRegistry < dyn DynTypedParentReduceRule > ,
120+ /// Parent reduce rules, indexed by (child_id, parent_id)
121+ parent_rules : ParentRuleRegistry < dyn DynParentReduceRule > ,
116122}
117123
118124// TODO(joe): follow up with rule debug info.
@@ -140,7 +146,10 @@ impl RewriteRuleRegistry {
140146 R : ' static ,
141147 R : ReduceRule < V , TypedRuleContext > ,
142148 {
143- let adapter = RuleAdapter :: new ( rule) ;
149+ let adapter = ReduceRuleAdapter {
150+ rule,
151+ _phantom : PhantomData ,
152+ } ;
144153 self . typed_reduce_rules
145154 . entry ( vtable. id ( ) )
146155 . or_default ( )
@@ -155,36 +164,62 @@ impl RewriteRuleRegistry {
155164 R : ' static ,
156165 R : ReduceRule < V , RuleContext > ,
157166 {
158- let adapter = RuleAdapter :: new ( rule) ;
167+ let adapter = ReduceRuleAdapter {
168+ rule,
169+ _phantom : PhantomData ,
170+ } ;
159171 self . reduce_rules
160172 . entry ( vtable. id ( ) )
161173 . or_default ( )
162174 . push ( Arc :: new ( adapter) ) ;
163175 }
164176
165- pub fn register_parent_rule < V , R > ( & mut self , vtable : & ' static V , rule : R )
166- where
167- V : VTable ,
177+ pub fn register_parent_rule < Child , Parent , R > (
178+ & mut self ,
179+ child_vtable : & ' static Child ,
180+ parent_vtable : & ' static Parent ,
181+ rule : R ,
182+ ) where
183+ Child : VTable ,
184+ Parent : VTable ,
168185 R : ' static ,
169- R : ParentReduceRule < V , RuleContext > ,
186+ R : ParentReduceRule < Child , Parent , RuleContext > ,
170187 {
171- let adapter = RuleAdapter :: new ( rule) ;
188+ let adapter = ReduceParentRuleAdapter {
189+ rule,
190+ _phantom : PhantomData ,
191+ } ;
172192 self . parent_rules
173- . entry ( vtable . id ( ) )
193+ . entry ( ( child_vtable . id ( ) , parent_vtable . id ( ) ) )
174194 . or_default ( )
175195 . push ( Arc :: new ( adapter) ) ;
176196 }
177197
178- /// Register a parent reduce rule.
179- pub fn register_typed_parent_rule < V , R > ( & mut self , vtable : & ' static V , rule : R )
180- where
181- V : VTable ,
198+ /// Register a typed parent reduce rule.
199+ ///
200+ /// # Type Parameters
201+ /// * `Child` - The child expression VTable type
202+ /// * `Parent` - The parent expression VTable type
203+ /// * `R` - The rule implementation
204+ ///
205+ /// The rule will only be invoked when both the child has type Child and the parent has type Parent.
206+ pub fn register_typed_parent_rule < Child , Parent , R > (
207+ & mut self ,
208+ child_vtable : & ' static Child ,
209+ parent_vtable : & ' static Parent ,
210+ rule : R ,
211+ ) where
212+ Child : VTable ,
213+ Parent : VTable ,
182214 R : ' static ,
183- R : ParentReduceRule < V , TypedRuleContext > ,
215+ R : ParentReduceRule < Child , Parent , TypedRuleContext > ,
184216 {
185- let adapter = RuleAdapter :: new ( rule) ;
217+ let adapter = ReduceParentRuleAdapter {
218+ rule,
219+ _phantom : PhantomData ,
220+ } ;
186221 self . typed_parent_rules
187- . entry ( vtable . id ( ) )
222+ . entry ( ( child_vtable . id ( ) , parent_vtable . id ( ) ) )
188223 . or_default ( )
189224 . push ( Arc :: new ( adapter) ) ;
190225 }
@@ -205,21 +240,26 @@ impl RewriteRuleRegistry {
205240 . unwrap_or_default ( )
206241 }
207242
208- /// Get all untyped parent reduce rules for a given expression ID.
209- pub ( crate ) fn parent_rules_for ( & self , id : & ExprId ) -> & [ Arc < dyn DynParentReduceRule > ] {
243+ /// Get all untyped parent reduce rules for a given child and parent expression ID pair.
244+ pub ( crate ) fn parent_rules_for (
245+ & self ,
246+ child_id : & ExprId ,
247+ parent_id : & ExprId ,
248+ ) -> & [ Arc < dyn DynParentReduceRule > ] {
210249 self . parent_rules
211- . get ( id )
250+ . get ( & ( child_id . clone ( ) , parent_id . clone ( ) ) )
212251 . map ( |v| v. as_slice ( ) )
213252 . unwrap_or_default ( )
214253 }
215254
216- /// Get all the typed parent reduce rules for a given expression ID.
255+ /// Get all the typed parent reduce rules for a given child and parent expression ID pair .
217256 pub ( crate ) fn typed_parent_rules_for (
218257 & self ,
219- id : & ExprId ,
258+ child_id : & ExprId ,
259+ parent_id : & ExprId ,
220260 ) -> & [ Arc < dyn DynTypedParentReduceRule > ] {
221261 self . typed_parent_rules
222- . get ( id )
262+ . get ( & ( child_id . clone ( ) , parent_id . clone ( ) ) )
223263 . map ( |v| v. as_slice ( ) )
224264 . unwrap_or_default ( )
225265 }
0 commit comments