@@ -43,18 +43,17 @@ impl TakeKernel for ListVTable {
4343 match_each_integer_ptype!( indices. ptype( ) , |I | {
4444 let offsets_slice = offsets. as_slice:: <O >( ) ;
4545 let indices_slice: & [ I ] = indices. as_slice:: <I >( ) ;
46- let approx_total_count = indices_slice
46+ let total_element_count = indices_slice
4747 . iter( )
4848 . map( |idx| {
4949 let idx: usize = idx. as_( ) ;
50- let length : usize = ( offsets_slice[ idx + 1 ] - offsets_slice[ idx] ) . as_( ) ;
51- length
50+ let diff : usize = ( offsets_slice[ idx + 1 ] - offsets_slice[ idx] ) . as_( ) ;
51+ diff
5252 } )
53- . max( )
54- . unwrap_or( 0 ) ;
53+ . sum:: <usize >( ) ;
5554
56- match_smallest_offset_type!( approx_total_count , |AccumType | {
57- _take:: <I , O , AccumType >(
55+ match_smallest_offset_type!( total_element_count , |OutputOffsetType | {
56+ _take:: <I , O , OutputOffsetType >(
5857 array,
5958 offsets_slice,
6059 & indices,
@@ -69,7 +68,7 @@ impl TakeKernel for ListVTable {
6968
7069register_kernel ! ( TakeKernelAdapter ( ListVTable ) . lift( ) ) ;
7170
72- fn _take < I : IntegerPType , O : IntegerPType , AccumType : IntegerPType > (
71+ fn _take < I : IntegerPType , O : IntegerPType , OutputOffsetType : IntegerPType > (
7372 array : & ListArray ,
7473 offsets : & [ O ] ,
7574 indices_array : & PrimitiveArray ,
@@ -79,7 +78,7 @@ fn _take<I: IntegerPType, O: IntegerPType, AccumType: IntegerPType>(
7978 let indices: & [ I ] = indices_array. as_slice :: < I > ( ) ;
8079
8180 if !indices_validity_mask. all_true ( ) || !data_validity. all_true ( ) {
82- return _take_nullable :: < I , O , AccumType > (
81+ return _take_nullable :: < I , O , OutputOffsetType > (
8382 array,
8483 offsets,
8584 indices,
@@ -88,12 +87,14 @@ fn _take<I: IntegerPType, O: IntegerPType, AccumType: IntegerPType>(
8887 ) ;
8988 }
9089
91- let mut new_offsets =
92- PrimitiveBuilder :: < AccumType > :: with_capacity ( Nullability :: NonNullable , indices. len ( ) ) ;
90+ let mut new_offsets = PrimitiveBuilder :: < OutputOffsetType > :: with_capacity (
91+ Nullability :: NonNullable ,
92+ indices. len ( ) ,
93+ ) ;
9394 let mut elements_to_take =
9495 PrimitiveBuilder :: with_capacity ( Nullability :: NonNullable , 2 * indices. len ( ) ) ;
9596
96- let mut current_offset = AccumType :: zero ( ) ;
97+ let mut current_offset = OutputOffsetType :: zero ( ) ;
9798 new_offsets. append_zero ( ) ;
9899
99100 for & data_idx in indices {
@@ -114,7 +115,7 @@ fn _take<I: IntegerPType, O: IntegerPType, AccumType: IntegerPType>(
114115 elements_to_take. append_value ( start + O :: from_usize ( i) . vortex_expect ( "i < additional" ) ) ;
115116 }
116117 current_offset +=
117- AccumType :: from_usize ( ( stop - start) . as_ ( ) ) . vortex_expect ( "offset conversion" ) ;
118+ OutputOffsetType :: from_usize ( ( stop - start) . as_ ( ) ) . vortex_expect ( "offset conversion" ) ;
118119 new_offsets. append_value ( current_offset) ;
119120 }
120121
@@ -134,15 +135,17 @@ fn _take<I: IntegerPType, O: IntegerPType, AccumType: IntegerPType>(
134135 . to_array ( ) )
135136}
136137
137- fn _take_nullable < I : IntegerPType , O : IntegerPType , AccumType : IntegerPType > (
138+ fn _take_nullable < I : IntegerPType , O : IntegerPType , OutputOffsetType : IntegerPType > (
138139 array : & ListArray ,
139140 offsets : & [ O ] ,
140141 indices : & [ I ] ,
141142 data_validity : Mask ,
142143 indices_validity : Mask ,
143144) -> VortexResult < ArrayRef > {
144- let mut new_offsets =
145- PrimitiveBuilder :: < AccumType > :: with_capacity ( Nullability :: NonNullable , indices. len ( ) ) ;
145+ let mut new_offsets = PrimitiveBuilder :: < OutputOffsetType > :: with_capacity (
146+ Nullability :: NonNullable ,
147+ indices. len ( ) ,
148+ ) ;
146149
147150 // This will be the indices we push down to the child array to call `take` with.
148151 //
@@ -154,7 +157,7 @@ fn _take_nullable<I: IntegerPType, O: IntegerPType, AccumType: IntegerPType>(
154157 let mut elements_to_take =
155158 PrimitiveBuilder :: < O > :: with_capacity ( Nullability :: NonNullable , 2 * indices. len ( ) ) ;
156159
157- let mut current_offset = AccumType :: zero ( ) ;
160+ let mut current_offset = OutputOffsetType :: zero ( ) ;
158161 new_offsets. append_zero ( ) ;
159162
160163 // Set all bits to invalid and selectively set which values are valid.
@@ -186,7 +189,7 @@ fn _take_nullable<I: IntegerPType, O: IntegerPType, AccumType: IntegerPType>(
186189 elements_to_take. append_value ( start + O :: from_usize ( i) . vortex_expect ( "i < additional" ) ) ;
187190 }
188191 current_offset +=
189- AccumType :: from_usize ( ( stop - start) . as_ ( ) ) . vortex_expect ( "offset conversion" ) ;
192+ OutputOffsetType :: from_usize ( ( stop - start) . as_ ( ) ) . vortex_expect ( "offset conversion" ) ;
190193 new_offsets. append_value ( current_offset) ;
191194 new_validity. set ( idx) ;
192195 }
@@ -209,7 +212,7 @@ mod test {
209212 use vortex_buffer:: buffer;
210213 use vortex_dtype:: DType ;
211214 use vortex_dtype:: Nullability ;
212- use vortex_dtype:: PType :: I32 ;
215+ use vortex_dtype:: PType :: { I32 , U8 , U16 , U32 } ;
213216 use vortex_scalar:: Scalar ;
214217
215218 use crate :: Array ;
@@ -464,4 +467,34 @@ mod test {
464467 assert ! ( result_view. is_invalid( 1 ) ) ;
465468 assert ! ( result_view. is_valid( 2 ) ) ;
466469 }
470+
471+ #[ rstest]
472+ #[ case( 10 , U8 ) ]
473+ #[ case( 300 , U16 ) ]
474+ #[ case( 70000 , U32 ) ]
475+ fn test_output_offset_type_selection (
476+ #[ case] element_count : usize ,
477+ #[ case] expected_ptype : vortex_dtype:: PType ,
478+ ) {
479+ let elements: Vec < i32 > = ( 0 ..element_count as i32 ) . collect ( ) ;
480+ let elements_array = PrimitiveArray :: from_iter ( elements) . to_array ( ) ;
481+
482+ let mut offsets = Vec :: with_capacity ( element_count + 1 ) ;
483+ for idx in 0 ..element_count {
484+ offsets. push ( idx as u64 ) ;
485+ }
486+ let offsets_array = PrimitiveArray :: from_iter ( offsets) . to_array ( ) ;
487+
488+ let list = ListArray :: try_new ( elements_array, offsets_array, Validity :: NonNullable )
489+ . unwrap ( )
490+ . to_array ( ) ;
491+
492+ let indices: Vec < u32 > = ( 0 ..element_count as u32 ) . collect ( ) ;
493+ let result = take ( & list, & PrimitiveArray :: from_iter ( indices) . to_array ( ) ) . unwrap ( ) ;
494+
495+ assert_eq ! (
496+ result. to_listview( ) . offsets( ) . dtype( ) . as_ptype( ) ,
497+ expected_ptype
498+ ) ;
499+ }
467500}
0 commit comments