@@ -109,7 +109,10 @@ impl<'a> StructFieldExpressionSplitter<'a> {
109109
110110 let mut splitter = StructFieldExpressionSplitter :: new ( & field_accesses, scope_dtype) ;
111111
112- let split = expr. clone ( ) . transform_with_context ( & mut splitter, ( ) ) ?;
112+ let split = expr
113+ . clone ( )
114+ . transform_with_context ( & mut splitter, ( ) ) ?
115+ . result ( ) ;
113116
114117 let mut remove_accesses: Vec < FieldName > = Vec :: new ( ) ;
115118
@@ -153,13 +156,19 @@ impl<'a> StructFieldExpressionSplitter<'a> {
153156 debug_assert_eq ! ( expression_access_counts. unwrap_or( 0 ) , partitions. len( ) ) ;
154157
155158 let split = split
156- . result ( )
157- . transform ( & mut ReplaceAccessesWithChild ( remove_accesses ) ) ? ;
159+ . transform ( & mut ReplaceAccessesWithChild ( remove_accesses ) ) ?
160+ . into_inner ( ) ;
158161
159- let ctx = ScopeDType :: new ( dtype. clone ( ) ) ;
162+ let ctx = ScopeDType :: new ( DType :: Struct (
163+ StructFields :: new (
164+ FieldNames :: from ( partition_names. clone ( ) ) ,
165+ partition_dtypes. clone ( ) ,
166+ ) ,
167+ Nullability :: NonNullable ,
168+ ) ) ;
160169
161170 Ok ( PartitionedExpr {
162- root : simplify_typed ( split. into_inner ( ) , & ctx) ?,
171+ root : simplify_typed ( split, & ctx) ?,
163172 partitions : partitions. into_boxed_slice ( ) ,
164173 partition_names : partition_names. into ( ) ,
165174 partition_dtypes : partition_dtypes. into_boxed_slice ( ) ,
@@ -293,11 +302,12 @@ mod tests {
293302 use vortex_dtype:: Nullability :: NonNullable ;
294303 use vortex_dtype:: PType :: I32 ;
295304 use vortex_dtype:: { DType , StructFields } ;
305+ use vortex_utils:: aliases:: hash_set:: HashSet ;
296306
297307 use super :: * ;
298308 use crate :: transform:: simplify:: simplify;
299309 use crate :: transform:: simplify_typed:: simplify_typed;
300- use crate :: { Pack , and, get_item, lit, pack, root, select} ;
310+ use crate :: { Pack , and, col , get_item, lit, merge , pack, root, select} ;
301311
302312 fn dtype ( ) -> DType {
303313 DType :: Struct (
@@ -448,4 +458,42 @@ mod tests {
448458 )
449459 )
450460 }
461+
462+ #[ test]
463+ fn test_expr_merge ( ) {
464+ let dtype = dtype ( ) ;
465+
466+ let expr = merge (
467+ [ col ( "a" ) , pack ( [ ( "b" , col ( "b" ) ) ] , NonNullable ) ] ,
468+ NonNullable ,
469+ ) ;
470+
471+ let partitioned = StructFieldExpressionSplitter :: split ( expr, & dtype) . unwrap ( ) ;
472+ let expected = pack (
473+ [
474+ ( "a" , get_item ( "a" , col ( "a" ) ) ) ,
475+ ( "b" , get_item ( "b" , col ( "b" ) ) ) ,
476+ ] ,
477+ NonNullable ,
478+ ) ;
479+ assert_eq ! (
480+ & partitioned. root, & expected,
481+ "{} {}" ,
482+ partitioned. root, expected
483+ ) ;
484+ let expected = [ root ( ) , pack ( [ ( "b" , root ( ) ) ] , NonNullable ) ]
485+ . into_iter ( )
486+ . collect :: < HashSet < _ > > ( ) ;
487+ assert_eq ! (
488+ & partitioned
489+ . partitions
490+ . clone( )
491+ . into_iter( )
492+ . collect:: <HashSet <_>>( ) ,
493+ & expected,
494+ "{} {}" ,
495+ partitioned. partitions. iter( ) . join( ";" ) ,
496+ expected. iter( ) . join( ";" )
497+ ) ;
498+ }
451499}
0 commit comments