Skip to content

Commit a4f00e3

Browse files
authored
Store partition field names once (#2279)
Instead of constructing them every time we evaluate an expression over a struct layout. In other news, `StructArray::try_new` is annoyingly expensive
1 parent 8dfcc58 commit a4f00e3

File tree

2 files changed

+39
-55
lines changed

2 files changed

+39
-55
lines changed

vortex-expr/src/transform/partition.rs

Lines changed: 35 additions & 46 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,5 @@
1-
use itertools::Itertools;
21
use vortex_array::aliases::hash_map::HashMap;
3-
use vortex_dtype::{DType, FieldName, StructDType};
2+
use vortex_dtype::{DType, FieldName, FieldNames, StructDType};
43
use vortex_error::{vortex_bail, VortexExpect, VortexResult};
54

65
use crate::transform::immediate_access::{immediate_scope_accesses, FieldAccesses};
@@ -36,26 +35,21 @@ pub struct PartitionedExpr {
3635
/// The root expression used to re-assemble the results.
3736
pub root: ExprRef,
3837
/// The partitions of the expression.
39-
pub partitions: Box<[Partition]>,
38+
pub partitions: Box<[ExprRef]>,
39+
/// The field names for the partitions
40+
pub partition_names: FieldNames,
4041
}
4142

4243
impl PartitionedExpr {
4344
/// Return the partition for a given field, if it exists.
44-
pub fn find_partition(&self, field: &FieldName) -> Option<&Partition> {
45-
self.partitions.iter().find(|p| &p.name == field)
45+
pub fn find_partition(&self, field: &FieldName) -> Option<&ExprRef> {
46+
self.partition_names
47+
.iter()
48+
.position(|name| name == field)
49+
.map(|idx| &self.partitions[idx])
4650
}
4751
}
4852

49-
/// A single partition of an expression.
50-
#[derive(Debug)]
51-
pub struct Partition {
52-
/// The name of the partition, to be used when re-assembling the results.
53-
// TODO(ngates): we wouldn't need this if we had a MergeExpr.
54-
pub name: FieldName,
55-
/// The expression that defines the partition.
56-
pub expr: ExprRef,
57-
}
58-
5953
#[derive(Debug)]
6054
struct StructFieldExpressionSplitter<'a> {
6155
sub_expressions: HashMap<FieldName, Vec<ExprRef>>,
@@ -91,30 +85,27 @@ impl<'a> StructFieldExpressionSplitter<'a> {
9185
let mut remove_accesses: Vec<FieldName> = Vec::new();
9286

9387
// Create partitions which can be passed to layout fields
94-
let partitions: Vec<Partition> = splitter
95-
.sub_expressions
96-
.into_iter()
97-
.map(|(name, exprs)| {
98-
let field_dtype = scope_dtype.field(&name)?;
99-
// If there is a single expr then we don't need to `pack` this, and we must update
100-
// the root expr removing this access.
101-
let expr = if exprs.len() == 1 {
102-
remove_accesses.push(Self::field_idx_name(&name, 0));
103-
exprs.first().vortex_expect("exprs is non-empty").clone()
104-
} else {
105-
pack(
106-
exprs
107-
.into_iter()
108-
.enumerate()
109-
.map(|(idx, expr)| (Self::field_idx_name(&name, idx), expr)),
110-
)
111-
};
112-
VortexResult::Ok(Partition {
113-
name,
114-
expr: simplify_typed(expr, &field_dtype)?,
115-
})
116-
})
117-
.try_collect()?;
88+
let mut partitions = Vec::with_capacity(splitter.sub_expressions.len());
89+
let mut partition_names = Vec::with_capacity(splitter.sub_expressions.len());
90+
for (name, exprs) in splitter.sub_expressions.into_iter() {
91+
let field_dtype = scope_dtype.field(&name)?;
92+
// If there is a single expr then we don't need to `pack` this, and we must update
93+
// the root expr removing this access.
94+
let expr = if exprs.len() == 1 {
95+
remove_accesses.push(Self::field_idx_name(&name, 0));
96+
exprs.first().vortex_expect("exprs is non-empty").clone()
97+
} else {
98+
pack(
99+
exprs
100+
.into_iter()
101+
.enumerate()
102+
.map(|(idx, expr)| (Self::field_idx_name(&name, idx), expr)),
103+
)
104+
};
105+
106+
partitions.push(simplify_typed(expr, &field_dtype)?);
107+
partition_names.push(name);
108+
}
118109

119110
let expression_access_counts = field_accesses.get(&expr).map(|ac| ac.len());
120111
// Ensure that there are not more accesses than partitions, we missed something
@@ -130,6 +121,7 @@ impl<'a> StructFieldExpressionSplitter<'a> {
130121
Ok(PartitionedExpr {
131122
root: simplify_typed(split.result, dtype)?,
132123
partitions: partitions.into_boxed_slice(),
124+
partition_names: partition_names.into(),
133125
})
134126
}
135127
}
@@ -308,11 +300,8 @@ mod tests {
308300
assert!(split_a.is_some());
309301
let split_a = split_a.unwrap();
310302

311-
assert_eq!(&partitioned.root, &get_item(split_a.name.clone(), ident()));
312-
assert_eq!(
313-
&simplify(split_a.expr.clone()).unwrap(),
314-
&get_item("b", ident())
315-
);
303+
assert_eq!(&partitioned.root, &get_item("a", ident()));
304+
assert_eq!(&simplify(split_a.clone()).unwrap(), &get_item("b", ident()));
316305
}
317306

318307
#[test]
@@ -328,7 +317,7 @@ mod tests {
328317

329318
let split_a = partitioned.find_partition(&"a".into()).unwrap();
330319
assert_eq!(
331-
&simplify(split_a.expr.clone()).unwrap(),
320+
&simplify(split_a.clone()).unwrap(),
332321
&pack([
333322
(
334323
StructFieldExpressionSplitter::field_idx_name(&"a".into(), 0),
@@ -341,7 +330,7 @@ mod tests {
341330
])
342331
);
343332
let split_c = partitioned.find_partition(&"c".into()).unwrap();
344-
assert_eq!(&simplify(split_c.expr.clone()).unwrap(), &ident())
333+
assert_eq!(&simplify(split_c.clone()).unwrap(), &ident())
345334
}
346335

347336
#[test]

vortex-layout/src/layouts/struct_/eval_expr.rs

Lines changed: 4 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -17,17 +17,17 @@ impl ExprEvaluator for StructReader {
1717
// Partition the expression into expressions that can be evaluated over individual fields
1818
let partitioned = self.partition_expr(expr.clone())?;
1919
let field_readers: Vec<_> = partitioned
20-
.partitions
20+
.partition_names
2121
.iter()
22-
.map(|partition| self.child(&partition.name.clone()))
22+
.map(|name| self.child(name))
2323
.try_collect()?;
2424

2525
let arrays = try_join_all(
2626
field_readers
2727
.iter()
2828
.zip_eq(partitioned.partitions.iter())
2929
.map(|(reader, partition)| {
30-
reader.evaluate_expr(row_mask.clone(), partition.expr.clone())
30+
reader.evaluate_expr(row_mask.clone(), partition.clone())
3131
}),
3232
)
3333
.await?;
@@ -36,12 +36,7 @@ impl ExprEvaluator for StructReader {
3636
debug_assert!(arrays.iter().all(|a| a.len() == row_count));
3737

3838
let root_scope = StructArray::try_new(
39-
partitioned
40-
.partitions
41-
.iter()
42-
.map(|p| p.name.clone())
43-
.collect::<Vec<_>>()
44-
.into(),
39+
partitioned.partition_names.clone(),
4540
arrays,
4641
row_count,
4742
Validity::NonNullable,

0 commit comments

Comments
 (0)