77
88use std:: mem:: MaybeUninit ;
99use std:: mem:: size_of;
10- use std:: mem:: transmute;
1110use std:: simd;
11+ use std:: simd:: cmp:: SimdPartialOrd ;
1212use std:: simd:: num:: SimdUint ;
1313
1414use multiversion:: multiversion;
@@ -83,6 +83,10 @@ fn take_with_indices<T: Copy + Default + simd::SimdElement, I: UnsignedPType>(
8383/// buffer. Uses SIMD instructions to process `LANE_COUNT` indices in parallel.
8484///
8585/// Returns a `Buffer<T>` where each element corresponds to `values[indices[i]]`.
86+ ///
87+ /// # Panics
88+ ///
89+ /// Panics if any index is out of bounds for `values`.
8690#[ multiversion( targets( "x86_64+avx2" , "x86_64+avx" , "aarch64+neon" ) ) ]
8791pub fn take_portable_simd < T , I , const LANE_COUNT : usize > ( values : & [ T ] , indices : & [ I ] ) -> Buffer < T >
8892where
@@ -100,37 +104,67 @@ where
100104
101105 let buf_slice = buffer. spare_capacity_mut ( ) ;
102106
107+ // Set up a vector that we can SIMD compare against for out-of-bounds indices.
108+ let len_vec = simd:: Simd :: < usize , LANE_COUNT > :: splat ( values. len ( ) ) ;
109+ let mut all_valid = simd:: Mask :: < isize , LANE_COUNT > :: splat ( true ) ;
110+
103111 for chunk_idx in 0 ..( indices_len / LANE_COUNT ) {
104112 let offset = chunk_idx * LANE_COUNT ;
105- let mask = simd:: Mask :: from_bitmask ( u64:: MAX ) ;
106113 let codes_chunk = simd:: Simd :: < I , LANE_COUNT > :: from_slice ( & indices[ offset..] ) ;
107-
108- let selection = simd:: Simd :: gather_select (
109- values,
110- mask,
111- codes_chunk. cast :: < usize > ( ) ,
112- simd:: Simd :: < T , LANE_COUNT > :: default ( ) ,
113- ) ;
114-
114+ let codes_usize = codes_chunk. cast :: < usize > ( ) ;
115+
116+ // Accumulate validity and use as gather mask. An out-of-bounds index will turn a bit off.
117+ all_valid &= codes_usize. simd_lt ( len_vec) ;
118+
119+ // SAFETY: We use `all_valid` to mask the gather, preventing OOB memory access. If any
120+ // index is OOB, `all_valid` will have those bits turned off, masking out the invalid
121+ // indices.
122+ // Note that this may also mask out valid indices in subsequent iterations. This is fine
123+ // because we will panic after the loop if **any** index was OOB, so we do not care if the
124+ // resulting gathered data is correct or not.
125+ let selection = unsafe {
126+ simd:: Simd :: gather_select_unchecked (
127+ values,
128+ all_valid,
129+ codes_usize,
130+ simd:: Simd :: < T , LANE_COUNT > :: default ( ) ,
131+ )
132+ } ;
133+
134+ // SAFETY: `MaybeUninit<T>` has the same layout as `T`, and we are about to initialize these
135+ // elements with the store.
136+ let uninit = unsafe {
137+ std:: mem:: transmute :: < & mut [ MaybeUninit < T > ] , & mut [ T ] > (
138+ & mut buf_slice[ offset..] [ ..LANE_COUNT ] ,
139+ )
140+ } ;
141+
142+ // SAFETY: The slice `buf_slice[offset..][..LANE_COUNT]` is guaranteed to have exactly
143+ // `LANE_COUNT` elements since `offset` is a multiple of `LANE_COUNT` and we only iterate
144+ // while `offset + LANE_COUNT <= indices_len`.
115145 unsafe {
116- selection. store_select_unchecked (
117- transmute :: < & mut [ MaybeUninit < T > ] , & mut [ T ] > ( & mut buf_slice[ offset..] [ ..64 ] ) ,
118- mask. cast ( ) ,
119- ) ;
146+ selection. store_select_unchecked ( uninit, simd:: Mask :: splat ( true ) ) ;
120147 }
121148 }
122149
150+ // Check accumulated validity after hot loop. If there are any 0's, then there was an
151+ // out-of-bounds index.
152+ assert ! ( all_valid. all( ) , "index out of bounds in SIMD take" ) ;
153+
154+ // Fall back to scalar iteration for the remainder.
123155 for idx in ( ( indices_len / LANE_COUNT ) * LANE_COUNT ) ..indices_len {
156+ // SAFETY: `idx` is in bounds for `buf_slice` since `idx < indices_len <= buf_slice.len()`.
157+ // Note that the `values[...]` access is already bounds-checked and will panic if OOB.
124158 unsafe {
125159 buf_slice
126160 . get_unchecked_mut ( idx)
127161 . write ( values[ indices[ idx] . as_ ( ) ] ) ;
128162 }
129163 }
130164
131- unsafe {
132- buffer . set_len ( indices_len) ;
133- }
165+ // SAFETY: All elements have been initialized: the SIMD loop handles `0..chunks * LANE_COUNT`
166+ // and the scalar loop handles the remainder up to `indices_len`.
167+ unsafe { buffer . set_len ( indices_len ) } ;
134168
135169 buffer. freeze ( )
136170}
@@ -141,12 +175,12 @@ mod tests {
141175 use super :: take_portable_simd;
142176
143177 #[ test]
178+ #[ should_panic( expected = "index out of bounds" ) ]
144179 fn test_take_out_of_bounds ( ) {
145180 let indices = vec ! [ 2_000_000u32 ; 64 ] ;
146181 let values = vec ! [ 1i32 ] ;
147182
148- let result = take_portable_simd :: < i32 , u32 , 64 > ( & values, & indices) ;
149- assert_eq ! ( result. as_slice( ) , [ 0i32 ; 64 ] ) ;
183+ drop ( take_portable_simd :: < i32 , u32 , 64 > ( & values, & indices) ) ;
150184 }
151185
152186 /// Tests SIMD gather with a mix of sequential, strided, and repeated indices. This exercises
0 commit comments