@@ -6,7 +6,7 @@ use std::marker::PhantomData;
66use std:: sync:: Arc ;
77
88use vortex_error:: VortexResult ;
9- use vortex_utils:: aliases:: hash_map :: HashMap ;
9+ use vortex_utils:: aliases:: dash_map :: DashMap ;
1010
1111use crate :: expr:: transform:: rules:: {
1212 AnyParent , ParentMatcher , ParentReduceRule , ReduceRule , RuleContext , TypedRuleContext ,
@@ -104,15 +104,13 @@ where
104104 }
105105}
106106
107- type RuleRegistry < Rule > = HashMap < ExprId , Vec < Arc < Rule > > > ;
108- type ParentRuleRegistry < Rule > = HashMap < ( ExprId , ExprId ) , Vec < Arc < Rule > > > ;
107+ type RuleRegistry < Rule > = DashMap < ExprId , Vec < Arc < Rule > > > ;
108+ type ParentRuleRegistry < Rule > = DashMap < ExprId , Vec < Arc < Rule > > > ;
109109
110- /// Registry of expression rewrite rules.
111- ///
112- /// Stores rewrite rules indexed by the expression ID they apply to.
113- /// Typed and untyped rules are stored separately for better organization.
110+ /// Inner struct that holds all the rule registries.
111+ /// Wrapped in a single Arc by RewriteRuleRegistry for efficient cloning.
114112#[ derive( Default ) ]
115- pub struct RewriteRuleRegistry {
113+ struct RewriteRuleRegistryInner {
116114 /// Typed reduce rules (require TypedRewriteContext), indexed by expression ID
117115 typed_reduce_rules : RuleRegistry < dyn DynTypedReduceRule > ,
118116 /// Untyped reduce rules (require only RewriteContext), indexed by expression ID
@@ -127,19 +125,34 @@ pub struct RewriteRuleRegistry {
127125 any_parent_rules : RuleRegistry < dyn DynParentReduceRule > ,
128126}
129127
128+ /// Registry of expression rewrite rules.
129+ ///
130+ /// Stores rewrite rules indexed by the expression ID they apply to.
131+ /// Typed and untyped rules are stored separately for better organization.
132+ #[ derive( Clone ) ]
133+ pub struct RewriteRuleRegistry {
134+ inner : Arc < RewriteRuleRegistryInner > ,
135+ }
136+
137+ impl Default for RewriteRuleRegistry {
138+ fn default ( ) -> Self {
139+ Self {
140+ inner : Arc :: new ( RewriteRuleRegistryInner :: default ( ) ) ,
141+ }
142+ }
143+ }
144+
130145// TODO(joe): follow up with rule debug info.
131146impl Debug for RewriteRuleRegistry {
132147 fn fmt ( & self , f : & mut std:: fmt:: Formatter < ' _ > ) -> std:: fmt:: Result {
133148 f. debug_struct ( "RewriteRuleRegistry" )
134- . field ( "typed_reduce_rules_count" , & self . typed_reduce_rules . len ( ) )
135- . field ( "reduce_rules_count" , & self . reduce_rules . len ( ) )
136- . field ( "typed_parent_rules" , & self . typed_parent_rules . len ( ) )
137- . field ( "parent_rules_count" , & self . parent_rules . len ( ) )
138149 . field (
139- "typed_any_parent_rules_count " ,
140- & self . typed_any_parent_rules . len ( ) ,
150+ "typed_reduce_rules_count " ,
151+ & self . inner . typed_reduce_rules . len ( ) ,
141152 )
142- . field ( "any_parent_rules_count" , & self . any_parent_rules . len ( ) )
153+ . field ( "reduce_rules_count" , & self . inner . reduce_rules . len ( ) )
154+ . field ( "typed_parent_rules" , & self . inner . typed_parent_rules . len ( ) )
155+ . field ( "parent_rules_count" , & self . inner . parent_rules . len ( ) )
143156 . finish ( )
144157 }
145158}
@@ -161,7 +174,8 @@ impl RewriteRuleRegistry {
161174 rule,
162175 _phantom : PhantomData ,
163176 } ;
164- self . typed_reduce_rules
177+ self . inner
178+ . typed_reduce_rules
165179 . entry ( vtable. id ( ) )
166180 . or_default ( )
167181 . push ( Arc :: new ( adapter) ) ;
@@ -179,7 +193,8 @@ impl RewriteRuleRegistry {
179193 rule,
180194 _phantom : PhantomData ,
181195 } ;
182- self . reduce_rules
196+ self . inner
197+ . reduce_rules
183198 . entry ( vtable. id ( ) )
184199 . or_default ( )
185200 . push ( Arc :: new ( adapter) ) ;
@@ -201,8 +216,9 @@ impl RewriteRuleRegistry {
201216 rule,
202217 _phantom : PhantomData ,
203218 } ;
204- self . parent_rules
205- . entry ( ( child_vtable. id ( ) , parent_vtable. id ( ) ) )
219+ self . inner
220+ . parent_rules
221+ . entry ( child_vtable. id ( ) )
206222 . or_default ( )
207223 . push ( Arc :: new ( adapter) ) ;
208224 }
@@ -218,7 +234,8 @@ impl RewriteRuleRegistry {
218234 rule,
219235 _phantom : PhantomData ,
220236 } ;
221- self . any_parent_rules
237+ self . inner
238+ . parent_rules
222239 . entry ( child_vtable. id ( ) )
223240 . or_default ( )
224241 . push ( Arc :: new ( adapter) ) ;
@@ -240,8 +257,9 @@ impl RewriteRuleRegistry {
240257 rule,
241258 _phantom : PhantomData ,
242259 } ;
243- self . typed_parent_rules
244- . entry ( ( child_vtable. id ( ) , parent_vtable. id ( ) ) )
260+ self . inner
261+ . typed_parent_rules
262+ . entry ( child_vtable. id ( ) )
245263 . or_default ( )
246264 . push ( Arc :: new ( adapter) ) ;
247265 }
@@ -260,74 +278,58 @@ impl RewriteRuleRegistry {
260278 rule,
261279 _phantom : PhantomData ,
262280 } ;
263- self . typed_any_parent_rules
281+ self . inner
282+ . typed_parent_rules
264283 . entry ( child_vtable. id ( ) )
265284 . or_default ( )
266285 . push ( Arc :: new ( adapter) ) ;
267286 }
268287
269- /// Get all typed reduce rules for a given expression ID.
270- pub ( crate ) fn typed_reduce_rules_for (
271- & self ,
272- id : & ExprId ,
273- ) -> impl Iterator < Item = & Arc < dyn DynTypedReduceRule > > {
274- self . typed_reduce_rules
275- . get ( id)
276- . into_iter ( )
277- . flat_map ( |v| v. iter ( ) )
288+ /// Execute a callback with all typed reduce rules for a given expression ID.
289+ pub ( crate ) fn with_typed_reduce_rules < F , R > ( & self , id : & ExprId , f : F ) -> R
290+ where
291+ F : FnOnce ( & [ Arc < dyn DynTypedReduceRule > ] ) -> R ,
292+ {
293+ if let Some ( entry) = self . inner . typed_reduce_rules . get ( id) {
294+ f ( entry. value ( ) )
295+ } else {
296+ f ( & [ ] )
297+ }
278298 }
279299
280- /// Get all untyped reduce rules for a given expression ID.
281- pub ( crate ) fn reduce_rules_for (
282- & self ,
283- id : & ExprId ,
284- ) -> impl Iterator < Item = & Arc < dyn DynReduceRule > > {
285- self . reduce_rules . get ( id) . into_iter ( ) . flat_map ( |v| v. iter ( ) )
300+ /// Execute a callback with all untyped reduce rules for a given expression ID.
301+ pub ( crate ) fn with_reduce_rules < F , R > ( & self , id : & ExprId , f : F ) -> R
302+ where
303+ F : FnOnce ( & [ Arc < dyn DynReduceRule > ] ) -> R ,
304+ {
305+ if let Some ( entry) = self . inner . reduce_rules . get ( id) {
306+ f ( entry. value ( ) )
307+ } else {
308+ f ( & [ ] )
309+ }
286310 }
287311
288- /// Get all untyped parent reduce rules for a given child and parent expression ID pair.
289- ///
290- /// Returns both specific parent rules and wildcard "any parent" rules.
291- pub ( crate ) fn parent_rules_for (
292- & self ,
293- child_id : & ExprId ,
294- parent_id : & ExprId ,
295- ) -> impl Iterator < Item = & Arc < dyn DynParentReduceRule > > {
296- let specific = self
297- . parent_rules
298- . get ( & ( child_id. clone ( ) , parent_id. clone ( ) ) )
299- . into_iter ( )
300- . flat_map ( |v| v. iter ( ) ) ;
301-
302- let wildcard = self
303- . any_parent_rules
304- . get ( child_id)
305- . into_iter ( )
306- . flat_map ( |v| v. iter ( ) ) ;
307-
308- specific. chain ( wildcard)
312+ /// Execute a callback with all untyped parent reduce rules for a given expression ID.
313+ pub ( crate ) fn with_parent_rules < F , R > ( & self , id : & ExprId , f : F ) -> R
314+ where
315+ F : FnOnce ( & [ Arc < dyn DynParentReduceRule > ] ) -> R ,
316+ {
317+ if let Some ( entry) = self . inner . parent_rules . get ( id) {
318+ f ( entry. value ( ) )
319+ } else {
320+ f ( & [ ] )
321+ }
309322 }
310323
311- /// Get all the typed parent reduce rules for a given child and parent expression ID pair.
312- ///
313- /// Returns both specific parent rules and wildcard "any parent" rules.
314- pub ( crate ) fn typed_parent_rules_for (
315- & self ,
316- child_id : & ExprId ,
317- parent_id : & ExprId ,
318- ) -> impl Iterator < Item = & Arc < dyn DynTypedParentReduceRule > > {
319- let specific = self
320- . typed_parent_rules
321- . get ( & ( child_id. clone ( ) , parent_id. clone ( ) ) )
322- . into_iter ( )
323- . flat_map ( |v| v. iter ( ) ) ;
324-
325- let wildcard = self
326- . typed_any_parent_rules
327- . get ( child_id)
328- . into_iter ( )
329- . flat_map ( |v| v. iter ( ) ) ;
330-
331- specific. chain ( wildcard)
324+ /// Execute a callback with all typed parent reduce rules for a given expression ID.
325+ pub ( crate ) fn with_typed_parent_rules < F , R > ( & self , id : & ExprId , f : F ) -> R
326+ where
327+ F : FnOnce ( & [ Arc < dyn DynTypedParentReduceRule > ] ) -> R ,
328+ {
329+ if let Some ( entry) = self . inner . typed_parent_rules . get ( id) {
330+ f ( entry. value ( ) )
331+ } else {
332+ f ( & [ ] )
333+ }
332334 }
333335}
0 commit comments