11use vortex_buffer:: BufferMut ;
2- use vortex_dtype:: PType ;
2+ use vortex_dtype:: { DType , PType } ;
33use vortex_error:: VortexResult ;
44
5- use crate :: arrays:: ChunkedVTable ;
65use crate :: arrays:: chunked:: ChunkedArray ;
6+ use crate :: arrays:: { ChunkedVTable , PrimitiveArray } ;
77use crate :: compute:: { TakeKernel , TakeKernelAdapter , cast, take} ;
8+ use crate :: validity:: Validity ;
89use crate :: { Array , ArrayRef , IntoArray , ToCanonical , register_kernel} ;
910
1011impl TakeKernel for ChunkedVTable {
1112 fn take ( & self , array : & ChunkedArray , indices : & dyn Array ) -> VortexResult < ArrayRef > {
12- let indices = cast ( indices, PType :: U64 . into ( ) ) ?. to_primitive ( ) ?;
13+ let indices = cast (
14+ indices,
15+ & DType :: Primitive ( PType :: U64 , indices. dtype ( ) . nullability ( ) ) ,
16+ ) ?
17+ . to_primitive ( ) ?;
18+
19+ // TODO(joe): Should we split this implementation based on indices nullability?
20+ let nullability = indices. dtype ( ) . nullability ( ) ;
21+ let indices_mask = indices. validity_mask ( ) ?;
22+ let indices = indices. as_slice :: < u64 > ( ) ;
1323
14- // While the chunk idx remains the same, accumulate a list of chunk indices.
1524 let mut chunks = Vec :: new ( ) ;
1625 let mut indices_in_chunk = BufferMut :: < u64 > :: empty ( ) ;
17- let mut prev_chunk_idx = array
18- . find_chunk_idx ( indices . as_slice :: < u64 > ( ) [ 0 ] . try_into ( ) ? )
19- . 0 ;
20- for idx in indices. as_slice :: < u64 > ( ) {
26+ let mut start = 0 ;
27+ let mut stop = 0 ;
28+ let mut prev_chunk_idx = array . find_chunk_idx ( indices [ 0 ] . try_into ( ) ? ) . 0 ;
29+ for idx in indices {
2130 let idx = usize:: try_from ( * idx) ?;
2231 let ( chunk_idx, idx_in_chunk) = array. find_chunk_idx ( idx) ;
2332
2433 if chunk_idx != prev_chunk_idx {
2534 // Start a new chunk
26- let indices_in_chunk_array = indices_in_chunk. clone ( ) . into_array ( ) ;
27- chunks. push ( take ( array. chunk ( prev_chunk_idx) ?, & indices_in_chunk_array) ?) ;
35+ let indices_in_chunk_array = PrimitiveArray :: new (
36+ indices_in_chunk. clone ( ) . freeze ( ) ,
37+ Validity :: from_mask ( indices_mask. slice ( start, stop - start) , nullability) ,
38+ ) ;
39+ chunks. push ( take (
40+ array. chunk ( prev_chunk_idx) ?,
41+ indices_in_chunk_array. as_ref ( ) ,
42+ ) ?) ;
2843 indices_in_chunk. clear ( ) ;
44+ start = stop;
2945 }
3046
3147 indices_in_chunk. push ( idx_in_chunk as u64 ) ;
48+ stop += 1 ;
3249 prev_chunk_idx = chunk_idx;
3350 }
3451
3552 if !indices_in_chunk. is_empty ( ) {
36- let indices_in_chunk_array = indices_in_chunk. into_array ( ) ;
37- chunks. push ( take ( array. chunk ( prev_chunk_idx) ?, & indices_in_chunk_array) ?) ;
53+ let indices_in_chunk_array = PrimitiveArray :: new (
54+ indices_in_chunk. freeze ( ) ,
55+ Validity :: from_mask ( indices_mask. slice ( start, stop - start) , nullability) ,
56+ ) ;
57+ chunks. push ( take (
58+ array. chunk ( prev_chunk_idx) ?,
59+ indices_in_chunk_array. as_ref ( ) ,
60+ ) ?) ;
3861 }
3962
40- Ok ( ChunkedArray :: new_unchecked ( chunks, array. dtype ( ) . clone ( ) ) . into_array ( ) )
63+ Ok ( ChunkedArray :: new_unchecked (
64+ chunks,
65+ array. dtype ( ) . clone ( ) . union_nullability ( nullability) ,
66+ )
67+ . into_array ( ) )
4168 }
4269}
4370
@@ -50,8 +77,10 @@ mod test {
5077 use crate :: IntoArray ;
5178 use crate :: array:: Array ;
5279 use crate :: arrays:: chunked:: ChunkedArray ;
80+ use crate :: arrays:: { BoolArray , PrimitiveArray , StructArray } ;
5381 use crate :: canonical:: ToCanonical ;
5482 use crate :: compute:: take;
83+ use crate :: validity:: Validity ;
5584
5685 #[ test]
5786 fn test_take ( ) {
@@ -68,4 +97,30 @@ mod test {
6897 . unwrap ( ) ;
6998 assert_eq ! ( result. as_slice:: <i32 >( ) , & [ 1 , 1 , 1 , 2 ] ) ;
7099 }
100+
101+ #[ test]
102+ fn test_take_nullability ( ) {
103+ let struct_array =
104+ StructArray :: try_new ( [ ] . into ( ) , vec ! [ ] , 100 , Validity :: NonNullable ) . unwrap ( ) ;
105+
106+ let arr = ChunkedArray :: from_iter ( vec ! [ struct_array. to_array( ) , struct_array. to_array( ) ] ) ;
107+
108+ let result = take (
109+ arr. as_ref ( ) ,
110+ PrimitiveArray :: from_option_iter ( vec ! [ Some ( 0 ) , None , Some ( 101 ) ] ) . as_ref ( ) ,
111+ )
112+ . unwrap ( ) ;
113+
114+ let expect = StructArray :: try_new (
115+ [ ] . into ( ) ,
116+ vec ! [ ] ,
117+ 3 ,
118+ Validity :: Array ( BoolArray :: from_iter ( vec ! [ true , false , true ] ) . to_array ( ) ) ,
119+ )
120+ . unwrap ( ) ;
121+ assert_eq ! ( result. dtype( ) , expect. dtype( ) ) ;
122+ assert_eq ! ( result. scalar_at( 0 ) . unwrap( ) , expect. scalar_at( 0 ) . unwrap( ) ) ;
123+ assert_eq ! ( result. scalar_at( 1 ) . unwrap( ) , expect. scalar_at( 1 ) . unwrap( ) ) ;
124+ assert_eq ! ( result. scalar_at( 2 ) . unwrap( ) , expect. scalar_at( 2 ) . unwrap( ) ) ;
125+ }
71126}
0 commit comments