Skip to content

Commit 2b21c62

Browse files
committed
check for OOB in avx2 take impl
Signed-off-by: Connor Tsui <[email protected]>
1 parent 01ccc6e commit 2b21c62

File tree

1 file changed

+58
-14
lines changed
  • vortex-compute/src/take/slice

1 file changed

+58
-14
lines changed

vortex-compute/src/take/slice/avx2.rs

Lines changed: 58 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,7 @@
1010

1111
use std::arch::x86_64::__m256i;
1212
use std::arch::x86_64::_mm_loadu_si128;
13+
use std::arch::x86_64::_mm_movemask_epi8;
1314
use std::arch::x86_64::_mm_setzero_si128;
1415
use std::arch::x86_64::_mm_shuffle_epi32;
1516
use std::arch::x86_64::_mm_storeu_si128;
@@ -26,6 +27,7 @@ use std::arch::x86_64::_mm256_loadu_si256;
2627
use std::arch::x86_64::_mm256_mask_i32gather_epi32;
2728
use std::arch::x86_64::_mm256_mask_i64gather_epi32;
2829
use std::arch::x86_64::_mm256_mask_i64gather_epi64;
30+
use std::arch::x86_64::_mm256_movemask_epi8;
2931
use std::arch::x86_64::_mm256_set1_epi32;
3032
use std::arch::x86_64::_mm256_set1_epi64x;
3133
use std::arch::x86_64::_mm256_setzero_si256;
@@ -102,56 +104,75 @@ pub(crate) trait GatherFn<Idx, Values> {
102104
/// Gather values from `src` into the `dst` using the `indices`, optionally using
103105
/// SIMD instructions.
104106
///
107+
/// Returns `true` if all indices in this batch were valid (less than `max_idx`), `false`
108+
/// otherwise. Invalid indices are masked out during the gather (substituting zeros).
109+
///
105110
/// # Safety
106111
///
107112
/// This function can read up to `STRIDE` elements through `indices`, and read/write up to
108113
/// `WIDTH` elements through `src` and `dst` respectively.
109-
unsafe fn gather(indices: *const Idx, max_idx: Idx, src: *const Values, dst: *mut Values);
114+
unsafe fn gather(
115+
indices: *const Idx,
116+
max_idx: Idx,
117+
src: *const Values,
118+
dst: *mut Values,
119+
) -> bool;
110120
}
111121

112122
/// AVX2 version of GatherFn defined for 32- and 64-bit value types.
113123
enum AVX2Gather {}
114124

115125
macro_rules! impl_gather {
116-
($idx:ty, $({$value:ty => load: $load:ident, extend: $extend:ident, splat: $splat:ident, zero_vec: $zero_vec:ident, mask_indices: $mask_indices:ident, mask_cvt: |$mask_var:ident| $mask_cvt:block, gather: $masked_gather:ident, store: $store:ident, WIDTH = $WIDTH:literal, STRIDE = $STRIDE:literal }),+) => {
126+
($idx:ty, $({$value:ty => load: $load:ident, extend: $extend:ident, splat: $splat:ident, zero_vec: $zero_vec:ident, mask_indices: $mask_indices:ident, mask_cvt: |$mask_var:ident| $mask_cvt:block, movemask: $movemask:ident, all_valid_mask: $all_valid_mask:expr, gather: $masked_gather:ident, store: $store:ident, WIDTH = $WIDTH:literal, STRIDE = $STRIDE:literal }),+) => {
117127
$(
118-
impl_gather!(single; $idx, $value, load: $load, extend: $extend, splat: $splat, zero_vec: $zero_vec, mask_indices: $mask_indices, mask_cvt: |$mask_var| $mask_cvt, gather: $masked_gather, store: $store, WIDTH = $WIDTH, STRIDE = $STRIDE);
128+
impl_gather!(single; $idx, $value, load: $load, extend: $extend, splat: $splat, zero_vec: $zero_vec, mask_indices: $mask_indices, mask_cvt: |$mask_var| $mask_cvt, movemask: $movemask, all_valid_mask: $all_valid_mask, gather: $masked_gather, store: $store, WIDTH = $WIDTH, STRIDE = $STRIDE);
119129
)*
120130
};
121-
(single; $idx:ty, $value:ty, load: $load:ident, extend: $extend:ident, splat: $splat:ident, zero_vec: $zero_vec:ident, mask_indices: $mask_indices:ident, mask_cvt: |$mask_var:ident| $mask_cvt:block, gather: $masked_gather:ident, store: $store:ident, WIDTH = $WIDTH:literal, STRIDE = $STRIDE:literal) => {
131+
(single; $idx:ty, $value:ty, load: $load:ident, extend: $extend:ident, splat: $splat:ident, zero_vec: $zero_vec:ident, mask_indices: $mask_indices:ident, mask_cvt: |$mask_var:ident| $mask_cvt:block, movemask: $movemask:ident, all_valid_mask: $all_valid_mask:expr, gather: $masked_gather:ident, store: $store:ident, WIDTH = $WIDTH:literal, STRIDE = $STRIDE:literal) => {
122132
impl GatherFn<$idx, $value> for AVX2Gather {
123133
const WIDTH: usize = $WIDTH;
124134
const STRIDE: usize = $STRIDE;
125135

126136
#[allow(unused_unsafe, clippy::cast_possible_truncation)]
127137
#[inline(always)]
128-
unsafe fn gather(indices: *const $idx, max_idx: $idx, src: *const $value, dst: *mut $value) {
138+
unsafe fn gather(
139+
indices: *const $idx,
140+
max_idx: $idx,
141+
src: *const $value,
142+
dst: *mut $value
143+
) -> bool {
129144
const {
130145
assert!($WIDTH <= $STRIDE, "dst cannot advance by more than the stride");
131146
}
132147

133148
const SCALE: i32 = std::mem::size_of::<$value>() as i32;
134149

135150
let indices_vec = unsafe { $load(indices.cast()) };
136-
// Extend indices to fill vector register
151+
// Extend indices to fill vector register.
137152
let indices_vec = unsafe { $extend(indices_vec) };
138153

139-
// create a vec of the max idx
154+
// Create a vec of the max idx.
140155
let max_idx_vec = unsafe { $splat(max_idx as _) };
141-
// create a mask for valid indices (where the max_idx > provided index).
142-
let invalid_mask = unsafe { $mask_indices(max_idx_vec, indices_vec) };
143-
let invalid_mask = {
144-
let $mask_var = invalid_mask;
156+
// Create a mask for valid indices (where the max_idx > provided index).
157+
let valid_mask = unsafe { $mask_indices(max_idx_vec, indices_vec) };
158+
let valid_mask = {
159+
let $mask_var = valid_mask;
145160
$mask_cvt
146161
};
147162
let zero_vec = unsafe { $zero_vec() };
148163

149164
// Gather the values into new vector register, for masked positions
150165
// it substitutes zero instead of accessing the src.
151-
let values_vec = unsafe { $masked_gather::<SCALE>(zero_vec, src.cast(), indices_vec, invalid_mask) };
166+
let values_vec = unsafe {
167+
$masked_gather::<SCALE>(zero_vec, src.cast(), indices_vec, valid_mask)
168+
};
152169

153170
// Write the vec out to dst.
154171
unsafe { $store(dst.cast(), values_vec) };
172+
173+
// Return true if all indices were valid (all mask bits set).
174+
let mask_bits = unsafe { $movemask(valid_mask) };
175+
mask_bits == $all_valid_mask
155176
}
156177
}
157178
};
@@ -167,6 +188,8 @@ impl_gather!(u8,
167188
zero_vec: _mm256_setzero_si256,
168189
mask_indices: _mm256_cmpgt_epi32,
169190
mask_cvt: |x| { x },
191+
movemask: _mm256_movemask_epi8,
192+
all_valid_mask: -1_i32,
170193
gather: _mm256_mask_i32gather_epi32,
171194
store: _mm256_storeu_si256,
172195
WIDTH = 8, STRIDE = 16
@@ -179,6 +202,8 @@ impl_gather!(u8,
179202
zero_vec: _mm256_setzero_si256,
180203
mask_indices: _mm256_cmpgt_epi64,
181204
mask_cvt: |x| { x },
205+
movemask: _mm256_movemask_epi8,
206+
all_valid_mask: -1_i32,
182207
gather: _mm256_mask_i64gather_epi64,
183208
store: _mm256_storeu_si256,
184209
WIDTH = 4, STRIDE = 16
@@ -195,6 +220,8 @@ impl_gather!(u16,
195220
zero_vec: _mm256_setzero_si256,
196221
mask_indices: _mm256_cmpgt_epi32,
197222
mask_cvt: |x| { x },
223+
movemask: _mm256_movemask_epi8,
224+
all_valid_mask: -1_i32,
198225
gather: _mm256_mask_i32gather_epi32,
199226
store: _mm256_storeu_si256,
200227
WIDTH = 8, STRIDE = 8
@@ -207,6 +234,8 @@ impl_gather!(u16,
207234
zero_vec: _mm256_setzero_si256,
208235
mask_indices: _mm256_cmpgt_epi64,
209236
mask_cvt: |x| { x },
237+
movemask: _mm256_movemask_epi8,
238+
all_valid_mask: -1_i32,
210239
gather: _mm256_mask_i64gather_epi64,
211240
store: _mm256_storeu_si256,
212241
WIDTH = 4, STRIDE = 8
@@ -223,6 +252,8 @@ impl_gather!(u32,
223252
zero_vec: _mm256_setzero_si256,
224253
mask_indices: _mm256_cmpgt_epi32,
225254
mask_cvt: |x| { x },
255+
movemask: _mm256_movemask_epi8,
256+
all_valid_mask: -1_i32,
226257
gather: _mm256_mask_i32gather_epi32,
227258
store: _mm256_storeu_si256,
228259
WIDTH = 8, STRIDE = 8
@@ -235,6 +266,8 @@ impl_gather!(u32,
235266
zero_vec: _mm256_setzero_si256,
236267
mask_indices: _mm256_cmpgt_epi64,
237268
mask_cvt: |x| { x },
269+
movemask: _mm256_movemask_epi8,
270+
all_valid_mask: -1_i32,
238271
gather: _mm256_mask_i64gather_epi64,
239272
store: _mm256_storeu_si256,
240273
WIDTH = 4, STRIDE = 4
@@ -259,6 +292,8 @@ impl_gather!(u64,
259292
_mm_unpacklo_epi64(lo_packed, hi_packed)
260293
}
261294
},
295+
movemask: _mm_movemask_epi8,
296+
all_valid_mask: 0xFFFF_i32,
262297
gather: _mm256_mask_i64gather_epi32,
263298
store: _mm_storeu_si128,
264299
WIDTH = 4, STRIDE = 4
@@ -271,6 +306,8 @@ impl_gather!(u64,
271306
zero_vec: _mm256_setzero_si256,
272307
mask_indices: _mm256_cmpgt_epi64,
273308
mask_cvt: |x| { x },
309+
movemask: _mm256_movemask_epi8,
310+
all_valid_mask: -1_i32,
274311
gather: _mm256_mask_i64gather_epi64,
275312
store: _mm256_storeu_si256,
276313
WIDTH = 4, STRIDE = 4
@@ -292,25 +329,32 @@ where
292329
let buf_uninit = buffer.spare_capacity_mut();
293330

294331
let mut offset = 0;
332+
let mut all_valid = true;
333+
295334
// Loop terminates STRIDE elements before end of the indices array because the GatherFn
296335
// might read up to STRIDE src elements at a time, even though it only advances WIDTH elements
297336
// in the dst.
298337
while offset + Gather::STRIDE < indices_len {
299338
// SAFETY: gather_simd preconditions satisfied:
300339
// 1. `(indices + offset)..(indices + offset + STRIDE)` is in-bounds for indices allocation
301340
// 2. `buffer` has same len as indices so `buffer + offset + STRIDE` is always valid.
302-
unsafe {
341+
let batch_valid = unsafe {
303342
Gather::gather(
304343
indices.as_ptr().add(offset),
305344
max_index,
306345
values.as_ptr(),
307346
buf_uninit.as_mut_ptr().add(offset).cast(),
308347
)
309348
};
349+
all_valid &= batch_valid;
310350
offset += Gather::WIDTH;
311351
}
312352

313-
// Remainder
353+
// Check accumulated validity after hot loop. If there are any 0's, then there was an
354+
// out-of-bounds index.
355+
assert!(all_valid, "index out of bounds in AVX2 take");
356+
357+
// Fall back to scalar iteration for the remainder.
314358
while offset < indices_len {
315359
buf_uninit[offset].write(values[indices[offset].as_()]);
316360
offset += 1;

0 commit comments

Comments
 (0)