Skip to content

Commit 4b3a5c3

Browse files
committed
add OOB check + safety comments
Signed-off-by: Connor Tsui <[email protected]>
1 parent c15a6fe commit 4b3a5c3

File tree

1 file changed

+53
-19
lines changed

1 file changed

+53
-19
lines changed

vortex-compute/src/take/slice/portable.rs

Lines changed: 53 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -7,8 +7,8 @@
77

88
use std::mem::MaybeUninit;
99
use std::mem::size_of;
10-
use std::mem::transmute;
1110
use std::simd;
11+
use std::simd::cmp::SimdPartialOrd;
1212
use std::simd::num::SimdUint;
1313

1414
use multiversion::multiversion;
@@ -83,6 +83,10 @@ fn take_with_indices<T: Copy + Default + simd::SimdElement, I: UnsignedPType>(
8383
/// buffer. Uses SIMD instructions to process `LANE_COUNT` indices in parallel.
8484
///
8585
/// Returns a `Buffer<T>` where each element corresponds to `values[indices[i]]`.
86+
///
87+
/// # Panics
88+
///
89+
/// Panics if any index is out of bounds for `values`.
8690
#[multiversion(targets("x86_64+avx2", "x86_64+avx", "aarch64+neon"))]
8791
pub fn take_portable_simd<T, I, const LANE_COUNT: usize>(values: &[T], indices: &[I]) -> Buffer<T>
8892
where
@@ -100,37 +104,67 @@ where
100104

101105
let buf_slice = buffer.spare_capacity_mut();
102106

107+
// Set up a vector that we can SIMD compare against for out-of-bounds indices.
108+
let len_vec = simd::Simd::<usize, LANE_COUNT>::splat(values.len());
109+
let mut all_valid = simd::Mask::<isize, LANE_COUNT>::splat(true);
110+
103111
for chunk_idx in 0..(indices_len / LANE_COUNT) {
104112
let offset = chunk_idx * LANE_COUNT;
105-
let mask = simd::Mask::from_bitmask(u64::MAX);
106113
let codes_chunk = simd::Simd::<I, LANE_COUNT>::from_slice(&indices[offset..]);
107-
108-
let selection = simd::Simd::gather_select(
109-
values,
110-
mask,
111-
codes_chunk.cast::<usize>(),
112-
simd::Simd::<T, LANE_COUNT>::default(),
113-
);
114-
114+
let codes_usize = codes_chunk.cast::<usize>();
115+
116+
// Accumulate validity and use as gather mask. An out-of-bounds index will turn a bit off.
117+
all_valid &= codes_usize.simd_lt(len_vec);
118+
119+
// SAFETY: We use `all_valid` to mask the gather, preventing OOB memory access. If any
120+
// index is OOB, `all_valid` will have those bits turned off, masking out the invalid
121+
// indices.
122+
// Note that this may also mask out valid indices in subsequent iterations. This is fine
123+
// because we will panic after the loop if **any** index was OOB, so we do not care if the
124+
// resulting gathered data is correct or not.
125+
let selection = unsafe {
126+
simd::Simd::gather_select_unchecked(
127+
values,
128+
all_valid,
129+
codes_usize,
130+
simd::Simd::<T, LANE_COUNT>::default(),
131+
)
132+
};
133+
134+
// SAFETY: `MaybeUninit<T>` has the same layout as `T`, and we are about to initialize these
135+
// elements with the store.
136+
let uninit = unsafe {
137+
std::mem::transmute::<&mut [MaybeUninit<T>], &mut [T]>(
138+
&mut buf_slice[offset..][..LANE_COUNT],
139+
)
140+
};
141+
142+
// SAFETY: The slice `buf_slice[offset..][..LANE_COUNT]` is guaranteed to have exactly
143+
// `LANE_COUNT` elements since `offset` is a multiple of `LANE_COUNT` and we only iterate
144+
// while `offset + LANE_COUNT <= indices_len`.
115145
unsafe {
116-
selection.store_select_unchecked(
117-
transmute::<&mut [MaybeUninit<T>], &mut [T]>(&mut buf_slice[offset..][..64]),
118-
mask.cast(),
119-
);
146+
selection.store_select_unchecked(uninit, simd::Mask::splat(true));
120147
}
121148
}
122149

150+
// Check accumulated validity after hot loop. If there are any 0's, then there was an
151+
// out-of-bounds index.
152+
assert!(all_valid.all(), "index out of bounds in SIMD take");
153+
154+
// Fall back to scalar iteration for the remainder.
123155
for idx in ((indices_len / LANE_COUNT) * LANE_COUNT)..indices_len {
156+
// SAFETY: `idx` is in bounds for `buf_slice` since `idx < indices_len == buf_slice.len()`.
157+
// Note that the `values[...]` access is already bounds-checked and will panic if OOB.
124158
unsafe {
125159
buf_slice
126160
.get_unchecked_mut(idx)
127161
.write(values[indices[idx].as_()]);
128162
}
129163
}
130164

131-
unsafe {
132-
buffer.set_len(indices_len);
133-
}
165+
// SAFETY: All elements have been initialized: the SIMD loop handles `0..chunks * LANE_COUNT`
166+
// and the scalar loop handles the remainder up to `indices_len`.
167+
unsafe { buffer.set_len(indices_len) };
134168

135169
buffer.freeze()
136170
}
@@ -141,12 +175,12 @@ mod tests {
141175
use super::take_portable_simd;
142176

143177
#[test]
178+
#[should_panic(expected = "index out of bounds")]
144179
fn test_take_out_of_bounds() {
145180
let indices = vec![2_000_000u32; 64];
146181
let values = vec![1i32];
147182

148-
let result = take_portable_simd::<i32, u32, 64>(&values, &indices);
149-
assert_eq!(result.as_slice(), [0i32; 64]);
183+
drop(take_portable_simd::<i32, u32, 64>(&values, &indices));
150184
}
151185

152186
/// Tests SIMD gather with a mix of sequential, strided, and repeated indices. This exercises

0 commit comments

Comments
 (0)