Skip to content

Commit 60492c4

Browse files
feat[vortex-expr]: add expr rewrite rules (#5348)
Breaking: Using a vortex session in more places - simplify_typed/simplify - Layout::new_reader --------- Signed-off-by: Joe Isaacs <[email protected]> Co-authored-by: Nicholas Gates <[email protected]>
1 parent 4c990b4 commit 60492c4

File tree

45 files changed

+1349
-363
lines changed

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

45 files changed

+1349
-363
lines changed

vortex-array/src/expr/exprs/get_item.rs renamed to vortex-array/src/expr/exprs/get_item/mod.rs

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,8 @@
11
// SPDX-License-Identifier: Apache-2.0
22
// SPDX-FileCopyrightText: Copyright the Vortex contributors
33

4+
pub mod transform;
5+
46
use std::fmt::Formatter;
57
use std::ops::Not;
68

Lines changed: 127 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,127 @@
1+
// SPDX-License-Identifier: Apache-2.0
2+
// SPDX-FileCopyrightText: Copyright the Vortex contributors
3+
4+
use vortex_error::VortexResult;
5+
6+
use crate::expr::exprs::get_item::GetItem;
7+
use crate::expr::exprs::pack::Pack;
8+
use crate::expr::transform::rules::{ReduceRule, RuleContext};
9+
use crate::expr::{Expression, ExpressionView};
10+
11+
/// Rewrite rule: `pack(l_1: e_1, ..., l_i: e_i, ..., l_n: e_n).get_item(l_i) = e_i`
12+
///
13+
/// Simplifies accessing a field from a pack expression by directly returning the field's
14+
/// expression instead of materializing the pack.
15+
pub struct PackGetItemRule;
16+
17+
impl ReduceRule<GetItem, RuleContext> for PackGetItemRule {
18+
fn reduce(
19+
&self,
20+
get_item: &ExpressionView<GetItem>,
21+
_ctx: &RuleContext,
22+
) -> VortexResult<Option<Expression>> {
23+
if let Some(pack) = get_item.child(0).as_opt::<Pack>() {
24+
let field_expr = pack.field(get_item.data())?;
25+
return Ok(Some(field_expr));
26+
}
27+
28+
Ok(None)
29+
}
30+
}
31+
32+
#[cfg(test)]
33+
mod tests {
34+
use vortex_dtype::Nullability::NonNullable;
35+
use vortex_dtype::{DType, PType};
36+
37+
use super::PackGetItemRule;
38+
use crate::expr::exprs::binary::checked_add;
39+
use crate::expr::exprs::get_item::{GetItem, get_item};
40+
use crate::expr::exprs::literal::lit;
41+
use crate::expr::exprs::pack::pack;
42+
use crate::expr::session::ExprSession;
43+
use crate::expr::transform::rules::{ReduceRule, RuleContext};
44+
use crate::expr::transform::simplify_typed;
45+
46+
#[test]
47+
fn test_pack_get_item_rule() {
48+
let rule = PackGetItemRule;
49+
50+
// Create: pack(a: lit(1), b: lit(2)).get_item("b")
51+
let pack_expr = pack([("a", lit(1)), ("b", lit(2))], NonNullable);
52+
let get_item_expr = get_item("b", pack_expr);
53+
54+
let get_item_view = get_item_expr.as_::<GetItem>();
55+
let result = rule.reduce(&get_item_view, &RuleContext).unwrap();
56+
57+
assert!(result.is_some());
58+
assert_eq!(&result.unwrap(), &lit(2));
59+
}
60+
61+
#[test]
62+
fn test_pack_get_item_rule_no_match() {
63+
let rule = PackGetItemRule;
64+
65+
// Create: get_item("x", lit(42)) - not a pack child
66+
let lit_expr = lit(42);
67+
let get_item_expr = get_item("x", lit_expr);
68+
69+
let get_item_view = get_item_expr.as_::<GetItem>();
70+
let result = rule.reduce(&get_item_view, &RuleContext).unwrap();
71+
72+
assert!(result.is_none());
73+
}
74+
75+
#[test]
76+
fn test_multi_level_pack_get_item_simplify() {
77+
let inner_pack = pack([("a", lit(1)), ("b", lit(2))], NonNullable);
78+
let get_a = get_item("a", inner_pack);
79+
80+
let outer_pack = pack([("x", get_a), ("y", lit(3)), ("z", lit(4))], NonNullable);
81+
let get_z = get_item("z", outer_pack);
82+
83+
let dtype = DType::Primitive(PType::I32, NonNullable);
84+
85+
let result = simplify_typed(get_z, &dtype, &ExprSession::default()).unwrap();
86+
87+
assert_eq!(&result, &lit(4));
88+
}
89+
90+
#[test]
91+
fn test_deeply_nested_pack_get_item() {
92+
let innermost = pack([("a", lit(42))], NonNullable);
93+
let get_a = get_item("a", innermost);
94+
95+
let level2 = pack([("b", get_a)], NonNullable);
96+
let get_b = get_item("b", level2);
97+
98+
let level3 = pack([("c", get_b)], NonNullable);
99+
let get_c = get_item("c", level3);
100+
101+
let outermost = pack([("final", get_c)], NonNullable);
102+
let get_final = get_item("final", outermost);
103+
104+
let dtype = DType::Primitive(PType::I32, NonNullable);
105+
106+
let result = simplify_typed(get_final, &dtype, &ExprSession::default()).unwrap();
107+
108+
assert_eq!(&result, &lit(42));
109+
}
110+
111+
#[test]
112+
fn test_partial_pack_get_item_simplify() {
113+
let inner_pack = pack([("x", lit(1)), ("y", lit(2))], NonNullable);
114+
let get_x = get_item("x", inner_pack);
115+
let add_expr = checked_add(get_x, lit(10));
116+
117+
let outer_pack = pack([("result", add_expr)], NonNullable);
118+
let get_result = get_item("result", outer_pack);
119+
120+
let dtype = DType::Primitive(PType::I32, NonNullable);
121+
122+
let result = simplify_typed(get_result, &dtype, &ExprSession::default()).unwrap();
123+
124+
let expected = checked_add(lit(1), lit(10));
125+
assert_eq!(&result, &expected);
126+
}
127+
}
Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,8 @@
11
// SPDX-License-Identifier: Apache-2.0
22
// SPDX-FileCopyrightText: Copyright the Vortex contributors
33

4+
pub mod transform;
5+
46
use std::fmt::Formatter;
57
use std::hash::Hash;
68
use std::sync::Arc;
Lines changed: 114 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,114 @@
1+
// SPDX-License-Identifier: Apache-2.0
2+
// SPDX-FileCopyrightText: Copyright the Vortex contributors
3+
4+
use itertools::Itertools as _;
5+
use vortex_error::{VortexExpect, VortexResult, vortex_bail};
6+
use vortex_utils::aliases::hash_set::HashSet;
7+
8+
use crate::expr::exprs::get_item::get_item;
9+
use crate::expr::exprs::merge::{DuplicateHandling, Merge};
10+
use crate::expr::exprs::pack::pack;
11+
use crate::expr::transform::rules::{ReduceRule, TypedRuleContext};
12+
use crate::expr::{Expression, ExpressionView};
13+
14+
/// Rule that removes Merge expressions by converting them to Pack + GetItem.
15+
///
16+
/// Transforms: `merge([struct1, struct2])` → `pack(field1: get_item("field1", struct1), field2: get_item("field2", struct2), ...)`
17+
pub struct RemoveMergeRule;
18+
19+
impl ReduceRule<Merge, TypedRuleContext> for RemoveMergeRule {
20+
fn reduce(
21+
&self,
22+
merge: &ExpressionView<Merge>,
23+
ctx: &TypedRuleContext,
24+
) -> VortexResult<Option<Expression>> {
25+
let merge_dtype = merge.return_dtype(ctx.dtype())?;
26+
let mut names = Vec::with_capacity(merge.children().len() * 2);
27+
let mut children = Vec::with_capacity(merge.children().len() * 2);
28+
let mut duplicate_names = HashSet::<_>::new();
29+
30+
for child in merge.children().iter() {
31+
let child_dtype = child.return_dtype(ctx.dtype())?;
32+
if !child_dtype.is_struct() {
33+
vortex_bail!(
34+
"Merge child must return a non-nullable struct dtype, got {}",
35+
child_dtype
36+
)
37+
}
38+
39+
let child_dtype = child_dtype
40+
.as_struct_fields_opt()
41+
.vortex_expect("expected struct");
42+
43+
for name in child_dtype.names().iter() {
44+
if let Some(idx) = names.iter().position(|n| n == name) {
45+
duplicate_names.insert(name.clone());
46+
children[idx] = child.clone();
47+
} else {
48+
names.push(name.clone());
49+
children.push(child.clone());
50+
}
51+
}
52+
53+
if merge.data() == &DuplicateHandling::Error && !duplicate_names.is_empty() {
54+
vortex_bail!(
55+
"merge: duplicate fields in children: {}",
56+
duplicate_names.into_iter().format(", ")
57+
)
58+
}
59+
}
60+
61+
let expr = pack(
62+
names
63+
.into_iter()
64+
.zip(children)
65+
.map(|(name, child)| (name.clone(), get_item(name, child))),
66+
merge_dtype.nullability(),
67+
);
68+
69+
Ok(Some(expr))
70+
}
71+
}
72+
73+
#[cfg(test)]
74+
mod tests {
75+
use vortex_dtype::DType;
76+
use vortex_dtype::Nullability::NonNullable;
77+
use vortex_dtype::PType::{I32, I64, U32, U64};
78+
79+
use super::RemoveMergeRule;
80+
use crate::expr::exprs::get_item::get_item;
81+
use crate::expr::exprs::merge::{DuplicateHandling, Merge, merge_opts};
82+
use crate::expr::exprs::pack::Pack;
83+
use crate::expr::exprs::root::root;
84+
use crate::expr::transform::rules::{ReduceRule, TypedRuleContext};
85+
86+
#[test]
87+
fn test_remove_merge() {
88+
let dtype = DType::struct_(
89+
[
90+
("0", DType::struct_([("a", I32), ("b", I64)], NonNullable)),
91+
("1", DType::struct_([("b", U32), ("c", U64)], NonNullable)),
92+
],
93+
NonNullable,
94+
);
95+
96+
let e = merge_opts(
97+
[get_item("0", root()), get_item("1", root())],
98+
DuplicateHandling::RightMost,
99+
);
100+
101+
let ctx = TypedRuleContext::new(dtype.clone());
102+
let rule = RemoveMergeRule;
103+
let merge_view = e.as_::<Merge>();
104+
let result = rule.reduce(&merge_view, &ctx).unwrap();
105+
106+
assert!(result.is_some());
107+
let result = result.unwrap();
108+
assert!(result.is::<Pack>());
109+
assert_eq!(
110+
result.return_dtype(&dtype).unwrap(),
111+
DType::struct_([("a", I32), ("b", U32), ("c", U64)], NonNullable)
112+
);
113+
}
114+
}
Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,8 @@
11
// SPDX-License-Identifier: Apache-2.0
22
// SPDX-FileCopyrightText: Copyright the Vortex contributors
33

4+
pub mod transform;
5+
46
use std::fmt::{Display, Formatter};
57

68
use itertools::Itertools;
Lines changed: 113 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,113 @@
1+
// SPDX-License-Identifier: Apache-2.0
2+
// SPDX-FileCopyrightText: Copyright the Vortex contributors
3+
4+
use vortex_error::{VortexResult, vortex_err};
5+
6+
use crate::expr::exprs::get_item::get_item;
7+
use crate::expr::exprs::pack::pack;
8+
use crate::expr::exprs::select::Select;
9+
use crate::expr::transform::rules::{ReduceRule, TypedRuleContext};
10+
use crate::expr::{Expression, ExpressionView};
11+
12+
/// Rule that removes Select expressions by converting them to Pack + GetItem.
13+
///
14+
/// Transforms: `select(["a", "b"], expr)` → `pack(a: get_item("a", expr), b: get_item("b", expr))`
15+
pub struct RemoveSelectRule;
16+
17+
impl ReduceRule<Select, TypedRuleContext> for RemoveSelectRule {
18+
fn reduce(
19+
&self,
20+
select: &ExpressionView<Select>,
21+
ctx: &TypedRuleContext,
22+
) -> VortexResult<Option<Expression>> {
23+
let child = select.child();
24+
let child_dtype = child.return_dtype(ctx.dtype())?;
25+
let child_nullability = child_dtype.nullability();
26+
27+
let child_dtype = child_dtype.as_struct_fields_opt().ok_or_else(|| {
28+
vortex_err!(
29+
"Select child must return a struct dtype, however it was a {}",
30+
child_dtype
31+
)
32+
})?;
33+
34+
let expr = pack(
35+
select
36+
.data()
37+
.as_include_names(child_dtype.names())
38+
.map_err(|e| {
39+
e.with_context(format!(
40+
"Select fields {:?} must be a subset of child fields {:?}",
41+
select.data(),
42+
child_dtype.names()
43+
))
44+
})?
45+
.iter()
46+
.map(|name| (name.clone(), get_item(name.clone(), child.clone()))),
47+
child_nullability,
48+
);
49+
50+
Ok(Some(expr))
51+
}
52+
}
53+
54+
#[cfg(test)]
55+
mod tests {
56+
use vortex_dtype::Nullability::Nullable;
57+
use vortex_dtype::PType::I32;
58+
use vortex_dtype::{DType, StructFields};
59+
60+
use super::RemoveSelectRule;
61+
use crate::expr::exprs::pack::Pack;
62+
use crate::expr::exprs::root::root;
63+
use crate::expr::exprs::select::{Select, select};
64+
use crate::expr::transform::rules::{ReduceRule, TypedRuleContext};
65+
66+
#[test]
67+
fn test_remove_select_rule() {
68+
let dtype = DType::Struct(
69+
StructFields::new(["a", "b"].into(), vec![I32.into(), I32.into()]),
70+
Nullable,
71+
);
72+
let e = select(["a", "b"], root());
73+
74+
let rule = RemoveSelectRule;
75+
let ctx = TypedRuleContext::new(dtype.clone());
76+
let select_view = e.as_::<Select>();
77+
let result = rule.reduce(&select_view, &ctx).unwrap();
78+
79+
assert!(result.is_some());
80+
let transformed = result.unwrap();
81+
assert!(transformed.is::<Pack>());
82+
assert!(transformed.return_dtype(&dtype).unwrap().is_nullable());
83+
}
84+
85+
#[test]
86+
fn test_remove_select_rule_exclude_fields() {
87+
use crate::expr::exprs::select::select_exclude;
88+
89+
let dtype = DType::Struct(
90+
StructFields::new(
91+
["a", "b", "c"].into(),
92+
vec![I32.into(), I32.into(), I32.into()],
93+
),
94+
Nullable,
95+
);
96+
let e = select_exclude(["c"], root());
97+
98+
let rule = RemoveSelectRule;
99+
let ctx = TypedRuleContext::new(dtype.clone());
100+
let select_view = e.as_::<Select>();
101+
let result = rule.reduce(&select_view, &ctx).unwrap();
102+
103+
assert!(result.is_some());
104+
let transformed = result.unwrap();
105+
assert!(transformed.is::<Pack>());
106+
107+
// Should exclude "c" and include "a" and "b"
108+
let result_dtype = transformed.return_dtype(&dtype).unwrap();
109+
assert!(result_dtype.is_nullable());
110+
let fields = result_dtype.as_struct_fields_opt().unwrap();
111+
assert_eq!(fields.names().as_ref(), &["a", "b"]);
112+
}
113+
}

0 commit comments

Comments
 (0)