@@ -3,24 +3,43 @@ mod filter;
33mod mask;
44
55use itertools:: Itertools ;
6+ use vortex_dtype:: Nullability :: NonNullable ;
67use vortex_error:: VortexResult ;
8+ use vortex_scalar:: Scalar ;
79
810use crate :: arrays:: StructVTable ;
911use crate :: arrays:: struct_:: StructArray ;
1012use crate :: compute:: {
1113 IsConstantKernel , IsConstantKernelAdapter , IsConstantOpts , MinMaxKernel , MinMaxKernelAdapter ,
12- MinMaxResult , TakeKernel , TakeKernelAdapter , is_constant_opts, take,
14+ MinMaxResult , TakeKernel , TakeKernelAdapter , fill_null , is_constant_opts, take,
1315} ;
16+ use crate :: validity:: Validity ;
1417use crate :: vtable:: ValidityHelper ;
1518use crate :: { Array , ArrayRef , IntoArray , register_kernel} ;
1619
1720impl TakeKernel for StructVTable {
1821 fn take ( & self , array : & StructArray , indices : & dyn Array ) -> VortexResult < ArrayRef > {
22+ // If the struct array is empty then the indices must be all null, otherwise it will access
23+ // an out of bounds element
24+ if array. is_empty ( ) {
25+ return StructArray :: try_new_with_dtype (
26+ array. fields ( ) . to_vec ( ) ,
27+ array. struct_fields ( ) . clone ( ) ,
28+ indices. len ( ) ,
29+ Validity :: AllInvalid ,
30+ )
31+ . map ( StructArray :: into_array) ;
32+ }
33+ // The validity is applied to the struct validity,
34+ let inner_indices = & fill_null (
35+ indices,
36+ & Scalar :: default_value ( indices. dtype ( ) . with_nullability ( NonNullable ) ) ,
37+ ) ?;
1938 StructArray :: try_new_with_dtype (
2039 array
2140 . fields ( )
2241 . iter ( )
23- . map ( |field| take ( field, indices ) )
42+ . map ( |field| take ( field, inner_indices ) )
2443 . try_collect ( ) ?,
2544 array. struct_fields ( ) . clone ( ) ,
2645 indices. len ( ) ,
@@ -71,13 +90,15 @@ register_kernel!(IsConstantKernelAdapter(StructVTable).lift());
7190mod tests {
7291 use std:: sync:: Arc ;
7392
93+ use Nullability :: { NonNullable , Nullable } ;
7494 use vortex_buffer:: buffer;
7595 use vortex_dtype:: { DType , FieldNames , Nullability , PType , StructFields } ;
7696 use vortex_mask:: Mask ;
97+ use vortex_scalar:: Scalar ;
7798
7899 use crate :: arrays:: { BoolArray , BooleanBuffer , PrimitiveArray , StructArray , VarBinArray } ;
79100 use crate :: compute:: conformance:: mask:: test_mask;
80- use crate :: compute:: { cast, filter} ;
101+ use crate :: compute:: { cast, filter, take } ;
81102 use crate :: validity:: Validity ;
82103 use crate :: { Array , IntoArray as _} ;
83104
@@ -92,6 +113,52 @@ mod tests {
92113 assert_eq ! ( filtered. len( ) , 5 ) ;
93114 }
94115
116+ #[ test]
117+ fn take_empty_struct ( ) {
118+ let struct_arr =
119+ StructArray :: try_new ( vec ! [ ] . into ( ) , vec ! [ ] , 10 , Validity :: NonNullable ) . unwrap ( ) ;
120+ let indices = PrimitiveArray :: from_option_iter ( [ Some ( 1 ) , None ] ) ;
121+ let taken = take ( struct_arr. as_ref ( ) , indices. as_ref ( ) ) . unwrap ( ) ;
122+ assert_eq ! ( taken. len( ) , 2 ) ;
123+
124+ assert_eq ! (
125+ taken. scalar_at( 0 ) . unwrap( ) ,
126+ Scalar :: struct_(
127+ DType :: Struct ( Arc :: new( StructFields :: new( [ ] . into( ) , vec![ ] ) ) , Nullable ) ,
128+ vec![ ]
129+ )
130+ ) ;
131+ assert_eq ! (
132+ taken. scalar_at( 1 ) . unwrap( ) ,
133+ Scalar :: null( DType :: Struct (
134+ Arc :: new( StructFields :: new( [ ] . into( ) , vec![ ] ) ) ,
135+ Nullable
136+ ) )
137+ ) ;
138+ }
139+
140+ #[ test]
141+ fn take_field_struct ( ) {
142+ let struct_arr =
143+ StructArray :: from_fields ( & [ ( "a" , PrimitiveArray :: from_iter ( 0 ..10 ) . to_array ( ) ) ] )
144+ . unwrap ( ) ;
145+ let indices = PrimitiveArray :: from_option_iter ( [ Some ( 1 ) , None ] ) ;
146+ let taken = take ( struct_arr. as_ref ( ) , indices. as_ref ( ) ) . unwrap ( ) ;
147+ assert_eq ! ( taken. len( ) , 2 ) ;
148+
149+ assert_eq ! (
150+ taken. scalar_at( 0 ) . unwrap( ) ,
151+ Scalar :: struct_(
152+ struct_arr. dtype( ) . union_nullability( Nullable ) ,
153+ vec![ Scalar :: primitive( 1 , NonNullable ) ] ,
154+ )
155+ ) ;
156+ assert_eq ! (
157+ taken. scalar_at( 1 ) . unwrap( ) ,
158+ Scalar :: null( struct_arr. dtype( ) . union_nullability( Nullable ) , )
159+ ) ;
160+ }
161+
95162 #[ test]
96163 fn filter_empty_struct_with_empty_filter ( ) {
97164 let struct_arr =
@@ -114,7 +181,7 @@ mod tests {
114181 let xs = buffer ! [ 0i64 , 1 , 2 , 3 , 4 ] . into_array ( ) ;
115182 let ys = VarBinArray :: from_iter (
116183 [ Some ( "a" ) , Some ( "b" ) , None , Some ( "d" ) , None ] ,
117- DType :: Utf8 ( Nullability :: Nullable ) ,
184+ DType :: Utf8 ( Nullable ) ,
118185 )
119186 . into_array ( ) ;
120187 let zs =
@@ -148,17 +215,13 @@ mod tests {
148215 let array = StructArray :: try_new ( vec ! [ ] . into ( ) , vec ! [ ] , 5 , Validity :: NonNullable )
149216 . unwrap ( )
150217 . into_array ( ) ;
151- let non_nullable_dtype = DType :: Struct (
152- Arc :: from ( StructFields :: new ( [ ] . into ( ) , vec ! [ ] ) ) ,
153- Nullability :: NonNullable ,
154- ) ;
218+ let non_nullable_dtype =
219+ DType :: Struct ( Arc :: from ( StructFields :: new ( [ ] . into ( ) , vec ! [ ] ) ) , NonNullable ) ;
155220 let casted = cast ( & array, & non_nullable_dtype) . unwrap ( ) ;
156221 assert_eq ! ( casted. dtype( ) , & non_nullable_dtype) ;
157222
158- let nullable_dtype = DType :: Struct (
159- Arc :: from ( StructFields :: new ( [ ] . into ( ) , vec ! [ ] ) ) ,
160- Nullability :: Nullable ,
161- ) ;
223+ let nullable_dtype =
224+ DType :: Struct ( Arc :: from ( StructFields :: new ( [ ] . into ( ) , vec ! [ ] ) ) , Nullable ) ;
162225 let casted = cast ( & array, & nullable_dtype) . unwrap ( ) ;
163226 assert_eq ! ( casted. dtype( ) , & nullable_dtype) ;
164227 }
@@ -177,7 +240,7 @@ mod tests {
177240 )
178241 . unwrap ( ) ;
179242
180- let tu8 = DType :: Primitive ( PType :: U8 , Nullability :: NonNullable ) ;
243+ let tu8 = DType :: Primitive ( PType :: U8 , NonNullable ) ;
181244
182245 let result = cast (
183246 array. as_ref ( ) ,
@@ -186,7 +249,7 @@ mod tests {
186249 FieldNames :: from ( [ "ys" . into ( ) , "xs" . into ( ) , "zs" . into ( ) ] ) ,
187250 vec ! [ tu8. clone( ) , tu8. clone( ) , tu8] ,
188251 ) ) ,
189- Nullability :: NonNullable ,
252+ NonNullable ,
190253 ) ,
191254 ) ;
192255 assert ! (
@@ -201,10 +264,7 @@ mod tests {
201264 #[ test]
202265 fn test_cast_complex_struct ( ) {
203266 let xs = PrimitiveArray :: from_option_iter ( [ Some ( 0i64 ) , Some ( 1 ) , Some ( 2 ) , Some ( 3 ) , Some ( 4 ) ] ) ;
204- let ys = VarBinArray :: from_vec (
205- vec ! [ "a" , "b" , "c" , "d" , "e" ] ,
206- DType :: Utf8 ( Nullability :: Nullable ) ,
207- ) ;
267+ let ys = VarBinArray :: from_vec ( vec ! [ "a" , "b" , "c" , "d" , "e" ] , DType :: Utf8 ( Nullable ) ) ;
208268 let zs = BoolArray :: new (
209269 BooleanBuffer :: from_iter ( [ true , true , false , false , true ] ) ,
210270 Validity :: AllValid ,
@@ -241,17 +301,17 @@ mod tests {
241301 Arc :: from( StructFields :: new(
242302 [ "left" . into( ) , "right" . into( ) ] . into( ) ,
243303 vec![
244- DType :: Primitive ( PType :: I64 , Nullability :: NonNullable ) ,
245- DType :: Primitive ( PType :: I64 , Nullability :: Nullable ) ,
304+ DType :: Primitive ( PType :: I64 , NonNullable ) ,
305+ DType :: Primitive ( PType :: I64 , Nullable ) ,
246306 ] ,
247307 ) ) ,
248- Nullability :: Nullable ,
308+ Nullable ,
249309 ) ,
250- DType :: Utf8 ( Nullability :: Nullable ) ,
251- DType :: Bool ( Nullability :: Nullable ) ,
310+ DType :: Utf8 ( Nullable ) ,
311+ DType :: Bool ( Nullable ) ,
252312 ] ,
253313 ) ) ,
254- Nullability :: Nullable ,
314+ Nullable ,
255315 ) ;
256316 let casted = cast ( & fully_nullable_array, & non_null_xs_right) . unwrap ( ) ;
257317 assert_eq ! ( casted. dtype( ) , & non_null_xs_right) ;
@@ -264,17 +324,17 @@ mod tests {
264324 Arc :: from( StructFields :: new(
265325 [ "left" . into( ) , "right" . into( ) ] . into( ) ,
266326 vec![
267- DType :: Primitive ( PType :: I64 , Nullability :: Nullable ) ,
268- DType :: Primitive ( PType :: I64 , Nullability :: Nullable ) ,
327+ DType :: Primitive ( PType :: I64 , Nullable ) ,
328+ DType :: Primitive ( PType :: I64 , Nullable ) ,
269329 ] ,
270330 ) ) ,
271- Nullability :: NonNullable ,
331+ NonNullable ,
272332 ) ,
273- DType :: Utf8 ( Nullability :: Nullable ) ,
274- DType :: Bool ( Nullability :: Nullable ) ,
333+ DType :: Utf8 ( Nullable ) ,
334+ DType :: Bool ( Nullable ) ,
275335 ] ,
276336 ) ) ,
277- Nullability :: Nullable ,
337+ Nullable ,
278338 ) ;
279339 let casted = cast ( & fully_nullable_array, & non_null_xs) . unwrap ( ) ;
280340 assert_eq ! ( casted. dtype( ) , & non_null_xs) ;
0 commit comments