 //! Only enabled on x86_64 hosts, and gated at runtime behind feature detection to ensure AVX2
 //! instructions are available.
 
+#![allow(
+    unused,
+    reason = "Compiler may see things in this module as unused based on enabled features"
+)]
 #![cfg(any(target_arch = "x86_64", target_arch = "x86"))]
 
 use std::arch::x86_64::__m256i;
 use std::arch::x86_64::_mm_loadu_si128;
+use std::arch::x86_64::_mm_movemask_epi8;
 use std::arch::x86_64::_mm_setzero_si128;
 use std::arch::x86_64::_mm_shuffle_epi32;
 use std::arch::x86_64::_mm_storeu_si128;
@@ -26,6 +31,7 @@ use std::arch::x86_64::_mm256_loadu_si256;
 use std::arch::x86_64::_mm256_mask_i32gather_epi32;
 use std::arch::x86_64::_mm256_mask_i64gather_epi32;
 use std::arch::x86_64::_mm256_mask_i64gather_epi64;
+use std::arch::x86_64::_mm256_movemask_epi8;
 use std::arch::x86_64::_mm256_set1_epi32;
 use std::arch::x86_64::_mm256_set1_epi64x;
 use std::arch::x86_64::_mm256_setzero_si256;
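
Editor's note: the module doc above says this code is gated at runtime behind feature detection. As a point of reference, here is a minimal sketch of that kind of gate using std's `is_x86_feature_detected!`; the `take_avx2` and `take_scalar` names are hypothetical stand-ins for illustration, not this crate's API:

// Hypothetical dispatch sketch: choose the SIMD path only when AVX2 is present.
fn take_dispatch(indices: &[u32], values: &[u32], out: &mut Vec<u32>) {
    if std::arch::is_x86_feature_detected!("avx2") {
        // SAFETY: AVX2 support was just confirmed at runtime.
        unsafe { take_avx2(indices, values, out) }
    } else {
        take_scalar(indices, values, out)
    }
}

// Portable fallback (illustrative).
fn take_scalar(indices: &[u32], values: &[u32], out: &mut Vec<u32>) {
    out.extend(indices.iter().map(|&i| values[i as usize]));
}

// Stub standing in for the AVX2 implementation in this module.
unsafe fn take_avx2(indices: &[u32], values: &[u32], out: &mut Vec<u32>) {
    take_scalar(indices, values, out);
}
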
@@ -102,56 +108,75 @@ pub(crate) trait GatherFn<Idx, Values> {
     /// Gather values from `src` into the `dst` using the `indices`, optionally using
     /// SIMD instructions.
     ///
+    /// Returns `true` if all indices in this batch were valid (less than `max_idx`), `false`
+    /// otherwise. Invalid indices are masked out during the gather (substituting zeros).
+    ///
     /// # Safety
     ///
     /// This function can read up to `STRIDE` elements through `indices`, and read/write up to
     /// `WIDTH` elements through `src` and `dst` respectively.
-    unsafe fn gather(indices: *const Idx, max_idx: Idx, src: *const Values, dst: *mut Values);
+    unsafe fn gather(
+        indices: *const Idx,
+        max_idx: Idx,
+        src: *const Values,
+        dst: *mut Values,
+    ) -> bool;
 }
 
 /// AVX2 version of GatherFn defined for 32- and 64-bit value types.
 enum AVX2Gather {}
 
 macro_rules! impl_gather {
-    ($idx:ty, $({ $value:ty => load: $load:ident, extend: $extend:ident, splat: $splat:ident, zero_vec: $zero_vec:ident, mask_indices: $mask_indices:ident, mask_cvt: |$mask_var:ident| $mask_cvt:block, gather: $masked_gather:ident, store: $store:ident, WIDTH = $WIDTH:literal, STRIDE = $STRIDE:literal }),+) => {
+    ($idx:ty, $({ $value:ty => load: $load:ident, extend: $extend:ident, splat: $splat:ident, zero_vec: $zero_vec:ident, mask_indices: $mask_indices:ident, mask_cvt: |$mask_var:ident| $mask_cvt:block, movemask: $movemask:ident, all_valid_mask: $all_valid_mask:expr, gather: $masked_gather:ident, store: $store:ident, WIDTH = $WIDTH:literal, STRIDE = $STRIDE:literal }),+) => {
         $(
-            impl_gather!(single; $idx, $value, load: $load, extend: $extend, splat: $splat, zero_vec: $zero_vec, mask_indices: $mask_indices, mask_cvt: |$mask_var| $mask_cvt, gather: $masked_gather, store: $store, WIDTH = $WIDTH, STRIDE = $STRIDE);
+            impl_gather!(single; $idx, $value, load: $load, extend: $extend, splat: $splat, zero_vec: $zero_vec, mask_indices: $mask_indices, mask_cvt: |$mask_var| $mask_cvt, movemask: $movemask, all_valid_mask: $all_valid_mask, gather: $masked_gather, store: $store, WIDTH = $WIDTH, STRIDE = $STRIDE);
         )*
     };
-    (single; $idx:ty, $value:ty, load: $load:ident, extend: $extend:ident, splat: $splat:ident, zero_vec: $zero_vec:ident, mask_indices: $mask_indices:ident, mask_cvt: |$mask_var:ident| $mask_cvt:block, gather: $masked_gather:ident, store: $store:ident, WIDTH = $WIDTH:literal, STRIDE = $STRIDE:literal) => {
+    (single; $idx:ty, $value:ty, load: $load:ident, extend: $extend:ident, splat: $splat:ident, zero_vec: $zero_vec:ident, mask_indices: $mask_indices:ident, mask_cvt: |$mask_var:ident| $mask_cvt:block, movemask: $movemask:ident, all_valid_mask: $all_valid_mask:expr, gather: $masked_gather:ident, store: $store:ident, WIDTH = $WIDTH:literal, STRIDE = $STRIDE:literal) => {
         impl GatherFn<$idx, $value> for AVX2Gather {
             const WIDTH: usize = $WIDTH;
             const STRIDE: usize = $STRIDE;
 
             #[allow(unused_unsafe, clippy::cast_possible_truncation)]
             #[inline(always)]
-            unsafe fn gather(indices: *const $idx, max_idx: $idx, src: *const $value, dst: *mut $value) {
+            unsafe fn gather(
+                indices: *const $idx,
+                max_idx: $idx,
+                src: *const $value,
+                dst: *mut $value
+            ) -> bool {
                 const {
                     assert!($WIDTH <= $STRIDE, "dst cannot advance by more than the stride");
                 }
 
                 const SCALE: i32 = std::mem::size_of::<$value>() as i32;
 
                 let indices_vec = unsafe { $load(indices.cast()) };
-                // Extend indices to fill vector register
+                // Extend indices to fill vector register.
                 let indices_vec = unsafe { $extend(indices_vec) };
 
-                // create a vec of the max idx
+                // Create a vec of the max idx.
                 let max_idx_vec = unsafe { $splat(max_idx as _) };
-                // create a mask for valid indices (where the max_idx > provided index).
-                let invalid_mask = unsafe { $mask_indices(max_idx_vec, indices_vec) };
-                let invalid_mask = {
-                    let $mask_var = invalid_mask;
+                // Create a mask for valid indices (where the max_idx > provided index).
+                let valid_mask = unsafe { $mask_indices(max_idx_vec, indices_vec) };
+                let valid_mask = {
+                    let $mask_var = valid_mask;
                     $mask_cvt
                 };
                 let zero_vec = unsafe { $zero_vec() };
 
                 // Gather the values into new vector register, for masked positions
                 // it substitutes zero instead of accessing the src.
-                let values_vec = unsafe { $masked_gather::<SCALE>(zero_vec, src.cast(), indices_vec, invalid_mask) };
+                let values_vec = unsafe {
+                    $masked_gather::<SCALE>(zero_vec, src.cast(), indices_vec, valid_mask)
+                };
 
                 // Write the vec out to dst.
                 unsafe { $store(dst.cast(), values_vec) };
+
+                // Return true if all indices were valid (all mask bits set).
+                let mask_bits = unsafe { $movemask(valid_mask) };
+                mask_bits == $all_valid_mask
             }
         }
     };
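
Editor's note: to make the new return-value contract concrete, here is a rough scalar model of what one `gather` batch computes, written for the epi32 case where WIDTH and STRIDE are both 8. It illustrates the documented semantics and is not code from the commit:

/// Scalar model of one masked-gather batch (epi32 case, 8 lanes).
/// Mirrors the trait contract: invalid lanes get zero and `src` is never
/// dereferenced for them; the return value reports whether every lane was valid.
unsafe fn gather_scalar(indices: *const u32, max_idx: u32, src: *const u32, dst: *mut u32) -> bool {
    let mut all_valid = true;
    for lane in 0..8 {
        let idx = unsafe { *indices.add(lane) };
        let value = if idx < max_idx {
            unsafe { *src.add(idx as usize) } // valid lane: real gather
        } else {
            all_valid = false;
            0 // masked lane: substitute zero instead of touching `src`
        };
        unsafe { dst.add(lane).write(value) };
    }
    all_valid
}
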
@@ -167,6 +192,8 @@ impl_gather!(u8,
     zero_vec: _mm256_setzero_si256,
     mask_indices: _mm256_cmpgt_epi32,
     mask_cvt: |x| { x },
+    movemask: _mm256_movemask_epi8,
+    all_valid_mask: -1_i32,
     gather: _mm256_mask_i32gather_epi32,
     store: _mm256_storeu_si256,
     WIDTH = 8, STRIDE = 16
@@ -179,6 +206,8 @@ impl_gather!(u8,
     zero_vec: _mm256_setzero_si256,
     mask_indices: _mm256_cmpgt_epi64,
     mask_cvt: |x| { x },
+    movemask: _mm256_movemask_epi8,
+    all_valid_mask: -1_i32,
     gather: _mm256_mask_i64gather_epi64,
     store: _mm256_storeu_si256,
     WIDTH = 4, STRIDE = 16
@@ -195,6 +224,8 @@ impl_gather!(u16,
     zero_vec: _mm256_setzero_si256,
     mask_indices: _mm256_cmpgt_epi32,
     mask_cvt: |x| { x },
+    movemask: _mm256_movemask_epi8,
+    all_valid_mask: -1_i32,
     gather: _mm256_mask_i32gather_epi32,
     store: _mm256_storeu_si256,
     WIDTH = 8, STRIDE = 8
@@ -207,6 +238,8 @@ impl_gather!(u16,
     zero_vec: _mm256_setzero_si256,
     mask_indices: _mm256_cmpgt_epi64,
     mask_cvt: |x| { x },
+    movemask: _mm256_movemask_epi8,
+    all_valid_mask: -1_i32,
     gather: _mm256_mask_i64gather_epi64,
     store: _mm256_storeu_si256,
     WIDTH = 4, STRIDE = 8
@@ -223,6 +256,8 @@ impl_gather!(u32,
     zero_vec: _mm256_setzero_si256,
     mask_indices: _mm256_cmpgt_epi32,
     mask_cvt: |x| { x },
+    movemask: _mm256_movemask_epi8,
+    all_valid_mask: -1_i32,
     gather: _mm256_mask_i32gather_epi32,
     store: _mm256_storeu_si256,
     WIDTH = 8, STRIDE = 8
@@ -235,6 +270,8 @@ impl_gather!(u32,
     zero_vec: _mm256_setzero_si256,
     mask_indices: _mm256_cmpgt_epi64,
     mask_cvt: |x| { x },
+    movemask: _mm256_movemask_epi8,
+    all_valid_mask: -1_i32,
     gather: _mm256_mask_i64gather_epi64,
     store: _mm256_storeu_si256,
     WIDTH = 4, STRIDE = 4
@@ -259,6 +296,8 @@ impl_gather!(u64,
             _mm_unpacklo_epi64(lo_packed, hi_packed)
         }
     },
+    movemask: _mm_movemask_epi8,
+    all_valid_mask: 0xFFFF_i32,
     gather: _mm256_mask_i64gather_epi32,
     store: _mm_storeu_si128,
     WIDTH = 4, STRIDE = 4
@@ -271,6 +310,8 @@ impl_gather!(u64,
     zero_vec: _mm256_setzero_si256,
     mask_indices: _mm256_cmpgt_epi64,
     mask_cvt: |x| { x },
+    movemask: _mm256_movemask_epi8,
+    all_valid_mask: -1_i32,
     gather: _mm256_mask_i64gather_epi64,
     store: _mm256_storeu_si256,
     WIDTH = 4, STRIDE = 4
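
Editor's note: the two `all_valid_mask` constants wired through these instantiations follow from the movemask width. `_mm256_movemask_epi8` packs the top bit of each of the 32 bytes of a 256-bit mask into an `i32`, so all-set is `-1`; `_mm_movemask_epi8` packs the 16 bytes of a 128-bit mask, so all-set is `0xFFFF`. A scalar model of the 256-bit epi32 case, as a sanity check:

/// Scalar model of `_mm256_movemask_epi8` over an 8-lane epi32 validity mask:
/// each valid lane is all-ones (4 bytes, hence 4 movemask bits), each invalid
/// lane all-zeros. All 32 bits set is exactly `-1_i32`.
fn all_valid_scalar(indices: &[u32; 8], max_idx: u32) -> bool {
    let mut mask_bits: i32 = 0;
    for (lane, &idx) in indices.iter().enumerate() {
        if idx < max_idx {
            mask_bits |= 0b1111 << (lane * 4); // one bit per mask byte
        }
    }
    mask_bits == -1_i32 // matches `all_valid_mask: -1_i32`
}
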
@@ -292,25 +333,32 @@ where
     let buf_uninit = buffer.spare_capacity_mut();
 
     let mut offset = 0;
+    let mut all_valid = true;
+
     // Loop terminates STRIDE elements before end of the indices array because the GatherFn
     // might read up to STRIDE src elements at a time, even though it only advances WIDTH elements
     // in the dst.
     while offset + Gather::STRIDE < indices_len {
         // SAFETY: gather_simd preconditions satisfied:
         // 1. `(indices + offset)..(indices + offset + STRIDE)` is in-bounds for indices allocation
         // 2. `buffer` has same len as indices so `buffer + offset + STRIDE` is always valid.
-        unsafe {
+        let batch_valid = unsafe {
             Gather::gather(
                 indices.as_ptr().add(offset),
                 max_index,
                 values.as_ptr(),
                 buf_uninit.as_mut_ptr().add(offset).cast(),
             )
         };
+        all_valid &= batch_valid;
         offset += Gather::WIDTH;
     }
 
-    // Remainder
+    // Check accumulated validity after the hot loop. If any batch reported an
+    // invalid lane, there was an out-of-bounds index.
+    assert!(all_valid, "index out of bounds in AVX2 take");
+
+    // Fall back to scalar iteration for the remainder.
     while offset < indices_len {
         buf_uninit[offset].write(values[indices[offset].as_()]);
         offset += 1;
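
Editor's note on the loop bounds: a batch may read STRIDE indices while only advancing WIDTH, so the hot loop stops a full STRIDE before the end and leaves the tail to the scalar loop. An illustrative helper (not part of the commit) showing how the offsets fall out:

/// Illustrative only: compute which offsets the SIMD loop would visit and
/// where the scalar tail begins, given the loop condition used above.
fn split_batches(indices_len: usize, stride: usize, width: usize) -> (Vec<usize>, usize) {
    let mut simd_offsets = Vec::new();
    let mut offset = 0;
    while offset + stride < indices_len {
        simd_offsets.push(offset); // a full STRIDE read stays in bounds here
        offset += width;           // dst (and indices) advance by WIDTH only
    }
    (simd_offsets, offset) // scalar tail covers offset..indices_len
}

// For u8 indices with 32-bit values (WIDTH = 8, STRIDE = 16):
// split_batches(20, 16, 8) == (vec![0], 8) -- one SIMD batch, tail handles 8..20.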