Commit c0d44ef

Fix: portable SIMD take implementation (#5654)
It turns out we were not running `RUSTFLAGS="--cfg vortex_nightly" cargo +nightly nextest run` in our CI. I'll add that to our actions in a follow-up PR. This commit also adds some non-trivial tests.

Signed-off-by: Connor Tsui <[email protected]>
1 parent bf654c4 commit c0d44ef

File tree: 6 files changed (+168 -31 lines)


Cargo.lock

Lines changed: 1 addition & 0 deletions
Some generated files are not rendered by default.

vortex-array/src/arrays/primitive/compute/take/portable.rs

Lines changed: 2 additions & 20 deletions
@@ -33,10 +33,6 @@ use crate::validity::Validity;
 
 pub(super) struct TakeKernelPortableSimd;
 
-// SIMD types larger than the SIMD register size are beneficial for
-// performance as this leads to better instruction level parallelism.
-const SIMD_WIDTH: usize = 64;
-
 impl TakeImpl for TakeKernelPortableSimd {
     fn take(
         &self,
@@ -47,7 +43,7 @@ impl TakeImpl for TakeKernelPortableSimd {
         if array.ptype() == PType::F16 {
             // Special handling for f16 to treat as opaque u16
             let decoded = match_each_unsigned_integer_ptype!(unsigned_indices.ptype(), |C| {
-                portable::take_portable_simd::<u16, C, SIMD_WIDTH>(
+                portable::take_portable_simd::<u16, C, { portable::SIMD_WIDTH }>(
                     array.reinterpret_cast(PType::U16).as_slice(),
                     unsigned_indices.as_slice(),
                 )
@@ -58,7 +54,7 @@ impl TakeImpl for TakeKernelPortableSimd {
         } else {
             match_each_unsigned_integer_ptype!(unsigned_indices.ptype(), |C| {
                 match_each_native_simd_ptype!(array.ptype(), |V| {
-                    let decoded = portable::take_portable_simd::<V, C, SIMD_WIDTH>(
+                    let decoded = portable::take_portable_simd::<V, C, { portable::SIMD_WIDTH }>(
                         array.as_slice(),
                         unsigned_indices.as_slice(),
                     );
@@ -68,17 +64,3 @@ impl TakeImpl for TakeKernelPortableSimd {
         }
     }
 }
-
-#[cfg(test)]
-mod tests {
-    use super::take_portable_simd;
-
-    #[test]
-    fn test_take_out_of_bounds() {
-        let indices = vec![2_000_000u32; 64];
-        let values = vec![1i32];
-
-        let result = take_portable_simd::<u32, i32, 64>(&indices, &values);
-        assert_eq!(result.as_slice(), [0i32; 64]);
-    }
-}
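
A note on the braces at the call sites: because `SIMD_WIDTH` now lives in the `portable` module, the const-generic argument is a multi-segment path, and Rust only accepts a literal or a bare identifier unbraced in const-generic position; any other expression, including a path like `portable::SIMD_WIDTH`, must be wrapped in a block. A minimal standalone illustration (hypothetical names, not from the commit):

    mod portable {
        pub const SIMD_WIDTH: usize = 64;
    }

    // A const parameter can be instantiated with a literal or a bare identifier
    // directly, but a path like `portable::SIMD_WIDTH` must be braced.
    fn lanes<const N: usize>() -> usize {
        N
    }

    fn main() {
        assert_eq!(lanes::<64>(), 64); // literal: no braces needed
        assert_eq!(lanes::<{ portable::SIMD_WIDTH }>(), 64); // path: braces required
    }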

vortex-compute/Cargo.toml

Lines changed: 1 addition & 0 deletions
@@ -29,6 +29,7 @@ vortex-vector = { workspace = true }
 arrow-array = { workspace = true, optional = true }
 arrow-buffer = { workspace = true, optional = true }
 arrow-schema = { workspace = true, optional = true }
+half = { workspace = true }
 log = { workspace = true }
 multiversion = { workspace = true }
 num-traits = { workspace = true }
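
The new `half` dependency backs the `f16` handling in `take_portable` below. A quick sketch of the property that code relies on — `half::f16` has the same size as `u16` and round-trips losslessly through its bit pattern, so gathering the `u16` reinterpretation moves `f16` values unchanged (illustrative only, not part of the commit):

    use half::f16;

    fn main() {
        // The reinterpretation is sound only if the sizes match, mirroring the
        // `assert_eq!(size_of::<half::f16>(), size_of::<T>())` in the diff below.
        assert_eq!(std::mem::size_of::<f16>(), std::mem::size_of::<u16>());

        // An f16 value survives a round-trip through its u16 bit pattern.
        let x = f16::from_f32(1.5);
        let bits: u16 = x.to_bits();
        assert_eq!(f16::from_bits(bits), x);
    }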

vortex-compute/src/lib.rs

Lines changed: 1 addition & 0 deletions
@@ -3,6 +3,7 @@
 
 //! A collection of compute functions primarily for operating over Vortex vectors.
 
+#![cfg_attr(vortex_nightly, feature(portable_simd))]
 #![deny(missing_docs)]
 #![deny(clippy::missing_panics_doc)]
 #![deny(clippy::missing_safety_doc)]
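
The crate-level `cfg_attr` turns on the unstable `portable_simd` feature only when the crate is compiled with `--cfg vortex_nightly` (i.e. `RUSTFLAGS="--cfg vortex_nightly" cargo +nightly ...`, per the commit message), so stable builds never see the feature gate. A minimal sketch of the pattern, with hypothetical function names:

    #![cfg_attr(vortex_nightly, feature(portable_simd))]

    // Nightly-only SIMD path: compiled only when `--cfg vortex_nightly` is set.
    #[cfg(vortex_nightly)]
    fn double_simd(x: u32) -> u32 {
        use std::simd::Simd;
        let v: Simd<u32, 4> = Simd::splat(x) * Simd::splat(2);
        v.to_array()[0]
    }

    // Stable fallback: always compiled.
    fn double_scalar(x: u32) -> u32 {
        x * 2
    }

    fn main() {
        #[cfg(vortex_nightly)]
        assert_eq!(double_simd(21), 42);
        assert_eq!(double_scalar(21), 42);
    }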

vortex-compute/src/take/slice/mod.rs

Lines changed: 1 addition & 0 deletions
@@ -30,6 +30,7 @@ impl<T: NativePType, I: UnsignedPType> Take<[I]> for &[T] {
             }
         }
 
+        #[allow(unreachable_code, reason = "`vortex_nightly` path returns early")]
         take_scalar(self, indices)
     }
 }
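
Why the `unreachable_code` allow: with `vortex_nightly` set, the SIMD branch returns before the scalar tail call, so the compiler flags `take_scalar` as unreachable on that configuration; without the cfg, the call is live. A minimal sketch of the dispatch shape (hypothetical names, not the Vortex trait):

    fn take(values: &[i32], indices: &[u32]) -> Vec<i32> {
        // Under `--cfg vortex_nightly`, this block returns early...
        #[cfg(vortex_nightly)]
        {
            return take_simd(values, indices);
        }

        // ...which makes this tail expression unreachable on that cfg.
        #[allow(unreachable_code, reason = "`vortex_nightly` path returns early")]
        take_scalar(values, indices)
    }

    #[cfg(vortex_nightly)]
    fn take_simd(values: &[i32], indices: &[u32]) -> Vec<i32> {
        take_scalar(values, indices) // stand-in for the real SIMD kernel
    }

    fn take_scalar(values: &[i32], indices: &[u32]) -> Vec<i32> {
        indices.iter().map(|&i| values[i as usize]).collect()
    }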

vortex-compute/src/take/slice/portable.rs

Lines changed: 162 additions & 11 deletions
@@ -6,6 +6,7 @@
 #![cfg(vortex_nightly)]
 
 use std::mem::MaybeUninit;
+use std::mem::size_of;
 use std::mem::transmute;
 use std::simd;
 use std::simd::num::SimdUint;
@@ -17,27 +18,55 @@ use vortex_buffer::BufferMut;
 use vortex_dtype::NativePType;
 use vortex_dtype::PType;
 use vortex_dtype::UnsignedPType;
+use vortex_dtype::match_each_native_simd_ptype;
+use vortex_dtype::match_each_unsigned_integer_ptype;
+
+/// SIMD types larger than the SIMD register size are beneficial for
+/// performance as this leads to better instruction level parallelism.
+pub const SIMD_WIDTH: usize = 64;
 
 /// Takes the specified indices into a new [`Buffer`] using portable SIMD.
+///
+/// This function handles the type matching required to satisfy `SimdElement` bounds.
+/// For `f16` values, it reinterprets them as `u16` since `f16` doesn't implement `SimdElement`.
 #[inline]
-pub fn take_portable<T, I>(buffer: &[T], indices: &[I]) -> Buffer<T>
-where
-    T: NativePType + simd::SimdElement,
-    I: UnsignedPType + simd::SimdElement,
-{
+pub fn take_portable<T: NativePType, I: UnsignedPType>(buffer: &[T], indices: &[I]) -> Buffer<T> {
     if T::PTYPE == PType::F16 {
+        assert_eq!(size_of::<half::f16>(), size_of::<T>());
+
         // Since Rust does not actually support 16-bit floats, we first reinterpret the data as
         // `u16` integers.
+        // SAFETY: We know that f16 has the same bit pattern as u16, so this transmute is fine to
+        // make.
         let u16_slice: &[u16] =
             unsafe { std::slice::from_raw_parts(buffer.as_ptr() as *const u16, buffer.len()) };
+        return take_with_indices(u16_slice, indices).cast_into::<T>();
+    }
 
-        let taken_u16 = take_portable_simd::<u16, I, SIMD_WIDTH>(u16_slice, indices);
-        let taken_f16 = taken_u16.cast_into::<T>();
+    match_each_native_simd_ptype!(T::PTYPE, |TC| {
+        assert_eq!(size_of::<TC>(), size_of::<T>());
 
-        taken_f16
-    } else {
-        take_portable_simd::<T, I, SIMD_WIDTH>(buffer, indices)
-    }
+        // SAFETY: This is essentially a no-op that tricks the compiler into adding the
+        // `simd::SimdElement` bound we need to call `take_with_indices`.
+        let buffer: &[TC] =
+            unsafe { std::slice::from_raw_parts(buffer.as_ptr() as *const TC, buffer.len()) };
+        take_with_indices(buffer, indices).cast_into::<T>()
+    })
+}
+
+/// Helper that matches on index type and calls `take_portable_simd`.
+///
+/// We separate this code out from above to add the [`simd::SimdElement`] constraint.
+#[inline]
+fn take_with_indices<T: NativePType + simd::SimdElement, I: UnsignedPType>(
+    buffer: &[T],
+    indices: &[I],
+) -> Buffer<T> {
+    match_each_unsigned_integer_ptype!(I::PTYPE, |IC| {
+        let indices: &[IC] =
+            unsafe { std::slice::from_raw_parts(indices.as_ptr() as *const IC, indices.len()) };
+        take_portable_simd::<T, IC, SIMD_WIDTH>(buffer, indices)
+    })
 }
 
 /// Takes elements from an array using SIMD indexing.
@@ -110,4 +139,126 @@ mod tests {
         let result = take_portable_simd::<i32, u32, 64>(&values, &indices);
         assert_eq!(result.as_slice(), [0i32; 64]);
     }
+
+    /// Tests SIMD gather with a mix of sequential, strided, and repeated indices. This exercises
+    /// irregular access patterns that stress the gather operation.
+    #[test]
+    fn test_take_mixed_access_patterns() {
+        // Create a values array with distinct elements.
+        let values: Vec<i64> = (0..256).map(|i| i * 100).collect();
+
+        // Build indices with mixed patterns:
+        // - Sequential access (0, 1, 2, ...)
+        // - Strided access (0, 4, 8, ...)
+        // - Repeated indices (same index multiple times)
+        // - Reverse order
+        let mut indices: Vec<u32> = Vec::with_capacity(200);
+
+        // Sequential: indices 0..64.
+        indices.extend(0u32..64);
+        // Strided by 4: 0, 4, 8, ..., 252.
+        indices.extend((0u32..64).map(|i| i * 4));
+        // Repeated: index 42 repeated 32 times.
+        indices.extend(std::iter::repeat(42u32).take(32));
+        // Reverse: 255, 254, ..., 216.
+        indices.extend((216u32..256).rev());
+
+        let result = take_portable_simd::<i64, u32, 64>(&values, &indices);
+        let result_slice = result.as_slice();
+
+        // Verify sequential portion.
+        for i in 0..64 {
+            assert_eq!(result_slice[i], (i as i64) * 100, "sequential at index {i}");
+        }
+
+        // Verify strided portion.
+        for i in 0..64 {
+            assert_eq!(
+                result_slice[64 + i],
+                (i as i64) * 4 * 100,
+                "strided at index {i}"
+            );
+        }
+
+        // Verify repeated portion.
+        for i in 0..32 {
+            assert_eq!(result_slice[128 + i], 42 * 100, "repeated at index {i}");
+        }
+
+        // Verify reverse portion.
+        for i in 0..40 {
+            assert_eq!(
+                result_slice[160 + i],
+                (255 - i as i64) * 100,
+                "reverse at index {i}"
+            );
+        }
+    }
+
+    /// Tests that the scalar remainder path works correctly when the number of indices is not
+    /// evenly divisible by the SIMD lane count.
+    #[test]
+    fn test_take_with_remainder() {
+        let values: Vec<u16> = (0..1000).collect();
+
+        // Use 64 + 37 = 101 indices to test both the SIMD loop (64 elements) and the scalar
+        // remainder (37 elements).
+        let indices: Vec<u8> = (0u8..101).collect();
+
+        let result = take_portable_simd::<u16, u8, 64>(&values, &indices);
+        let result_slice = result.as_slice();
+
+        assert_eq!(result_slice.len(), 101);
+
+        // Verify all elements.
+        for i in 0..101 {
+            assert_eq!(result_slice[i], i as u16, "mismatch at index {i}");
+        }
+
+        // Also test with exactly 1 remainder element.
+        let indices_one_remainder: Vec<u8> = (0u8..65).collect();
+        let result_one = take_portable_simd::<u16, u8, 64>(&values, &indices_one_remainder);
+        assert_eq!(result_one.as_slice().len(), 65);
+        assert_eq!(result_one.as_slice()[64], 64);
+    }
+
+    /// Tests gather with large 64-bit values and various index types to ensure no truncation
+    /// occurs during the operation.
+    #[test]
+    fn test_take_large_values_no_truncation() {
+        // Create values near the edges of i64 range.
+        let values: Vec<i64> = vec![
+            i64::MIN,
+            i64::MIN + 1,
+            -1_000_000_000_000i64,
+            -1,
+            0,
+            1,
+            1_000_000_000_000i64,
+            i64::MAX - 1,
+            i64::MAX,
+        ];
+
+        // Indices that access each value multiple times in different orders.
+        let indices: Vec<u16> = vec![
+            0, 8, 1, 7, 2, 6, 3, 5, 4, // Forward-backward interleaved.
+            8, 8, 8, 0, 0, 0, // Repeated extremes.
+            4, 4, 4, 4, 4, 4, 4, 4, // Repeated zero.
+            0, 1, 2, 3, 4, 5, 6, 7, 8, // Sequential.
+            8, 7, 6, 5, 4, 3, 2, 1, 0, // Reverse.
+            // Pad to 64 to ensure we hit the SIMD path.
+            0, 1, 2, 3, 4, 5, 6, 7, 8, 0, 1, 2, 3, 4, 5, 6, 7, 8, 0, 1, 2, 3,
+        ];
+
+        let result = take_portable_simd::<i64, u16, 64>(&values, &indices);
+        let result_slice = result.as_slice();
+
+        // Verify each result matches the expected value.
+        for (i, &idx) in indices.iter().enumerate() {
+            assert_eq!(
+                result_slice[i], values[idx as usize],
+                "mismatch at position {i} for index {idx}"
+            );
+        }
+    }
 }
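
The `take_portable_simd` kernel itself falls outside this diff's hunks. For orientation, here is a minimal sketch of what a portable-SIMD gather of this shape can look like — not the Vortex implementation; the function name, fixed element/index types, and tail handling are assumptions — using `Simd::gather_or_default` so out-of-bounds indices yield zero, consistent with the out-of-bounds expectation in the test context above:

    #![feature(portable_simd)] // requires a nightly toolchain

    use std::simd::{LaneCount, Simd, SupportedLaneCount};

    /// Gathers `values[indices[i]]` into a new Vec, `LANES` elements at a time.
    /// Out-of-bounds indices produce 0, and a scalar loop handles the tail when
    /// `indices.len()` is not a multiple of `LANES`.
    fn take_simd_sketch<const LANES: usize>(values: &[i64], indices: &[u32]) -> Vec<i64>
    where
        LaneCount<LANES>: SupportedLaneCount,
    {
        let mut out = Vec::with_capacity(indices.len());
        let mut chunks = indices.chunks_exact(LANES);

        for chunk in chunks.by_ref() {
            // Widen the indices to usize and gather; out-of-bounds lanes get 0.
            let idx: Simd<usize, LANES> = Simd::from_slice(chunk).cast();
            out.extend_from_slice(&Simd::gather_or_default(values, idx).to_array());
        }
        for &i in chunks.remainder() {
            out.push(values.get(i as usize).copied().unwrap_or_default());
        }
        out
    }

    fn main() {
        let values: Vec<i64> = (0..256).map(|i| i * 100).collect();
        let indices: Vec<u32> = vec![0, 255, 42, 42, 2_000_000]; // last is out of bounds
        assert_eq!(take_simd_sketch::<64>(&values, &indices), [0, 25_500, 4_200, 4_200, 0]);
    }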
