1- use itertools:: Itertools ;
21use vortex_array:: aliases:: hash_map:: HashMap ;
3- use vortex_dtype:: { DType , FieldName , StructDType } ;
2+ use vortex_dtype:: { DType , FieldName , FieldNames , StructDType } ;
43use vortex_error:: { vortex_bail, VortexExpect , VortexResult } ;
54
65use crate :: transform:: immediate_access:: { immediate_scope_accesses, FieldAccesses } ;
@@ -36,26 +35,21 @@ pub struct PartitionedExpr {
3635 /// The root expression used to re-assemble the results.
3736 pub root : ExprRef ,
3837 /// The partitions of the expression.
39- pub partitions : Box < [ Partition ] > ,
38+ pub partitions : Box < [ ExprRef ] > ,
39+ /// The field names for the partitions
40+ pub partition_names : FieldNames ,
4041}
4142
4243impl PartitionedExpr {
4344 /// Return the partition for a given field, if it exists.
44- pub fn find_partition ( & self , field : & FieldName ) -> Option < & Partition > {
45- self . partitions . iter ( ) . find ( |p| & p. name == field)
45+ pub fn find_partition ( & self , field : & FieldName ) -> Option < & ExprRef > {
46+ self . partition_names
47+ . iter ( )
48+ . position ( |name| name == field)
49+ . map ( |idx| & self . partitions [ idx] )
4650 }
4751}
4852
49- /// A single partition of an expression.
50- #[ derive( Debug ) ]
51- pub struct Partition {
52- /// The name of the partition, to be used when re-assembling the results.
53- // TODO(ngates): we wouldn't need this if we had a MergeExpr.
54- pub name : FieldName ,
55- /// The expression that defines the partition.
56- pub expr : ExprRef ,
57- }
58-
5953#[ derive( Debug ) ]
6054struct StructFieldExpressionSplitter < ' a > {
6155 sub_expressions : HashMap < FieldName , Vec < ExprRef > > ,
@@ -91,30 +85,27 @@ impl<'a> StructFieldExpressionSplitter<'a> {
9185 let mut remove_accesses: Vec < FieldName > = Vec :: new ( ) ;
9286
9387 // Create partitions which can be passed to layout fields
94- let partitions: Vec < Partition > = splitter
95- . sub_expressions
96- . into_iter ( )
97- . map ( |( name, exprs) | {
98- let field_dtype = scope_dtype. field ( & name) ?;
99- // If there is a single expr then we don't need to `pack` this, and we must update
100- // the root expr removing this access.
101- let expr = if exprs. len ( ) == 1 {
102- remove_accesses. push ( Self :: field_idx_name ( & name, 0 ) ) ;
103- exprs. first ( ) . vortex_expect ( "exprs is non-empty" ) . clone ( )
104- } else {
105- pack (
106- exprs
107- . into_iter ( )
108- . enumerate ( )
109- . map ( |( idx, expr) | ( Self :: field_idx_name ( & name, idx) , expr) ) ,
110- )
111- } ;
112- VortexResult :: Ok ( Partition {
113- name,
114- expr : simplify_typed ( expr, & field_dtype) ?,
115- } )
116- } )
117- . try_collect ( ) ?;
88+ let mut partitions = Vec :: with_capacity ( splitter. sub_expressions . len ( ) ) ;
89+ let mut partition_names = Vec :: with_capacity ( splitter. sub_expressions . len ( ) ) ;
90+ for ( name, exprs) in splitter. sub_expressions . into_iter ( ) {
91+ let field_dtype = scope_dtype. field ( & name) ?;
92+ // If there is a single expr then we don't need to `pack` this, and we must update
93+ // the root expr removing this access.
94+ let expr = if exprs. len ( ) == 1 {
95+ remove_accesses. push ( Self :: field_idx_name ( & name, 0 ) ) ;
96+ exprs. first ( ) . vortex_expect ( "exprs is non-empty" ) . clone ( )
97+ } else {
98+ pack (
99+ exprs
100+ . into_iter ( )
101+ . enumerate ( )
102+ . map ( |( idx, expr) | ( Self :: field_idx_name ( & name, idx) , expr) ) ,
103+ )
104+ } ;
105+
106+ partitions. push ( simplify_typed ( expr, & field_dtype) ?) ;
107+ partition_names. push ( name) ;
108+ }
118109
119110 let expression_access_counts = field_accesses. get ( & expr) . map ( |ac| ac. len ( ) ) ;
120111 // Ensure that there are not more accesses than partitions, we missed something
@@ -130,6 +121,7 @@ impl<'a> StructFieldExpressionSplitter<'a> {
130121 Ok ( PartitionedExpr {
131122 root : simplify_typed ( split. result , dtype) ?,
132123 partitions : partitions. into_boxed_slice ( ) ,
124+ partition_names : partition_names. into ( ) ,
133125 } )
134126 }
135127}
@@ -308,11 +300,8 @@ mod tests {
308300 assert ! ( split_a. is_some( ) ) ;
309301 let split_a = split_a. unwrap ( ) ;
310302
311- assert_eq ! ( & partitioned. root, & get_item( split_a. name. clone( ) , ident( ) ) ) ;
312- assert_eq ! (
313- & simplify( split_a. expr. clone( ) ) . unwrap( ) ,
314- & get_item( "b" , ident( ) )
315- ) ;
303+ assert_eq ! ( & partitioned. root, & get_item( "a" , ident( ) ) ) ;
304+ assert_eq ! ( & simplify( split_a. clone( ) ) . unwrap( ) , & get_item( "b" , ident( ) ) ) ;
316305 }
317306
318307 #[ test]
@@ -328,7 +317,7 @@ mod tests {
328317
329318 let split_a = partitioned. find_partition ( & "a" . into ( ) ) . unwrap ( ) ;
330319 assert_eq ! (
331- & simplify( split_a. expr . clone( ) ) . unwrap( ) ,
320+ & simplify( split_a. clone( ) ) . unwrap( ) ,
332321 & pack( [
333322 (
334323 StructFieldExpressionSplitter :: field_idx_name( & "a" . into( ) , 0 ) ,
@@ -341,7 +330,7 @@ mod tests {
341330 ] )
342331 ) ;
343332 let split_c = partitioned. find_partition ( & "c" . into ( ) ) . unwrap ( ) ;
344- assert_eq ! ( & simplify( split_c. expr . clone( ) ) . unwrap( ) , & ident( ) )
333+ assert_eq ! ( & simplify( split_c. clone( ) ) . unwrap( ) , & ident( ) )
345334 }
346335
347336 #[ test]
0 commit comments