Skip to content

Commit de2811c

Browse files
committed
fix: avoid oob take for primitive
Signed-off-by: Andrew Duffy <[email protected]>
1 parent fe4c81b commit de2811c

File tree

2 files changed

+113
-27
lines changed

2 files changed

+113
-27
lines changed

vortex-array/src/arrays/primitive/compute/take/mod.rs

Lines changed: 2 additions & 27 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@ mod avx2;
66

77
#[cfg(vortex_nightly)]
88
mod portable;
9+
mod scalar;
910

1011
use std::sync::LazyLock;
1112

@@ -48,7 +49,7 @@ static PRIMITIVE_TAKE_KERNEL: LazyLock<&'static dyn TakeImpl> = LazyLock::new(||
4849
}
4950
} else {
5051
// stable all other platforms: scalar kernel
51-
&TakeKernelScalar
52+
&scalar::TakeKernelScalar
5253
}
5354
}
5455
});
@@ -62,25 +63,6 @@ trait TakeImpl: Send + Sync {
6263
) -> VortexResult<ArrayRef>;
6364
}
6465

65-
#[allow(unused)]
66-
struct TakeKernelScalar;
67-
68-
impl TakeImpl for TakeKernelScalar {
69-
fn take(
70-
&self,
71-
array: &PrimitiveArray,
72-
indices: &PrimitiveArray,
73-
validity: Validity,
74-
) -> VortexResult<ArrayRef> {
75-
match_each_native_ptype!(array.ptype(), |T| {
76-
match_each_integer_ptype!(indices.ptype(), |I| {
77-
let values = take_primitive_scalar(array.as_slice::<T>(), indices.as_slice::<I>());
78-
Ok(PrimitiveArray::new(values, validity).into_array())
79-
})
80-
})
81-
}
82-
}
83-
8466
impl TakeKernel for PrimitiveVTable {
8567
fn take(&self, array: &PrimitiveArray, indices: &dyn Array) -> VortexResult<ArrayRef> {
8668
let DType::Primitive(ptype, null) = indices.dtype() else {
@@ -102,13 +84,6 @@ impl TakeKernel for PrimitiveVTable {
10284

10385
register_kernel!(TakeKernelAdapter(PrimitiveVTable).lift());
10486

105-
// Compiler may see this as unused based on enabled features
106-
#[allow(unused)]
107-
#[inline(always)]
108-
fn take_primitive_scalar<T: NativePType, I: IntegerPType>(array: &[T], indices: &[I]) -> Buffer<T> {
109-
indices.iter().map(|idx| array[idx.as_()]).collect()
110-
}
111-
11287
#[cfg(any(target_arch = "x86_64", target_arch = "x86"))]
11388
#[cfg(test)]
11489
mod test {
Lines changed: 111 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,111 @@
1+
// SPDX-License-Identifier: Apache-2.0
2+
// SPDX-FileCopyrightText: Copyright the Vortex contributors
3+
4+
use vortex_buffer::Buffer;
5+
use vortex_dtype::{IntegerPType, NativePType, match_each_integer_ptype, match_each_native_ptype};
6+
use vortex_error::VortexResult;
7+
8+
use crate::arrays::PrimitiveArray;
9+
use crate::arrays::primitive::compute::take::TakeImpl;
10+
use crate::validity::Validity;
11+
use crate::vtable::ValidityHelper;
12+
use crate::{ArrayRef, IntoArray};
13+
14+
#[allow(unused)]
15+
pub(super) struct TakeKernelScalar;
16+
17+
impl TakeImpl for TakeKernelScalar {
18+
#[allow(clippy::cognitive_complexity)]
19+
fn take(
20+
&self,
21+
array: &PrimitiveArray,
22+
indices: &PrimitiveArray,
23+
validity: Validity,
24+
) -> VortexResult<ArrayRef> {
25+
match_each_native_ptype!(array.ptype(), |T| {
26+
match_each_integer_ptype!(indices.ptype(), |I| {
27+
let indices_slice = indices.as_slice::<I>();
28+
let indices_validity = indices.validity();
29+
let values = if indices_validity.all_valid(indices_slice.len()) {
30+
// Fast path: indices have no nulls, safe to index directly
31+
take_primitive_scalar(array.as_slice::<T>(), indices_slice)
32+
} else {
33+
// Slow path: indices may have nulls with garbage values
34+
take_primitive_scalar_with_nulls(
35+
array.as_slice::<T>(),
36+
indices_slice,
37+
indices_validity,
38+
)
39+
};
40+
Ok(PrimitiveArray::new(values, validity).into_array())
41+
})
42+
})
43+
}
44+
}
45+
46+
// Compiler may see this as unused based on enabled features
47+
#[allow(unused)]
48+
#[inline(always)]
49+
fn take_primitive_scalar<T: NativePType, I: IntegerPType>(array: &[T], indices: &[I]) -> Buffer<T> {
50+
indices.iter().map(|idx| array[idx.as_()]).collect()
51+
}
52+
53+
/// Slow path for take when indices may contain nulls with garbage values.
54+
/// Uses 0 as a safe index for null positions (the value will be masked out by validity).
55+
#[allow(unused)]
56+
#[inline(always)]
57+
fn take_primitive_scalar_with_nulls<T: NativePType, I: IntegerPType>(
58+
array: &[T],
59+
indices: &[I],
60+
validity: &Validity,
61+
) -> Buffer<T> {
62+
indices
63+
.iter()
64+
.enumerate()
65+
.map(|(i, idx)| {
66+
if validity.is_valid(i) {
67+
array[idx.as_()]
68+
} else {
69+
T::zero()
70+
}
71+
})
72+
.collect()
73+
}
74+
75+
#[cfg(test)]
mod tests {
    use vortex_buffer::buffer;

    use crate::arrays::PrimitiveArray;
    use crate::arrays::primitive::compute::take::TakeImpl;
    use crate::arrays::primitive::compute::take::scalar::TakeKernelScalar;
    use crate::validity::Validity;
    use crate::{IntoArray, ToCanonical};

    /// Non-null, in-bounds indices (with repeats) gather the matching values in order.
    #[test]
    fn test_scalar_basic() {
        let source = buffer![1, 2, 3, 4, 5].into_array().to_primitive();
        let positions = buffer![0, 1, 1, 2, 2, 3, 4].into_array().to_primitive();

        let taken = TakeKernelScalar
            .take(&source, &positions, Validity::NonNullable)
            .unwrap()
            .to_primitive();

        assert_eq!(taken.as_slice::<i32>(), &[1, 2, 2, 3, 3, 4, 5]);
    }

    /// A null index whose stored value (100) is out of bounds must not panic; the
    /// corresponding output slot is a zero placeholder and the validity is carried over.
    #[test]
    fn test_scalar_with_nulls() {
        let source = buffer![1, 2, 3, 4, 5].into_array().to_primitive();
        let nulls = Validity::from_iter([true, false, true, true, true]);
        let positions = PrimitiveArray::new(buffer![0, 100, 2, 3, 4], nulls.clone());

        let taken = TakeKernelScalar
            .take(&source, &positions, nulls.clone())
            .unwrap()
            .to_primitive();

        assert_eq!(taken.as_slice::<i32>(), &[1, 0, 3, 4, 5]);
        assert_eq!(taken.validity, nulls);
    }
}

0 commit comments

Comments
 (0)