@@ -105,7 +105,7 @@ where
105105}
106106
107107type RuleRegistry < Rule > = DashMap < ExprId , Vec < Arc < Rule > > > ;
108- type ParentRuleRegistry < Rule > = DashMap < ExprId , Vec < Arc < Rule > > > ;
108+ type ParentRuleRegistry < Rule > = DashMap < ( ExprId , ExprId ) , Vec < Arc < Rule > > > ;
109109
110110/// Inner struct that holds all the rule registries.
111111/// Wrapped in a single Arc by RewriteRuleRegistry for efficient cloning.
@@ -218,7 +218,7 @@ impl RewriteRuleRegistry {
218218 } ;
219219 self . inner
220220 . parent_rules
221- . entry ( child_vtable. id ( ) )
221+ . entry ( ( child_vtable. id ( ) , parent_vtable . id ( ) ) )
222222 . or_default ( )
223223 . push ( Arc :: new ( adapter) ) ;
224224 }
@@ -235,7 +235,7 @@ impl RewriteRuleRegistry {
235235 _phantom : PhantomData ,
236236 } ;
237237 self . inner
238- . parent_rules
238+ . any_parent_rules
239239 . entry ( child_vtable. id ( ) )
240240 . or_default ( )
241241 . push ( Arc :: new ( adapter) ) ;
@@ -259,7 +259,7 @@ impl RewriteRuleRegistry {
259259 } ;
260260 self . inner
261261 . typed_parent_rules
262- . entry ( child_vtable. id ( ) )
262+ . entry ( ( child_vtable. id ( ) , parent_vtable . id ( ) ) )
263263 . or_default ( )
264264 . push ( Arc :: new ( adapter) ) ;
265265 }
@@ -279,7 +279,7 @@ impl RewriteRuleRegistry {
279279 _phantom : PhantomData ,
280280 } ;
281281 self . inner
282- . typed_parent_rules
282+ . typed_any_parent_rules
283283 . entry ( child_vtable. id ( ) )
284284 . or_default ( )
285285 . push ( Arc :: new ( adapter) ) ;
@@ -288,48 +288,84 @@ impl RewriteRuleRegistry {
288288 /// Execute a callback with all typed reduce rules for a given expression ID.
289289 pub ( crate ) fn with_typed_reduce_rules < F , R > ( & self , id : & ExprId , f : F ) -> R
290290 where
291- F : FnOnce ( & [ Arc < dyn DynTypedReduceRule > ] ) -> R ,
291+ F : FnOnce ( & mut dyn Iterator < Item = & dyn DynTypedReduceRule > ) -> R ,
292292 {
293- if let Some ( entry) = self . inner . typed_reduce_rules . get ( id) {
294- f ( entry. value ( ) )
295- } else {
296- f ( & [ ] )
297- }
293+ f ( & mut self
294+ . inner
295+ . typed_reduce_rules
296+ . get ( id)
297+ . iter ( )
298+ . map ( |v| v. value ( ) )
299+ . flatten ( )
300+ . map ( |arc| arc. as_ref ( ) ) )
298301 }
299302
300303 /// Execute a callback with all untyped reduce rules for a given expression ID.
301304 pub ( crate ) fn with_reduce_rules < F , R > ( & self , id : & ExprId , f : F ) -> R
302305 where
303- F : FnOnce ( & [ Arc < dyn DynReduceRule > ] ) -> R ,
306+ F : FnOnce ( & mut dyn Iterator < Item = & dyn DynReduceRule > ) -> R ,
304307 {
305- if let Some ( entry) = self . inner . reduce_rules . get ( id) {
306- f ( entry. value ( ) )
307- } else {
308- f ( & [ ] )
309- }
308+ f ( & mut self
309+ . inner
310+ . reduce_rules
311+ . get ( id)
312+ . iter ( )
313+ . map ( |v| v. value ( ) )
314+ . flatten ( )
315+ . map ( |arc| arc. as_ref ( ) ) )
310316 }
311317
312- /// Execute a callback with all untyped parent reduce rules for a given expression ID.
313- pub ( crate ) fn with_parent_rules < F , R > ( & self , id : & ExprId , f : F ) -> R
318+ /// Execute a callback with all untyped parent reduce rules for a given child and parent expression ID.
319+ ///
320+ /// Returns rules from both specific parent rules (if parent_id provided) and "any parent" wildcard rules.
321+ pub ( crate ) fn with_parent_rules < F , R > (
322+ & self ,
323+ child_id : & ExprId ,
324+ parent_id : Option < & ExprId > ,
325+ f : F ,
326+ ) -> R
314327 where
315- F : FnOnce ( & [ Arc < dyn DynParentReduceRule > ] ) -> R ,
328+ F : FnOnce ( & mut dyn Iterator < Item = & dyn DynParentReduceRule > ) -> R ,
316329 {
317- if let Some ( entry) = self . inner . parent_rules . get ( id) {
318- f ( entry. value ( ) )
319- } else {
320- f ( & [ ] )
321- }
330+ let specific_entry = parent_id. and_then ( |pid| {
331+ self . inner
332+ . parent_rules
333+ . get ( & ( child_id. clone ( ) , pid. clone ( ) ) )
334+ } ) ;
335+ let wildcard_entry = self . inner . any_parent_rules . get ( child_id) ;
336+
337+ f ( & mut specific_entry
338+ . iter ( )
339+ . map ( |v| v. value ( ) )
340+ . flatten ( )
341+ . chain ( wildcard_entry. iter ( ) . map ( |v| v. value ( ) ) . flatten ( ) )
342+ . map ( |arc| arc. as_ref ( ) ) )
322343 }
323344
324- /// Execute a callback with all typed parent reduce rules for a given expression ID.
325- pub ( crate ) fn with_typed_parent_rules < F , R > ( & self , id : & ExprId , f : F ) -> R
345+ /// Execute a callback with all typed parent reduce rules for a given child and parent expression ID.
346+ ///
347+ /// Returns rules from both specific parent rules (if parent_id provided) and "any parent" wildcard rules.
348+ pub ( crate ) fn with_typed_parent_rules < F , R > (
349+ & self ,
350+ child_id : & ExprId ,
351+ parent_id : Option < & ExprId > ,
352+ f : F ,
353+ ) -> R
326354 where
327- F : FnOnce ( & [ Arc < dyn DynTypedParentReduceRule > ] ) -> R ,
355+ F : FnOnce ( & mut dyn Iterator < Item = & dyn DynTypedParentReduceRule > ) -> R ,
328356 {
329- if let Some ( entry) = self . inner . typed_parent_rules . get ( id) {
330- f ( entry. value ( ) )
331- } else {
332- f ( & [ ] )
333- }
357+ let specific_entry = parent_id. and_then ( |pid| {
358+ self . inner
359+ . typed_parent_rules
360+ . get ( & ( child_id. clone ( ) , pid. clone ( ) ) )
361+ } ) ;
362+ let wildcard_entry = self . inner . typed_any_parent_rules . get ( child_id) ;
363+
364+ f ( & mut specific_entry
365+ . iter ( )
366+ . map ( |v| v. value ( ) )
367+ . flatten ( )
368+ . chain ( wildcard_entry. iter ( ) . map ( |v| v. value ( ) ) . flatten ( ) )
369+ . map ( |arc| arc. as_ref ( ) ) )
334370 }
335371}
0 commit comments