@@ -11,7 +11,7 @@ use vortex_dtype::{
1111 DType , IntegerPType , NativePType , Nullability , match_each_integer_ptype,
1212 match_each_native_ptype,
1313} ;
14- use vortex_error:: { VortexExpect , VortexResult , vortex_bail } ;
14+ use vortex_error:: { VortexExpect , VortexResult , vortex_panic } ;
1515use vortex_mask:: { AllOr , Mask } ;
1616use vortex_scalar:: Scalar ;
1717
@@ -23,40 +23,38 @@ impl TakeKernel for SequenceVTable {
2323 let indices = indices. to_primitive ( ) ;
2424 let result_nullability = array. dtype ( ) . nullability ( ) | indices. dtype ( ) . nullability ( ) ;
2525
26- Ok ( match_each_integer_ptype ! ( indices. ptype( ) , |T | {
26+ match_each_integer_ptype ! ( indices. ptype( ) , |T | {
2727 let indices = indices. as_slice:: <T >( ) ;
28- check_bounds( indices, array. len( ) ) ?;
29-
3028 match_each_native_ptype!( array. ptype( ) , |S | {
3129 let mul = array. multiplier( ) . cast:: <S >( ) ;
3230 let base = array. base( ) . cast:: <S >( ) ;
33- take( mul, base, indices, mask, result_nullability)
31+ Ok ( take(
32+ mul,
33+ base,
34+ indices,
35+ mask,
36+ result_nullability,
37+ array. len( ) ,
38+ ) )
3439 } )
35- } ) )
40+ } )
3641 }
3742}
3843
39- fn check_bounds < T : IntegerPType > ( indices : & [ T ] , len : usize ) -> VortexResult < ( ) > {
40- for & i in indices {
41- let i = i. as_ ( ) ;
42- if i >= len {
43- vortex_bail ! ( OutOfBounds : i, 0 , len) ;
44- }
45- }
46-
47- Ok ( ( ) )
48- }
49-
5044fn take < T : IntegerPType , S : NativePType > (
5145 mul : S ,
5246 base : S ,
5347 indices : & [ T ] ,
5448 indices_mask : Mask ,
5549 result_nullability : Nullability ,
50+ len : usize ,
5651) -> ArrayRef {
5752 match indices_mask. bit_buffer ( ) {
5853 AllOr :: All => PrimitiveArray :: new (
5954 Buffer :: from_trusted_len_iter ( indices. iter ( ) . map ( |i| {
55+ if i. as_ ( ) >= len {
56+ vortex_panic ! ( OutOfBounds : i. as_( ) , 0 , len) ;
57+ }
6058 let i = <S as NumCast >:: from :: < T > ( * i) . vortex_expect ( "all indices fit" ) ;
6159 base + i * mul
6260 } ) ) ,
@@ -72,6 +70,10 @@ fn take<T: IntegerPType, S: NativePType>(
7270 let buffer =
7371 Buffer :: from_trusted_len_iter ( indices. iter ( ) . enumerate ( ) . map ( |( mask_index, i) | {
7472 if b. value ( mask_index) {
73+ if i. as_ ( ) >= len {
74+ vortex_panic ! ( OutOfBounds : i. as_( ) , 0 , len) ;
75+ }
76+
7577 let i =
7678 <S as NumCast >:: from :: < T > ( * i) . vortex_expect ( "all valid indices fit" ) ;
7779 base + i * mul
@@ -149,17 +151,10 @@ mod test {
149151 }
150152
151153 #[ test]
154+ #[ should_panic( expected = "index 20 out of bounds" ) ]
152155 fn test_bounds_check ( ) {
153156 let array = SequenceArray :: typed_new ( 0i32 , 1i32 , Nullability :: NonNullable , 10 ) . unwrap ( ) ;
154157 let indices = vortex_array:: arrays:: PrimitiveArray :: from_iter ( [ 0i32 , 20 ] ) ;
155- let result = take ( array. as_ref ( ) , indices. as_ref ( ) ) ;
156- assert ! ( result. is_err( ) ) ;
157- assert ! (
158- result
159- . err( )
160- . unwrap( )
161- . to_string( )
162- . contains( "out of bounds from" )
163- ) ;
158+ let _array = take ( array. as_ref ( ) , indices. as_ref ( ) ) . unwrap ( ) ;
164159 }
165160}
0 commit comments