@@ -5,9 +5,9 @@ use vortex_buffer::BitBufferMut;
55use vortex_dtype:: IntegerPType ;
66use vortex_dtype:: Nullability ;
77use vortex_dtype:: match_each_integer_ptype;
8+ use vortex_dtype:: match_smallest_offset_type;
89use vortex_error:: VortexExpect ;
910use vortex_error:: VortexResult ;
10- use vortex_error:: vortex_panic;
1111use vortex_mask:: Mask ;
1212
1313use crate :: Array ;
@@ -34,27 +34,33 @@ use crate::vtable::ValidityHelper;
3434/// that lists are stored contiguously and in-order (`offset[i+1] >= offset[i]`). Taking
3535/// non-contiguous indices would violate this requirement.
3636impl TakeKernel for ListVTable {
37+ #[ expect( clippy:: cognitive_complexity) ]
3738 fn take ( & self , array : & ListArray , indices : & dyn Array ) -> VortexResult < ArrayRef > {
3839 let indices = indices. to_primitive ( ) ;
3940 let offsets = array. offsets ( ) . to_primitive ( ) ;
41+ // This is an over-approximation of the total number of elements in the resulting array.
42+ let total_approx = array. elements ( ) . len ( ) * indices. len ( ) ;
4043
4144 match_each_integer_ptype ! ( offsets. dtype( ) . as_ptype( ) , |O | {
45+ let offsets_slice = offsets. as_slice:: <O >( ) ;
4246 match_each_integer_ptype!( indices. ptype( ) , |I | {
43- _take:: <I , O >(
44- array,
45- offsets. as_slice:: <O >( ) ,
46- & indices,
47- array. validity_mask( ) ,
48- indices. validity_mask( ) ,
49- )
47+ match_smallest_offset_type!( total_approx, |OutputOffsetType | {
48+ _take:: <I , O , OutputOffsetType >(
49+ array,
50+ offsets_slice,
51+ & indices,
52+ array. validity_mask( ) ,
53+ indices. validity_mask( ) ,
54+ )
55+ } )
5056 } )
5157 } )
5258 }
5359}
5460
5561register_kernel ! ( TakeKernelAdapter ( ListVTable ) . lift( ) ) ;
5662
57- fn _take < I : IntegerPType , O : IntegerPType > (
63+ fn _take < I : IntegerPType , O : IntegerPType , OutputOffsetType : IntegerPType > (
5864 array : & ListArray ,
5965 offsets : & [ O ] ,
6066 indices_array : & PrimitiveArray ,
@@ -64,7 +70,7 @@ fn _take<I: IntegerPType, O: IntegerPType>(
6470 let indices: & [ I ] = indices_array. as_slice :: < I > ( ) ;
6571
6672 if !indices_validity_mask. all_true ( ) || !data_validity. all_true ( ) {
67- return _take_nullable :: < I , O > (
73+ return _take_nullable :: < I , O , OutputOffsetType > (
6874 array,
6975 offsets,
7076 indices,
@@ -73,18 +79,18 @@ fn _take<I: IntegerPType, O: IntegerPType>(
7379 ) ;
7480 }
7581
76- let mut new_offsets =
77- PrimitiveBuilder :: < u64 > :: with_capacity ( Nullability :: NonNullable , indices. len ( ) ) ;
82+ let mut new_offsets = PrimitiveBuilder :: < OutputOffsetType > :: with_capacity (
83+ Nullability :: NonNullable ,
84+ indices. len ( ) ,
85+ ) ;
7886 let mut elements_to_take =
7987 PrimitiveBuilder :: with_capacity ( Nullability :: NonNullable , 2 * indices. len ( ) ) ;
8088
81- let mut current_offset = 0u64 ;
89+ let mut current_offset = OutputOffsetType :: zero ( ) ;
8290 new_offsets. append_zero ( ) ;
8391
8492 for & data_idx in indices {
85- let data_idx = data_idx
86- . to_usize ( )
87- . unwrap_or_else ( || vortex_panic ! ( "Failed to convert index to usize: {}" , data_idx) ) ;
93+ let data_idx: usize = data_idx. as_ ( ) ;
8894
8995 let start = offsets[ data_idx] ;
9096 let stop = offsets[ data_idx + 1 ] ;
@@ -94,15 +100,14 @@ fn _take<I: IntegerPType, O: IntegerPType>(
94100 // We could convert start and end to usize, but that would impose a potentially
95101 // harder constraint - now we don't care if they fit into usize as long as their
96102 // difference does.
97- let additional = ( stop - start) . to_usize ( ) . unwrap_or_else ( || {
98- vortex_panic ! ( "Failed to convert range length to usize: {}" , stop - start)
99- } ) ;
103+ let additional: usize = ( stop - start) . as_ ( ) ;
100104
101105 elements_to_take. reserve_exact ( additional) ;
102106 for i in 0 ..additional {
103107 elements_to_take. append_value ( start + O :: from_usize ( i) . vortex_expect ( "i < additional" ) ) ;
104108 }
105- current_offset += ( stop - start) . as_ ( ) as u64 ;
109+ current_offset +=
110+ OutputOffsetType :: from_usize ( ( stop - start) . as_ ( ) ) . vortex_expect ( "offset conversion" ) ;
106111 new_offsets. append_value ( current_offset) ;
107112 }
108113
@@ -122,15 +127,17 @@ fn _take<I: IntegerPType, O: IntegerPType>(
122127 . to_array ( ) )
123128}
124129
125- fn _take_nullable < I : IntegerPType , O : IntegerPType > (
130+ fn _take_nullable < I : IntegerPType , O : IntegerPType , OutputOffsetType : IntegerPType > (
126131 array : & ListArray ,
127132 offsets : & [ O ] ,
128133 indices : & [ I ] ,
129134 data_validity : Mask ,
130135 indices_validity : Mask ,
131136) -> VortexResult < ArrayRef > {
132- let mut new_offsets =
133- PrimitiveBuilder :: < u64 > :: with_capacity ( Nullability :: NonNullable , indices. len ( ) ) ;
137+ let mut new_offsets = PrimitiveBuilder :: < OutputOffsetType > :: with_capacity (
138+ Nullability :: NonNullable ,
139+ indices. len ( ) ,
140+ ) ;
134141
135142 // This will be the indices we push down to the child array to call `take` with.
136143 //
@@ -142,7 +149,7 @@ fn _take_nullable<I: IntegerPType, O: IntegerPType>(
142149 let mut elements_to_take =
143150 PrimitiveBuilder :: < O > :: with_capacity ( Nullability :: NonNullable , 2 * indices. len ( ) ) ;
144151
145- let mut current_offset = 0u64 ;
152+ let mut current_offset = OutputOffsetType :: zero ( ) ;
146153 new_offsets. append_zero ( ) ;
147154
148155 // Set all bits to invalid and selectively set which values are valid.
@@ -155,9 +162,7 @@ fn _take_nullable<I: IntegerPType, O: IntegerPType>(
155162 continue ;
156163 }
157164
158- let data_idx = data_idx
159- . to_usize ( )
160- . unwrap_or_else ( || vortex_panic ! ( "Failed to convert index to usize: {}" , data_idx) ) ;
165+ let data_idx: usize = data_idx. as_ ( ) ;
161166
162167 if !data_validity. value ( data_idx) {
163168 new_offsets. append_value ( current_offset) ;
@@ -169,15 +174,14 @@ fn _take_nullable<I: IntegerPType, O: IntegerPType>(
169174 let stop = offsets[ data_idx + 1 ] ;
170175
171176 // See the note it the `take` on the reasoning
172- let additional = ( stop - start) . to_usize ( ) . unwrap_or_else ( || {
173- vortex_panic ! ( "Failed to convert range length to usize: {}" , stop - start)
174- } ) ;
177+ let additional: usize = ( stop - start) . as_ ( ) ;
175178
176179 elements_to_take. reserve_exact ( additional) ;
177180 for i in 0 ..additional {
178181 elements_to_take. append_value ( start + O :: from_usize ( i) . vortex_expect ( "i < additional" ) ) ;
179182 }
180- current_offset += ( stop - start) . as_ ( ) as u64 ;
183+ current_offset +=
184+ OutputOffsetType :: from_usize ( ( stop - start) . as_ ( ) ) . vortex_expect ( "offset conversion" ) ;
181185 new_offsets. append_value ( current_offset) ;
182186 new_validity. set ( idx) ;
183187 }
@@ -201,6 +205,9 @@ mod test {
201205 use vortex_dtype:: DType ;
202206 use vortex_dtype:: Nullability ;
203207 use vortex_dtype:: PType :: I32 ;
208+ use vortex_dtype:: PType :: U8 ;
209+ use vortex_dtype:: PType :: U16 ;
210+ use vortex_dtype:: PType :: U32 ;
204211 use vortex_scalar:: Scalar ;
205212
206213 use crate :: Array ;
@@ -455,4 +462,34 @@ mod test {
455462 assert ! ( result_view. is_invalid( 1 ) ) ;
456463 assert ! ( result_view. is_valid( 2 ) ) ;
457464 }
465+
466+ #[ rstest]
467+ #[ case( 10 , U8 ) ]
468+ #[ case( 300 , U16 ) ]
469+ #[ case( 70000 , U32 ) ]
470+ fn test_output_offset_type_selection (
471+ #[ case] element_count : u32 ,
472+ #[ case] expected_ptype : vortex_dtype:: PType ,
473+ ) {
474+ let elements: Vec < _ > = ( 0 ..element_count) . collect ( ) ;
475+ let elements_array = PrimitiveArray :: from_iter ( elements) . to_array ( ) ;
476+
477+ let mut offsets = Vec :: with_capacity ( ( element_count + 1 ) as usize ) ;
478+ for idx in 0 ..=element_count {
479+ offsets. push ( idx as u64 ) ;
480+ }
481+ let offsets_array = PrimitiveArray :: from_iter ( offsets) . to_array ( ) ;
482+
483+ let list = ListArray :: try_new ( elements_array, offsets_array, Validity :: NonNullable )
484+ . unwrap ( )
485+ . to_array ( ) ;
486+
487+ let indices: Vec < u32 > = ( 0 ..element_count) . collect ( ) ;
488+ let result = take ( & list, & PrimitiveArray :: from_iter ( indices) . to_array ( ) ) . unwrap ( ) ;
489+
490+ assert_eq ! (
491+ result. to_listview( ) . offsets( ) . dtype( ) . as_ptype( ) ,
492+ expected_ptype
493+ ) ;
494+ }
458495}
0 commit comments