11// SPDX-License-Identifier: Apache-2.0
22// SPDX-FileCopyrightText: Copyright the Vortex contributors
33
4+ use std:: fmt:: Display ;
45use std:: hash:: Hash ;
56
67use itertools:: Itertools as _;
78use vortex_array:: arrays:: StructArray ;
89use vortex_array:: validity:: Validity ;
9- use vortex_array:: { Array , ArrayRef , DeserializeMetadata , EmptyMetadata , IntoArray , ToCanonical } ;
10+ use vortex_array:: {
11+ Array , ArrayRef , DeserializeMetadata , EmptyMetadata , IntoArray as _, ToCanonical ,
12+ } ;
1013use vortex_dtype:: { DType , FieldNames , Nullability , StructFields } ;
11- use vortex_error:: { VortexExpect as _, VortexResult , vortex_bail} ;
14+ use vortex_error:: { VortexResult , vortex_bail} ;
15+ use vortex_utils:: aliases:: hash_set:: HashSet ;
1216
1317use crate :: display:: { DisplayAs , DisplayFormat } ;
1418use crate :: { AnalysisExpr , ExprEncodingRef , ExprId , ExprRef , IntoExpr , Scope , VTable , vtable} ;
@@ -25,6 +29,32 @@ vtable!(Merge);
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
pub struct MergeExpr {
    /// Child expressions whose struct fields are combined, in order of appearance.
    values: Vec<ExprRef>,
    /// Policy applied when two or more children share a field name.
    duplicate_handling: DuplicateHandling,
}
34+
35+ impl MergeExpr {
36+ pub fn duplicate_handling ( & self ) -> DuplicateHandling {
37+ self . duplicate_handling
38+ }
39+ }
40+
/// What to do when merged structs share a field name.
///
/// The default policy is [`DuplicateHandling::Error`].
#[derive(Default, Debug, Copy, Clone, PartialEq, Eq, Hash)]
pub enum DuplicateHandling {
    /// If two structs share a field name, take the value from the right-most struct.
    RightMost,
    /// If two structs share a field name, error.
    #[default]
    Error,
}
50+
51+ impl Display for DuplicateHandling {
52+ fn fmt ( & self , f : & mut std:: fmt:: Formatter < ' _ > ) -> std:: fmt:: Result {
53+ match self {
54+ DuplicateHandling :: Error => write ! ( f, "error" ) ,
55+ DuplicateHandling :: RightMost => write ! ( f, "right-most" ) ,
56+ }
57+ }
2858}
2959
/// Unit marker type associated with merge expressions.
///
/// NOTE(review): presumably the encoding handle wired up by `vtable!(Merge)`; confirm against
/// the macro expansion.
pub struct MergeExprEncoding;
@@ -50,8 +80,11 @@ impl VTable for MergeVTable {
5080 expr. values . iter ( ) . collect ( )
5181 }
5282
53- fn with_children ( _expr : & Self :: Expr , children : Vec < ExprRef > ) -> VortexResult < Self :: Expr > {
54- Ok ( MergeExpr { values : children } )
83+ fn with_children ( expr : & Self :: Expr , children : Vec < ExprRef > ) -> VortexResult < Self :: Expr > {
84+ Ok ( MergeExpr {
85+ values : children,
86+ duplicate_handling : expr. duplicate_handling ,
87+ } )
5588 }
5689
5790 fn build (
@@ -65,39 +98,33 @@ impl VTable for MergeVTable {
6598 children
6699 ) ;
67100 }
68- Ok ( MergeExpr { values : children } )
101+ Ok ( MergeExpr {
102+ values : children,
103+ duplicate_handling : DuplicateHandling :: default ( ) ,
104+ } )
69105 }
70106
71107 fn evaluate ( expr : & Self :: Expr , scope : & Scope ) -> VortexResult < ArrayRef > {
72- let len = scope. len ( ) ;
73- let value_arrays = expr
74- . values
75- . iter ( )
76- . map ( |value_expr| value_expr. unchecked_evaluate ( scope) )
77- . process_results ( |it| it. collect :: < Vec < _ > > ( ) ) ?;
78-
79108 // Collect fields in order of appearance. Later fields overwrite earlier fields.
80109 let mut field_names = Vec :: new ( ) ;
81110 let mut arrays = Vec :: new ( ) ;
111+ let mut duplicate_names = HashSet :: < _ > :: new ( ) ;
82112
83- for value_array in value_arrays . iter ( ) {
113+ for expr in expr . values . iter ( ) {
84114 // TODO(marko): When nullable, we need to merge struct validity into field validity.
85- if value_array. dtype ( ) . is_nullable ( ) {
86- todo ! ( "merge nullable structs" ) ;
115+ let array = expr. unchecked_evaluate ( scope) ?;
116+ if array. dtype ( ) . is_nullable ( ) {
117+ vortex_bail ! ( "merge expects non-nullable input" ) ;
87118 }
88- if !value_array . dtype ( ) . is_struct ( ) {
89- vortex_bail ! ( "merge expects non-nullable struct input" ) ;
119+ if !array . dtype ( ) . is_struct ( ) {
120+ vortex_bail ! ( "merge expects struct input" ) ;
90121 }
122+ let array = array. to_struct ( ) ;
91123
92- let struct_array = value_array. to_struct ( ) ;
93-
94- for ( field_name, array) in struct_array
95- . names ( )
96- . iter ( )
97- . zip_eq ( struct_array. fields ( ) . iter ( ) . cloned ( ) )
98- {
124+ for ( field_name, array) in array. names ( ) . iter ( ) . zip_eq ( array. fields ( ) . iter ( ) . cloned ( ) ) {
99125 // Update or insert field.
100126 if let Some ( idx) = field_names. iter ( ) . position ( |name| name == field_name) {
127+ duplicate_names. insert ( field_name. clone ( ) ) ;
101128 arrays[ idx] = array;
102129 } else {
103130 field_names. push ( field_name. clone ( ) ) ;
@@ -106,8 +133,16 @@ impl VTable for MergeVTable {
106133 }
107134 }
108135
136+ if expr. duplicate_handling == DuplicateHandling :: Error && !duplicate_names. is_empty ( ) {
137+ vortex_bail ! (
138+ "merge: duplicate fields in children: {}" ,
139+ duplicate_names. into_iter( ) . format( ", " )
140+ )
141+ }
142+
109143 // TODO(DK): When children are allowed to be nullable, this needs to change.
110144 let validity = Validity :: NonNullable ;
145+ let len = scope. len ( ) ;
111146 Ok (
112147 StructArray :: try_new ( FieldNames :: from ( field_names) , arrays, len, validity) ?
113148 . into_array ( ) ,
@@ -117,27 +152,23 @@ impl VTable for MergeVTable {
117152 fn return_dtype ( expr : & Self :: Expr , scope : & DType ) -> VortexResult < DType > {
118153 let mut field_names = Vec :: new ( ) ;
119154 let mut arrays = Vec :: new ( ) ;
155+ let mut merge_nullability = Nullability :: NonNullable ;
156+ let mut duplicate_names = HashSet :: < _ > :: new ( ) ;
120157
121- let mut nullability = Nullability :: NonNullable ;
122-
123- for value in expr. values . iter ( ) {
124- let dtype = value. return_dtype ( scope) ?;
125- if !dtype. is_struct ( ) {
158+ for expr in expr. values . iter ( ) {
159+ let dtype = expr. return_dtype ( scope) ?;
160+ let Some ( fields) = dtype. as_struct_fields_opt ( ) else {
126161 vortex_bail ! ( "merge expects struct input" ) ;
127- }
162+ } ;
128163 if dtype. is_nullable ( ) {
129164 vortex_bail ! ( "merge expects non-nullable input" ) ;
130165 }
131- nullability |= dtype. nullability ( ) ;
132166
133- let struct_dtype = dtype
134- . as_struct_fields_opt ( )
135- . vortex_expect ( "merge expects struct input" ) ;
167+ merge_nullability |= dtype. nullability ( ) ;
136168
137- for i in 0 ..struct_dtype. nfields ( ) {
138- let field_name = struct_dtype. field_name ( i) . vortex_expect ( "never OOB" ) ;
139- let field_dtype = struct_dtype. field_by_index ( i) . vortex_expect ( "never OOB" ) ;
169+ for ( field_name, field_dtype) in fields. names ( ) . iter ( ) . zip_eq ( fields. fields ( ) ) {
140170 if let Some ( idx) = field_names. iter ( ) . position ( |name| name == field_name) {
171+ duplicate_names. insert ( field_name. clone ( ) ) ;
141172 arrays[ idx] = field_dtype;
142173 } else {
143174 field_names. push ( field_name. clone ( ) ) ;
@@ -146,21 +177,42 @@ impl VTable for MergeVTable {
146177 }
147178 }
148179
180+ if expr. duplicate_handling == DuplicateHandling :: Error && !duplicate_names. is_empty ( ) {
181+ vortex_bail ! (
182+ "merge: duplicate fields in children: {}" ,
183+ duplicate_names. into_iter( ) . format( ", " )
184+ )
185+ }
186+
149187 Ok ( DType :: Struct (
150188 StructFields :: new ( FieldNames :: from ( field_names) , arrays) ,
151- nullability ,
189+ merge_nullability ,
152190 ) )
153191 }
154192}
155193
156194impl MergeExpr {
157195 pub fn new ( values : Vec < ExprRef > ) -> Self {
158- MergeExpr { values }
196+ MergeExpr {
197+ values,
198+ duplicate_handling : DuplicateHandling :: default ( ) ,
199+ }
159200 }
160201
161202 pub fn new_expr ( values : Vec < ExprRef > ) -> ExprRef {
162203 Self :: new ( values) . into_expr ( )
163204 }
205+
206+ pub fn new_opts ( values : Vec < ExprRef > , duplicate_handling : DuplicateHandling ) -> Self {
207+ MergeExpr {
208+ values,
209+ duplicate_handling,
210+ }
211+ }
212+
213+ pub fn new_expr_opts ( values : Vec < ExprRef > , duplicate_handling : DuplicateHandling ) -> ExprRef {
214+ Self :: new_opts ( values, duplicate_handling) . into_expr ( )
215+ }
164216}
165217
166218/// Creates an expression that merges struct expressions into a single struct.
@@ -178,11 +230,24 @@ pub fn merge(elements: impl IntoIterator<Item = impl Into<ExprRef>>) -> ExprRef
178230 MergeExpr :: new ( values) . into_expr ( )
179231}
180232
233+ pub fn merge_opts (
234+ elements : impl IntoIterator < Item = impl Into < ExprRef > > ,
235+ duplicate_handling : DuplicateHandling ,
236+ ) -> ExprRef {
237+ let values = elements. into_iter ( ) . map ( |value| value. into ( ) ) . collect_vec ( ) ;
238+ MergeExpr :: new_opts ( values, duplicate_handling) . into_expr ( )
239+ }
240+
181241impl DisplayAs for MergeExpr {
182242 fn fmt_as ( & self , df : DisplayFormat , f : & mut std:: fmt:: Formatter ) -> std:: fmt:: Result {
183243 match df {
184244 DisplayFormat :: Compact => {
185- write ! ( f, "merge({})" , self . values. iter( ) . format( ", " ) , )
245+ write ! (
246+ f,
247+ "merge[{}]({})" ,
248+ self . duplicate_handling,
249+ self . values. iter( ) . format( ", " ) ,
250+ )
186251 }
187252 DisplayFormat :: Tree => {
188253 write ! ( f, "Merge" )
@@ -200,7 +265,7 @@ mod tests {
200265 use vortex_buffer:: buffer;
201266 use vortex_error:: { VortexResult , vortex_bail} ;
202267
203- use crate :: { MergeExpr , Scope , get_item, merge, root} ;
268+ use crate :: { DuplicateHandling , MergeExpr , Scope , get_item, merge, root} ;
204269
205270 fn primitive_field ( array : & dyn Array , field_path : & [ & str ] ) -> VortexResult < PrimitiveArray > {
206271 let mut field_path = field_path. iter ( ) ;
@@ -217,12 +282,15 @@ mod tests {
217282 }
218283
219284 #[ test]
220- pub fn test_merge ( ) {
221- let expr = MergeExpr :: new ( vec ! [
222- get_item( "0" , root( ) ) ,
223- get_item( "1" , root( ) ) ,
224- get_item( "2" , root( ) ) ,
225- ] ) ;
285+ pub fn test_merge_right_most ( ) {
286+ let expr = MergeExpr :: new_opts (
287+ vec ! [
288+ get_item( "0" , root( ) ) ,
289+ get_item( "1" , root( ) ) ,
290+ get_item( "2" , root( ) ) ,
291+ ] ,
292+ DuplicateHandling :: RightMost ,
293+ ) ;
226294
227295 let test_array = StructArray :: from_fields ( & [
228296 (
@@ -294,6 +362,52 @@ mod tests {
294362 ) ;
295363 }
296364
#[test]
#[should_panic(expected = "merge: duplicate fields in children")]
pub fn test_merge_error_on_dupe_return_dtype() {
    // Both children carry a field named "b", which the `Error` policy must reject.
    let left = StructArray::try_from_iter([("a", buffer![1]), ("b", buffer![1])]).unwrap();
    let right = StructArray::try_from_iter([("c", buffer![1]), ("b", buffer![1])]).unwrap();
    let input = StructArray::try_from_iter([("0", left), ("1", right)])
        .unwrap()
        .into_array();

    let children = vec![get_item("0", root()), get_item("1", root())];
    let merge_expr = MergeExpr::new_opts(children, DuplicateHandling::Error);

    // Resolving the output dtype alone must already report the clash.
    merge_expr.return_dtype(input.dtype()).unwrap();
}
387+
#[test]
#[should_panic(expected = "merge: duplicate fields in children")]
pub fn test_merge_error_on_dupe_evaluate() {
    // Both children carry a field named "b", which the `Error` policy must reject.
    let left = StructArray::try_from_iter([("a", buffer![1]), ("b", buffer![1])]).unwrap();
    let right = StructArray::try_from_iter([("c", buffer![1]), ("b", buffer![1])]).unwrap();
    let input = StructArray::try_from_iter([("0", left), ("1", right)])
        .unwrap()
        .into_array();

    let children = vec![get_item("0", root()), get_item("1", root())];
    let merge_expr = MergeExpr::new_opts(children, DuplicateHandling::Error);

    // Evaluating against real data must report the clash as well.
    let scope = Scope::new(input);
    merge_expr.evaluate(&scope).unwrap();
}
410+
297411 #[ test]
298412 pub fn test_empty_merge ( ) {
299413 let expr = MergeExpr :: new ( Vec :: new ( ) ) ;
@@ -310,7 +424,10 @@ mod tests {
310424 pub fn test_nested_merge ( ) {
311425 // Nested structs are not merged!
312426
313- let expr = MergeExpr :: new ( vec ! [ get_item( "0" , root( ) ) , get_item( "1" , root( ) ) ] ) ;
427+ let expr = MergeExpr :: new_opts (
428+ vec ! [ get_item( "0" , root( ) ) , get_item( "1" , root( ) ) ] ,
429+ DuplicateHandling :: RightMost ,
430+ ) ;
314431
315432 let test_array = StructArray :: from_fields ( & [
316433 (
@@ -396,9 +513,9 @@ mod tests {
#[test]
pub fn test_display() {
    // The compact form is `merge[<duplicate handling>](<children>)`: the format string emits
    // no space between the closing bracket and the opening parenthesis, and both constructors
    // used here fall back to the default `error` policy.
    let expr = merge([get_item("struct1", root()), get_item("struct2", root())]);
    assert_eq!(expr.to_string(), "merge[error]($.struct1, $.struct2)");

    let expr2 = MergeExpr::new(vec![get_item("a", root())]);
    assert_eq!(expr2.to_string(), "merge[error]($.a)");
}
404521}
0 commit comments