11// SPDX-License-Identifier: Apache-2.0
22// SPDX-FileCopyrightText: Copyright the Vortex contributors
33
4- use vortex_buffer:: BitBufferMut ;
54use vortex_dtype:: IntegerPType ;
65use vortex_dtype:: Nullability ;
76use vortex_dtype:: match_each_integer_ptype;
87use vortex_dtype:: match_smallest_offset_type;
98use vortex_error:: VortexExpect ;
109use vortex_error:: VortexResult ;
11- use vortex_mask:: Mask ;
1210
1311use crate :: Array ;
1412use crate :: ArrayRef ;
@@ -22,7 +20,6 @@ use crate::compute::TakeKernel;
2220use crate :: compute:: TakeKernelAdapter ;
2321use crate :: compute:: take;
2422use crate :: register_kernel;
25- use crate :: validity:: Validity ;
2623use crate :: vtable:: ValidityHelper ;
2724
2825// TODO(connor)[ListView]: Re-revert to the version where we simply convert to a `ListView` and call
@@ -37,21 +34,13 @@ impl TakeKernel for ListVTable {
3734 #[ expect( clippy:: cognitive_complexity) ]
3835 fn take ( & self , array : & ListArray , indices : & dyn Array ) -> VortexResult < ArrayRef > {
3936 let indices = indices. to_primitive ( ) ;
40- let offsets = array. offsets ( ) . to_primitive ( ) ;
4137 // This is an over-approximation of the total number of elements in the resulting array.
4238 let total_approx = array. elements ( ) . len ( ) . saturating_mul ( indices. len ( ) ) ;
4339
44- match_each_integer_ptype ! ( offsets. dtype( ) . as_ptype( ) , |O | {
45- let offsets_slice = offsets. as_slice:: <O >( ) ;
40+ match_each_integer_ptype ! ( array. offsets( ) . dtype( ) . as_ptype( ) , |O | {
4641 match_each_integer_ptype!( indices. ptype( ) , |I | {
4742 match_smallest_offset_type!( total_approx, |OutputOffsetType | {
48- _take:: <I , O , OutputOffsetType >(
49- array,
50- offsets_slice,
51- & indices,
52- array. validity_mask( ) ,
53- indices. validity_mask( ) ,
54- )
43+ _take:: <I , O , OutputOffsetType >( array, & indices)
5544 } )
5645 } )
5746 } )
@@ -62,23 +51,19 @@ register_kernel!(TakeKernelAdapter(ListVTable).lift());
6251
6352fn _take < I : IntegerPType , O : IntegerPType , OutputOffsetType : IntegerPType > (
6453 array : & ListArray ,
65- offsets : & [ O ] ,
6654 indices_array : & PrimitiveArray ,
67- data_validity : Mask ,
68- indices_validity_mask : Mask ,
6955) -> VortexResult < ArrayRef > {
70- let indices: & [ I ] = indices_array. as_slice :: < I > ( ) ;
71-
72- if !indices_validity_mask. all_true ( ) || !data_validity. all_true ( ) {
73- return _take_nullable :: < I , O , OutputOffsetType > (
74- array,
75- offsets,
76- indices,
77- data_validity,
78- indices_validity_mask,
79- ) ;
56+ let data_validity = array. validity_mask ( ) ;
57+ let indices_validity = indices_array. validity_mask ( ) ;
58+
59+ if !indices_validity. all_true ( ) || !data_validity. all_true ( ) {
60+ return _take_nullable :: < I , O , OutputOffsetType > ( array, indices_array) ;
8061 }
8162
63+ let offsets_array = array. offsets ( ) . to_primitive ( ) ;
64+ let offsets: & [ O ] = offsets_array. as_slice ( ) ;
65+ let indices: & [ I ] = indices_array. as_slice ( ) ;
66+
8267 let mut new_offsets = PrimitiveBuilder :: < OutputOffsetType > :: with_capacity (
8368 Nullability :: NonNullable ,
8469 indices. len ( ) ,
@@ -120,21 +105,21 @@ fn _take<I: IntegerPType, O: IntegerPType, OutputOffsetType: IntegerPType>(
120105 Ok ( ListArray :: try_new (
121106 new_elements,
122107 new_offsets,
123- indices_array
124- . validity ( )
125- . clone ( )
126- . and ( array. validity ( ) . clone ( ) ) ,
108+ array. validity ( ) . clone ( ) . take ( indices_array. as_ref ( ) ) ?,
127109 ) ?
128110 . to_array ( ) )
129111}
130112
131113fn _take_nullable < I : IntegerPType , O : IntegerPType , OutputOffsetType : IntegerPType > (
132114 array : & ListArray ,
133- offsets : & [ O ] ,
134- indices : & [ I ] ,
135- data_validity : Mask ,
136- indices_validity : Mask ,
115+ indices_array : & PrimitiveArray ,
137116) -> VortexResult < ArrayRef > {
117+ let offsets_array = array. offsets ( ) . to_primitive ( ) ;
118+ let offsets: & [ O ] = offsets_array. as_slice ( ) ;
119+ let indices: & [ I ] = indices_array. as_slice ( ) ;
120+ let data_validity = array. validity_mask ( ) ;
121+ let indices_validity = indices_array. validity_mask ( ) ;
122+
138123 let mut new_offsets = PrimitiveBuilder :: < OutputOffsetType > :: with_capacity (
139124 Nullability :: NonNullable ,
140125 indices. len ( ) ,
@@ -153,28 +138,23 @@ fn _take_nullable<I: IntegerPType, O: IntegerPType, OutputOffsetType: IntegerPTy
153138 let mut current_offset = OutputOffsetType :: zero ( ) ;
154139 new_offsets. append_zero ( ) ;
155140
156- // Set all bits to invalid and selectively set which values are valid.
157- let mut new_validity = BitBufferMut :: new_unset ( indices. len ( ) ) ;
158-
159141 for ( idx, data_idx) in indices. iter ( ) . enumerate ( ) {
160142 if !indices_validity. value ( idx) {
161143 new_offsets. append_value ( current_offset) ;
162- // Bit buffer already has this set to invalid.
163144 continue ;
164145 }
165146
166147 let data_idx: usize = data_idx. as_ ( ) ;
167148
168149 if !data_validity. value ( data_idx) {
169150 new_offsets. append_value ( current_offset) ;
170- // Bit buffer already has this set to invalid.
171151 continue ;
172152 }
173153
174154 let start = offsets[ data_idx] ;
175155 let stop = offsets[ data_idx + 1 ] ;
176156
177- // See the note it the `take ` on the reasoning
157+ // See the note in `_take ` on the reasoning.
178158 let additional: usize = ( stop - start) . as_ ( ) ;
179159
180160 elements_to_take. reserve_exact ( additional) ;
@@ -184,17 +164,18 @@ fn _take_nullable<I: IntegerPType, O: IntegerPType, OutputOffsetType: IntegerPTy
184164 current_offset +=
185165 OutputOffsetType :: from_usize ( ( stop - start) . as_ ( ) ) . vortex_expect ( "offset conversion" ) ;
186166 new_offsets. append_value ( current_offset) ;
187- new_validity. set ( idx) ;
188167 }
189168
190169 let elements_to_take = elements_to_take. finish ( ) ;
191170 let new_offsets = new_offsets. finish ( ) ;
192171 let new_elements = take ( array. elements ( ) , elements_to_take. as_ref ( ) ) ?;
193172
194- let new_validity = Validity :: from ( new_validity. freeze ( ) ) ;
195- // data are indexes are nullable, so the final result is also nullable.
196-
197- Ok ( ListArray :: try_new ( new_elements, new_offsets, new_validity) ?. to_array ( ) )
173+ Ok ( ListArray :: try_new (
174+ new_elements,
175+ new_offsets,
176+ array. validity ( ) . clone ( ) . take ( indices_array. as_ref ( ) ) ?,
177+ ) ?
178+ . to_array ( ) )
198179}
199180
200181#[ cfg( test) ]
@@ -460,4 +441,27 @@ mod test {
460441 assert ! ( result_view. is_invalid( 1 ) ) ;
461442 assert ! ( result_view. is_valid( 2 ) ) ;
462443 }
444+
445+ /// Regression test for validity length mismatch bug.
446+ ///
447+ /// When source array has `Validity::Array(...)` and indices are non-nullable,
448+ /// the result validity must have length equal to indices.len(), not source.len().
449+ #[ test]
450+ fn test_take_validity_length_mismatch_regression ( ) {
451+ // Source array with explicit validity array (length 2).
452+ let list = ListArray :: try_new (
453+ buffer ! [ 1i32 , 2 , 3 , 4 ] . into_array ( ) ,
454+ buffer ! [ 0 , 2 , 4 ] . into_array ( ) ,
455+ Validity :: Array ( BoolArray :: from_iter ( vec ! [ true , true ] ) . to_array ( ) ) ,
456+ )
457+ . unwrap ( )
458+ . to_array ( ) ;
459+
460+ // Take more indices than source length (4 vs 2) with non-nullable indices.
461+ let idx = buffer ! [ 0u32 , 1 , 0 , 1 ] . into_array ( ) ;
462+
463+ // This should not panic - result should have length 4.
464+ let result = take ( & list, & idx) . unwrap ( ) ;
465+ assert_eq ! ( result. len( ) , 4 ) ;
466+ }
463467}
0 commit comments