@@ -31,87 +31,62 @@ use std::arch::x86_64::_mm256_set1_epi64x;
 use std::arch::x86_64::_mm256_setzero_si256;
 use std::arch::x86_64::_mm256_storeu_si256;
 use std::convert::identity;
+use std::mem::size_of;

 use vortex_buffer::Alignment;
 use vortex_buffer::Buffer;
 use vortex_buffer::BufferMut;
-use vortex_dtype::NativePType;
-use vortex_dtype::PType;
 use vortex_dtype::UnsignedPType;
+use vortex_dtype::match_each_unsigned_integer_ptype;

 use crate::take::slice::take_scalar;

 /// Takes the specified indices into a new [`Buffer`] using AVX2 SIMD.
 ///
-/// This returns None if the AVX2 feature is not detected at runtime, signalling to the caller
-/// that it should fall back to the scalar implementation.
-///
-/// If AVX2 is available, this returns a PrimitiveArray containing the result of the take operation
-/// accelerated using AVX2 instructions.
+/// This function satisfies the type requirements of the AVX2 gather instructions by casting
+/// values to unsigned integers of the same size. It falls back to the scalar implementation
+/// for unsupported type sizes.
 ///
 /// # Panics
 ///
-/// This function panics if any of the provided `indices` are out of bounds for `values`
+/// This function panics if any of the provided `indices` are out of bounds for `buffer`.
 ///
 /// # Safety
 ///
 /// The caller must ensure the `avx2` feature is enabled.
-#[allow(dead_code, unused_variables, reason = "TODO(connor): Implement this")]
 #[target_feature(enable = "avx2")]
 #[inline]
-pub unsafe fn take_avx2<V: NativePType, I: UnsignedPType>(
-    buffer: &[V],
-    indices: &[I],
-) -> Buffer<V> {
-    macro_rules! dispatch_avx2 {
-        ($indices:ty, $values:ty) => {
-            { let result = dispatch_avx2!($indices, $values, cast: $values); result }
-        };
-        ($indices:ty, $values:ty, cast: $cast:ty) => {{
-            let indices = unsafe { std::mem::transmute::<&[I], &[$indices]>(indices) };
-            let values = unsafe { std::mem::transmute::<&[V], &[$cast]>(buffer) };
-
-            let result = exec_take::<$cast, $indices, AVX2Gather>(values, indices);
-            result.cast_into::<V>()
-        }};
-    }
-
-    match (I::PTYPE, V::PTYPE) {
-        // Int value types. Only 32 and 64 bit types are supported.
-        (PType::U8, PType::I32) => dispatch_avx2!(u8, i32),
-        (PType::U8, PType::U32) => dispatch_avx2!(u8, u32),
-        (PType::U8, PType::I64) => dispatch_avx2!(u8, i64),
-        (PType::U8, PType::U64) => dispatch_avx2!(u8, u64),
-        (PType::U16, PType::I32) => dispatch_avx2!(u16, i32),
-        (PType::U16, PType::U32) => dispatch_avx2!(u16, u32),
-        (PType::U16, PType::I64) => dispatch_avx2!(u16, i64),
-        (PType::U16, PType::U64) => dispatch_avx2!(u16, u64),
-        (PType::U32, PType::I32) => dispatch_avx2!(u32, i32),
-        (PType::U32, PType::U32) => dispatch_avx2!(u32, u32),
-        (PType::U32, PType::I64) => dispatch_avx2!(u32, i64),
-        (PType::U32, PType::U64) => dispatch_avx2!(u32, u64),
-
-        // Float value types, treat them as if they were corresponding int types.
-        (PType::U8, PType::F32) => dispatch_avx2!(u8, f32, cast: u32),
-        (PType::U16, PType::F32) => dispatch_avx2!(u16, f32, cast: u32),
-        (PType::U32, PType::F32) => dispatch_avx2!(u32, f32, cast: u32),
-        (PType::U64, PType::F32) => dispatch_avx2!(u64, f32, cast: u32),
-
-        (PType::U8, PType::F64) => dispatch_avx2!(u8, f64, cast: u64),
-        (PType::U16, PType::F64) => dispatch_avx2!(u16, f64, cast: u64),
-        (PType::U32, PType::F64) => dispatch_avx2!(u32, f64, cast: u64),
-        (PType::U64, PType::F64) => dispatch_avx2!(u64, f64, cast: u64),
-
-        // Scalar fallback for unsupported value types.
-        _ => {
-            tracing::trace!(
-                "take AVX2 kernel missing for indices {} values {}, falling back to scalar",
-                I::PTYPE,
-                V::PTYPE
-            );
-
-            take_scalar(buffer, indices)
+pub unsafe fn take_avx2<V: Copy, I: UnsignedPType>(buffer: &[V], indices: &[I]) -> Buffer<V> {
+    // AVX2 gather operations only care about bit patterns, not semantic type. We cast to unsigned
+    // integers, which have the required gather implementations, and then cast back.
+    //
+    // SAFETY: The pointer casts below are safe because:
+    // - `V` and the target type have the same size (matched by `size_of::<V>()`)
+    // - The alignment of unsigned integers is always <= their size, and `buffer` came from a valid
+    //   `&[V]`, which guarantees proper alignment for types of the same size.
+    match size_of::<V>() {
+        4 => {
+            let values: &[u32] =
+                unsafe { std::slice::from_raw_parts(buffer.as_ptr().cast::<u32>(), buffer.len()) };
+            match_each_unsigned_integer_ptype!(I::PTYPE, |IC| {
+                let indices: &[IC] = unsafe {
+                    std::slice::from_raw_parts(indices.as_ptr().cast::<IC>(), indices.len())
+                };
+                exec_take::<u32, IC, AVX2Gather>(values, indices).cast_into::<V>()
+            })
+        }
+        8 => {
+            let values: &[u64] =
+                unsafe { std::slice::from_raw_parts(buffer.as_ptr().cast::<u64>(), buffer.len()) };
+            match_each_unsigned_integer_ptype!(I::PTYPE, |IC| {
+                let indices: &[IC] = unsafe {
+                    std::slice::from_raw_parts(indices.as_ptr().cast::<IC>(), indices.len())
+                };
+                exec_take::<u64, IC, AVX2Gather>(values, indices).cast_into::<V>()
+            })
         }
+        // Fall back to the scalar implementation for unsupported type sizes (1- and 2-byte types).
+        _ => take_scalar(buffer, indices),
     }
 }

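`take_avx2` leaves AVX2 detection to its caller. A minimal dispatch sketch, assuming a hypothetical `take` wrapper around the two implementations in this file and that `take_scalar` accepts the same arguments:

```rust
use std::arch::is_x86_feature_detected;

// Hypothetical wrapper (not part of this diff): check for AVX2 at runtime,
// satisfying `take_avx2`'s safety contract before the unsafe call.
fn take<V: Copy, I: UnsignedPType>(values: &[V], indices: &[I]) -> Buffer<V> {
    if is_x86_feature_detected!("avx2") {
        // SAFETY: we just confirmed this CPU supports AVX2.
        unsafe { take_avx2(values, indices) }
    } else {
        take_scalar(values, indices)
    }
}
```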
@@ -182,9 +157,9 @@ macro_rules! impl_gather {
     };
 }

-// kernels for u8 indices
+// Kernels for u8 indices.
 impl_gather!(u8,
-    // 32-bit values, loaded 8 at a time
+    // 32-bit values, loaded 8 at a time.
     { u32 =>
         load: _mm_loadu_si128,
         extend: _mm256_cvtepu8_epi32,
@@ -196,19 +171,7 @@ impl_gather!(u8,
         store: _mm256_storeu_si256,
         WIDTH = 8, STRIDE = 16
     },
-    { i32 =>
-        load: _mm_loadu_si128,
-        extend: _mm256_cvtepu8_epi32,
-        splat: _mm256_set1_epi32,
-        zero_vec: _mm256_setzero_si256,
-        mask_indices: _mm256_cmpgt_epi32,
-        mask_cvt: |x| { x },
-        gather: _mm256_mask_i32gather_epi32,
-        store: _mm256_storeu_si256,
-        WIDTH = 8, STRIDE = 16
-    },
-
-    // 64-bit values, loaded 4 at a time
+    // 64-bit values, loaded 4 at a time.
     { u64 =>
         load: _mm_loadu_si128,
         extend: _mm256_cvtepu8_epi64,
@@ -219,23 +182,12 @@ impl_gather!(u8,
         gather: _mm256_mask_i64gather_epi64,
         store: _mm256_storeu_si256,
         WIDTH = 4, STRIDE = 16
-    },
-    { i64 =>
-        load: _mm_loadu_si128,
-        extend: _mm256_cvtepu8_epi64,
-        splat: _mm256_set1_epi64x,
-        zero_vec: _mm256_setzero_si256,
-        mask_indices: _mm256_cmpgt_epi64,
-        mask_cvt: |x| { x },
-        gather: _mm256_mask_i64gather_epi64,
-        store: _mm256_storeu_si256,
-        WIDTH = 4, STRIDE = 16
     }
 );

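Each entry above names the steps of one gather-loop iteration. A standalone sketch of a single unmasked u8-to-u32 step, assuming at least 16 readable index bytes and all indices in bounds (the macro itself uses the masked gather variant, so this is illustrative rather than the exact expansion):

```rust
use std::arch::x86_64::*;

/// One full-width step of the u8 -> u32 kernel (hypothetical standalone form).
#[target_feature(enable = "avx2")]
unsafe fn gather_step_u8_u32(values: &[u32], idx_bytes: &[u8], out: *mut u32) {
    unsafe {
        // `load`: read 16 index bytes (STRIDE = 16); only the low 8 are used here.
        let raw = _mm_loadu_si128(idx_bytes.as_ptr().cast());
        // `extend`: zero-extend the low 8 bytes to 8 x 32-bit lanes (WIDTH = 8).
        let idx = _mm256_cvtepu8_epi32(raw);
        // `gather`: fetch 8 values; SCALE = 4 because u32 elements are 4 bytes apart.
        let got = _mm256_i32gather_epi32::<4>(values.as_ptr().cast::<i32>(), idx);
        // `store`: write the 8 gathered values.
        _mm256_storeu_si256(out.cast(), got);
    }
}
```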
-// kernels for u16 indices
+// Kernels for u16 indices.
 impl_gather!(u16,
-    // 32-bit values. 8x indices loaded at a time and 8x values written at a time
+    // 32-bit values. 8x indices loaded at a time and 8x values written at a time.
     { u32 =>
         load: _mm_loadu_si128,
         extend: _mm256_cvtepu16_epi32,
@@ -247,18 +199,6 @@ impl_gather!(u16,
         store: _mm256_storeu_si256,
         WIDTH = 8, STRIDE = 8
     },
-    { i32 =>
-        load: _mm_loadu_si128,
-        extend: _mm256_cvtepu16_epi32,
-        splat: _mm256_set1_epi32,
-        zero_vec: _mm256_setzero_si256,
-        mask_indices: _mm256_cmpgt_epi32,
-        mask_cvt: |x| { x },
-        gather: _mm256_mask_i32gather_epi32,
-        store: _mm256_storeu_si256,
-        WIDTH = 8, STRIDE = 8
-    },
-
     // 64-bit values. 8x indices loaded at a time and 4x values loaded at a time.
     { u64 =>
         load: _mm_loadu_si128,
@@ -270,23 +210,12 @@ impl_gather!(u16,
         gather: _mm256_mask_i64gather_epi64,
         store: _mm256_storeu_si256,
         WIDTH = 4, STRIDE = 8
-    },
-    { i64 =>
-        load: _mm_loadu_si128,
-        extend: _mm256_cvtepu16_epi64,
-        splat: _mm256_set1_epi64x,
-        zero_vec: _mm256_setzero_si256,
-        mask_indices: _mm256_cmpgt_epi64,
-        mask_cvt: |x| { x },
-        gather: _mm256_mask_i64gather_epi64,
-        store: _mm256_storeu_si256,
-        WIDTH = 4, STRIDE = 8
     }
 );

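The 64-bit-value kernels differ only in the widening `extend` and the gather width. The same kind of sketch for one unmasked u16-to-u64 step, under the same in-bounds and readable-slack assumptions as above:

```rust
use std::arch::x86_64::*;

/// One full-width step of the u16 -> u64 kernel (hypothetical standalone form).
#[target_feature(enable = "avx2")]
unsafe fn gather_step_u16_u64(values: &[u64], idx: &[u16], out: *mut u64) {
    unsafe {
        // `load`: read 8 x u16 indices; `extend` consumes only the low 4 (WIDTH = 4),
        // zero-extending them to 4 x 64-bit lanes.
        let raw = _mm_loadu_si128(idx.as_ptr().cast());
        let wide = _mm256_cvtepu16_epi64(raw);
        // `gather`: fetch 4 x u64 values; SCALE = 8 for 8-byte elements.
        let got = _mm256_i64gather_epi64::<8>(values.as_ptr().cast::<i64>(), wide);
        // `store`: write the 4 gathered values.
        _mm256_storeu_si256(out.cast(), got);
    }
}
```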
-// kernels for u32 indices
+// Kernels for u32 indices.
 impl_gather!(u32,
-    // 32-bit values. 8x indices loaded at a time and 8x values written
+    // 32-bit values. 8x indices loaded at a time and 8x values written.
     { u32 =>
         load: _mm256_loadu_si256,
         extend: identity,
@@ -298,19 +227,7 @@ impl_gather!(u32,
         store: _mm256_storeu_si256,
         WIDTH = 8, STRIDE = 8
     },
-    { i32 =>
-        load: _mm256_loadu_si256,
-        extend: identity,
-        splat: _mm256_set1_epi32,
-        zero_vec: _mm256_setzero_si256,
-        mask_indices: _mm256_cmpgt_epi32,
-        mask_cvt: |x| { x },
-        gather: _mm256_mask_i32gather_epi32,
-        store: _mm256_storeu_si256,
-        WIDTH = 8, STRIDE = 8
-    },
-
-    // 64-bit values
+    // 64-bit values.
     { u64 =>
         load: _mm_loadu_si128,
         extend: _mm256_cvtepu32_epi64,
@@ -321,22 +238,12 @@ impl_gather!(u32,
         gather: _mm256_mask_i64gather_epi64,
         store: _mm256_storeu_si256,
         WIDTH = 4, STRIDE = 4
-    },
-    { i64 =>
-        load: _mm_loadu_si128,
-        extend: _mm256_cvtepu32_epi64,
-        splat: _mm256_set1_epi64x,
-        zero_vec: _mm256_setzero_si256,
-        mask_indices: _mm256_cmpgt_epi64,
-        mask_cvt: |x| { x },
-        gather: _mm256_mask_i64gather_epi64,
-        store: _mm256_storeu_si256,
-        WIDTH = 4, STRIDE = 4
     }
 );

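The `splat`, `zero_vec`, and `mask_indices` fields exist so a partial iteration can run with inactive lanes masked off. One plausible assembly of those pieces for u32 indices and values, offered as a sketch rather than the macro's exact expansion, assuming the index and output buffers stay readable and writable for a full 8-lane vector (e.g. via over-allocation):

```rust
use std::arch::x86_64::*;

/// Masked tail step for u32 indices -> u32 values (hypothetical standalone form).
#[target_feature(enable = "avx2")]
unsafe fn gather_tail_u32_u32(values: &[u32], idx: &[u32], remaining: i32, out: *mut u32) {
    unsafe {
        let lane_ids = _mm256_setr_epi32(0, 1, 2, 3, 4, 5, 6, 7);
        // `splat` + `mask_indices`: lane i is active iff remaining > i; the signed
        // compare yields all-ones (active) or all-zeros (inactive) per lane.
        let mask = _mm256_cmpgt_epi32(_mm256_set1_epi32(remaining), lane_ids);
        let offsets = _mm256_loadu_si256(idx.as_ptr().cast());
        // `zero_vec` supplies the fallback value for masked-off lanes; the gather
        // only dereferences `values` for active lanes.
        let got = _mm256_mask_i32gather_epi32::<4>(
            _mm256_setzero_si256(),
            values.as_ptr().cast::<i32>(),
            offsets,
            mask,
        );
        _mm256_storeu_si256(out.cast(), got);
    }
}
```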
-// kernels for u64 indices
+// Kernels for u64 indices.
 impl_gather!(u64,
+    // 32-bit values.
     { u32 =>
         load: _mm256_loadu_si256,
         extend: identity,
@@ -356,27 +263,7 @@ impl_gather!(u64,
         store: _mm_storeu_si128,
         WIDTH = 4, STRIDE = 4
     },
-    { i32 =>
-        load: _mm256_loadu_si256,
-        extend: identity,
-        splat: _mm256_set1_epi64x,
-        zero_vec: _mm_setzero_si128,
-        mask_indices: _mm256_cmpgt_epi64,
-        mask_cvt: |m| {
-            unsafe {
-                let lo_bits = _mm256_extracti128_si256::<0>(m); // lower half
-                let hi_bits = _mm256_extracti128_si256::<1>(m); // upper half
-                let lo_packed = _mm_shuffle_epi32::<0b01_01_01_01>(lo_bits);
-                let hi_packed = _mm_shuffle_epi32::<0b01_01_01_01>(hi_bits);
-                _mm_unpacklo_epi64(lo_packed, hi_packed)
-            }
-        },
-        gather: _mm256_mask_i64gather_epi32,
-        store: _mm_storeu_si128,
-        WIDTH = 4, STRIDE = 4
-    },
-
-    // 64-bit values
+    // 64-bit values.
     { u64 =>
         load: _mm256_loadu_si256,
         extend: identity,
@@ -387,17 +274,6 @@ impl_gather!(u64,
         gather: _mm256_mask_i64gather_epi64,
         store: _mm256_storeu_si256,
         WIDTH = 4, STRIDE = 4
-    },
-    { i64 =>
-        load: _mm256_loadu_si256,
-        extend: identity,
-        splat: _mm256_set1_epi64x,
-        zero_vec: _mm256_setzero_si256,
-        mask_indices: _mm256_cmpgt_epi64,
-        mask_cvt: |x| { x },
-        gather: _mm256_mask_i64gather_epi64,
-        store: _mm256_storeu_si256,
-        WIDTH = 4, STRIDE = 4
     }
 );

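The deleted `i32` kernel above shows the awkward variant: `_mm256_cmpgt_epi64` produces four 64-bit mask lanes, but `_mm256_mask_i64gather_epi32` wants four 32-bit lanes in a `__m128i`, hence the `mask_cvt` extract/shuffle/unpack sequence. Since each 64-bit lane is all-ones or all-zeros, an alternative conversion can collect the sign-carrying high dword of each lane with one cross-lane permute; this is a sketch, not the crate's code:

```rust
use std::arch::x86_64::*;

/// Narrow four 64-bit lane masks to four 32-bit lane masks by gathering the
/// high (sign-carrying) dword of each 64-bit lane (hypothetical alternative).
#[target_feature(enable = "avx2")]
unsafe fn mask_epi64_to_epi32(m: __m256i) -> __m128i {
    unsafe {
        // Dwords 1, 3, 5, 7 are the high halves of the four 64-bit lanes.
        let pick_high_dwords = _mm256_setr_epi32(1, 3, 5, 7, 0, 0, 0, 0);
        _mm256_castsi256_si128(_mm256_permutevar8x32_epi32(m, pick_high_dwords))
    }
}
```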