Commit 01ccc6e

add avx2 take impl back and bound by Copy
Signed-off-by: Connor Tsui <[email protected]>
1 parent 4b3a5c3 commit 01ccc6e
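
At its core, the commit relaxes the trait bound on the AVX2 entry point from `NativePType` to `Copy` and switches dispatch from matching on `PType` pairs to matching on `size_of::<V>()`. Condensed to one line each, the before and after signatures from the diff below:

-pub unsafe fn take_avx2<V: NativePType, I: UnsignedPType>(buffer: &[V], indices: &[I]) -> Buffer<V>
+pub unsafe fn take_avx2<V: Copy, I: UnsignedPType>(buffer: &[V], indices: &[I]) -> Buffer<V>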

File tree

2 files changed (+47, -177 lines)

vortex-compute/src/take/slice/avx2.rs

Lines changed: 47 additions & 171 deletions
@@ -31,87 +31,62 @@ use std::arch::x86_64::_mm256_set1_epi64x;
 use std::arch::x86_64::_mm256_setzero_si256;
 use std::arch::x86_64::_mm256_storeu_si256;
 use std::convert::identity;
+use std::mem::size_of;

 use vortex_buffer::Alignment;
 use vortex_buffer::Buffer;
 use vortex_buffer::BufferMut;
-use vortex_dtype::NativePType;
-use vortex_dtype::PType;
 use vortex_dtype::UnsignedPType;
+use vortex_dtype::match_each_unsigned_integer_ptype;

 use crate::take::slice::take_scalar;

 /// Takes the specified indices into a new [`Buffer`] using AVX2 SIMD.
 ///
-/// This returns None if the AVX2 feature is not detected at runtime, signalling to the caller
-/// that it should fall back to the scalar implementation.
-///
-/// If AVX2 is available, this returns a PrimitiveArray containing the result of the take operation
-/// accelerated using AVX2 instructions.
+/// This function handles the type matching required to satisfy AVX2 gather instruction requirements
+/// by casting to unsigned integers of the same size. Falls back to scalar implementation for
+/// unsupported type sizes.
 ///
 /// # Panics
 ///
-/// This function panics if any of the provided `indices` are out of bounds for `values`
+/// This function panics if any of the provided `indices` are out of bounds for `values`.
 ///
 /// # Safety
 ///
 /// The caller must ensure the `avx2` feature is enabled.
-#[allow(dead_code, unused_variables, reason = "TODO(connor): Implement this")]
 #[target_feature(enable = "avx2")]
 #[inline]
-pub unsafe fn take_avx2<V: NativePType, I: UnsignedPType>(
-    buffer: &[V],
-    indices: &[I],
-) -> Buffer<V> {
-    macro_rules! dispatch_avx2 {
-        ($indices:ty, $values:ty) => {
-            { let result = dispatch_avx2!($indices, $values, cast: $values); result }
-        };
-        ($indices:ty, $values:ty, cast: $cast:ty) => {{
-            let indices = unsafe { std::mem::transmute::<&[I], &[$indices]>(indices) };
-            let values = unsafe { std::mem::transmute::<&[V], &[$cast]>(buffer) };
-
-            let result = exec_take::<$cast, $indices, AVX2Gather>(values, indices);
-            result.cast_into::<V>()
-        }};
-    }
-
-    match (I::PTYPE, V::PTYPE) {
-        // Int value types. Only 32 and 64 bit types are supported.
-        (PType::U8, PType::I32) => dispatch_avx2!(u8, i32),
-        (PType::U8, PType::U32) => dispatch_avx2!(u8, u32),
-        (PType::U8, PType::I64) => dispatch_avx2!(u8, i64),
-        (PType::U8, PType::U64) => dispatch_avx2!(u8, u64),
-        (PType::U16, PType::I32) => dispatch_avx2!(u16, i32),
-        (PType::U16, PType::U32) => dispatch_avx2!(u16, u32),
-        (PType::U16, PType::I64) => dispatch_avx2!(u16, i64),
-        (PType::U16, PType::U64) => dispatch_avx2!(u16, u64),
-        (PType::U32, PType::I32) => dispatch_avx2!(u32, i32),
-        (PType::U32, PType::U32) => dispatch_avx2!(u32, u32),
-        (PType::U32, PType::I64) => dispatch_avx2!(u32, i64),
-        (PType::U32, PType::U64) => dispatch_avx2!(u32, u64),
-
-        // Float value types, treat them as if they were corresponding int types.
-        (PType::U8, PType::F32) => dispatch_avx2!(u8, f32, cast: u32),
-        (PType::U16, PType::F32) => dispatch_avx2!(u16, f32, cast: u32),
-        (PType::U32, PType::F32) => dispatch_avx2!(u32, f32, cast: u32),
-        (PType::U64, PType::F32) => dispatch_avx2!(u64, f32, cast: u32),
-
-        (PType::U8, PType::F64) => dispatch_avx2!(u8, f64, cast: u64),
-        (PType::U16, PType::F64) => dispatch_avx2!(u16, f64, cast: u64),
-        (PType::U32, PType::F64) => dispatch_avx2!(u32, f64, cast: u64),
-        (PType::U64, PType::F64) => dispatch_avx2!(u64, f64, cast: u64),
-
-        // Scalar fallback for unsupported value types.
-        _ => {
-            tracing::trace!(
-                "take AVX2 kernel missing for indices {} values {}, falling back to scalar",
-                I::PTYPE,
-                V::PTYPE
-            );
-
-            take_scalar(buffer, indices)
+pub unsafe fn take_avx2<V: Copy, I: UnsignedPType>(buffer: &[V], indices: &[I]) -> Buffer<V> {
+    // AVX2 gather operations only care about bit patterns, not semantic type. We cast to unsigned
+    // integers which have the required gather implementations and then cast back.
+    //
+    // SAFETY: The pointer casts below are safe because:
+    // - `V` and the target type have the same size (matched by `size_of::<V>()`)
+    // - The alignment of unsigned integers is always <= their size, and `buffer` came from a valid
+    //   `&[V]` which guarantees proper alignment for types of the same size.
+    match size_of::<V>() {
+        4 => {
+            let values: &[u32] =
+                unsafe { std::slice::from_raw_parts(buffer.as_ptr().cast::<u32>(), buffer.len()) };
+            match_each_unsigned_integer_ptype!(I::PTYPE, |IC| {
+                let indices: &[IC] = unsafe {
+                    std::slice::from_raw_parts(indices.as_ptr().cast::<IC>(), indices.len())
+                };
+                exec_take::<u32, IC, AVX2Gather>(values, indices).cast_into::<V>()
+            })
+        }
+        8 => {
+            let values: &[u64] =
+                unsafe { std::slice::from_raw_parts(buffer.as_ptr().cast::<u64>(), buffer.len()) };
+            match_each_unsigned_integer_ptype!(I::PTYPE, |IC| {
+                let indices: &[IC] = unsafe {
+                    std::slice::from_raw_parts(indices.as_ptr().cast::<IC>(), indices.len())
+                };
+                exec_take::<u64, IC, AVX2Gather>(values, indices).cast_into::<V>()
+            })
         }
+        // Fall back to scalar implementation for unsupported type sizes (1, 2 byte types).
+        _ => take_scalar(buffer, indices),
     }
 }

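The heart of the new implementation is easier to see without the SIMD machinery: a gather only moves bit patterns, so any 4-byte `Copy` type can be gathered as `u32` (and any 8-byte type as `u64`). Below is a minimal, self-contained Rust sketch of that cast-and-gather shape. `gather_u32` and `take_by_bits` are hypothetical stand-ins for `exec_take` and `take_avx2`, and the sketch assumes a primitive-like `V` (i32, u32, f32) whose alignment matches `u32`'s:

use std::mem::size_of;

// Stand-in for the SIMD gather kernel (`exec_take` in the diff): takes `indices`
// into `values`, operating purely on u32 bit patterns.
fn gather_u32(values: &[u32], indices: &[u32]) -> Vec<u32> {
    indices.iter().map(|&i| values[i as usize]).collect()
}

// Takes indices into a slice of any 4-byte `Copy` type by viewing it as `&[u32]`.
fn take_by_bits<V: Copy>(values: &[V], indices: &[u32]) -> Vec<V> {
    // Assumption: V is primitive-like, so its alignment equals its 4-byte size;
    // that is what makes the `&[V]` -> `&[u32]` view below valid.
    assert_eq!(size_of::<V>(), 4);
    // SAFETY: same length, same element size, alignment per the assumption above.
    let bits: &[u32] =
        unsafe { std::slice::from_raw_parts(values.as_ptr().cast::<u32>(), values.len()) };
    gather_u32(bits, indices)
        .into_iter()
        // SAFETY: each element is the bit pattern of some element of `values`.
        .map(|b| unsafe { std::mem::transmute_copy::<u32, V>(&b) })
        .collect()
}

fn main() {
    let values = [1.5f32, 2.5, 3.5, 4.5];
    let indices = [3u32, 0, 2];
    assert_eq!(take_by_bits(&values, &indices), vec![4.5f32, 1.5, 3.5]);
}

Like `take_avx2`, the sketch panics if an index is out of bounds (via the slice index in `gather_u32`).
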
@@ -182,9 +157,9 @@ macro_rules! impl_gather {
     };
 }

-// kernels for u8 indices
+// Kernels for u8 indices.
 impl_gather!(u8,
-    // 32-bit values, loaded 8 at a time
+    // 32-bit values, loaded 8 at a time.
     { u32 =>
         load: _mm_loadu_si128,
         extend: _mm256_cvtepu8_epi32,
@@ -196,19 +171,7 @@ impl_gather!(u8,
         store: _mm256_storeu_si256,
         WIDTH = 8, STRIDE = 16
     },
-    { i32 =>
-        load: _mm_loadu_si128,
-        extend: _mm256_cvtepu8_epi32,
-        splat: _mm256_set1_epi32,
-        zero_vec: _mm256_setzero_si256,
-        mask_indices: _mm256_cmpgt_epi32,
-        mask_cvt: |x| { x },
-        gather: _mm256_mask_i32gather_epi32,
-        store: _mm256_storeu_si256,
-        WIDTH = 8, STRIDE = 16
-    },
-
-    // 64-bit values, loaded 4 at a time
+    // 64-bit values, loaded 4 at a time.
     { u64 =>
         load: _mm_loadu_si128,
         extend: _mm256_cvtepu8_epi64,
@@ -219,23 +182,12 @@ impl_gather!(u8,
         gather: _mm256_mask_i64gather_epi64,
         store: _mm256_storeu_si256,
         WIDTH = 4, STRIDE = 16
-    },
-    { i64 =>
-        load: _mm_loadu_si128,
-        extend: _mm256_cvtepu8_epi64,
-        splat: _mm256_set1_epi64x,
-        zero_vec: _mm256_setzero_si256,
-        mask_indices: _mm256_cmpgt_epi64,
-        mask_cvt: |x| { x },
-        gather: _mm256_mask_i64gather_epi64,
-        store: _mm256_storeu_si256,
-        WIDTH = 4, STRIDE = 16
     }
 );

-// kernels for u16 indices
+// Kernels for u16 indices.
 impl_gather!(u16,
-    // 32-bit values. 8x indices loaded at a time and 8x values written at a time
+    // 32-bit values. 8x indices loaded at a time and 8x values written at a time.
     { u32 =>
         load: _mm_loadu_si128,
         extend: _mm256_cvtepu16_epi32,
@@ -247,18 +199,6 @@ impl_gather!(u16,
         store: _mm256_storeu_si256,
         WIDTH = 8, STRIDE = 8
     },
-    { i32 =>
-        load: _mm_loadu_si128,
-        extend: _mm256_cvtepu16_epi32,
-        splat: _mm256_set1_epi32,
-        zero_vec: _mm256_setzero_si256,
-        mask_indices: _mm256_cmpgt_epi32,
-        mask_cvt: |x| { x },
-        gather: _mm256_mask_i32gather_epi32,
-        store: _mm256_storeu_si256,
-        WIDTH = 8, STRIDE = 8
-    },
-
     // 64-bit values. 8x indices loaded at a time and 4x values loaded at a time.
     { u64 =>
         load: _mm_loadu_si128,
@@ -270,23 +210,12 @@ impl_gather!(u16,
         gather: _mm256_mask_i64gather_epi64,
         store: _mm256_storeu_si256,
         WIDTH = 4, STRIDE = 8
-    },
-    { i64 =>
-        load: _mm_loadu_si128,
-        extend: _mm256_cvtepu16_epi64,
-        splat: _mm256_set1_epi64x,
-        zero_vec: _mm256_setzero_si256,
-        mask_indices: _mm256_cmpgt_epi64,
-        mask_cvt: |x| { x },
-        gather: _mm256_mask_i64gather_epi64,
-        store: _mm256_storeu_si256,
-        WIDTH = 4, STRIDE = 8
     }
 );

-// kernels for u32 indices
+// Kernels for u32 indices.
 impl_gather!(u32,
-    // 32-bit values. 8x indices loaded at a time and 8x values written
+    // 32-bit values. 8x indices loaded at a time and 8x values written.
     { u32 =>
         load: _mm256_loadu_si256,
         extend: identity,
@@ -298,19 +227,7 @@ impl_gather!(u32,
         store: _mm256_storeu_si256,
         WIDTH = 8, STRIDE = 8
     },
-    { i32 =>
-        load: _mm256_loadu_si256,
-        extend: identity,
-        splat: _mm256_set1_epi32,
-        zero_vec: _mm256_setzero_si256,
-        mask_indices: _mm256_cmpgt_epi32,
-        mask_cvt: |x| { x },
-        gather: _mm256_mask_i32gather_epi32,
-        store: _mm256_storeu_si256,
-        WIDTH = 8, STRIDE = 8
-    },
-
-    // 64-bit values
+    // 64-bit values.
     { u64 =>
         load: _mm_loadu_si128,
         extend: _mm256_cvtepu32_epi64,
@@ -321,22 +238,12 @@ impl_gather!(u32,
         gather: _mm256_mask_i64gather_epi64,
         store: _mm256_storeu_si256,
         WIDTH = 4, STRIDE = 4
-    },
-    { i64 =>
-        load: _mm_loadu_si128,
-        extend: _mm256_cvtepu32_epi64,
-        splat: _mm256_set1_epi64x,
-        zero_vec: _mm256_setzero_si256,
-        mask_indices: _mm256_cmpgt_epi64,
-        mask_cvt: |x| { x },
-        gather: _mm256_mask_i64gather_epi64,
-        store: _mm256_storeu_si256,
-        WIDTH = 4, STRIDE = 4
     }
 );

-// kernels for u64 indices
+// Kernels for u64 indices.
 impl_gather!(u64,
+    // 32-bit values.
     { u32 =>
         load: _mm256_loadu_si256,
         extend: identity,
@@ -356,27 +263,7 @@ impl_gather!(u64,
         store: _mm_storeu_si128,
         WIDTH = 4, STRIDE = 4
     },
-    { i32 =>
-        load: _mm256_loadu_si256,
-        extend: identity,
-        splat: _mm256_set1_epi64x,
-        zero_vec: _mm_setzero_si128,
-        mask_indices: _mm256_cmpgt_epi64,
-        mask_cvt: |m| {
-            unsafe {
-                let lo_bits = _mm256_extracti128_si256::<0>(m); // lower half
-                let hi_bits = _mm256_extracti128_si256::<1>(m); // upper half
-                let lo_packed = _mm_shuffle_epi32::<0b01_01_01_01>(lo_bits);
-                let hi_packed = _mm_shuffle_epi32::<0b01_01_01_01>(hi_bits);
-                _mm_unpacklo_epi64(lo_packed, hi_packed)
-            }
-        },
-        gather: _mm256_mask_i64gather_epi32,
-        store: _mm_storeu_si128,
-        WIDTH = 4, STRIDE = 4
-    },
-
-    // 64-bit values
+    // 64-bit values.
     { u64 =>
         load: _mm256_loadu_si256,
         extend: identity,
@@ -387,17 +274,6 @@ impl_gather!(u64,
         gather: _mm256_mask_i64gather_epi64,
         store: _mm256_storeu_si256,
         WIDTH = 4, STRIDE = 4
-    },
-    { i64 =>
-        load: _mm256_loadu_si256,
-        extend: identity,
-        splat: _mm256_set1_epi64x,
-        zero_vec: _mm256_setzero_si256,
-        mask_indices: _mm256_cmpgt_epi64,
-        mask_cvt: |x| { x },
-        gather: _mm256_mask_i64gather_epi64,
-        store: _mm256_storeu_si256,
-        WIDTH = 4, STRIDE = 4
     }
 );

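Each `impl_gather!` invocation above expands to a kernel of the same shape: `load` STRIDE raw indices per block, `extend` WIDTH of them to the gather's lane width, build a tail mask with `splat`/`mask_indices`/`mask_cvt`, then issue one hardware `gather` and `store` the lanes. A minimal, self-contained sketch of a single step of the (u8 indices, u32 values) kernel, with a hypothetical function name, using the unmasked gather for brevity where the real kernels use the `_mm256_mask_*gather*` variants to guard the tail:

/// One gather step: fetch `values[indices[i]]` for the low 8 index bytes.
///
/// SAFETY: the caller must verify AVX2 support (e.g. with
/// `is_x86_feature_detected!("avx2")`) and that the low 8 index bytes are all
/// in bounds for `values`.
#[cfg(target_arch = "x86_64")]
#[target_feature(enable = "avx2")]
unsafe fn gather8_u8_into_u32(values: &[u32], indices: &[u8; 16]) -> [u32; 8] {
    use std::arch::x86_64::*;

    unsafe {
        // load: read 16 index bytes into a 128-bit register (STRIDE = 16);
        // one step consumes the low 8 of them (WIDTH = 8).
        let idx_bytes = _mm_loadu_si128(indices.as_ptr().cast());
        // extend: zero-extend the low 8 bytes to eight 32-bit index lanes.
        let idx_lanes = _mm256_cvtepu8_epi32(idx_bytes);
        // gather: fetch values[i] for all 8 lanes in a single instruction,
        // scaling each index by 4 bytes.
        let gathered = _mm256_i32gather_epi32::<4>(values.as_ptr().cast::<i32>(), idx_lanes);
        // store: write the 8 gathered lanes to the output.
        let mut out = [0u32; 8];
        _mm256_storeu_si256(out.as_mut_ptr().cast(), gathered);
        out
    }
}

A caller would guard this with the same runtime check the dispatch in mod.rs below performs.
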
vortex-compute/src/take/slice/mod.rs

Lines changed: 0 additions & 6 deletions
@@ -21,10 +21,6 @@ impl<T: Copy, I: UnsignedPType> Take<[I]> for &[T] {
             return portable::take_portable(self, indices);
         }

-        // TODO(connor): Make the SIMD implementations bound by `Copy` instead of `NativePType`.
-
-        /*
-
         #[cfg(any(target_arch = "x86", target_arch = "x86_64"))]
         {
             if is_x86_feature_detected!("avx2") {
@@ -33,8 +29,6 @@ impl<T: Copy, I: UnsignedPType> Take<[I]> for &[T] {
             }
         }

-        */
-
         #[allow(unreachable_code, reason = "`vortex_nightly` path returns early")]
         take_scalar(self, indices)
     }
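
The block restored here is the standard runtime-dispatch pattern: detect AVX2 once at the call site, take the SIMD path when it is present, and otherwise fall through to the scalar loop. A minimal sketch of that shape with hypothetical stand-in implementations (the real trait method dispatches to `take_avx2` from avx2.rs and the crate's own `take_scalar`):

fn take_scalar<T: Copy>(values: &[T], indices: &[u32]) -> Vec<T> {
    indices.iter().map(|&i| values[i as usize]).collect()
}

#[cfg(any(target_arch = "x86", target_arch = "x86_64"))]
#[target_feature(enable = "avx2")]
unsafe fn take_avx2<T: Copy>(values: &[T], indices: &[u32]) -> Vec<T> {
    // Stand-in body; the real kernel uses the AVX2 gathers shown above.
    take_scalar(values, indices)
}

fn take<T: Copy>(values: &[T], indices: &[u32]) -> Vec<T> {
    #[cfg(any(target_arch = "x86", target_arch = "x86_64"))]
    {
        if is_x86_feature_detected!("avx2") {
            // SAFETY: the runtime check above proves the `avx2` feature is available.
            return unsafe { take_avx2(values, indices) };
        }
    }

    take_scalar(values, indices)
}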

0 commit comments