@@ -17,37 +17,118 @@ impl TakeKernel for VarBinVTable {
1717 let offsets = array. offsets ( ) . to_primitive ( ) ;
1818 let data = array. bytes ( ) ;
1919 let indices = indices. to_primitive ( ) ;
20- match_each_integer_ptype ! ( offsets. ptype( ) , |O | {
21- match_each_integer_ptype!( indices. ptype( ) , |I | {
22- Ok ( take(
20+ let array = match_each_integer_ptype ! ( indices. ptype( ) , |I | {
21+ // On take, offsets get widened to either 32- or 64-bit based on the original type,
22+ // to avoid overflow issues.
23+ match offsets. ptype( ) {
24+ PType :: U8 => take:: <I , u8 , u32 >(
2325 array
2426 . dtype( )
2527 . clone( )
2628 . union_nullability( indices. dtype( ) . nullability( ) ) ,
27- offsets. as_slice:: <O >( ) ,
29+ offsets. as_slice:: <u8 >( ) ,
2830 data. as_slice( ) ,
2931 indices. as_slice:: <I >( ) ,
3032 array. validity_mask( ) ,
3133 indices. validity_mask( ) ,
32- ) ?
33- . into_array( ) )
34- } )
35- } )
34+ ) ,
35+ PType :: U16 => take:: <I , u16 , u32 >(
36+ array
37+ . dtype( )
38+ . clone( )
39+ . union_nullability( indices. dtype( ) . nullability( ) ) ,
40+ offsets. as_slice:: <u16 >( ) ,
41+ data. as_slice( ) ,
42+ indices. as_slice:: <I >( ) ,
43+ array. validity_mask( ) ,
44+ indices. validity_mask( ) ,
45+ ) ,
46+ PType :: U32 => take:: <I , u32 , u32 >(
47+ array
48+ . dtype( )
49+ . clone( )
50+ . union_nullability( indices. dtype( ) . nullability( ) ) ,
51+ offsets. as_slice:: <u32 >( ) ,
52+ data. as_slice( ) ,
53+ indices. as_slice:: <I >( ) ,
54+ array. validity_mask( ) ,
55+ indices. validity_mask( ) ,
56+ ) ,
57+ PType :: U64 => take:: <I , u64 , u64 >(
58+ array
59+ . dtype( )
60+ . clone( )
61+ . union_nullability( indices. dtype( ) . nullability( ) ) ,
62+ offsets. as_slice:: <u64 >( ) ,
63+ data. as_slice( ) ,
64+ indices. as_slice:: <I >( ) ,
65+ array. validity_mask( ) ,
66+ indices. validity_mask( ) ,
67+ ) ,
68+ PType :: I8 => take:: <I , i8 , i32 >(
69+ array
70+ . dtype( )
71+ . clone( )
72+ . union_nullability( indices. dtype( ) . nullability( ) ) ,
73+ offsets. as_slice:: <i8 >( ) ,
74+ data. as_slice( ) ,
75+ indices. as_slice:: <I >( ) ,
76+ array. validity_mask( ) ,
77+ indices. validity_mask( ) ,
78+ ) ,
79+ PType :: I16 => take:: <I , i16 , i32 >(
80+ array
81+ . dtype( )
82+ . clone( )
83+ . union_nullability( indices. dtype( ) . nullability( ) ) ,
84+ offsets. as_slice:: <i16 >( ) ,
85+ data. as_slice( ) ,
86+ indices. as_slice:: <I >( ) ,
87+ array. validity_mask( ) ,
88+ indices. validity_mask( ) ,
89+ ) ,
90+ PType :: I32 => take:: <I , i32 , i32 >(
91+ array
92+ . dtype( )
93+ . clone( )
94+ . union_nullability( indices. dtype( ) . nullability( ) ) ,
95+ offsets. as_slice:: <i32 >( ) ,
96+ data. as_slice( ) ,
97+ indices. as_slice:: <I >( ) ,
98+ array. validity_mask( ) ,
99+ indices. validity_mask( ) ,
100+ ) ,
101+ PType :: I64 => take:: <I , i64 , i64 >(
102+ array
103+ . dtype( )
104+ . clone( )
105+ . union_nullability( indices. dtype( ) . nullability( ) ) ,
106+ offsets. as_slice:: <i64 >( ) ,
107+ data. as_slice( ) ,
108+ indices. as_slice:: <I >( ) ,
109+ array. validity_mask( ) ,
110+ indices. validity_mask( ) ,
111+ ) ,
112+ _ => unreachable!( "invalid PType for offsets" ) ,
113+ }
114+ } ) ;
115+
116+ Ok ( array?. into_array ( ) )
36117 }
37118}
38119
39120register_kernel ! ( TakeKernelAdapter ( VarBinVTable ) . lift( ) ) ;
40121
41- fn take < I : IntegerPType , O : IntegerPType > (
122+ fn take < Index : IntegerPType , Offset : IntegerPType , NewOffset : IntegerPType > (
42123 dtype : DType ,
43- offsets : & [ O ] ,
124+ offsets : & [ Offset ] ,
44125 data : & [ u8 ] ,
45- indices : & [ I ] ,
126+ indices : & [ Index ] ,
46127 validity_mask : Mask ,
47128 indices_validity_mask : Mask ,
48129) -> VortexResult < VarBinArray > {
49130 if !validity_mask. all_true ( ) || !indices_validity_mask. all_true ( ) {
50- return Ok ( take_nullable (
131+ return Ok ( take_nullable :: < Index , Offset , NewOffset > (
51132 dtype,
52133 offsets,
53134 data,
@@ -57,25 +138,22 @@ fn take<I: IntegerPType, O: IntegerPType>(
57138 ) ) ;
58139 }
59140
60- let mut new_offsets = BufferMut :: with_capacity ( indices. len ( ) + 1 ) ;
61- new_offsets. push ( O :: zero ( ) ) ;
62- let mut current_offset = O :: zero ( ) ;
141+ let mut new_offsets = BufferMut :: < NewOffset > :: with_capacity ( indices. len ( ) + 1 ) ;
142+ new_offsets. push ( NewOffset :: zero ( ) ) ;
143+ let mut current_offset = NewOffset :: zero ( ) ;
63144
64145 for & idx in indices {
65146 let idx = idx
66147 . to_usize ( )
67148 . unwrap_or_else ( || vortex_panic ! ( "Failed to convert index to usize: {}" , idx) ) ;
68149 let start = offsets[ idx] ;
69150 let stop = offsets[ idx + 1 ] ;
70- current_offset += stop - start;
151+
152+ current_offset += NewOffset :: from ( stop - start) . vortex_expect ( "offset type overflow" ) ;
71153 new_offsets. push ( current_offset) ;
72154 }
73155
74- let mut new_data = ByteBufferMut :: with_capacity (
75- current_offset
76- . to_usize ( )
77- . vortex_expect ( "Failed to cast max offset to usize" ) ,
78- ) ;
156+ let mut new_data = ByteBufferMut :: with_capacity ( current_offset. as_ ( ) ) ;
79157
80158 for idx in indices {
81159 let idx = idx
@@ -104,17 +182,17 @@ fn take<I: IntegerPType, O: IntegerPType>(
104182 }
105183}
106184
107- fn take_nullable < I : IntegerPType , O : IntegerPType > (
185+ fn take_nullable < Index : IntegerPType , Offset : IntegerPType , NewOffset : IntegerPType > (
108186 dtype : DType ,
109- offsets : & [ O ] ,
187+ offsets : & [ Offset ] ,
110188 data : & [ u8 ] ,
111- indices : & [ I ] ,
189+ indices : & [ Index ] ,
112190 data_validity : Mask ,
113191 indices_validity : Mask ,
114192) -> VarBinArray {
115- let mut new_offsets = BufferMut :: with_capacity ( indices. len ( ) + 1 ) ;
116- new_offsets. push ( O :: zero ( ) ) ;
117- let mut current_offset = O :: zero ( ) ;
193+ let mut new_offsets = BufferMut :: < NewOffset > :: with_capacity ( indices. len ( ) + 1 ) ;
194+ new_offsets. push ( NewOffset :: zero ( ) ) ;
195+ let mut current_offset = NewOffset :: zero ( ) ;
118196
119197 let mut validity_buffer = BitBufferMut :: with_capacity ( indices. len ( ) ) ;
120198
@@ -135,7 +213,7 @@ fn take_nullable<I: IntegerPType, O: IntegerPType>(
135213 validity_buffer. append ( true ) ;
136214 let start = offsets[ data_idx_usize] ;
137215 let stop = offsets[ data_idx_usize + 1 ] ;
138- current_offset += stop - start;
216+ current_offset += NewOffset :: from ( stop - start) . vortex_expect ( "offset type overflow" ) ;
139217 new_offsets. push ( current_offset) ;
140218 valid_indices. push ( data_idx_usize) ;
141219 } else {
@@ -144,11 +222,7 @@ fn take_nullable<I: IntegerPType, O: IntegerPType>(
144222 }
145223 }
146224
147- let mut new_data = ByteBufferMut :: with_capacity (
148- current_offset
149- . to_usize ( )
150- . vortex_expect ( "Failed to cast max offset to usize" ) ,
151- ) ;
225+ let mut new_data = ByteBufferMut :: with_capacity ( current_offset. as_ ( ) ) ;
152226
153227 // Second pass: copy data for valid indices only
154228 for data_idx in valid_indices {
@@ -178,12 +252,14 @@ fn take_nullable<I: IntegerPType, O: IntegerPType>(
178252#[ cfg( test) ]
179253mod tests {
180254 use rstest:: rstest;
255+ use vortex_buffer:: { ByteBuffer , buffer} ;
181256 use vortex_dtype:: { DType , Nullability } ;
182257
183- use crate :: Array ;
184- use crate :: arrays:: { PrimitiveArray , VarBinArray } ;
258+ use crate :: arrays:: { PrimitiveArray , VarBinArray , VarBinVTable } ;
185259 use crate :: compute:: conformance:: take:: test_take_conformance;
186260 use crate :: compute:: take;
261+ use crate :: validity:: Validity ;
262+ use crate :: { Array , IntoArray } ;
187263
188264 #[ test]
189265 fn test_null_take ( ) {
@@ -221,4 +297,27 @@ mod tests {
221297 fn test_take_varbin_conformance ( #[ case] array : VarBinArray ) {
222298 test_take_conformance ( array. as_ref ( ) ) ;
223299 }
300+
/// Regression test for offset overflow in `take` on a `VarBinArray` with `u8` offsets.
///
/// The source array holds one 128-byte string, so its offsets (`0` and `128`) fit in a `u8`.
/// Taking that row three times produces 384 bytes of data, which a `u8` offset cannot
/// address, so the take must succeed and return the value intact at every output row.
#[test]
fn test_take_overflow() {
    // One 128-byte value made of repeated "a" characters.
    let scream = "a".repeat(128);

    let array = VarBinArray::new(
        buffer![0u8, 128u8].into_array(),
        ByteBuffer::copy_from(scream.as_bytes()),
        DType::Utf8(Nullability::NonNullable),
        Validity::NonNullable,
    );

    // Select row 0 three times: the final output offset is 3 * 128 = 384.
    let indices = buffer![0u32, 0u32, 0u32].into_array();
    let taken = take(array.as_ref(), indices.as_ref()).unwrap();
    let taken_str = taken.as_::<VarBinVTable>();

    assert_eq!(taken_str.len(), 3);
    for row in 0..3 {
        assert_eq!(taken_str.bytes_at(row).as_bytes(), scream.as_bytes());
    }
}
224323}
0 commit comments