Skip to content

Commit e628469

Browse files
committed
cleanup
Signed-off-by: Joe Isaacs <[email protected]>
1 parent 1dfb21d commit e628469

File tree

6 files changed

+89
-52
lines changed

6 files changed

+89
-52
lines changed

vortex-array/src/expr/exprs/get_item/transform.rs

Lines changed: 10 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -40,8 +40,8 @@ mod tests {
4040
use crate::expr::exprs::literal::lit;
4141
use crate::expr::exprs::pack::pack;
4242
use crate::expr::session::ExprSession;
43+
use crate::expr::transform::ExprOptimizer;
4344
use crate::expr::transform::rules::{ReduceRule, RuleContext};
44-
use crate::expr::transform::simplify_typed;
4545

4646
#[test]
4747
fn test_pack_get_item_rule() {
@@ -82,7 +82,9 @@ mod tests {
8282

8383
let dtype = DType::Primitive(PType::I32, NonNullable);
8484

85-
let result = simplify_typed(get_z, &dtype, ExprSession::default().rewrite_rules()).unwrap();
85+
let session = ExprSession::default();
86+
let optimizer = ExprOptimizer::new(&session);
87+
let result = optimizer.optimize_typed(get_z, &dtype).unwrap();
8688

8789
assert_eq!(&result, &lit(4));
8890
}
@@ -103,8 +105,9 @@ mod tests {
103105

104106
let dtype = DType::Primitive(PType::I32, NonNullable);
105107

106-
let result =
107-
simplify_typed(get_final, &dtype, ExprSession::default().rewrite_rules()).unwrap();
108+
let session = ExprSession::default();
109+
let optimizer = ExprOptimizer::new(&session);
110+
let result = optimizer.optimize_typed(get_final, &dtype).unwrap();
108111

109112
assert_eq!(&result, &lit(42));
110113
}
@@ -120,8 +123,9 @@ mod tests {
120123

121124
let dtype = DType::Primitive(PType::I32, NonNullable);
122125

123-
let result =
124-
simplify_typed(get_result, &dtype, ExprSession::default().rewrite_rules()).unwrap();
126+
let session = ExprSession::default();
127+
let optimizer = ExprOptimizer::new(&session);
128+
let result = optimizer.optimize_typed(get_result, &dtype).unwrap();
125129

126130
let expected = checked_add(lit(1), lit(10));
127131
assert_eq!(&result, &expected);

vortex-array/src/expr/session/rewrite.rs

Lines changed: 69 additions & 33 deletions
Original file line numberDiff line numberDiff line change
@@ -105,7 +105,7 @@ where
105105
}
106106

107107
type RuleRegistry<Rule> = DashMap<ExprId, Vec<Arc<Rule>>>;
108-
type ParentRuleRegistry<Rule> = DashMap<ExprId, Vec<Arc<Rule>>>;
108+
type ParentRuleRegistry<Rule> = DashMap<(ExprId, ExprId), Vec<Arc<Rule>>>;
109109

110110
/// Inner struct that holds all the rule registries.
111111
/// Wrapped in a single Arc by RewriteRuleRegistry for efficient cloning.
@@ -218,7 +218,7 @@ impl RewriteRuleRegistry {
218218
};
219219
self.inner
220220
.parent_rules
221-
.entry(child_vtable.id())
221+
.entry((child_vtable.id(), parent_vtable.id()))
222222
.or_default()
223223
.push(Arc::new(adapter));
224224
}
@@ -235,7 +235,7 @@ impl RewriteRuleRegistry {
235235
_phantom: PhantomData,
236236
};
237237
self.inner
238-
.parent_rules
238+
.any_parent_rules
239239
.entry(child_vtable.id())
240240
.or_default()
241241
.push(Arc::new(adapter));
@@ -259,7 +259,7 @@ impl RewriteRuleRegistry {
259259
};
260260
self.inner
261261
.typed_parent_rules
262-
.entry(child_vtable.id())
262+
.entry((child_vtable.id(), parent_vtable.id()))
263263
.or_default()
264264
.push(Arc::new(adapter));
265265
}
@@ -279,7 +279,7 @@ impl RewriteRuleRegistry {
279279
_phantom: PhantomData,
280280
};
281281
self.inner
282-
.typed_parent_rules
282+
.typed_any_parent_rules
283283
.entry(child_vtable.id())
284284
.or_default()
285285
.push(Arc::new(adapter));
@@ -288,48 +288,84 @@ impl RewriteRuleRegistry {
288288
/// Execute a callback with all typed reduce rules for a given expression ID.
289289
pub(crate) fn with_typed_reduce_rules<F, R>(&self, id: &ExprId, f: F) -> R
290290
where
291-
F: FnOnce(&[Arc<dyn DynTypedReduceRule>]) -> R,
291+
F: FnOnce(&mut dyn Iterator<Item = &dyn DynTypedReduceRule>) -> R,
292292
{
293-
if let Some(entry) = self.inner.typed_reduce_rules.get(id) {
294-
f(entry.value())
295-
} else {
296-
f(&[])
297-
}
293+
f(&mut self
294+
.inner
295+
.typed_reduce_rules
296+
.get(id)
297+
.iter()
298+
.map(|v| v.value())
299+
.flatten()
300+
.map(|arc| arc.as_ref()))
298301
}
299302

300303
/// Execute a callback with all untyped reduce rules for a given expression ID.
301304
pub(crate) fn with_reduce_rules<F, R>(&self, id: &ExprId, f: F) -> R
302305
where
303-
F: FnOnce(&[Arc<dyn DynReduceRule>]) -> R,
306+
F: FnOnce(&mut dyn Iterator<Item = &dyn DynReduceRule>) -> R,
304307
{
305-
if let Some(entry) = self.inner.reduce_rules.get(id) {
306-
f(entry.value())
307-
} else {
308-
f(&[])
309-
}
308+
f(&mut self
309+
.inner
310+
.reduce_rules
311+
.get(id)
312+
.iter()
313+
.map(|v| v.value())
314+
.flatten()
315+
.map(|arc| arc.as_ref()))
310316
}
311317

312-
/// Execute a callback with all untyped parent reduce rules for a given expression ID.
313-
pub(crate) fn with_parent_rules<F, R>(&self, id: &ExprId, f: F) -> R
318+
/// Execute a callback with all untyped parent reduce rules for a given child and parent expression ID.
319+
///
320+
/// Returns rules from both specific parent rules (if parent_id provided) and "any parent" wildcard rules.
321+
pub(crate) fn with_parent_rules<F, R>(
322+
&self,
323+
child_id: &ExprId,
324+
parent_id: Option<&ExprId>,
325+
f: F,
326+
) -> R
314327
where
315-
F: FnOnce(&[Arc<dyn DynParentReduceRule>]) -> R,
328+
F: FnOnce(&mut dyn Iterator<Item = &dyn DynParentReduceRule>) -> R,
316329
{
317-
if let Some(entry) = self.inner.parent_rules.get(id) {
318-
f(entry.value())
319-
} else {
320-
f(&[])
321-
}
330+
let specific_entry = parent_id.and_then(|pid| {
331+
self.inner
332+
.parent_rules
333+
.get(&(child_id.clone(), pid.clone()))
334+
});
335+
let wildcard_entry = self.inner.any_parent_rules.get(child_id);
336+
337+
f(&mut specific_entry
338+
.iter()
339+
.map(|v| v.value())
340+
.flatten()
341+
.chain(wildcard_entry.iter().map(|v| v.value()).flatten())
342+
.map(|arc| arc.as_ref()))
322343
}
323344

324-
/// Execute a callback with all typed parent reduce rules for a given expression ID.
325-
pub(crate) fn with_typed_parent_rules<F, R>(&self, id: &ExprId, f: F) -> R
345+
/// Execute a callback with all typed parent reduce rules for a given child and parent expression ID.
346+
///
347+
/// Returns rules from both specific parent rules (if parent_id provided) and "any parent" wildcard rules.
348+
pub(crate) fn with_typed_parent_rules<F, R>(
349+
&self,
350+
child_id: &ExprId,
351+
parent_id: Option<&ExprId>,
352+
f: F,
353+
) -> R
326354
where
327-
F: FnOnce(&[Arc<dyn DynTypedParentReduceRule>]) -> R,
355+
F: FnOnce(&mut dyn Iterator<Item = &dyn DynTypedParentReduceRule>) -> R,
328356
{
329-
if let Some(entry) = self.inner.typed_parent_rules.get(id) {
330-
f(entry.value())
331-
} else {
332-
f(&[])
333-
}
357+
let specific_entry = parent_id.and_then(|pid| {
358+
self.inner
359+
.typed_parent_rules
360+
.get(&(child_id.clone(), pid.clone()))
361+
});
362+
let wildcard_entry = self.inner.typed_any_parent_rules.get(child_id);
363+
364+
f(&mut specific_entry
365+
.iter()
366+
.map(|v| v.value())
367+
.flatten()
368+
.chain(wildcard_entry.iter().map(|v| v.value()).flatten())
369+
.map(|arc| arc.as_ref()))
334370
}
335371
}

vortex-array/src/expr/transform/mod.rs

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -16,5 +16,3 @@ pub use optimizer::*;
1616
pub use partition::*;
1717
pub use replace::*;
1818
pub use rules::*;
19-
pub(crate) use simplify::*;
20-
pub(crate) use simplify_typed::*;

vortex-array/src/expr/transform/optimizer.rs

Lines changed: 4 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -6,27 +6,23 @@ use vortex_error::VortexResult;
66

77
use crate::expr::Expression;
88
use crate::expr::session::{ExprSession, RewriteRuleRegistry};
9-
use crate::expr::transform::{simplify, simplify_typed};
9+
use crate::expr::transform::simplify::simplify;
10+
use crate::expr::transform::simplify_typed::simplify_typed;
1011

1112
/// A unified optimizer for expressions that can work with or without type information.
1213
pub struct ExprOptimizer {
1314
rule_registry: RewriteRuleRegistry,
1415
}
1516

1617
impl ExprOptimizer {
17-
/// Create a new untyped optimizer.
18-
///
19-
/// This optimizer will use untyped simplification rules only.
18+
/// Creates a new optimizer with the rules in `ExprSession`.
2019
pub fn new(session: &ExprSession) -> Self {
2120
Self {
2221
rule_registry: session.rewrite_rules().clone(),
2322
}
2423
}
2524

26-
/// Optimize the given expression.
27-
///
28-
/// If this optimizer was created with a dtype, this will perform typed optimization.
29-
/// Otherwise, it will perform untyped optimization.
25+
/// Optimize the given expression without a dtype.
3026
pub fn optimize(&self, expr: Expression) -> VortexResult<Expression> {
3127
simplify(expr, &self.rule_registry)
3228
}

vortex-array/src/expr/transform/simplify.rs

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,7 @@ use crate::expr::traversal::{NodeExt, Transformed};
1313
///
1414
/// This applies only untyped rewrite rules registered in the default session.
1515
/// If the scope dtype is known, see `simplify_typed` for a simplifier which uses dtype.
16-
pub(crate) fn simplify(
16+
pub(super) fn simplify(
1717
e: Expression,
1818
rule_registry: &RewriteRuleRegistry,
1919
) -> VortexResult<Expression> {
@@ -35,6 +35,7 @@ fn apply_parent_rules(
3535
for (idx, child) in node.children().iter().enumerate() {
3636
let result = rule_registry.with_parent_rules(
3737
&child.id(),
38+
Some(&node.id()),
3839
|rules| -> VortexResult<Option<Expression>> {
3940
for rule in rules {
4041
if let Some(new_expr) = rule.reduce_parent(child, &node, idx, ctx)? {
@@ -53,7 +54,7 @@ fn apply_parent_rules(
5354
.map(|t| t.into_inner())
5455
}
5556

56-
pub(crate) fn apply_child_rules_impl(
57+
fn apply_child_rules_impl(
5758
expr: Expression,
5859
ctx: &RuleContext,
5960
rule_registry: &RewriteRuleRegistry,

vortex-array/src/expr/transform/simplify_typed.rs

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,7 @@ use crate::expr::traversal::{NodeExt, Transformed};
1515
///
1616
/// NOTE: After typed simplification, returned expressions is "bound" to the scope DType.
1717
/// Applying the returned expression to a different DType may produce wrong results.
18-
pub(crate) fn simplify_typed(
18+
pub(super) fn simplify_typed(
1919
expr: Expression,
2020
dtype: &DType,
2121
rule_registry: &RewriteRuleRegistry,
@@ -88,6 +88,7 @@ fn apply_parent_rules_impl_typed(
8888
for (idx, child) in node.children().iter().enumerate() {
8989
let result = rule_registry.with_typed_parent_rules(
9090
&child.id(),
91+
Some(&node.id()),
9192
|rules| -> VortexResult<Option<Expression>> {
9293
for rule in rules {
9394
if let Some(new_expr) = rule.reduce_parent(child, &node, idx, ctx)? {
@@ -105,6 +106,7 @@ fn apply_parent_rules_impl_typed(
105106
let untyped_ctx: RuleContext = ctx.into();
106107
let result = rule_registry.with_parent_rules(
107108
&child.id(),
109+
Some(&node.id()),
108110
|rules| -> VortexResult<Option<Expression>> {
109111
for rule in rules {
110112
if let Some(new_expr) =

0 commit comments

Comments
 (0)