Skip to content

Commit 7f8ec35

Browse files
committed
wip
Signed-off-by: Joe Isaacs <[email protected]>
1 parent dd547b3 commit 7f8ec35

File tree

6 files changed

+717
-48
lines changed

6 files changed

+717
-48
lines changed

vortex-array/src/expr/session.rs

Lines changed: 97 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,12 @@
11
// SPDX-License-Identifier: Apache-2.0
22
// SPDX-FileCopyrightText: Copyright the Vortex contributors
33

4+
use std::sync::Arc;
5+
46
use vortex_session::registry::Registry;
57
use vortex_session::{Ref, SessionExt};
8+
use vortex_utils::aliases::hash_map::HashMap;
69

7-
use crate::expr::ExprVTable;
810
use crate::expr::exprs::between::Between;
911
use crate::expr::exprs::binary::Binary;
1012
use crate::expr::exprs::cast::Cast;
@@ -18,21 +20,93 @@ use crate::expr::exprs::not::Not;
1820
use crate::expr::exprs::pack::Pack;
1921
use crate::expr::exprs::root::Root;
2022
use crate::expr::exprs::select::Select;
23+
use crate::expr::transform::remove_select::RemoveSelectRule;
24+
use crate::expr::transform::simplify::PackGetItemRule;
25+
use crate::expr::transform::traits::{ChildReduceRule, ParentReduceRule, ReduceRule};
26+
use crate::expr::{ExprId, ExprVTable};
2127

2228
/// Registry of expression vtables.
2329
pub type ExprRegistry = Registry<ExprVTable>;
2430

25-
/// Session state for expression vtables.
31+
/// Registry of expression rewrite rules.
32+
///
33+
/// Stores rewrite rules grouped by the expression ID each rule reports from its `id()` method; several rules may be registered under the same ID.
34+
#[derive(Default)]
35+
pub struct RewriteRuleRegistry {
36+
/// Generic reduce rules, indexed by expression ID. NOTE(review): "generic" presumably means no parent/child context is needed — confirm against the `ReduceRule` trait.
37+
reduce_rules: HashMap<ExprId, Vec<Arc<dyn ReduceRule>>>,
38+
/// Child reduce rules, indexed by expression ID; each entry holds every rule registered for that ID.
39+
child_rules: HashMap<ExprId, Vec<Arc<dyn ChildReduceRule>>>,
40+
/// Parent reduce rules, indexed by expression ID; each entry holds every rule registered for that ID.
41+
parent_rules: HashMap<ExprId, Vec<Arc<dyn ParentReduceRule>>>,
42+
}
43+
44+
impl std::fmt::Debug for RewriteRuleRegistry {
45+
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
46+
f.debug_struct("RewriteRuleRegistry")
47+
.field("reduce_rules_count", &self.reduce_rules.len())
48+
.field("child_rules_count", &self.child_rules.len())
49+
.field("parent_rules_count", &self.parent_rules.len())
50+
.finish()
51+
}
52+
}
53+
54+
impl RewriteRuleRegistry {
55+
pub fn new() -> Self {
56+
Self::default()
57+
}
58+
59+
/// Register a generic reduce rule.
60+
pub fn register_reduce_rule(&mut self, rule: Arc<dyn ReduceRule>) {
61+
let id = rule.id();
62+
self.reduce_rules.entry(id).or_default().push(rule);
63+
}
64+
65+
/// Register a child reduce rule.
66+
pub fn register_child_rule(&mut self, rule: Arc<dyn ChildReduceRule>) {
67+
let id = rule.id();
68+
self.child_rules.entry(id).or_default().push(rule);
69+
}
70+
71+
/// Register a parent reduce rule.
72+
pub fn register_parent_rule(&mut self, rule: Arc<dyn ParentReduceRule>) {
73+
let id = rule.id();
74+
self.parent_rules.entry(id).or_default().push(rule);
75+
}
76+
77+
/// Get all generic reduce rules for a given expression ID.
78+
pub fn reduce_rules_for(&self, id: &ExprId) -> Option<&[Arc<dyn ReduceRule>]> {
79+
self.reduce_rules.get(id).map(|v| v.as_slice())
80+
}
81+
82+
/// Get all child reduce rules for a given expression ID.
83+
pub fn child_rules_for(&self, id: &ExprId) -> Option<&[Arc<dyn ChildReduceRule>]> {
84+
self.child_rules.get(id).map(|v| v.as_slice())
85+
}
86+
87+
/// Get all parent reduce rules for a given expression ID.
88+
pub fn parent_rules_for(&self, id: &ExprId) -> Option<&[Arc<dyn ParentReduceRule>]> {
89+
self.parent_rules.get(id).map(|v| v.as_slice())
90+
}
91+
}
92+
93+
/// Session state for expression vtables and rewrite rules.
2694
#[derive(Debug)]
2795
pub struct ExprSession {
2896
/// Registry of expression vtables (see [`ExprRegistry`]).
registry: ExprRegistry,
97+
/// Registry of expression rewrite rules (see [`RewriteRuleRegistry`]).
rewrite_rules: RewriteRuleRegistry,
2998
}
3099

31100
impl ExprSession {
32101
pub fn registry(&self) -> &ExprRegistry {
33102
&self.registry
34103
}
35104

105+
/// Get the rewrite rule registry.
106+
pub fn rewrite_rules(&self) -> &RewriteRuleRegistry {
107+
&self.rewrite_rules
108+
}
109+
36110
/// Register an expression vtable in the session, replacing any existing vtable with the same ID.
37111
pub fn register(&self, expr: ExprVTable) {
38112
self.registry.register(expr)
@@ -42,6 +116,21 @@ impl ExprSession {
42116
pub fn register_many(&self, exprs: impl IntoIterator<Item = ExprVTable>) {
43117
self.registry.register_many(exprs);
44118
}
119+
120+
/// Register a generic reduce rule in the session.
121+
pub fn register_reduce_rule(&mut self, rule: impl ReduceRule + 'static) {
122+
self.rewrite_rules.register_reduce_rule(Arc::new(rule));
123+
}
124+
125+
/// Register a child reduce rule in the session.
126+
pub fn register_child_rule(&mut self, rule: impl ChildReduceRule + 'static) {
127+
self.rewrite_rules.register_child_rule(Arc::new(rule));
128+
}
129+
130+
/// Register a parent reduce rule in the session.
131+
pub fn register_parent_rule(&mut self, rule: impl ParentReduceRule + 'static) {
132+
self.rewrite_rules.register_parent_rule(Arc::new(rule));
133+
}
45134
}
46135

47136
impl Default for ExprSession {
@@ -65,8 +154,14 @@ impl Default for ExprSession {
65154
ExprVTable::from_static(&Select),
66155
]);
67156

157+
// Register built-in rewrite rules
158+
let mut rewrite_rules = RewriteRuleRegistry::new();
159+
rewrite_rules.register_reduce_rule(Arc::new(RemoveSelectRule));
160+
rewrite_rules.register_child_rule(Arc::new(PackGetItemRule));
161+
68162
Self {
69163
registry: expressions,
164+
rewrite_rules,
70165
}
71166
}
72167
}

vortex-array/src/expr/transform/mod.rs

Lines changed: 4 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -7,21 +7,14 @@ pub mod immediate_access;
77
pub(crate) mod match_between;
88
mod partition;
99
mod remove_merge;
10-
mod remove_select;
10+
pub mod remove_select;
1111
mod replace;
12-
mod simplify;
12+
pub mod simplify;
1313
mod simplify_typed;
14+
pub mod traits;
1415

1516
pub use partition::*;
1617
pub use replace::*;
1718
pub use simplify::*;
1819
pub use simplify_typed::*;
19-
20-
use crate::expr::traversal::Transformed;
21-
use crate::expr::{ExprId, Expression};
22-
23-
trait ExpressionParentTransformer {
24-
fn id(&self) -> ExprId;
25-
26-
fn reduce_parent(&self, expr: &Expression, parent: &Expression) -> Transformed<Expression>;
27-
}
20+
pub use traits::*;

vortex-array/src/expr/transform/remove_select.rs

Lines changed: 133 additions & 34 deletions
Original file line numberDiff line numberDiff line change
@@ -4,11 +4,12 @@
44
use vortex_dtype::DType;
55
use vortex_error::{VortexResult, vortex_err};
66

7-
use crate::expr::Expression;
87
use crate::expr::exprs::get_item::get_item;
98
use crate::expr::exprs::pack::pack;
109
use crate::expr::exprs::select::Select;
10+
use crate::expr::transform::traits::{ReduceRule, RewriteContext};
1111
use crate::expr::traversal::{NodeExt, Transformed};
12+
use crate::expr::{ExprId, Expression};
1213

1314
/// Replaces [crate::SelectExpr] with a combination of [crate::GetItem] and [crate::Pack] expressions.
1415
pub(crate) fn remove_select(e: Expression, ctx: &DType) -> VortexResult<Expression> {
@@ -20,38 +21,87 @@ fn remove_select_transformer(
2021
node: Expression,
2122
ctx: &DType,
2223
) -> VortexResult<Transformed<Expression>> {
23-
match node.as_opt::<Select>() {
24-
None => Ok(Transformed::no(node)),
25-
Some(select) => {
26-
let child = select.child();
27-
let child_dtype = child.return_dtype(ctx)?;
28-
let child_nullability = child_dtype.nullability();
29-
30-
let child_dtype = child_dtype.as_struct_fields_opt().ok_or_else(|| {
31-
vortex_err!(
32-
"Select child must return a struct dtype, however it was a {}",
33-
child_dtype
34-
)
35-
})?;
36-
37-
let expr = pack(
38-
select
39-
.data()
40-
.as_include_names(child_dtype.names())
41-
.map_err(|e| {
42-
e.with_context(format!(
43-
"Select fields {:?} must be a subset of child fields {:?}",
44-
select.data(),
45-
child_dtype.names()
46-
))
47-
})?
48-
.iter()
49-
.map(|name| (name.clone(), get_item(name.clone(), child.clone()))),
50-
child_nullability,
51-
);
52-
53-
Ok(Transformed::yes(expr))
54-
}
24+
if let Some(select) = node.as_opt::<Select>() {
25+
let child = select.child();
26+
let child_dtype = child.return_dtype(ctx)?;
27+
let child_nullability = child_dtype.nullability();
28+
29+
let child_dtype = child_dtype.as_struct_fields_opt().ok_or_else(|| {
30+
vortex_err!(
31+
"Select child must return a struct dtype, however it was a {}",
32+
child_dtype
33+
)
34+
})?;
35+
36+
let expr = pack(
37+
select
38+
.data()
39+
.as_include_names(child_dtype.names())
40+
.map_err(|e| {
41+
e.with_context(format!(
42+
"Select fields {:?} must be a subset of child fields {:?}",
43+
select.data(),
44+
child_dtype.names()
45+
))
46+
})?
47+
.iter()
48+
.map(|name| (name.clone(), get_item(name.clone(), child.clone()))),
49+
child_nullability,
50+
);
51+
52+
Ok(Transformed::yes(expr))
53+
} else {
54+
Ok(Transformed::no(node))
55+
}
56+
}
57+
58+
/// Rule that removes Select expressions by converting them to Pack + GetItem.
59+
///
60+
/// Transforms: `select(["a", "b"], expr)` → `pack(a: get_item("a", expr), b: get_item("b", expr))`
61+
pub struct RemoveSelectRule;
62+
63+
impl ReduceRule for RemoveSelectRule {
64+
fn id(&self) -> ExprId {
65+
ExprId::new_ref("vortex.select")
66+
}
67+
68+
fn reduce(
69+
&self,
70+
expr: &Expression,
71+
ctx: &dyn RewriteContext,
72+
) -> VortexResult<Option<Expression>> {
73+
let Some(select) = expr.as_opt::<Select>() else {
74+
return Ok(None);
75+
};
76+
77+
let child = select.child();
78+
let child_dtype = child.return_dtype(ctx.dtype())?;
79+
let child_nullability = child_dtype.nullability();
80+
81+
let child_dtype = child_dtype.as_struct_fields_opt().ok_or_else(|| {
82+
vortex_err!(
83+
"Select child must return a struct dtype, however it was a {}",
84+
child_dtype
85+
)
86+
})?;
87+
88+
let expr = pack(
89+
select
90+
.data()
91+
.as_include_names(child_dtype.names())
92+
.map_err(|e| {
93+
e.with_context(format!(
94+
"Select fields {:?} must be a subset of child fields {:?}",
95+
select.data(),
96+
child_dtype.names()
97+
))
98+
})?
99+
.iter()
100+
.map(|name| (name.clone(), get_item(name.clone(), child.clone()))),
101+
child_nullability,
102+
);
103+
104+
Ok(Some(expr))
55105
}
56106
}
57107

@@ -61,10 +111,13 @@ mod tests {
61111
use vortex_dtype::PType::I32;
62112
use vortex_dtype::{DType, StructFields};
63113

64-
use super::remove_select;
114+
use super::{RemoveSelectRule, remove_select};
65115
use crate::expr::exprs::pack::Pack;
66116
use crate::expr::exprs::root::root;
67117
use crate::expr::exprs::select::select;
118+
use crate::expr::session::ExprSession;
119+
use crate::expr::transform::simplify_typed::apply_child_rules;
120+
use crate::expr::transform::traits::{ReduceRule, SimpleRewriteContext};
68121

69122
#[test]
70123
fn test_remove_select() {
@@ -78,4 +131,50 @@ mod tests {
78131
assert!(e.is::<Pack>());
79132
assert!(e.return_dtype(&dtype).unwrap().is_nullable());
80133
}
134+
135+
#[test]
fn test_remove_select_rule_direct() {
    // Scope dtype: a nullable struct with two i32 fields, `a` and `b`.
    let fields = StructFields::new(["a", "b"].into(), vec![I32.into(), I32.into()]);
    let dtype = DType::Struct(fields, Nullable);

    // Invoke the rule directly, without going through a session.
    let ctx = SimpleRewriteContext { dtype: &dtype };
    let result = RemoveSelectRule
        .reduce(&select(["a", "b"], root()), &ctx)
        .unwrap();

    assert!(result.is_some());
    let packed = result.unwrap();
    assert!(packed.is::<Pack>());
    assert!(packed.return_dtype(&dtype).unwrap().is_nullable());
}
152+
153+
#[test]
fn test_remove_select_via_session() {
    // Scope dtype: a nullable struct with three i32 fields, `a`, `b` and `c`.
    let fields = StructFields::new(
        ["a", "b", "c"].into(),
        vec![I32.into(), I32.into(), I32.into()],
    );
    let dtype = DType::Struct(fields, Nullable);

    // The default session comes with RemoveSelectRule already registered.
    let session = ExprSession::default();
    let result = apply_child_rules(select(["a", "c"], root()), &dtype, &session).unwrap();

    // The select must have been rewritten into a Pack.
    assert!(result.is::<Pack>());

    // Only the selected fields survive, in selection order.
    let result_dtype = result.return_dtype(&dtype).unwrap();
    let selected = result_dtype.as_struct_fields_opt().unwrap().names();
    assert_eq!(selected.len(), 2);
    assert_eq!(selected[0].as_ref(), "a");
    assert_eq!(selected[1].as_ref(), "c");
}
81180
}

0 commit comments

Comments
 (0)