Skip to content

Commit 94447a4

Browse files
committed
check for OOB in avx2 take impl
Signed-off-by: Connor Tsui <[email protected]>
1 parent 01ccc6e commit 94447a4

File tree

2 files changed

+67
-18
lines changed

2 files changed

+67
-18
lines changed

vortex-compute/src/take/slice/avx2.rs

Lines changed: 62 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -6,10 +6,15 @@
66
//! Only enabled for x86_64 hosts and it is gated at runtime behind feature detection to ensure AVX2
77
//! instructions are available.
88
9+
#![allow(
10+
unused,
11+
reason = "Compiler may see things in this module as unused based on enabled features"
12+
)]
913
#![cfg(any(target_arch = "x86_64", target_arch = "x86"))]
1014

1115
use std::arch::x86_64::__m256i;
1216
use std::arch::x86_64::_mm_loadu_si128;
17+
use std::arch::x86_64::_mm_movemask_epi8;
1318
use std::arch::x86_64::_mm_setzero_si128;
1419
use std::arch::x86_64::_mm_shuffle_epi32;
1520
use std::arch::x86_64::_mm_storeu_si128;
@@ -26,6 +31,7 @@ use std::arch::x86_64::_mm256_loadu_si256;
2631
use std::arch::x86_64::_mm256_mask_i32gather_epi32;
2732
use std::arch::x86_64::_mm256_mask_i64gather_epi32;
2833
use std::arch::x86_64::_mm256_mask_i64gather_epi64;
34+
use std::arch::x86_64::_mm256_movemask_epi8;
2935
use std::arch::x86_64::_mm256_set1_epi32;
3036
use std::arch::x86_64::_mm256_set1_epi64x;
3137
use std::arch::x86_64::_mm256_setzero_si256;
@@ -102,56 +108,75 @@ pub(crate) trait GatherFn<Idx, Values> {
102108
/// Gather values from `src` into the `dst` using the `indices`, optionally using
103109
/// SIMD instructions.
104110
///
111+
/// Returns `true` if all indices in this batch were valid (less than `max_idx`), `false`
112+
/// otherwise. Invalid indices are masked out during the gather (substituting zeros).
113+
///
105114
/// # Safety
106115
///
107116
/// This function can read up to `STRIDE` elements through `indices`, and read/write up to
108117
/// `WIDTH` elements through `src` and `dst` respectively.
109-
unsafe fn gather(indices: *const Idx, max_idx: Idx, src: *const Values, dst: *mut Values);
118+
unsafe fn gather(
119+
indices: *const Idx,
120+
max_idx: Idx,
121+
src: *const Values,
122+
dst: *mut Values,
123+
) -> bool;
110124
}
111125

112126
/// AVX2 version of GatherFn defined for 32- and 64-bit value types.
113127
enum AVX2Gather {}
114128

115129
macro_rules! impl_gather {
116-
($idx:ty, $({$value:ty => load: $load:ident, extend: $extend:ident, splat: $splat:ident, zero_vec: $zero_vec:ident, mask_indices: $mask_indices:ident, mask_cvt: |$mask_var:ident| $mask_cvt:block, gather: $masked_gather:ident, store: $store:ident, WIDTH = $WIDTH:literal, STRIDE = $STRIDE:literal }),+) => {
130+
($idx:ty, $({$value:ty => load: $load:ident, extend: $extend:ident, splat: $splat:ident, zero_vec: $zero_vec:ident, mask_indices: $mask_indices:ident, mask_cvt: |$mask_var:ident| $mask_cvt:block, movemask: $movemask:ident, all_valid_mask: $all_valid_mask:expr, gather: $masked_gather:ident, store: $store:ident, WIDTH = $WIDTH:literal, STRIDE = $STRIDE:literal }),+) => {
117131
$(
118-
impl_gather!(single; $idx, $value, load: $load, extend: $extend, splat: $splat, zero_vec: $zero_vec, mask_indices: $mask_indices, mask_cvt: |$mask_var| $mask_cvt, gather: $masked_gather, store: $store, WIDTH = $WIDTH, STRIDE = $STRIDE);
132+
impl_gather!(single; $idx, $value, load: $load, extend: $extend, splat: $splat, zero_vec: $zero_vec, mask_indices: $mask_indices, mask_cvt: |$mask_var| $mask_cvt, movemask: $movemask, all_valid_mask: $all_valid_mask, gather: $masked_gather, store: $store, WIDTH = $WIDTH, STRIDE = $STRIDE);
119133
)*
120134
};
121-
(single; $idx:ty, $value:ty, load: $load:ident, extend: $extend:ident, splat: $splat:ident, zero_vec: $zero_vec:ident, mask_indices: $mask_indices:ident, mask_cvt: |$mask_var:ident| $mask_cvt:block, gather: $masked_gather:ident, store: $store:ident, WIDTH = $WIDTH:literal, STRIDE = $STRIDE:literal) => {
135+
(single; $idx:ty, $value:ty, load: $load:ident, extend: $extend:ident, splat: $splat:ident, zero_vec: $zero_vec:ident, mask_indices: $mask_indices:ident, mask_cvt: |$mask_var:ident| $mask_cvt:block, movemask: $movemask:ident, all_valid_mask: $all_valid_mask:expr, gather: $masked_gather:ident, store: $store:ident, WIDTH = $WIDTH:literal, STRIDE = $STRIDE:literal) => {
122136
impl GatherFn<$idx, $value> for AVX2Gather {
123137
const WIDTH: usize = $WIDTH;
124138
const STRIDE: usize = $STRIDE;
125139

126140
#[allow(unused_unsafe, clippy::cast_possible_truncation)]
127141
#[inline(always)]
128-
unsafe fn gather(indices: *const $idx, max_idx: $idx, src: *const $value, dst: *mut $value) {
142+
unsafe fn gather(
143+
indices: *const $idx,
144+
max_idx: $idx,
145+
src: *const $value,
146+
dst: *mut $value
147+
) -> bool {
129148
const {
130149
assert!($WIDTH <= $STRIDE, "dst cannot advance by more than the stride");
131150
}
132151

133152
const SCALE: i32 = std::mem::size_of::<$value>() as i32;
134153

135154
let indices_vec = unsafe { $load(indices.cast()) };
136-
// Extend indices to fill vector register
155+
// Extend indices to fill vector register.
137156
let indices_vec = unsafe { $extend(indices_vec) };
138157

139-
// create a vec of the max idx
158+
// Create a vec of the max idx.
140159
let max_idx_vec = unsafe { $splat(max_idx as _) };
141-
// create a mask for valid indices (where the max_idx > provided index).
142-
let invalid_mask = unsafe { $mask_indices(max_idx_vec, indices_vec) };
143-
let invalid_mask = {
144-
let $mask_var = invalid_mask;
160+
// Create a mask for valid indices (where the max_idx > provided index).
161+
let valid_mask = unsafe { $mask_indices(max_idx_vec, indices_vec) };
162+
let valid_mask = {
163+
let $mask_var = valid_mask;
145164
$mask_cvt
146165
};
147166
let zero_vec = unsafe { $zero_vec() };
148167

149168
// Gather the values into a new vector register; for masked positions
150169
// it substitutes zero instead of accessing the src.
151-
let values_vec = unsafe { $masked_gather::<SCALE>(zero_vec, src.cast(), indices_vec, invalid_mask) };
170+
let values_vec = unsafe {
171+
$masked_gather::<SCALE>(zero_vec, src.cast(), indices_vec, valid_mask)
172+
};
152173

153174
// Write the vec out to dst.
154175
unsafe { $store(dst.cast(), values_vec) };
176+
177+
// Return true if all indices were valid (all mask bits set).
178+
let mask_bits = unsafe { $movemask(valid_mask) };
179+
mask_bits == $all_valid_mask
155180
}
156181
}
157182
};
@@ -167,6 +192,8 @@ impl_gather!(u8,
167192
zero_vec: _mm256_setzero_si256,
168193
mask_indices: _mm256_cmpgt_epi32,
169194
mask_cvt: |x| { x },
195+
movemask: _mm256_movemask_epi8,
196+
all_valid_mask: -1_i32,
170197
gather: _mm256_mask_i32gather_epi32,
171198
store: _mm256_storeu_si256,
172199
WIDTH = 8, STRIDE = 16
@@ -179,6 +206,8 @@ impl_gather!(u8,
179206
zero_vec: _mm256_setzero_si256,
180207
mask_indices: _mm256_cmpgt_epi64,
181208
mask_cvt: |x| { x },
209+
movemask: _mm256_movemask_epi8,
210+
all_valid_mask: -1_i32,
182211
gather: _mm256_mask_i64gather_epi64,
183212
store: _mm256_storeu_si256,
184213
WIDTH = 4, STRIDE = 16
@@ -195,6 +224,8 @@ impl_gather!(u16,
195224
zero_vec: _mm256_setzero_si256,
196225
mask_indices: _mm256_cmpgt_epi32,
197226
mask_cvt: |x| { x },
227+
movemask: _mm256_movemask_epi8,
228+
all_valid_mask: -1_i32,
198229
gather: _mm256_mask_i32gather_epi32,
199230
store: _mm256_storeu_si256,
200231
WIDTH = 8, STRIDE = 8
@@ -207,6 +238,8 @@ impl_gather!(u16,
207238
zero_vec: _mm256_setzero_si256,
208239
mask_indices: _mm256_cmpgt_epi64,
209240
mask_cvt: |x| { x },
241+
movemask: _mm256_movemask_epi8,
242+
all_valid_mask: -1_i32,
210243
gather: _mm256_mask_i64gather_epi64,
211244
store: _mm256_storeu_si256,
212245
WIDTH = 4, STRIDE = 8
@@ -223,6 +256,8 @@ impl_gather!(u32,
223256
zero_vec: _mm256_setzero_si256,
224257
mask_indices: _mm256_cmpgt_epi32,
225258
mask_cvt: |x| { x },
259+
movemask: _mm256_movemask_epi8,
260+
all_valid_mask: -1_i32,
226261
gather: _mm256_mask_i32gather_epi32,
227262
store: _mm256_storeu_si256,
228263
WIDTH = 8, STRIDE = 8
@@ -235,6 +270,8 @@ impl_gather!(u32,
235270
zero_vec: _mm256_setzero_si256,
236271
mask_indices: _mm256_cmpgt_epi64,
237272
mask_cvt: |x| { x },
273+
movemask: _mm256_movemask_epi8,
274+
all_valid_mask: -1_i32,
238275
gather: _mm256_mask_i64gather_epi64,
239276
store: _mm256_storeu_si256,
240277
WIDTH = 4, STRIDE = 4
@@ -259,6 +296,8 @@ impl_gather!(u64,
259296
_mm_unpacklo_epi64(lo_packed, hi_packed)
260297
}
261298
},
299+
movemask: _mm_movemask_epi8,
300+
all_valid_mask: 0xFFFF_i32,
262301
gather: _mm256_mask_i64gather_epi32,
263302
store: _mm_storeu_si128,
264303
WIDTH = 4, STRIDE = 4
@@ -271,6 +310,8 @@ impl_gather!(u64,
271310
zero_vec: _mm256_setzero_si256,
272311
mask_indices: _mm256_cmpgt_epi64,
273312
mask_cvt: |x| { x },
313+
movemask: _mm256_movemask_epi8,
314+
all_valid_mask: -1_i32,
274315
gather: _mm256_mask_i64gather_epi64,
275316
store: _mm256_storeu_si256,
276317
WIDTH = 4, STRIDE = 4
@@ -292,25 +333,32 @@ where
292333
let buf_uninit = buffer.spare_capacity_mut();
293334

294335
let mut offset = 0;
336+
let mut all_valid = true;
337+
295338
// Loop terminates STRIDE elements before the end of the indices array because the GatherFn
296339
// might read up to STRIDE src elements at a time, even though it only advances WIDTH elements
297340
// in the dst.
298341
while offset + Gather::STRIDE < indices_len {
299342
// SAFETY: gather_simd preconditions satisfied:
300343
// 1. `(indices + offset)..(indices + offset + STRIDE)` is in-bounds for indices allocation
301344
// 2. `buffer` has same len as indices so `buffer + offset + STRIDE` is always valid.
302-
unsafe {
345+
let batch_valid = unsafe {
303346
Gather::gather(
304347
indices.as_ptr().add(offset),
305348
max_index,
306349
values.as_ptr(),
307350
buf_uninit.as_mut_ptr().add(offset).cast(),
308351
)
309352
};
353+
all_valid &= batch_valid;
310354
offset += Gather::WIDTH;
311355
}
312356

313-
// Remainder
357+
// Check accumulated validity after the hot loop. If any mask bits were zero, then there was an
358+
// out-of-bounds index.
359+
assert!(all_valid, "index out of bounds in AVX2 take");
360+
361+
// Fall back to scalar iteration for the remainder.
314362
while offset < indices_len {
315363
buf_uninit[offset].write(values[indices[offset].as_()]);
316364
offset += 1;

vortex-compute/src/take/slice/mod.rs

Lines changed: 5 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,11 @@
33

44
//! Take function implementations on slices.
55
6+
#![allow(
7+
unused,
8+
reason = "Compiler may see things in this module as unused based on enabled features"
9+
)]
10+
611
use vortex_buffer::Buffer;
712
use vortex_dtype::UnsignedPType;
813

@@ -34,10 +39,6 @@ impl<T: Copy, I: UnsignedPType> Take<[I]> for &[T] {
3439
}
3540
}
3641

37-
#[allow(
38-
unused,
39-
reason = "Compiler may see this as unused based on enabled features"
40-
)]
4142
#[inline]
4243
pub(crate) fn take_scalar<T: Copy, I: UnsignedPType>(buffer: &[T], indices: &[I]) -> Buffer<T> {
4344
indices.iter().map(|idx| buffer[idx.as_()]).collect()

0 commit comments

Comments
 (0)