Skip to content

Commit a5c357a

Browse files
authored
fix: teach StructFieldExpressionSplitter the scope of its root expr (#3743)
Simplifications that need the scope (e.g. merge) need the _right_ scope. The scope for the root of a split expression is a structure with one field per partition. Signed-off-by: Daniel King <[email protected]>
1 parent 8d38555 commit a5c357a

File tree

1 file changed

+54
-6
lines changed

1 file changed

+54
-6
lines changed

vortex-expr/src/transform/partition.rs

Lines changed: 54 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -109,7 +109,10 @@ impl<'a> StructFieldExpressionSplitter<'a> {
109109

110110
let mut splitter = StructFieldExpressionSplitter::new(&field_accesses, scope_dtype);
111111

112-
let split = expr.clone().transform_with_context(&mut splitter, ())?;
112+
let split = expr
113+
.clone()
114+
.transform_with_context(&mut splitter, ())?
115+
.result();
113116

114117
let mut remove_accesses: Vec<FieldName> = Vec::new();
115118

@@ -153,13 +156,19 @@ impl<'a> StructFieldExpressionSplitter<'a> {
153156
debug_assert_eq!(expression_access_counts.unwrap_or(0), partitions.len());
154157

155158
let split = split
156-
.result()
157-
.transform(&mut ReplaceAccessesWithChild(remove_accesses))?;
159+
.transform(&mut ReplaceAccessesWithChild(remove_accesses))?
160+
.into_inner();
158161

159-
let ctx = ScopeDType::new(dtype.clone());
162+
let ctx = ScopeDType::new(DType::Struct(
163+
StructFields::new(
164+
FieldNames::from(partition_names.clone()),
165+
partition_dtypes.clone(),
166+
),
167+
Nullability::NonNullable,
168+
));
160169

161170
Ok(PartitionedExpr {
162-
root: simplify_typed(split.into_inner(), &ctx)?,
171+
root: simplify_typed(split, &ctx)?,
163172
partitions: partitions.into_boxed_slice(),
164173
partition_names: partition_names.into(),
165174
partition_dtypes: partition_dtypes.into_boxed_slice(),
@@ -293,11 +302,12 @@ mod tests {
293302
use vortex_dtype::Nullability::NonNullable;
294303
use vortex_dtype::PType::I32;
295304
use vortex_dtype::{DType, StructFields};
305+
use vortex_utils::aliases::hash_set::HashSet;
296306

297307
use super::*;
298308
use crate::transform::simplify::simplify;
299309
use crate::transform::simplify_typed::simplify_typed;
300-
use crate::{Pack, and, get_item, lit, pack, root, select};
310+
use crate::{Pack, and, col, get_item, lit, merge, pack, root, select};
301311

302312
fn dtype() -> DType {
303313
DType::Struct(
@@ -448,4 +458,42 @@ mod tests {
448458
)
449459
)
450460
}
461+
462+
#[test]
463+
fn test_expr_merge() {
464+
let dtype = dtype();
465+
466+
let expr = merge(
467+
[col("a"), pack([("b", col("b"))], NonNullable)],
468+
NonNullable,
469+
);
470+
471+
let partitioned = StructFieldExpressionSplitter::split(expr, &dtype).unwrap();
472+
let expected = pack(
473+
[
474+
("a", get_item("a", col("a"))),
475+
("b", get_item("b", col("b"))),
476+
],
477+
NonNullable,
478+
);
479+
assert_eq!(
480+
&partitioned.root, &expected,
481+
"{} {}",
482+
partitioned.root, expected
483+
);
484+
let expected = [root(), pack([("b", root())], NonNullable)]
485+
.into_iter()
486+
.collect::<HashSet<_>>();
487+
assert_eq!(
488+
&partitioned
489+
.partitions
490+
.clone()
491+
.into_iter()
492+
.collect::<HashSet<_>>(),
493+
&expected,
494+
"{} {}",
495+
partitioned.partitions.iter().join(";"),
496+
expected.iter().join(";")
497+
);
498+
}
451499
}

0 commit comments

Comments
 (0)