Skip to content

Commit b096726

Browse files
committed
Check in existing loop
Signed-off-by: blaginin <[email protected]>
1 parent 8c4cc7b commit b096726

File tree

1 file changed

+21
-26
lines changed
  • encodings/sequence/src/compute

1 file changed

+21
-26
lines changed

encodings/sequence/src/compute/take.rs

Lines changed: 21 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,7 @@ use vortex_dtype::{
1111
DType, IntegerPType, NativePType, Nullability, match_each_integer_ptype,
1212
match_each_native_ptype,
1313
};
14-
use vortex_error::{VortexExpect, VortexResult, vortex_bail};
14+
use vortex_error::{VortexExpect, VortexResult, vortex_panic};
1515
use vortex_mask::{AllOr, Mask};
1616
use vortex_scalar::Scalar;
1717

@@ -23,40 +23,38 @@ impl TakeKernel for SequenceVTable {
2323
let indices = indices.to_primitive();
2424
let result_nullability = array.dtype().nullability() | indices.dtype().nullability();
2525

26-
Ok(match_each_integer_ptype!(indices.ptype(), |T| {
26+
match_each_integer_ptype!(indices.ptype(), |T| {
2727
let indices = indices.as_slice::<T>();
28-
check_bounds(indices, array.len())?;
29-
3028
match_each_native_ptype!(array.ptype(), |S| {
3129
let mul = array.multiplier().cast::<S>();
3230
let base = array.base().cast::<S>();
33-
take(mul, base, indices, mask, result_nullability)
31+
Ok(take(
32+
mul,
33+
base,
34+
indices,
35+
mask,
36+
result_nullability,
37+
array.len(),
38+
))
3439
})
35-
}))
40+
})
3641
}
3742
}
3843

39-
fn check_bounds<T: IntegerPType>(indices: &[T], len: usize) -> VortexResult<()> {
40-
for &i in indices {
41-
let i = i.as_();
42-
if i >= len {
43-
vortex_bail!(OutOfBounds: i, 0, len);
44-
}
45-
}
46-
47-
Ok(())
48-
}
49-
5044
fn take<T: IntegerPType, S: NativePType>(
5145
mul: S,
5246
base: S,
5347
indices: &[T],
5448
indices_mask: Mask,
5549
result_nullability: Nullability,
50+
len: usize,
5651
) -> ArrayRef {
5752
match indices_mask.bit_buffer() {
5853
AllOr::All => PrimitiveArray::new(
5954
Buffer::from_trusted_len_iter(indices.iter().map(|i| {
55+
if i.as_() >= len {
56+
vortex_panic!(OutOfBounds: i.as_(), 0, len);
57+
}
6058
let i = <S as NumCast>::from::<T>(*i).vortex_expect("all indices fit");
6159
base + i * mul
6260
})),
@@ -72,6 +70,10 @@ fn take<T: IntegerPType, S: NativePType>(
7270
let buffer =
7371
Buffer::from_trusted_len_iter(indices.iter().enumerate().map(|(mask_index, i)| {
7472
if b.value(mask_index) {
73+
if i.as_() >= len {
74+
vortex_panic!(OutOfBounds: i.as_(), 0, len);
75+
}
76+
7577
let i =
7678
<S as NumCast>::from::<T>(*i).vortex_expect("all valid indices fit");
7779
base + i * mul
@@ -149,17 +151,10 @@ mod test {
149151
}
150152

151153
#[test]
154+
#[should_panic(expected = "index 20 out of bounds")]
152155
fn test_bounds_check() {
153156
let array = SequenceArray::typed_new(0i32, 1i32, Nullability::NonNullable, 10).unwrap();
154157
let indices = vortex_array::arrays::PrimitiveArray::from_iter([0i32, 20]);
155-
let result = take(array.as_ref(), indices.as_ref());
156-
assert!(result.is_err());
157-
assert!(
158-
result
159-
.err()
160-
.unwrap()
161-
.to_string()
162-
.contains("out of bounds from")
163-
);
158+
let _array = take(array.as_ref(), indices.as_ref()).unwrap();
164159
}
165160
}

0 commit comments

Comments
 (0)