11// SPDX-License-Identifier: Apache-2.0
22// SPDX-FileCopyrightText: Copyright the Vortex contributors
33
4+ use num_traits:: ToPrimitive ;
45use vortex_buffer:: BitBufferMut ;
56use vortex_dtype:: IntegerPType ;
67use vortex_dtype:: Nullability ;
@@ -40,21 +41,44 @@ impl TakeKernel for ListVTable {
4041
4142 match_each_integer_ptype ! ( offsets. dtype( ) . as_ptype( ) , |O | {
4243 match_each_integer_ptype!( indices. ptype( ) , |I | {
43- _take:: <I , O >(
44- array,
45- offsets. as_slice:: <O >( ) ,
46- & indices,
47- array. validity_mask( ) ,
48- indices. validity_mask( ) ,
49- )
44+ let offsets_slice = offsets. as_slice:: <O >( ) ;
45+ let indices_slice: & [ I ] = indices. as_slice:: <I >( ) ;
46+
47+ // Sum the lengths of all selected lists; the total is used to pick the smallest offset type that can hold the accumulated offsets
48+ let total_count = indices_slice
49+ . iter( )
50+ . map( |idx| {
51+ let idx = idx. to_usize( ) . unwrap_or_else( || {
52+ vortex_panic!( "Failed to convert index to usize: {}" , idx)
53+ } ) ;
54+ ( offsets_slice[ idx + 1 ] - offsets_slice[ idx] )
55+ . to_usize( )
56+ . unwrap_or_else( || {
57+ vortex_panic!(
58+ "Failed to convert offset difference to usize: {}" ,
59+ offsets_slice[ idx + 1 ] - offsets_slice[ idx]
60+ )
61+ } )
62+ } )
63+ . sum:: <usize >( ) ;
64+
65+ match_smallest_offset_type!( total_count, |AccumType | {
66+ _take:: <I , O , AccumType >(
67+ array,
68+ offsets_slice,
69+ & indices,
70+ array. validity_mask( ) ,
71+ indices. validity_mask( ) ,
72+ )
73+ } )
5074 } )
5175 } )
5276 }
5377}
5478
5579register_kernel ! ( TakeKernelAdapter ( ListVTable ) . lift( ) ) ;
5680
57- fn _take < I : IntegerPType , O : IntegerPType > (
81+ fn _take < I : IntegerPType , O : IntegerPType , AccumType : IntegerPType > (
5882 array : & ListArray ,
5983 offsets : & [ O ] ,
6084 indices_array : & PrimitiveArray ,
@@ -64,7 +88,7 @@ fn _take<I: IntegerPType, O: IntegerPType>(
6488 let indices: & [ I ] = indices_array. as_slice :: < I > ( ) ;
6589
6690 if !indices_validity_mask. all_true ( ) || !data_validity. all_true ( ) {
67- return _take_nullable :: < I , O > (
91+ return _take_nullable :: < I , O , AccumType > (
6892 array,
6993 offsets,
7094 indices,
@@ -74,24 +98,13 @@ fn _take<I: IntegerPType, O: IntegerPType>(
7498 }
7599
76100 let mut new_offsets =
77- PrimitiveBuilder :: < u64 > :: with_capacity ( Nullability :: NonNullable , indices. len ( ) ) ;
101+ PrimitiveBuilder :: < AccumType > :: with_capacity ( Nullability :: NonNullable , indices. len ( ) ) ;
78102 let mut elements_to_take =
79103 PrimitiveBuilder :: with_capacity ( Nullability :: NonNullable , 2 * indices. len ( ) ) ;
80104
81- let mut current_offset = 0u64 ;
105+ let mut current_offset = AccumType :: zero ( ) ;
82106 new_offsets. append_zero ( ) ;
83107
84- // Total element count.
85- let total_count = indices
86- . iter ( )
87- . map ( |idx| {
88- let idx = idx
89- . to_usize ( )
90- . unwrap_or_else ( || vortex_panic ! ( "Failed to convert index to usize: {}" , idx) ) ;
91- ( offsets[ idx + 1 ] - offsets[ idx] ) . as_ ( ) as usize
92- } )
93- . sum :: < usize > ( ) ;
94-
95108 for & data_idx in indices {
96109 let data_idx = data_idx
97110 . to_usize ( )
@@ -113,7 +126,13 @@ fn _take<I: IntegerPType, O: IntegerPType>(
113126 for i in 0 ..additional {
114127 elements_to_take. append_value ( start + O :: from_usize ( i) . vortex_expect ( "i < additional" ) ) ;
115128 }
116- current_offset += ( stop - start) . as_ ( ) as u64 ;
129+ current_offset += AccumType :: from_usize ( ( stop - start) . to_usize ( ) . unwrap_or_else ( || {
130+ vortex_panic ! (
131+ "Failed to convert offset difference to usize: {}" ,
132+ stop - start
133+ )
134+ } ) )
135+ . vortex_expect ( "offset conversion" ) ;
117136 new_offsets. append_value ( current_offset) ;
118137 }
119138
@@ -133,15 +152,15 @@ fn _take<I: IntegerPType, O: IntegerPType>(
133152 . to_array ( ) )
134153}
135154
136- fn _take_nullable < I : IntegerPType , O : IntegerPType > (
155+ fn _take_nullable < I : IntegerPType , O : IntegerPType , AccumType : IntegerPType > (
137156 array : & ListArray ,
138157 offsets : & [ O ] ,
139158 indices : & [ I ] ,
140159 data_validity : Mask ,
141160 indices_validity : Mask ,
142161) -> VortexResult < ArrayRef > {
143162 let mut new_offsets =
144- PrimitiveBuilder :: < u64 > :: with_capacity ( Nullability :: NonNullable , indices. len ( ) ) ;
163+ PrimitiveBuilder :: < AccumType > :: with_capacity ( Nullability :: NonNullable , indices. len ( ) ) ;
145164
146165 // These are the indices we push down to the child array to call `take` with.
147166 //
@@ -153,7 +172,7 @@ fn _take_nullable<I: IntegerPType, O: IntegerPType>(
153172 let mut elements_to_take =
154173 PrimitiveBuilder :: < O > :: with_capacity ( Nullability :: NonNullable , 2 * indices. len ( ) ) ;
155174
156- let mut current_offset = 0u64 ;
175+ let mut current_offset = AccumType :: zero ( ) ;
157176 new_offsets. append_zero ( ) ;
158177
159178 // Set all bits to invalid and selectively set which values are valid.
@@ -188,7 +207,13 @@ fn _take_nullable<I: IntegerPType, O: IntegerPType>(
188207 for i in 0 ..additional {
189208 elements_to_take. append_value ( start + O :: from_usize ( i) . vortex_expect ( "i < additional" ) ) ;
190209 }
191- current_offset += ( stop - start) . as_ ( ) as u64 ;
210+ current_offset += AccumType :: from_usize ( ( stop - start) . to_usize ( ) . unwrap_or_else ( || {
211+ vortex_panic ! (
212+ "Failed to convert offset difference to usize: {}" ,
213+ stop - start
214+ )
215+ } ) )
216+ . vortex_expect ( "offset conversion" ) ;
192217 new_offsets. append_value ( current_offset) ;
193218 new_validity. set ( idx) ;
194219 }
0 commit comments