11// SPDX-License-Identifier: Apache-2.0
22// SPDX-FileCopyrightText: Copyright the Vortex contributors
33
4+ use num_traits:: AsPrimitive ;
45use vortex_buffer:: BitBufferMut ;
56use vortex_dtype:: IntegerPType ;
67use vortex_dtype:: Nullability ;
78use vortex_dtype:: match_each_integer_ptype;
9+ use vortex_dtype:: match_smallest_offset_type;
810use vortex_error:: VortexExpect ;
911use vortex_error:: VortexResult ;
10- use vortex_error:: vortex_panic;
1112use vortex_mask:: Mask ;
1213
1314use crate :: Array ;
@@ -40,21 +41,34 @@ impl TakeKernel for ListVTable {
4041
4142 match_each_integer_ptype ! ( offsets. dtype( ) . as_ptype( ) , |O | {
4243 match_each_integer_ptype!( indices. ptype( ) , |I | {
43- _take:: <I , O >(
44- array,
45- offsets. as_slice:: <O >( ) ,
46- & indices,
47- array. validity_mask( ) ,
48- indices. validity_mask( ) ,
49- )
44+ let offsets_slice = offsets. as_slice:: <O >( ) ;
45+ let indices_slice: & [ I ] = indices. as_slice:: <I >( ) ;
46+ let total_element_count = indices_slice
47+ . iter( )
48+ . map( |idx| {
49+ let idx: usize = idx. as_( ) ;
50+ let diff: usize = ( offsets_slice[ idx + 1 ] - offsets_slice[ idx] ) . as_( ) ;
51+ diff
52+ } )
53+ . sum:: <usize >( ) ;
54+
55+ match_smallest_offset_type!( total_element_count, |OutputOffsetType | {
56+ _take:: <I , O , OutputOffsetType >(
57+ array,
58+ offsets_slice,
59+ & indices,
60+ array. validity_mask( ) ,
61+ indices. validity_mask( ) ,
62+ )
63+ } )
5064 } )
5165 } )
5266 }
5367}
5468
5569register_kernel ! ( TakeKernelAdapter ( ListVTable ) . lift( ) ) ;
5670
57- fn _take < I : IntegerPType , O : IntegerPType > (
71+ fn _take < I : IntegerPType , O : IntegerPType , OutputOffsetType : IntegerPType > (
5872 array : & ListArray ,
5973 offsets : & [ O ] ,
6074 indices_array : & PrimitiveArray ,
@@ -64,7 +78,7 @@ fn _take<I: IntegerPType, O: IntegerPType>(
6478 let indices: & [ I ] = indices_array. as_slice :: < I > ( ) ;
6579
6680 if !indices_validity_mask. all_true ( ) || !data_validity. all_true ( ) {
67- return _take_nullable :: < I , O > (
81+ return _take_nullable :: < I , O , OutputOffsetType > (
6882 array,
6983 offsets,
7084 indices,
@@ -73,18 +87,18 @@ fn _take<I: IntegerPType, O: IntegerPType>(
7387 ) ;
7488 }
7589
76- let mut new_offsets =
77- PrimitiveBuilder :: < u64 > :: with_capacity ( Nullability :: NonNullable , indices. len ( ) ) ;
90+ let mut new_offsets = PrimitiveBuilder :: < OutputOffsetType > :: with_capacity (
91+ Nullability :: NonNullable ,
92+ indices. len ( ) ,
93+ ) ;
7894 let mut elements_to_take =
7995 PrimitiveBuilder :: with_capacity ( Nullability :: NonNullable , 2 * indices. len ( ) ) ;
8096
81- let mut current_offset = 0u64 ;
97+ let mut current_offset = OutputOffsetType :: zero ( ) ;
8298 new_offsets. append_zero ( ) ;
8399
84100 for & data_idx in indices {
85- let data_idx = data_idx
86- . to_usize ( )
87- . unwrap_or_else ( || vortex_panic ! ( "Failed to convert index to usize: {}" , data_idx) ) ;
101+ let data_idx: usize = data_idx. as_ ( ) ;
88102
89103 let start = offsets[ data_idx] ;
90104 let stop = offsets[ data_idx + 1 ] ;
@@ -94,15 +108,14 @@ fn _take<I: IntegerPType, O: IntegerPType>(
94108 // We could convert start and end to usize, but that would impose a potentially
95109 // harder constraint - now we don't care if they fit into usize as long as their
96110 // difference does.
97- let additional = ( stop - start) . to_usize ( ) . unwrap_or_else ( || {
98- vortex_panic ! ( "Failed to convert range length to usize: {}" , stop - start)
99- } ) ;
111+ let additional: usize = ( stop - start) . as_ ( ) ;
100112
101113 elements_to_take. reserve_exact ( additional) ;
102114 for i in 0 ..additional {
103115 elements_to_take. append_value ( start + O :: from_usize ( i) . vortex_expect ( "i < additional" ) ) ;
104116 }
105- current_offset += ( stop - start) . as_ ( ) as u64 ;
117+ current_offset +=
118+ OutputOffsetType :: from_usize ( ( stop - start) . as_ ( ) ) . vortex_expect ( "offset conversion" ) ;
106119 new_offsets. append_value ( current_offset) ;
107120 }
108121
@@ -122,15 +135,17 @@ fn _take<I: IntegerPType, O: IntegerPType>(
122135 . to_array ( ) )
123136}
124137
125- fn _take_nullable < I : IntegerPType , O : IntegerPType > (
138+ fn _take_nullable < I : IntegerPType , O : IntegerPType , OutputOffsetType : IntegerPType > (
126139 array : & ListArray ,
127140 offsets : & [ O ] ,
128141 indices : & [ I ] ,
129142 data_validity : Mask ,
130143 indices_validity : Mask ,
131144) -> VortexResult < ArrayRef > {
132- let mut new_offsets =
133- PrimitiveBuilder :: < u64 > :: with_capacity ( Nullability :: NonNullable , indices. len ( ) ) ;
145+ let mut new_offsets = PrimitiveBuilder :: < OutputOffsetType > :: with_capacity (
146+ Nullability :: NonNullable ,
147+ indices. len ( ) ,
148+ ) ;
134149
135150 // This will be the indices we push down to the child array to call `take` with.
136151 //
@@ -142,7 +157,7 @@ fn _take_nullable<I: IntegerPType, O: IntegerPType>(
142157 let mut elements_to_take =
143158 PrimitiveBuilder :: < O > :: with_capacity ( Nullability :: NonNullable , 2 * indices. len ( ) ) ;
144159
145- let mut current_offset = 0u64 ;
160+ let mut current_offset = OutputOffsetType :: zero ( ) ;
146161 new_offsets. append_zero ( ) ;
147162
148163 // Set all bits to invalid and selectively set which values are valid.
@@ -155,9 +170,7 @@ fn _take_nullable<I: IntegerPType, O: IntegerPType>(
155170 continue ;
156171 }
157172
158- let data_idx = data_idx
159- . to_usize ( )
160- . unwrap_or_else ( || vortex_panic ! ( "Failed to convert index to usize: {}" , data_idx) ) ;
173+ let data_idx: usize = data_idx. as_ ( ) ;
161174
162175 if !data_validity. value ( data_idx) {
163176 new_offsets. append_value ( current_offset) ;
@@ -169,15 +182,14 @@ fn _take_nullable<I: IntegerPType, O: IntegerPType>(
169182 let stop = offsets[ data_idx + 1 ] ;
170183
171184 // See the note in `_take` for the reasoning.
172- let additional = ( stop - start) . to_usize ( ) . unwrap_or_else ( || {
173- vortex_panic ! ( "Failed to convert range length to usize: {}" , stop - start)
174- } ) ;
185+ let additional: usize = ( stop - start) . as_ ( ) ;
175186
176187 elements_to_take. reserve_exact ( additional) ;
177188 for i in 0 ..additional {
178189 elements_to_take. append_value ( start + O :: from_usize ( i) . vortex_expect ( "i < additional" ) ) ;
179190 }
180- current_offset += ( stop - start) . as_ ( ) as u64 ;
191+ current_offset +=
192+ OutputOffsetType :: from_usize ( ( stop - start) . as_ ( ) ) . vortex_expect ( "offset conversion" ) ;
181193 new_offsets. append_value ( current_offset) ;
182194 new_validity. set ( idx) ;
183195 }
@@ -200,7 +212,7 @@ mod test {
200212 use vortex_buffer:: buffer;
201213 use vortex_dtype:: DType ;
202214 use vortex_dtype:: Nullability ;
203- use vortex_dtype:: PType :: I32 ;
215+ use vortex_dtype:: PType :: { I32 , U8 , U16 , U32 } ;
204216 use vortex_scalar:: Scalar ;
205217
206218 use crate :: Array ;
@@ -455,4 +467,34 @@ mod test {
455467 assert ! ( result_view. is_invalid( 1 ) ) ;
456468 assert ! ( result_view. is_valid( 2 ) ) ;
457469 }
470+
471+ #[ rstest]
472+ #[ case( 10 , U8 ) ]
473+ #[ case( 300 , U16 ) ]
474+ #[ case( 70000 , U32 ) ]
475+ fn test_output_offset_type_selection (
476+ #[ case] element_count : usize ,
477+ #[ case] expected_ptype : vortex_dtype:: PType ,
478+ ) {
479+ let elements: Vec < i32 > = ( 0 ..element_count as i32 ) . collect ( ) ;
480+ let elements_array = PrimitiveArray :: from_iter ( elements) . to_array ( ) ;
481+
482+ let mut offsets = Vec :: with_capacity ( element_count + 1 ) ;
483+ for idx in 0 ..=element_count {
484+ offsets. push ( idx as u64 ) ;
485+ }
486+ let offsets_array = PrimitiveArray :: from_iter ( offsets) . to_array ( ) ;
487+
488+ let list = ListArray :: try_new ( elements_array, offsets_array, Validity :: NonNullable )
489+ . unwrap ( )
490+ . to_array ( ) ;
491+
492+ let indices: Vec < u32 > = ( 0 ..element_count as u32 ) . collect ( ) ;
493+ let result = take ( & list, & PrimitiveArray :: from_iter ( indices) . to_array ( ) ) . unwrap ( ) ;
494+
495+ assert_eq ! (
496+ result. to_listview( ) . offsets( ) . dtype( ) . as_ptype( ) ,
497+ expected_ptype
498+ ) ;
499+ }
458500}
0 commit comments