 #![cfg(vortex_nightly)]
 
 use std::mem::MaybeUninit;
+use std::mem::size_of;
 use std::mem::transmute;
 use std::simd;
 use std::simd::num::SimdUint;
@@ -17,27 +18,55 @@ use vortex_buffer::BufferMut;
 use vortex_dtype::NativePType;
 use vortex_dtype::PType;
 use vortex_dtype::UnsignedPType;
+use vortex_dtype::match_each_native_simd_ptype;
+use vortex_dtype::match_each_unsigned_integer_ptype;
+
+/// SIMD types larger than the SIMD register size are beneficial for
+/// performance as this leads to better instruction level parallelism.
+pub const SIMD_WIDTH: usize = 64;
 
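// Editorial illustration (not part of this commit; the `_WIDE_VECTOR_BYTES` name is invented
// here): with 64 lanes of `u64`, a single portable-SIMD vector spans 512 bytes, several times
// wider than any current register (16 B SSE, 32 B AVX2, 64 B AVX-512), so one operation on it
// lowers to several independent register-width instructions that the CPU can overlap.
const _WIDE_VECTOR_BYTES: usize = SIMD_WIDTH * size_of::<u64>(); // 64 lanes * 8 B = 512 B
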
 /// Takes the specified indices into a new [`Buffer`] using portable SIMD.
+///
+/// This function handles the type matching required to satisfy `SimdElement` bounds.
+/// For `f16` values, it reinterprets them as `u16` since `f16` doesn't implement `SimdElement`.
 #[inline]
-pub fn take_portable<T, I>(buffer: &[T], indices: &[I]) -> Buffer<T>
-where
-    T: NativePType + simd::SimdElement,
-    I: UnsignedPType + simd::SimdElement,
-{
+pub fn take_portable<T: NativePType, I: UnsignedPType>(buffer: &[T], indices: &[I]) -> Buffer<T> {
     if T::PTYPE == PType::F16 {
+        assert_eq!(size_of::<half::f16>(), size_of::<T>());
+
         // Since Rust does not actually support 16-bit floats, we first reinterpret the data as
         // `u16` integers.
+        // SAFETY: We know that f16 has the same bit pattern as u16, so this transmute is fine to
+        // make.
         let u16_slice: &[u16] =
             unsafe { std::slice::from_raw_parts(buffer.as_ptr() as *const u16, buffer.len()) };
+        return take_with_indices(u16_slice, indices).cast_into::<T>();
+    }
 
-        let taken_u16 = take_portable_simd::<u16, I, SIMD_WIDTH>(u16_slice, indices);
-        let taken_f16 = taken_u16.cast_into::<T>();
+    match_each_native_simd_ptype!(T::PTYPE, |TC| {
+        assert_eq!(size_of::<TC>(), size_of::<T>());
 
-        taken_f16
-    } else {
-        take_portable_simd::<T, I, SIMD_WIDTH>(buffer, indices)
-    }
+        // SAFETY: This is essentially a no-op that tricks the compiler into adding the
+        // `simd::SimdElement` bound we need to call `take_with_indices`.
+        let buffer: &[TC] =
+            unsafe { std::slice::from_raw_parts(buffer.as_ptr() as *const TC, buffer.len()) };
+        take_with_indices(buffer, indices).cast_into::<T>()
+    })
+}
+
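// Illustrative usage sketch (editorial, not part of this commit): gathering `i32` values by
// `u32` indices. The `_take_portable_usage_sketch` name and the sample data are invented for
// illustration; `as_slice` mirrors how the tests below read results out of the returned buffer.
fn _take_portable_usage_sketch() {
    let data: &[i32] = &[10, 20, 30, 40, 50];
    let indices: &[u32] = &[4, 0, 2, 2];
    // Each output position i holds data[indices[i]].
    let taken = take_portable(data, indices);
    assert_eq!(taken.as_slice(), [50, 10, 30, 30]);
}
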
+/// Helper that matches on index type and calls `take_portable_simd`.
+///
+/// We separate this code out from above to add the [`simd::SimdElement`] constraint.
+#[inline]
+fn take_with_indices<T: NativePType + simd::SimdElement, I: UnsignedPType>(
+    buffer: &[T],
+    indices: &[I],
+) -> Buffer<T> {
+    match_each_unsigned_integer_ptype!(I::PTYPE, |IC| {
+        // SAFETY: `IC` is the concrete type behind `I::PTYPE`, so `I` and `IC` share the same
+        // layout and this reinterpretation is sound.
+        let indices: &[IC] =
+            unsafe { std::slice::from_raw_parts(indices.as_ptr() as *const IC, indices.len()) };
+        take_portable_simd::<T, IC, SIMD_WIDTH>(buffer, indices)
+    })
 }
 
 /// Takes elements from an array using SIMD indexing.
@@ -110,4 +139,126 @@ mod tests {
         let result = take_portable_simd::<i32, u32, 64>(&values, &indices);
         assert_eq!(result.as_slice(), [0i32; 64]);
     }
+
+    /// Tests SIMD gather with a mix of sequential, strided, and repeated indices. This exercises
+    /// irregular access patterns that stress the gather operation.
+    #[test]
+    fn test_take_mixed_access_patterns() {
+        // Create a values array with distinct elements.
+        let values: Vec<i64> = (0..256).map(|i| i * 100).collect();
+
+        // Build indices with mixed patterns:
+        // - Sequential access (0, 1, 2, ...)
+        // - Strided access (0, 4, 8, ...)
+        // - Repeated indices (same index multiple times)
+        // - Reverse order
+        let mut indices: Vec<u32> = Vec::with_capacity(200);
+
+        // Sequential: indices 0..64.
+        indices.extend(0u32..64);
+        // Strided by 4: 0, 4, 8, ..., 252.
+        indices.extend((0u32..64).map(|i| i * 4));
+        // Repeated: index 42 repeated 32 times.
+        indices.extend(std::iter::repeat(42u32).take(32));
+        // Reverse: 255, 254, ..., 216.
+        indices.extend((216u32..256).rev());
+
+        let result = take_portable_simd::<i64, u32, 64>(&values, &indices);
+        let result_slice = result.as_slice();
+
+        // Verify sequential portion.
+        for i in 0..64 {
+            assert_eq!(result_slice[i], (i as i64) * 100, "sequential at index {i}");
+        }
+
+        // Verify strided portion.
+        for i in 0..64 {
+            assert_eq!(
+                result_slice[64 + i],
+                (i as i64) * 4 * 100,
+                "strided at index {i}"
+            );
+        }
+
+        // Verify repeated portion.
+        for i in 0..32 {
+            assert_eq!(result_slice[128 + i], 42 * 100, "repeated at index {i}");
+        }
+
+        // Verify reverse portion.
+        for i in 0..40 {
+            assert_eq!(
+                result_slice[160 + i],
+                (255 - i as i64) * 100,
+                "reverse at index {i}"
+            );
+        }
+    }
+
+    /// Tests that the scalar remainder path works correctly when the number of indices is not
+    /// evenly divisible by the SIMD lane count.
+    #[test]
+    fn test_take_with_remainder() {
+        let values: Vec<u16> = (0..1000).collect();
+
+        // Use 64 + 37 = 101 indices to test both the SIMD loop (64 elements) and the scalar
+        // remainder (37 elements).
+        let indices: Vec<u8> = (0u8..101).collect();
+
+        let result = take_portable_simd::<u16, u8, 64>(&values, &indices);
+        let result_slice = result.as_slice();
+
+        assert_eq!(result_slice.len(), 101);
+
+        // Verify all elements.
+        for i in 0..101 {
+            assert_eq!(result_slice[i], i as u16, "mismatch at index {i}");
+        }
+
+        // Also test with exactly 1 remainder element.
+        let indices_one_remainder: Vec<u8> = (0u8..65).collect();
+        let result_one = take_portable_simd::<u16, u8, 64>(&values, &indices_one_remainder);
+        assert_eq!(result_one.as_slice().len(), 65);
+        assert_eq!(result_one.as_slice()[64], 64);
+    }
+
+    /// Tests gather with large 64-bit values and various index types to ensure no truncation
+    /// occurs during the operation.
+    #[test]
+    fn test_take_large_values_no_truncation() {
+        // Create values near the edges of the i64 range.
+        let values: Vec<i64> = vec![
+            i64::MIN,
+            i64::MIN + 1,
+            -1_000_000_000_000i64,
+            -1,
+            0,
+            1,
+            1_000_000_000_000i64,
+            i64::MAX - 1,
+            i64::MAX,
+        ];
+
+        // Indices that access each value multiple times in different orders.
+        let indices: Vec<u16> = vec![
+            0, 8, 1, 7, 2, 6, 3, 5, 4, // Forward-backward interleaved.
+            8, 8, 8, 0, 0, 0, // Repeated extremes.
+            4, 4, 4, 4, 4, 4, 4, 4, // Repeated zero.
+            0, 1, 2, 3, 4, 5, 6, 7, 8, // Sequential.
+            8, 7, 6, 5, 4, 3, 2, 1, 0, // Reverse.
+            // Pad to 64 total indices to ensure we hit the SIMD path.
+            0, 1, 2, 3, 4, 5, 6, 7, 8, 0, 1, 2, 3, 4, 5, 6, 7, 8, 0, 1, 2, 3, 4,
+        ];
+
+        let result = take_portable_simd::<i64, u16, 64>(&values, &indices);
+        let result_slice = result.as_slice();
+
+        // Verify each result matches the expected value.
+        for (i, &idx) in indices.iter().enumerate() {
+            assert_eq!(
+                result_slice[i], values[idx as usize],
+                "mismatch at position {i} for index {idx}"
+            );
+        }
+    }
 }