@@ -17,37 +17,98 @@ impl TakeKernel for VarBinVTable {
1717 let offsets = array. offsets ( ) . to_primitive ( ) ;
1818 let data = array. bytes ( ) ;
1919 let indices = indices. to_primitive ( ) ;
20- match_each_integer_ptype ! ( offsets. ptype( ) , |O | {
21- match_each_integer_ptype!( indices. ptype( ) , |I | {
22- Ok ( take(
23- array
24- . dtype( )
25- . clone( )
26- . union_nullability( indices. dtype( ) . nullability( ) ) ,
27- offsets. as_slice:: <O >( ) ,
20+ let dtype = array
21+ . dtype ( )
22+ . clone ( )
23+ . union_nullability ( indices. dtype ( ) . nullability ( ) ) ;
24+ let array = match_each_integer_ptype ! ( indices. ptype( ) , |I | {
25+ // On take, offsets get widened to either 32- or 64-bit based on the original type,
26+ // to avoid overflow issues.
27+ match offsets. ptype( ) {
28+ PType :: U8 => take:: <I , u8 , u32 >(
29+ dtype,
30+ offsets. as_slice:: <u8 >( ) ,
2831 data. as_slice( ) ,
2932 indices. as_slice:: <I >( ) ,
3033 array. validity_mask( ) ,
3134 indices. validity_mask( ) ,
32- ) ?
33- . into_array( ) )
34- } )
35- } )
35+ ) ,
36+ PType :: U16 => take:: <I , u16 , u32 >(
37+ dtype,
38+ offsets. as_slice:: <u16 >( ) ,
39+ data. as_slice( ) ,
40+ indices. as_slice:: <I >( ) ,
41+ array. validity_mask( ) ,
42+ indices. validity_mask( ) ,
43+ ) ,
44+ PType :: U32 => take:: <I , u32 , u32 >(
45+ dtype,
46+ offsets. as_slice:: <u32 >( ) ,
47+ data. as_slice( ) ,
48+ indices. as_slice:: <I >( ) ,
49+ array. validity_mask( ) ,
50+ indices. validity_mask( ) ,
51+ ) ,
52+ PType :: U64 => take:: <I , u64 , u64 >(
53+ dtype,
54+ offsets. as_slice:: <u64 >( ) ,
55+ data. as_slice( ) ,
56+ indices. as_slice:: <I >( ) ,
57+ array. validity_mask( ) ,
58+ indices. validity_mask( ) ,
59+ ) ,
60+ PType :: I8 => take:: <I , i8 , i32 >(
61+ dtype,
62+ offsets. as_slice:: <i8 >( ) ,
63+ data. as_slice( ) ,
64+ indices. as_slice:: <I >( ) ,
65+ array. validity_mask( ) ,
66+ indices. validity_mask( ) ,
67+ ) ,
68+ PType :: I16 => take:: <I , i16 , i32 >(
69+ dtype,
70+ offsets. as_slice:: <i16 >( ) ,
71+ data. as_slice( ) ,
72+ indices. as_slice:: <I >( ) ,
73+ array. validity_mask( ) ,
74+ indices. validity_mask( ) ,
75+ ) ,
76+ PType :: I32 => take:: <I , i32 , i32 >(
77+ dtype,
78+ offsets. as_slice:: <i32 >( ) ,
79+ data. as_slice( ) ,
80+ indices. as_slice:: <I >( ) ,
81+ array. validity_mask( ) ,
82+ indices. validity_mask( ) ,
83+ ) ,
84+ PType :: I64 => take:: <I , i64 , i64 >(
85+ dtype,
86+ offsets. as_slice:: <i64 >( ) ,
87+ data. as_slice( ) ,
88+ indices. as_slice:: <I >( ) ,
89+ array. validity_mask( ) ,
90+ indices. validity_mask( ) ,
91+ ) ,
92+ _ => unreachable!( "invalid PType for offsets" ) ,
93+ }
94+ } ) ;
95+
96+ Ok ( array?. into_array ( ) )
3697 }
3798}
3899
39100register_kernel ! ( TakeKernelAdapter ( VarBinVTable ) . lift( ) ) ;
40101
41- fn take < I : IntegerPType , O : IntegerPType > (
102+ fn take < Index : IntegerPType , Offset : IntegerPType , NewOffset : IntegerPType > (
42103 dtype : DType ,
43- offsets : & [ O ] ,
104+ offsets : & [ Offset ] ,
44105 data : & [ u8 ] ,
45- indices : & [ I ] ,
106+ indices : & [ Index ] ,
46107 validity_mask : Mask ,
47108 indices_validity_mask : Mask ,
48109) -> VortexResult < VarBinArray > {
49110 if !validity_mask. all_true ( ) || !indices_validity_mask. all_true ( ) {
50- return Ok ( take_nullable (
111+ return Ok ( take_nullable :: < Index , Offset , NewOffset > (
51112 dtype,
52113 offsets,
53114 data,
@@ -57,25 +118,22 @@ fn take<I: IntegerPType, O: IntegerPType>(
57118 ) ) ;
58119 }
59120
60- let mut new_offsets = BufferMut :: with_capacity ( indices. len ( ) + 1 ) ;
61- new_offsets. push ( O :: zero ( ) ) ;
62- let mut current_offset = O :: zero ( ) ;
121+ let mut new_offsets = BufferMut :: < NewOffset > :: with_capacity ( indices. len ( ) + 1 ) ;
122+ new_offsets. push ( NewOffset :: zero ( ) ) ;
123+ let mut current_offset = NewOffset :: zero ( ) ;
63124
64125 for & idx in indices {
65126 let idx = idx
66127 . to_usize ( )
67128 . unwrap_or_else ( || vortex_panic ! ( "Failed to convert index to usize: {}" , idx) ) ;
68129 let start = offsets[ idx] ;
69130 let stop = offsets[ idx + 1 ] ;
70- current_offset += stop - start;
131+
132+ current_offset += NewOffset :: from ( stop - start) . vortex_expect ( "offset type overflow" ) ;
71133 new_offsets. push ( current_offset) ;
72134 }
73135
74- let mut new_data = ByteBufferMut :: with_capacity (
75- current_offset
76- . to_usize ( )
77- . vortex_expect ( "Failed to cast max offset to usize" ) ,
78- ) ;
136+ let mut new_data = ByteBufferMut :: with_capacity ( current_offset. as_ ( ) ) ;
79137
80138 for idx in indices {
81139 let idx = idx
@@ -104,17 +162,17 @@ fn take<I: IntegerPType, O: IntegerPType>(
104162 }
105163}
106164
107- fn take_nullable < I : IntegerPType , O : IntegerPType > (
165+ fn take_nullable < Index : IntegerPType , Offset : IntegerPType , NewOffset : IntegerPType > (
108166 dtype : DType ,
109- offsets : & [ O ] ,
167+ offsets : & [ Offset ] ,
110168 data : & [ u8 ] ,
111- indices : & [ I ] ,
169+ indices : & [ Index ] ,
112170 data_validity : Mask ,
113171 indices_validity : Mask ,
114172) -> VarBinArray {
115- let mut new_offsets = BufferMut :: with_capacity ( indices. len ( ) + 1 ) ;
116- new_offsets. push ( O :: zero ( ) ) ;
117- let mut current_offset = O :: zero ( ) ;
173+ let mut new_offsets = BufferMut :: < NewOffset > :: with_capacity ( indices. len ( ) + 1 ) ;
174+ new_offsets. push ( NewOffset :: zero ( ) ) ;
175+ let mut current_offset = NewOffset :: zero ( ) ;
118176
119177 let mut validity_buffer = BitBufferMut :: with_capacity ( indices. len ( ) ) ;
120178
@@ -135,7 +193,7 @@ fn take_nullable<I: IntegerPType, O: IntegerPType>(
135193 validity_buffer. append ( true ) ;
136194 let start = offsets[ data_idx_usize] ;
137195 let stop = offsets[ data_idx_usize + 1 ] ;
138- current_offset += stop - start;
196+ current_offset += NewOffset :: from ( stop - start) . vortex_expect ( "offset type overflow" ) ;
139197 new_offsets. push ( current_offset) ;
140198 valid_indices. push ( data_idx_usize) ;
141199 } else {
@@ -144,11 +202,7 @@ fn take_nullable<I: IntegerPType, O: IntegerPType>(
144202 }
145203 }
146204
147- let mut new_data = ByteBufferMut :: with_capacity (
148- current_offset
149- . to_usize ( )
150- . vortex_expect ( "Failed to cast max offset to usize" ) ,
151- ) ;
205+ let mut new_data = ByteBufferMut :: with_capacity ( current_offset. as_ ( ) ) ;
152206
153207 // Second pass: copy data for valid indices only
154208 for data_idx in valid_indices {
@@ -178,12 +232,14 @@ fn take_nullable<I: IntegerPType, O: IntegerPType>(
178232#[ cfg( test) ]
179233mod tests {
180234 use rstest:: rstest;
235+ use vortex_buffer:: { ByteBuffer , buffer} ;
181236 use vortex_dtype:: { DType , Nullability } ;
182237
183- use crate :: Array ;
184- use crate :: arrays:: { PrimitiveArray , VarBinArray } ;
238+ use crate :: arrays:: { PrimitiveArray , VarBinArray , VarBinVTable } ;
185239 use crate :: compute:: conformance:: take:: test_take_conformance;
186240 use crate :: compute:: take;
241+ use crate :: validity:: Validity ;
242+ use crate :: { Array , IntoArray } ;
187243
188244 #[ test]
189245 fn test_null_take ( ) {
@@ -221,4 +277,27 @@ mod tests {
221277 fn test_take_varbin_conformance ( #[ case] array : VarBinArray ) {
222278 test_take_conformance ( array. as_ref ( ) ) ;
223279 }
280+
/// Regression test: taking the same 128-byte value three times from an array whose offsets
/// are stored as `u8` yields 384 bytes of output, which cannot be addressed by `u8` offsets.
#[test]
fn test_take_overflow() {
    // One 128-character string, addressed by the `u8` offsets [0, 128].
    let scream = "a".repeat(128);
    let array = VarBinArray::new(
        buffer![0u8, 128u8].into_array(),
        ByteBuffer::copy_from(scream.as_bytes()),
        DType::Utf8(Nullability::NonNullable),
        Validity::NonNullable,
    );

    // Select element 0 three times.
    let indices = buffer![0u32, 0u32, 0u32].into_array();
    let taken = take(array.as_ref(), indices.as_ref()).unwrap();
    let taken_str = taken.as_::<VarBinVTable>();

    assert_eq!(taken_str.len(), 3);
    // Every selected element must round-trip to the original bytes.
    for position in 0..3 {
        assert_eq!(taken_str.bytes_at(position).as_bytes(), scream.as_bytes());
    }
}
224303}
0 commit comments