Skip to content

Commit 1dfb21d

Browse files
committed
wip
Signed-off-by: Joe Isaacs <[email protected]>
1 parent 63aa58c commit 1dfb21d

File tree

10 files changed

+257
-164
lines changed

10 files changed

+257
-164
lines changed

vortex-array/src/expr/exprs/get_item/transform.rs

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -82,7 +82,7 @@ mod tests {
8282

8383
let dtype = DType::Primitive(PType::I32, NonNullable);
8484

85-
let result = simplify_typed(get_z, &dtype, &ExprSession::default()).unwrap();
85+
let result = simplify_typed(get_z, &dtype, ExprSession::default().rewrite_rules()).unwrap();
8686

8787
assert_eq!(&result, &lit(4));
8888
}
@@ -103,7 +103,8 @@ mod tests {
103103

104104
let dtype = DType::Primitive(PType::I32, NonNullable);
105105

106-
let result = simplify_typed(get_final, &dtype, &ExprSession::default()).unwrap();
106+
let result =
107+
simplify_typed(get_final, &dtype, ExprSession::default().rewrite_rules()).unwrap();
107108

108109
assert_eq!(&result, &lit(42));
109110
}
@@ -119,7 +120,8 @@ mod tests {
119120

120121
let dtype = DType::Primitive(PType::I32, NonNullable);
121122

122-
let result = simplify_typed(get_result, &dtype, &ExprSession::default()).unwrap();
123+
let result =
124+
simplify_typed(get_result, &dtype, ExprSession::default().rewrite_rules()).unwrap();
123125

124126
let expected = checked_add(lit(1), lit(10));
125127
assert_eq!(&result, &expected);

vortex-array/src/expr/session/rewrite.rs

Lines changed: 82 additions & 80 deletions
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,7 @@ use std::marker::PhantomData;
66
use std::sync::Arc;
77

88
use vortex_error::VortexResult;
9-
use vortex_utils::aliases::hash_map::HashMap;
9+
use vortex_utils::aliases::dash_map::DashMap;
1010

1111
use crate::expr::transform::rules::{
1212
AnyParent, ParentMatcher, ParentReduceRule, ReduceRule, RuleContext, TypedRuleContext,
@@ -104,15 +104,13 @@ where
104104
}
105105
}
106106

107-
type RuleRegistry<Rule> = HashMap<ExprId, Vec<Arc<Rule>>>;
108-
type ParentRuleRegistry<Rule> = HashMap<(ExprId, ExprId), Vec<Arc<Rule>>>;
107+
type RuleRegistry<Rule> = DashMap<ExprId, Vec<Arc<Rule>>>;
108+
type ParentRuleRegistry<Rule> = DashMap<ExprId, Vec<Arc<Rule>>>;
109109

110-
/// Registry of expression rewrite rules.
111-
///
112-
/// Stores rewrite rules indexed by the expression ID they apply to.
113-
/// Typed and untyped rules are stored separately for better organization.
110+
/// Inner struct that holds all the rule registries.
111+
/// Wrapped in a single Arc by RewriteRuleRegistry for efficient cloning.
114112
#[derive(Default)]
115-
pub struct RewriteRuleRegistry {
113+
struct RewriteRuleRegistryInner {
116114
/// Typed reduce rules (require TypedRewriteContext), indexed by expression ID
117115
typed_reduce_rules: RuleRegistry<dyn DynTypedReduceRule>,
118116
/// Untyped reduce rules (require only RewriteContext), indexed by expression ID
@@ -127,19 +125,34 @@ pub struct RewriteRuleRegistry {
127125
any_parent_rules: RuleRegistry<dyn DynParentReduceRule>,
128126
}
129127

128+
/// Registry of expression rewrite rules.
129+
///
130+
/// Stores rewrite rules indexed by the expression ID they apply to.
131+
/// Typed and untyped rules are stored separately for better organization.
132+
#[derive(Clone)]
133+
pub struct RewriteRuleRegistry {
134+
inner: Arc<RewriteRuleRegistryInner>,
135+
}
136+
137+
impl Default for RewriteRuleRegistry {
138+
fn default() -> Self {
139+
Self {
140+
inner: Arc::new(RewriteRuleRegistryInner::default()),
141+
}
142+
}
143+
}
144+
130145
// TODO(joe): follow up with rule debug info.
131146
impl Debug for RewriteRuleRegistry {
132147
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
133148
f.debug_struct("RewriteRuleRegistry")
134-
.field("typed_reduce_rules_count", &self.typed_reduce_rules.len())
135-
.field("reduce_rules_count", &self.reduce_rules.len())
136-
.field("typed_parent_rules", &self.typed_parent_rules.len())
137-
.field("parent_rules_count", &self.parent_rules.len())
138149
.field(
139-
"typed_any_parent_rules_count",
140-
&self.typed_any_parent_rules.len(),
150+
"typed_reduce_rules_count",
151+
&self.inner.typed_reduce_rules.len(),
141152
)
142-
.field("any_parent_rules_count", &self.any_parent_rules.len())
153+
.field("reduce_rules_count", &self.inner.reduce_rules.len())
154+
.field("typed_parent_rules", &self.inner.typed_parent_rules.len())
155+
.field("parent_rules_count", &self.inner.parent_rules.len())
143156
.finish()
144157
}
145158
}
@@ -161,7 +174,8 @@ impl RewriteRuleRegistry {
161174
rule,
162175
_phantom: PhantomData,
163176
};
164-
self.typed_reduce_rules
177+
self.inner
178+
.typed_reduce_rules
165179
.entry(vtable.id())
166180
.or_default()
167181
.push(Arc::new(adapter));
@@ -179,7 +193,8 @@ impl RewriteRuleRegistry {
179193
rule,
180194
_phantom: PhantomData,
181195
};
182-
self.reduce_rules
196+
self.inner
197+
.reduce_rules
183198
.entry(vtable.id())
184199
.or_default()
185200
.push(Arc::new(adapter));
@@ -201,8 +216,9 @@ impl RewriteRuleRegistry {
201216
rule,
202217
_phantom: PhantomData,
203218
};
204-
self.parent_rules
205-
.entry((child_vtable.id(), parent_vtable.id()))
219+
self.inner
220+
.parent_rules
221+
.entry(child_vtable.id())
206222
.or_default()
207223
.push(Arc::new(adapter));
208224
}
@@ -218,7 +234,8 @@ impl RewriteRuleRegistry {
218234
rule,
219235
_phantom: PhantomData,
220236
};
221-
self.any_parent_rules
237+
self.inner
238+
.parent_rules
222239
.entry(child_vtable.id())
223240
.or_default()
224241
.push(Arc::new(adapter));
@@ -240,8 +257,9 @@ impl RewriteRuleRegistry {
240257
rule,
241258
_phantom: PhantomData,
242259
};
243-
self.typed_parent_rules
244-
.entry((child_vtable.id(), parent_vtable.id()))
260+
self.inner
261+
.typed_parent_rules
262+
.entry(child_vtable.id())
245263
.or_default()
246264
.push(Arc::new(adapter));
247265
}
@@ -260,74 +278,58 @@ impl RewriteRuleRegistry {
260278
rule,
261279
_phantom: PhantomData,
262280
};
263-
self.typed_any_parent_rules
281+
self.inner
282+
.typed_parent_rules
264283
.entry(child_vtable.id())
265284
.or_default()
266285
.push(Arc::new(adapter));
267286
}
268287

269-
/// Get all typed reduce rules for a given expression ID.
270-
pub(crate) fn typed_reduce_rules_for(
271-
&self,
272-
id: &ExprId,
273-
) -> impl Iterator<Item = &Arc<dyn DynTypedReduceRule>> {
274-
self.typed_reduce_rules
275-
.get(id)
276-
.into_iter()
277-
.flat_map(|v| v.iter())
288+
/// Execute a callback with all typed reduce rules for a given expression ID.
289+
pub(crate) fn with_typed_reduce_rules<F, R>(&self, id: &ExprId, f: F) -> R
290+
where
291+
F: FnOnce(&[Arc<dyn DynTypedReduceRule>]) -> R,
292+
{
293+
if let Some(entry) = self.inner.typed_reduce_rules.get(id) {
294+
f(entry.value())
295+
} else {
296+
f(&[])
297+
}
278298
}
279299

280-
/// Get all untyped reduce rules for a given expression ID.
281-
pub(crate) fn reduce_rules_for(
282-
&self,
283-
id: &ExprId,
284-
) -> impl Iterator<Item = &Arc<dyn DynReduceRule>> {
285-
self.reduce_rules.get(id).into_iter().flat_map(|v| v.iter())
300+
/// Execute a callback with all untyped reduce rules for a given expression ID.
301+
pub(crate) fn with_reduce_rules<F, R>(&self, id: &ExprId, f: F) -> R
302+
where
303+
F: FnOnce(&[Arc<dyn DynReduceRule>]) -> R,
304+
{
305+
if let Some(entry) = self.inner.reduce_rules.get(id) {
306+
f(entry.value())
307+
} else {
308+
f(&[])
309+
}
286310
}
287311

288-
/// Get all untyped parent reduce rules for a given child and parent expression ID pair.
289-
///
290-
/// Returns both specific parent rules and wildcard "any parent" rules.
291-
pub(crate) fn parent_rules_for(
292-
&self,
293-
child_id: &ExprId,
294-
parent_id: &ExprId,
295-
) -> impl Iterator<Item = &Arc<dyn DynParentReduceRule>> {
296-
let specific = self
297-
.parent_rules
298-
.get(&(child_id.clone(), parent_id.clone()))
299-
.into_iter()
300-
.flat_map(|v| v.iter());
301-
302-
let wildcard = self
303-
.any_parent_rules
304-
.get(child_id)
305-
.into_iter()
306-
.flat_map(|v| v.iter());
307-
308-
specific.chain(wildcard)
312+
/// Execute a callback with all untyped parent reduce rules for a given expression ID.
313+
pub(crate) fn with_parent_rules<F, R>(&self, id: &ExprId, f: F) -> R
314+
where
315+
F: FnOnce(&[Arc<dyn DynParentReduceRule>]) -> R,
316+
{
317+
if let Some(entry) = self.inner.parent_rules.get(id) {
318+
f(entry.value())
319+
} else {
320+
f(&[])
321+
}
309322
}
310323

311-
/// Get all the typed parent reduce rules for a given child and parent expression ID pair.
312-
///
313-
/// Returns both specific parent rules and wildcard "any parent" rules.
314-
pub(crate) fn typed_parent_rules_for(
315-
&self,
316-
child_id: &ExprId,
317-
parent_id: &ExprId,
318-
) -> impl Iterator<Item = &Arc<dyn DynTypedParentReduceRule>> {
319-
let specific = self
320-
.typed_parent_rules
321-
.get(&(child_id.clone(), parent_id.clone()))
322-
.into_iter()
323-
.flat_map(|v| v.iter());
324-
325-
let wildcard = self
326-
.typed_any_parent_rules
327-
.get(child_id)
328-
.into_iter()
329-
.flat_map(|v| v.iter());
330-
331-
specific.chain(wildcard)
324+
/// Execute a callback with all typed parent reduce rules for a given expression ID.
325+
pub(crate) fn with_typed_parent_rules<F, R>(&self, id: &ExprId, f: F) -> R
326+
where
327+
F: FnOnce(&[Arc<dyn DynTypedParentReduceRule>]) -> R,
328+
{
329+
if let Some(entry) = self.inner.typed_parent_rules.get(id) {
330+
f(entry.value())
331+
} else {
332+
f(&[])
333+
}
332334
}
333335
}

vortex-array/src/expr/transform/optimizer.rs

Lines changed: 10 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -5,33 +5,35 @@ use vortex_dtype::DType;
55
use vortex_error::VortexResult;
66

77
use crate::expr::Expression;
8-
use crate::expr::session::ExprSession;
8+
use crate::expr::session::{ExprSession, RewriteRuleRegistry};
99
use crate::expr::transform::{simplify, simplify_typed};
1010

1111
/// A unified optimizer for expressions that can work with or without type information.
12-
pub struct ExprOptimizer<'a> {
13-
session: &'a ExprSession,
12+
pub struct ExprOptimizer {
13+
rule_registry: RewriteRuleRegistry,
1414
}
1515

16-
impl<'a> ExprOptimizer<'a> {
16+
impl ExprOptimizer {
1717
/// Create a new untyped optimizer.
1818
///
1919
/// This optimizer will use untyped simplification rules only.
20-
pub fn new(session: &'a ExprSession) -> Self {
21-
Self { session }
20+
pub fn new(session: &ExprSession) -> Self {
21+
Self {
22+
rule_registry: session.rewrite_rules().clone(),
23+
}
2224
}
2325

2426
/// Optimize the given expression.
2527
///
2628
/// If this optimizer was created with a dtype, this will perform typed optimization.
2729
/// Otherwise, it will perform untyped optimization.
2830
pub fn optimize(&self, expr: Expression) -> VortexResult<Expression> {
29-
simplify(expr, self.session)
31+
simplify(expr, &self.rule_registry)
3032
}
3133

3234
/// Apply optimize rules to the expression, with a known dtype. This will also apply rules
3335
/// in `optimize`.
3436
pub fn optimize_typed(&self, expr: Expression, dtype: &DType) -> VortexResult<Expression> {
35-
simplify_typed(expr, dtype, self.session)
37+
simplify_typed(expr, dtype, &self.rule_registry)
3638
}
3739
}

vortex-array/src/expr/transform/partition.rs

Lines changed: 7 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -290,7 +290,12 @@ mod tests {
290290

291291
let split_a = partitioned.find_partition(&"a".into()).unwrap();
292292
assert_eq!(
293-
&simplify_typed(split_a.clone(), &dtype, &ExprSession::default()).unwrap(),
293+
&simplify_typed(
294+
split_a.clone(),
295+
&dtype,
296+
ExprSession::default().rewrite_rules()
297+
)
298+
.unwrap(),
294299
&pack(
295300
[
296301
("a_0", get_item("x", get_item("a", root()))),
@@ -340,7 +345,7 @@ mod tests {
340345
get_item("y", get_item("a", root())),
341346
select(["a", "b"], root()),
342347
);
343-
let expr = simplify_typed(expr, &dtype, &ExprSession::default()).unwrap();
348+
let expr = simplify_typed(expr, &dtype, ExprSession::default().rewrite_rules()).unwrap();
344349
let partitioned =
345350
partition(expr, &dtype, annotate_scope_access(fields), &optimizer).unwrap();
346351

vortex-array/src/expr/transform/rules.rs

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -137,6 +137,12 @@ pub struct RuleContext;
137137
impl private::Sealed for RuleContext {}
138138
impl RewriteContext for RuleContext {}
139139

140+
impl From<&TypedRuleContext> for RuleContext {
141+
fn from(_value: &TypedRuleContext) -> Self {
142+
RuleContext
143+
}
144+
}
145+
140146
/// Type-erased wrappers that allows dynamic dispatch.
141147
pub(crate) trait DynReduceRule: Send + Sync {
142148
fn reduce(&self, expr: &Expression, ctx: &RuleContext) -> VortexResult<Option<Expression>>;

0 commit comments

Comments
 (0)