Skip to content

Commit 9c8c0a7

Browse files
committed
wip
Signed-off-by: Joe Isaacs <[email protected]>
1 parent da256c6 commit 9c8c0a7

File tree

12 files changed

+163
-321
lines changed

12 files changed

+163
-321
lines changed

vortex-array/src/expr/exprs/get_item.rs renamed to vortex-array/src/expr/exprs/get_item/mod.rs

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,8 @@
11
// SPDX-License-Identifier: Apache-2.0
22
// SPDX-FileCopyrightText: Copyright the Vortex contributors
33

4+
pub mod transform;
5+
46
use std::fmt::Formatter;
57
use std::ops::Not;
68

Lines changed: 47 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,47 @@
1+
// SPDX-License-Identifier: Apache-2.0
2+
// SPDX-FileCopyrightText: Copyright the Vortex contributors
3+
4+
use vortex_error::VortexResult;
5+
6+
use crate::expr::exprs::get_item::GetItem;
7+
use crate::expr::exprs::pack::Pack;
8+
use crate::expr::transform::traits::{ChildReduceRule, RewriteContext};
9+
use crate::expr::{Expression, ExpressionView};
10+
11+
/// Rewrite rule: `pack(l_1: e_1, ..., l_i: e_i, ..., l_n: e_n).get_item(l_i) = e_i`
12+
///
13+
/// Simplifies accessing a field from a pack expression by directly returning the field's
14+
/// expression instead of materializing the pack.
15+
///
16+
/// # Example
17+
/// ```
18+
/// # use vortex_array::expr::exprs::{get_item::get_item, literal::lit, pack::pack};
19+
/// # use vortex_dtype::Nullability::NonNullable;
20+
/// let e = get_item("b", pack([("a", lit(1)), ("b", lit(2))], NonNullable));
21+
/// // After applying PackGetItemRule, this becomes: lit(2)
22+
/// ```
23+
pub struct PackGetItemRule;
24+
25+
impl ChildReduceRule<GetItem> for PackGetItemRule {
26+
fn reduce_child(
27+
&self,
28+
get_item: &ExpressionView<GetItem>,
29+
child: &Expression,
30+
child_idx: usize,
31+
_ctx: &dyn RewriteContext,
32+
) -> VortexResult<Option<Expression>> {
33+
// Only consider the first child (child_idx == 0) of GetItem expressions
34+
if child_idx != 0 {
35+
return Ok(None);
36+
}
37+
38+
// Check if child is Pack
39+
if let Some(pack) = child.as_opt::<Pack>() {
40+
// Extract the field from the pack
41+
let field_expr = pack.field(get_item.data())?;
42+
return Ok(Some(field_expr));
43+
}
44+
45+
Ok(None)
46+
}
47+
}
Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,8 @@
11
// SPDX-License-Identifier: Apache-2.0
22
// SPDX-FileCopyrightText: Copyright the Vortex contributors
33

4+
pub mod transform;
5+
46
use std::fmt::Formatter;
57
use std::hash::Hash;
68
use std::sync::Arc;
Lines changed: 71 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,71 @@
1+
// SPDX-License-Identifier: Apache-2.0
2+
// SPDX-FileCopyrightText: Copyright the Vortex contributors
3+
4+
use itertools::Itertools as _;
5+
use vortex_error::{VortexExpect, VortexResult, vortex_bail, vortex_err};
6+
use vortex_utils::aliases::hash_set::HashSet;
7+
8+
use crate::expr::exprs::get_item::get_item;
9+
use crate::expr::exprs::merge::{DuplicateHandling, Merge};
10+
use crate::expr::exprs::pack::pack;
11+
use crate::expr::transform::traits::{ReduceRule, RewriteContext};
12+
use crate::expr::{Expression, ExpressionView};
13+
14+
/// Rule that removes Merge expressions by converting them to Pack + GetItem.
15+
///
16+
/// Transforms: `merge([struct1, struct2])` → `pack(field1: get_item("field1", struct1), field2: get_item("field2", struct2), ...)`
17+
pub struct RemoveMergeRule;
18+
19+
impl ReduceRule<Merge> for RemoveMergeRule {
20+
fn reduce(
21+
&self,
22+
merge: &ExpressionView<Merge>,
23+
ctx: &dyn RewriteContext,
24+
) -> VortexResult<Option<Expression>> {
25+
let merge_dtype = merge.return_dtype(ctx.dtype())?;
26+
let mut names = Vec::with_capacity(merge.children().len() * 2);
27+
let mut children = Vec::with_capacity(merge.children().len() * 2);
28+
let mut duplicate_names = HashSet::<_>::new();
29+
30+
for child in merge.children().iter() {
31+
let child_dtype = child.return_dtype(ctx.dtype())?;
32+
if !child_dtype.is_struct() {
33+
return Err(vortex_err!(
34+
"Merge child must return a non-nullable struct dtype, got {}",
35+
child_dtype
36+
));
37+
}
38+
39+
let child_dtype = child_dtype
40+
.as_struct_fields_opt()
41+
.vortex_expect("expected struct");
42+
43+
for name in child_dtype.names().iter() {
44+
if let Some(idx) = names.iter().position(|n| n == name) {
45+
duplicate_names.insert(name.clone());
46+
children[idx] = child.clone();
47+
} else {
48+
names.push(name.clone());
49+
children.push(child.clone());
50+
}
51+
}
52+
53+
if merge.data() == &DuplicateHandling::Error && !duplicate_names.is_empty() {
54+
vortex_bail!(
55+
"merge: duplicate fields in children: {}",
56+
duplicate_names.into_iter().format(", ")
57+
)
58+
}
59+
}
60+
61+
let expr = pack(
62+
names
63+
.into_iter()
64+
.zip(children)
65+
.map(|(name, child)| (name.clone(), get_item(name, child))),
66+
merge_dtype.nullability(),
67+
);
68+
69+
Ok(Some(expr))
70+
}
71+
}
Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,8 @@
11
// SPDX-License-Identifier: Apache-2.0
22
// SPDX-FileCopyrightText: Copyright the Vortex contributors
33

4+
pub mod transform;
5+
46
use std::fmt::{Display, Formatter};
57

68
use itertools::Itertools;
Lines changed: 17 additions & 77 deletions
Original file line numberDiff line numberDiff line change
@@ -1,60 +1,14 @@
11
// SPDX-License-Identifier: Apache-2.0
22
// SPDX-FileCopyrightText: Copyright the Vortex contributors
33

4-
use vortex_dtype::DType;
54
use vortex_error::{VortexResult, vortex_err};
65

76
use crate::expr::exprs::get_item::get_item;
87
use crate::expr::exprs::pack::pack;
98
use crate::expr::exprs::select::Select;
109
use crate::expr::transform::traits::{ReduceRule, RewriteContext};
11-
use crate::expr::traversal::{NodeExt, Transformed};
1210
use crate::expr::{Expression, ExpressionView};
1311

14-
/// Replaces [crate::SelectExpr] with combination of [crate::GetItem] and [crate::Pack] expressions.
15-
pub(crate) fn remove_select(e: Expression, ctx: &DType) -> VortexResult<Expression> {
16-
e.transform_up(|node| remove_select_transformer(node, ctx))
17-
.map(|e| e.into_inner())
18-
}
19-
20-
fn remove_select_transformer(
21-
node: Expression,
22-
ctx: &DType,
23-
) -> VortexResult<Transformed<Expression>> {
24-
if let Some(select) = node.as_opt::<Select>() {
25-
let child = select.child();
26-
let child_dtype = child.return_dtype(ctx)?;
27-
let child_nullability = child_dtype.nullability();
28-
29-
let child_dtype = child_dtype.as_struct_fields_opt().ok_or_else(|| {
30-
vortex_err!(
31-
"Select child must return a struct dtype, however it was a {}",
32-
child_dtype
33-
)
34-
})?;
35-
36-
let expr = pack(
37-
select
38-
.data()
39-
.as_include_names(child_dtype.names())
40-
.map_err(|e| {
41-
e.with_context(format!(
42-
"Select fields {:?} must be a subset of child fields {:?}",
43-
select.data(),
44-
child_dtype.names()
45-
))
46-
})?
47-
.iter()
48-
.map(|name| (name.clone(), get_item(name.clone(), child.clone()))),
49-
child_nullability,
50-
);
51-
52-
Ok(Transformed::yes(expr))
53-
} else {
54-
Ok(Transformed::no(node))
55-
}
56-
}
57-
5812
/// Rule that removes Select expressions by converting them to Pack + GetItem.
5913
///
6014
/// Transforms: `select(["a", "b"], expr)` → `pack(a: get_item("a", expr), b: get_item("b", expr))`
@@ -103,29 +57,14 @@ mod tests {
10357
use vortex_dtype::PType::I32;
10458
use vortex_dtype::{DType, StructFields};
10559

106-
use super::{RemoveSelectRule, remove_select};
60+
use super::RemoveSelectRule;
10761
use crate::expr::exprs::pack::Pack;
10862
use crate::expr::exprs::root::root;
10963
use crate::expr::exprs::select::{Select, select};
110-
use crate::expr::session::ExprSession;
111-
use crate::expr::transform::simplify_typed::apply_child_rules;
11264
use crate::expr::transform::traits::{ReduceRule, SimpleRewriteContext};
11365

11466
#[test]
115-
fn test_remove_select() {
116-
let dtype = DType::Struct(
117-
StructFields::new(["a", "b"].into(), vec![I32.into(), I32.into()]),
118-
Nullable,
119-
);
120-
let e = select(["a", "b"], root());
121-
let e = remove_select(e, &dtype).unwrap();
122-
123-
assert!(e.is::<Pack>());
124-
assert!(e.return_dtype(&dtype).unwrap().is_nullable());
125-
}
126-
127-
#[test]
128-
fn test_remove_select_rule_direct() {
67+
fn test_remove_select_rule() {
12968
let dtype = DType::Struct(
13069
StructFields::new(["a", "b"].into(), vec![I32.into(), I32.into()]),
13170
Nullable,
@@ -144,30 +83,31 @@ mod tests {
14483
}
14584

14685
#[test]
147-
fn test_remove_select_via_session() {
86+
fn test_remove_select_rule_exclude_fields() {
87+
use crate::expr::exprs::select::select_exclude;
88+
14889
let dtype = DType::Struct(
14990
StructFields::new(
15091
["a", "b", "c"].into(),
15192
vec![I32.into(), I32.into(), I32.into()],
15293
),
15394
Nullable,
15495
);
96+
let e = select_exclude(["c"], root());
15597

156-
// Create expression: select(["a", "c"], root())
157-
let e = select(["a", "c"], root());
158-
159-
// Use session which has RemoveSelectRule registered
160-
let session = ExprSession::default();
161-
let result = apply_child_rules(e, &dtype, &session).unwrap();
98+
let rule = RemoveSelectRule;
99+
let ctx = SimpleRewriteContext { dtype: &dtype };
100+
let select_view = e.as_::<Select>();
101+
let result = rule.reduce(&select_view, &ctx).unwrap();
162102

163-
// Should be transformed to Pack
164-
assert!(result.is::<Pack>());
103+
assert!(result.is_some());
104+
let transformed = result.unwrap();
105+
assert!(transformed.is::<Pack>());
165106

166-
// Verify the dtype has only selected fields
167-
let result_dtype = result.return_dtype(&dtype).unwrap();
107+
// Should exclude "c" and include "a" and "b"
108+
let result_dtype = transformed.return_dtype(&dtype).unwrap();
109+
assert!(result_dtype.is_nullable());
168110
let fields = result_dtype.as_struct_fields_opt().unwrap();
169-
assert_eq!(fields.names().len(), 2);
170-
assert_eq!(fields.names()[0].as_ref(), "a");
171-
assert_eq!(fields.names()[1].as_ref(), "c");
111+
assert_eq!(fields.names().as_ref(), &["a", "b"]);
172112
}
173113
}

vortex-array/src/expr/session.rs

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -12,17 +12,18 @@ use crate::expr::exprs::between::Between;
1212
use crate::expr::exprs::binary::Binary;
1313
use crate::expr::exprs::cast::Cast;
1414
use crate::expr::exprs::get_item::GetItem;
15+
use crate::expr::exprs::get_item::transform::PackGetItemRule;
1516
use crate::expr::exprs::is_null::IsNull;
1617
use crate::expr::exprs::like::Like;
1718
use crate::expr::exprs::list_contains::ListContains;
1819
use crate::expr::exprs::literal::Literal;
1920
use crate::expr::exprs::merge::Merge;
21+
use crate::expr::exprs::merge::transform::RemoveMergeRule;
2022
use crate::expr::exprs::not::Not;
2123
use crate::expr::exprs::pack::Pack;
2224
use crate::expr::exprs::root::Root;
2325
use crate::expr::exprs::select::Select;
24-
use crate::expr::transform::remove_select::RemoveSelectRule;
25-
use crate::expr::transform::simplify::PackGetItemRule;
26+
use crate::expr::exprs::select::transform::RemoveSelectRule;
2627
use crate::expr::transform::traits::{
2728
ChildReduceRule, ParentReduceRule, ReduceRule, RewriteContext,
2829
};
@@ -313,6 +314,7 @@ impl Default for ExprSession {
313314
// Register built-in rewrite rules
314315
let mut rewrite_rules = RewriteRuleRegistry::new();
315316
rewrite_rules.register_reduce_rule(&Select, RemoveSelectRule);
317+
rewrite_rules.register_reduce_rule(&Merge, RemoveMergeRule);
316318
rewrite_rules.register_child_rule(&GetItem, PackGetItemRule);
317319

318320
Self {

vortex-array/src/expr/transform/mod.rs

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -6,8 +6,6 @@ pub mod annotations;
66
pub mod immediate_access;
77
pub(crate) mod match_between;
88
mod partition;
9-
mod remove_merge;
10-
pub mod remove_select;
119
mod replace;
1210
pub mod simplify;
1311
mod simplify_typed;

vortex-array/src/expr/transform/partition.rs

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -209,7 +209,6 @@ mod tests {
209209
use crate::expr::exprs::select::select;
210210
use crate::expr::transform::immediate_access::annotate_scope_access;
211211
use crate::expr::transform::replace::replace_root_fields;
212-
use crate::expr::transform::simplify::simplify;
213212
use crate::expr::transform::simplify_typed::simplify_typed;
214213

215214
#[fixture]
@@ -274,7 +273,7 @@ mod tests {
274273

275274
let split_a = partitioned.find_partition(&"a".into()).unwrap();
276275
assert_eq!(
277-
&simplify(split_a.clone()).unwrap(),
276+
&simplify_typed(split_a.clone(), &dtype).unwrap(),
278277
&pack(
279278
[
280279
("a_0", get_item("x", get_item("a", root()))),

0 commit comments

Comments
 (0)