 
 use std::arch::x86_64::__m256i;
 use std::arch::x86_64::_mm_loadu_si128;
+use std::arch::x86_64::_mm_movemask_epi8;
 use std::arch::x86_64::_mm_setzero_si128;
 use std::arch::x86_64::_mm_shuffle_epi32;
 use std::arch::x86_64::_mm_storeu_si128;
@@ -26,6 +27,7 @@ use std::arch::x86_64::_mm256_loadu_si256;
 use std::arch::x86_64::_mm256_mask_i32gather_epi32;
 use std::arch::x86_64::_mm256_mask_i64gather_epi32;
 use std::arch::x86_64::_mm256_mask_i64gather_epi64;
+use std::arch::x86_64::_mm256_movemask_epi8;
 use std::arch::x86_64::_mm256_set1_epi32;
 use std::arch::x86_64::_mm256_set1_epi64x;
 use std::arch::x86_64::_mm256_setzero_si256;
@@ -102,56 +104,75 @@ pub(crate) trait GatherFn<Idx, Values> {
     /// Gather values from `src` into the `dst` using the `indices`, optionally using
     /// SIMD instructions.
     ///
+    /// Returns `true` if all indices in this batch were valid (less than `max_idx`), `false`
+    /// otherwise. Invalid indices are masked out during the gather (substituting zeros).
+    ///
     /// # Safety
     ///
     /// This function can read up to `STRIDE` elements through `indices`, and read/write up to
     /// `WIDTH` elements through `src` and `dst` respectively.
-    unsafe fn gather(indices: *const Idx, max_idx: Idx, src: *const Values, dst: *mut Values);
+    unsafe fn gather(
+        indices: *const Idx,
+        max_idx: Idx,
+        src: *const Values,
+        dst: *mut Values,
+    ) -> bool;
 }
 
 /// AVX2 version of GatherFn defined for 32- and 64-bit value types.
 enum AVX2Gather {}
 
 macro_rules! impl_gather {
-    ($idx:ty, $({ $value:ty => load: $load:ident, extend: $extend:ident, splat: $splat:ident, zero_vec: $zero_vec:ident, mask_indices: $mask_indices:ident, mask_cvt: |$mask_var:ident| $mask_cvt:block, gather: $masked_gather:ident, store: $store:ident, WIDTH = $WIDTH:literal, STRIDE = $STRIDE:literal }),+) => {
+    ($idx:ty, $({ $value:ty => load: $load:ident, extend: $extend:ident, splat: $splat:ident, zero_vec: $zero_vec:ident, mask_indices: $mask_indices:ident, mask_cvt: |$mask_var:ident| $mask_cvt:block, movemask: $movemask:ident, all_valid_mask: $all_valid_mask:expr, gather: $masked_gather:ident, store: $store:ident, WIDTH = $WIDTH:literal, STRIDE = $STRIDE:literal }),+) => {
         $(
-            impl_gather!(single; $idx, $value, load: $load, extend: $extend, splat: $splat, zero_vec: $zero_vec, mask_indices: $mask_indices, mask_cvt: |$mask_var| $mask_cvt, gather: $masked_gather, store: $store, WIDTH = $WIDTH, STRIDE = $STRIDE);
+            impl_gather!(single; $idx, $value, load: $load, extend: $extend, splat: $splat, zero_vec: $zero_vec, mask_indices: $mask_indices, mask_cvt: |$mask_var| $mask_cvt, movemask: $movemask, all_valid_mask: $all_valid_mask, gather: $masked_gather, store: $store, WIDTH = $WIDTH, STRIDE = $STRIDE);
         )*
     };
-    (single; $idx:ty, $value:ty, load: $load:ident, extend: $extend:ident, splat: $splat:ident, zero_vec: $zero_vec:ident, mask_indices: $mask_indices:ident, mask_cvt: |$mask_var:ident| $mask_cvt:block, gather: $masked_gather:ident, store: $store:ident, WIDTH = $WIDTH:literal, STRIDE = $STRIDE:literal) => {
+    (single; $idx:ty, $value:ty, load: $load:ident, extend: $extend:ident, splat: $splat:ident, zero_vec: $zero_vec:ident, mask_indices: $mask_indices:ident, mask_cvt: |$mask_var:ident| $mask_cvt:block, movemask: $movemask:ident, all_valid_mask: $all_valid_mask:expr, gather: $masked_gather:ident, store: $store:ident, WIDTH = $WIDTH:literal, STRIDE = $STRIDE:literal) => {
         impl GatherFn<$idx, $value> for AVX2Gather {
             const WIDTH: usize = $WIDTH;
             const STRIDE: usize = $STRIDE;
 
             #[allow(unused_unsafe, clippy::cast_possible_truncation)]
             #[inline(always)]
-            unsafe fn gather(indices: *const $idx, max_idx: $idx, src: *const $value, dst: *mut $value) {
+            unsafe fn gather(
+                indices: *const $idx,
+                max_idx: $idx,
+                src: *const $value,
+                dst: *mut $value
+            ) -> bool {
                 const {
                     assert!($WIDTH <= $STRIDE, "dst cannot advance by more than the stride");
                 }
 
                 const SCALE: i32 = std::mem::size_of::<$value>() as i32;
 
                 let indices_vec = unsafe { $load(indices.cast()) };
-                // Extend indices to fill vector register
+                // Extend indices to fill vector register.
                 let indices_vec = unsafe { $extend(indices_vec) };
 
-                // create a vec of the max idx
+                // Create a vec of the max idx.
                 let max_idx_vec = unsafe { $splat(max_idx as _) };
-                // create a mask for valid indices (where the max_idx > provided index).
-                let invalid_mask = unsafe { $mask_indices(max_idx_vec, indices_vec) };
-                let invalid_mask = {
-                    let $mask_var = invalid_mask;
+                // Create a mask for valid indices (where the max_idx > provided index).
+                let valid_mask = unsafe { $mask_indices(max_idx_vec, indices_vec) };
+                let valid_mask = {
+                    let $mask_var = valid_mask;
                     $mask_cvt
                 };
                 let zero_vec = unsafe { $zero_vec() };
 
                 // Gather the values into new vector register, for masked positions
                 // it substitutes zero instead of accessing the src.
-                let values_vec = unsafe { $masked_gather::<SCALE>(zero_vec, src.cast(), indices_vec, invalid_mask) };
+                let values_vec = unsafe {
+                    $masked_gather::<SCALE>(zero_vec, src.cast(), indices_vec, valid_mask)
+                };
 
                 // Write the vec out to dst.
                 unsafe { $store(dst.cast(), values_vec) };
+
+                // Return true if all indices were valid (all mask bits set).
+                let mask_bits = unsafe { $movemask(valid_mask) };
+                mask_bits == $all_valid_mask
             }
         }
     };
@@ -167,6 +188,8 @@ impl_gather!(u8,
         zero_vec: _mm256_setzero_si256,
         mask_indices: _mm256_cmpgt_epi32,
         mask_cvt: |x| { x },
+        movemask: _mm256_movemask_epi8,
+        all_valid_mask: -1_i32,
         gather: _mm256_mask_i32gather_epi32,
         store: _mm256_storeu_si256,
         WIDTH = 8, STRIDE = 16
@@ -179,6 +202,8 @@ impl_gather!(u8,
         zero_vec: _mm256_setzero_si256,
         mask_indices: _mm256_cmpgt_epi64,
         mask_cvt: |x| { x },
+        movemask: _mm256_movemask_epi8,
+        all_valid_mask: -1_i32,
         gather: _mm256_mask_i64gather_epi64,
         store: _mm256_storeu_si256,
         WIDTH = 4, STRIDE = 16
@@ -195,6 +220,8 @@ impl_gather!(u16,
         zero_vec: _mm256_setzero_si256,
         mask_indices: _mm256_cmpgt_epi32,
         mask_cvt: |x| { x },
+        movemask: _mm256_movemask_epi8,
+        all_valid_mask: -1_i32,
         gather: _mm256_mask_i32gather_epi32,
         store: _mm256_storeu_si256,
         WIDTH = 8, STRIDE = 8
@@ -207,6 +234,8 @@ impl_gather!(u16,
         zero_vec: _mm256_setzero_si256,
         mask_indices: _mm256_cmpgt_epi64,
         mask_cvt: |x| { x },
+        movemask: _mm256_movemask_epi8,
+        all_valid_mask: -1_i32,
         gather: _mm256_mask_i64gather_epi64,
         store: _mm256_storeu_si256,
         WIDTH = 4, STRIDE = 8
@@ -223,6 +252,8 @@ impl_gather!(u32,
         zero_vec: _mm256_setzero_si256,
         mask_indices: _mm256_cmpgt_epi32,
         mask_cvt: |x| { x },
+        movemask: _mm256_movemask_epi8,
+        all_valid_mask: -1_i32,
         gather: _mm256_mask_i32gather_epi32,
         store: _mm256_storeu_si256,
         WIDTH = 8, STRIDE = 8
@@ -235,6 +266,8 @@ impl_gather!(u32,
         zero_vec: _mm256_setzero_si256,
         mask_indices: _mm256_cmpgt_epi64,
         mask_cvt: |x| { x },
+        movemask: _mm256_movemask_epi8,
+        all_valid_mask: -1_i32,
         gather: _mm256_mask_i64gather_epi64,
         store: _mm256_storeu_si256,
         WIDTH = 4, STRIDE = 4
@@ -259,6 +292,8 @@ impl_gather!(u64,
                 _mm_unpacklo_epi64(lo_packed, hi_packed)
             }
         },
+        movemask: _mm_movemask_epi8,
+        all_valid_mask: 0xFFFF_i32,
         gather: _mm256_mask_i64gather_epi32,
         store: _mm_storeu_si128,
         WIDTH = 4, STRIDE = 4
@@ -271,6 +306,8 @@ impl_gather!(u64,
         zero_vec: _mm256_setzero_si256,
         mask_indices: _mm256_cmpgt_epi64,
         mask_cvt: |x| { x },
+        movemask: _mm256_movemask_epi8,
+        all_valid_mask: -1_i32,
         gather: _mm256_mask_i64gather_epi64,
         store: _mm256_storeu_si256,
         WIDTH = 4, STRIDE = 4
@@ -292,25 +329,32 @@ where
     let buf_uninit = buffer.spare_capacity_mut();
 
     let mut offset = 0;
+    let mut all_valid = true;
+
     // Loop terminates STRIDE elements before end of the indices array because the GatherFn
     // might read up to STRIDE src elements at a time, even though it only advances WIDTH elements
     // in the dst.
     while offset + Gather::STRIDE < indices_len {
         // SAFETY: gather_simd preconditions satisfied:
         // 1. `(indices + offset)..(indices + offset + STRIDE)` is in-bounds for indices allocation
         // 2. `buffer` has same len as indices so `buffer + offset + STRIDE` is always valid.
-        unsafe {
+        let batch_valid = unsafe {
             Gather::gather(
                 indices.as_ptr().add(offset),
                 max_index,
                 values.as_ptr(),
                 buf_uninit.as_mut_ptr().add(offset).cast(),
             )
         };
+        all_valid &= batch_valid;
         offset += Gather::WIDTH;
     }
 
-    // Remainder
+    // Check the accumulated validity after the hot loop. Any batch that reported `false`
+    // means there was an out-of-bounds index.
+    assert!(all_valid, "index out of bounds in AVX2 take");
+
+    // Fall back to scalar iteration for the remainder.
     while offset < indices_len {
         buf_uninit[offset].write(values[indices[offset].as_()]);
         offset += 1;
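
A minimal scalar sketch of the contract the new boolean return encodes, for readers skimming the diff: out-of-range lanes are zero-filled just as the masked AVX2 gather does, and the return value reports whether every index in the batch was in bounds. The function name and the concrete `u32` types below are hypothetical, chosen only for illustration; they are not part of this change.

fn gather_scalar_model(indices: &[u32], max_idx: u32, src: &[u32], dst: &mut [u32]) -> bool {
    let mut all_valid = true;
    for (slot, &idx) in dst.iter_mut().zip(indices) {
        if idx < max_idx {
            *slot = src[idx as usize];
        } else {
            // Same substitution the masked SIMD gather makes for invalid lanes.
            *slot = 0;
            all_valid = false;
        }
    }
    all_valid
}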