@@ -5,9 +5,9 @@ use vortex_buffer::BitBufferMut;
55use vortex_dtype:: IntegerPType ;
66use vortex_dtype:: Nullability ;
77use vortex_dtype:: match_each_integer_ptype;
8+ use vortex_dtype:: match_smallest_offset_type;
89use vortex_error:: VortexExpect ;
910use vortex_error:: VortexResult ;
10- use vortex_error:: vortex_panic;
1111use vortex_mask:: Mask ;
1212
1313use crate :: Array ;
@@ -34,27 +34,33 @@ use crate::vtable::ValidityHelper;
3434/// that lists are stored contiguously and in-order (`offset[i+1] >= offset[i]`). Taking
3535/// non-contiguous indices would violate this requirement.
3636impl TakeKernel for ListVTable {
37+ #[ expect( clippy:: cognitive_complexity) ]
3738 fn take ( & self , array : & ListArray , indices : & dyn Array ) -> VortexResult < ArrayRef > {
3839 let indices = indices. to_primitive ( ) ;
3940 let offsets = array. offsets ( ) . to_primitive ( ) ;
41+ // This is an over-approximation of the total number of elements in the resulting array.
42+ let total_approx = array. elements ( ) . len ( ) . saturating_mul ( indices. len ( ) ) ;
4043
4144 match_each_integer_ptype ! ( offsets. dtype( ) . as_ptype( ) , |O | {
45+ let offsets_slice = offsets. as_slice:: <O >( ) ;
4246 match_each_integer_ptype!( indices. ptype( ) , |I | {
43- _take:: <I , O >(
44- array,
45- offsets. as_slice:: <O >( ) ,
46- & indices,
47- array. validity_mask( ) ,
48- indices. validity_mask( ) ,
49- )
47+ match_smallest_offset_type!( total_approx, |OutputOffsetType | {
48+ _take:: <I , O , OutputOffsetType >(
49+ array,
50+ offsets_slice,
51+ & indices,
52+ array. validity_mask( ) ,
53+ indices. validity_mask( ) ,
54+ )
55+ } )
5056 } )
5157 } )
5258 }
5359}
5460
5561register_kernel ! ( TakeKernelAdapter ( ListVTable ) . lift( ) ) ;
5662
57- fn _take < I : IntegerPType , O : IntegerPType > (
63+ fn _take < I : IntegerPType , O : IntegerPType , OutputOffsetType : IntegerPType > (
5864 array : & ListArray ,
5965 offsets : & [ O ] ,
6066 indices_array : & PrimitiveArray ,
@@ -64,7 +70,7 @@ fn _take<I: IntegerPType, O: IntegerPType>(
6470 let indices: & [ I ] = indices_array. as_slice :: < I > ( ) ;
6571
6672 if !indices_validity_mask. all_true ( ) || !data_validity. all_true ( ) {
67- return _take_nullable :: < I , O > (
73+ return _take_nullable :: < I , O , OutputOffsetType > (
6874 array,
6975 offsets,
7076 indices,
@@ -73,17 +79,18 @@ fn _take<I: IntegerPType, O: IntegerPType>(
7379 ) ;
7480 }
7581
76- let mut new_offsets = PrimitiveBuilder :: with_capacity ( Nullability :: NonNullable , indices. len ( ) ) ;
82+ let mut new_offsets = PrimitiveBuilder :: < OutputOffsetType > :: with_capacity (
83+ Nullability :: NonNullable ,
84+ indices. len ( ) ,
85+ ) ;
7786 let mut elements_to_take =
7887 PrimitiveBuilder :: with_capacity ( Nullability :: NonNullable , 2 * indices. len ( ) ) ;
7988
80- let mut current_offset = O :: zero ( ) ;
89+ let mut current_offset = OutputOffsetType :: zero ( ) ;
8190 new_offsets. append_zero ( ) ;
8291
8392 for & data_idx in indices {
84- let data_idx = data_idx
85- . to_usize ( )
86- . unwrap_or_else ( || vortex_panic ! ( "Failed to convert index to usize: {}" , data_idx) ) ;
93+ let data_idx: usize = data_idx. as_ ( ) ;
8794
8895 let start = offsets[ data_idx] ;
8996 let stop = offsets[ data_idx + 1 ] ;
@@ -93,15 +100,15 @@ fn _take<I: IntegerPType, O: IntegerPType>(
93100 // We could convert start and end to usize, but that would impose a potentially
94101 // harder constraint - now we don't care if they fit into usize as long as their
95102 // difference does.
96- let additional = ( stop - start) . to_usize ( ) . unwrap_or_else ( || {
97- vortex_panic ! ( "Failed to convert range length to usize: {}" , stop - start)
98- } ) ;
103+ let additional: usize = ( stop - start) . as_ ( ) ;
99104
105+ // TODO(0ax1): optimize this
100106 elements_to_take. reserve_exact ( additional) ;
101107 for i in 0 ..additional {
102108 elements_to_take. append_value ( start + O :: from_usize ( i) . vortex_expect ( "i < additional" ) ) ;
103109 }
104- current_offset += stop - start;
110+ current_offset +=
111+ OutputOffsetType :: from_usize ( ( stop - start) . as_ ( ) ) . vortex_expect ( "offset conversion" ) ;
105112 new_offsets. append_value ( current_offset) ;
106113 }
107114
@@ -121,14 +128,17 @@ fn _take<I: IntegerPType, O: IntegerPType>(
121128 . to_array ( ) )
122129}
123130
124- fn _take_nullable < I : IntegerPType , O : IntegerPType > (
131+ fn _take_nullable < I : IntegerPType , O : IntegerPType , OutputOffsetType : IntegerPType > (
125132 array : & ListArray ,
126133 offsets : & [ O ] ,
127134 indices : & [ I ] ,
128135 data_validity : Mask ,
129136 indices_validity : Mask ,
130137) -> VortexResult < ArrayRef > {
131- let mut new_offsets = PrimitiveBuilder :: with_capacity ( Nullability :: NonNullable , indices. len ( ) ) ;
138+ let mut new_offsets = PrimitiveBuilder :: < OutputOffsetType > :: with_capacity (
139+ Nullability :: NonNullable ,
140+ indices. len ( ) ,
141+ ) ;
132142
133143 // This will be the indices we push down to the child array to call `take` with.
134144 //
@@ -140,7 +150,7 @@ fn _take_nullable<I: IntegerPType, O: IntegerPType>(
140150 let mut elements_to_take =
141151 PrimitiveBuilder :: < O > :: with_capacity ( Nullability :: NonNullable , 2 * indices. len ( ) ) ;
142152
143- let mut current_offset = O :: zero ( ) ;
153+ let mut current_offset = OutputOffsetType :: zero ( ) ;
144154 new_offsets. append_zero ( ) ;
145155
146156 // Set all bits to invalid and selectively set which values are valid.
@@ -153,9 +163,7 @@ fn _take_nullable<I: IntegerPType, O: IntegerPType>(
153163 continue ;
154164 }
155165
156- let data_idx = data_idx
157- . to_usize ( )
158- . unwrap_or_else ( || vortex_panic ! ( "Failed to convert index to usize: {}" , data_idx) ) ;
166+ let data_idx: usize = data_idx. as_ ( ) ;
159167
160168 if !data_validity. value ( data_idx) {
161169 new_offsets. append_value ( current_offset) ;
@@ -167,15 +175,14 @@ fn _take_nullable<I: IntegerPType, O: IntegerPType>(
167175 let stop = offsets[ data_idx + 1 ] ;
168176
169177 // See the note in `_take` for why the difference, rather than the endpoints, is converted to usize.
170- let additional = ( stop - start) . to_usize ( ) . unwrap_or_else ( || {
171- vortex_panic ! ( "Failed to convert range length to usize: {}" , stop - start)
172- } ) ;
178+ let additional: usize = ( stop - start) . as_ ( ) ;
173179
174180 elements_to_take. reserve_exact ( additional) ;
175181 for i in 0 ..additional {
176182 elements_to_take. append_value ( start + O :: from_usize ( i) . vortex_expect ( "i < additional" ) ) ;
177183 }
178- current_offset += stop - start;
184+ current_offset +=
185+ OutputOffsetType :: from_usize ( ( stop - start) . as_ ( ) ) . vortex_expect ( "offset conversion" ) ;
179186 new_offsets. append_value ( current_offset) ;
180187 new_validity. set ( idx) ;
181188 }
@@ -411,4 +418,46 @@ mod test {
411418 fn test_take_list_conformance ( #[ case] list : ListArray ) {
412419 test_take_conformance ( list. as_ref ( ) ) ;
413420 }
421+
422+ #[ test]
423+ fn test_u64_offset_accumulation_non_nullable ( ) {
424+ let elements = buffer ! [ 0i32 ; 200 ] . into_array ( ) ;
425+ let offsets = buffer ! [ 0u8 , 200 ] . into_array ( ) ;
426+ let list = ListArray :: try_new ( elements, offsets, Validity :: NonNullable )
427+ . unwrap ( )
428+ . to_array ( ) ;
429+
430+ // Take the same 200-element list twice: the final offset (400) does not fit in the u8 input offset type, so the output offsets must use a wider type.
431+ let idx = buffer ! [ 0u8 , 0 ] . into_array ( ) ;
432+ let result = take ( & list, & idx) . unwrap ( ) ;
433+
434+ assert_eq ! ( result. len( ) , 2 ) ;
435+
436+ let result_view = result. to_listview ( ) ;
437+ assert_eq ! ( result_view. len( ) , 2 ) ;
438+ assert ! ( result_view. is_valid( 0 ) ) ;
439+ assert ! ( result_view. is_valid( 1 ) ) ;
440+ }
441+
442+ #[ test]
443+ fn test_u64_offset_accumulation_nullable ( ) {
444+ let elements = buffer ! [ 0i32 ; 150 ] . into_array ( ) ;
445+ let offsets = buffer ! [ 0u8 , 150 , 150 ] . into_array ( ) ;
446+ let validity = BoolArray :: from_iter ( vec ! [ true , false ] ) . to_array ( ) ;
447+ let list = ListArray :: try_new ( elements, offsets, Validity :: Array ( validity) )
448+ . unwrap ( )
449+ . to_array ( ) ;
450+
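451+ // Take the same 150-element list twice, with a null index in between: the final offset (300) does not fit in the u8 input offset type, so the output offsets must use a wider type.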
452+ let idx = PrimitiveArray :: from_option_iter ( vec ! [ Some ( 0u8 ) , None , Some ( 0u8 ) ] ) . to_array ( ) ;
453+ let result = take ( & list, & idx) . unwrap ( ) ;
454+
455+ assert_eq ! ( result. len( ) , 3 ) ;
456+
457+ let result_view = result. to_listview ( ) ;
458+ assert_eq ! ( result_view. len( ) , 3 ) ;
459+ assert ! ( result_view. is_valid( 0 ) ) ;
460+ assert ! ( result_view. is_invalid( 1 ) ) ;
461+ assert ! ( result_view. is_valid( 2 ) ) ;
462+ }
414463}
0 commit comments