Skip to content

Commit c15a6fe

Browse files
committed
bring portable simd take back
Signed-off-by: Connor Tsui <[email protected]>
1 parent 5c5f7d1 commit c15a6fe

File tree

3 files changed

+43
-35
lines changed

3 files changed

+43
-35
lines changed

vortex-compute/src/take/slice/mod.rs

Lines changed: 5 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -16,14 +16,15 @@ impl<T: Copy, I: UnsignedPType> Take<[I]> for &[T] {
1616
type Output = Buffer<T>;
1717

1818
fn take(self, indices: &[I]) -> Buffer<T> {
19-
// TODO(connor): Make the SIMD implementations bound by `Copy` instead of `NativePType`.
20-
/*
21-
2219
#[cfg(vortex_nightly)]
2320
{
2421
return portable::take_portable(self, indices);
2522
}
2623

24+
// TODO(connor): Make the SIMD implementations bound by `Copy` instead of `NativePType`.
25+
26+
/*
27+
2728
#[cfg(any(target_arch = "x86", target_arch = "x86_64"))]
2829
{
2930
if is_x86_feature_detected!("avx2") {
@@ -44,6 +45,6 @@ impl<T: Copy, I: UnsignedPType> Take<[I]> for &[T] {
4445
reason = "Compiler may see this as unused based on enabled features"
4546
)]
4647
#[inline]
47-
fn take_scalar<T: Copy, I: UnsignedPType>(buffer: &[T], indices: &[I]) -> Buffer<T> {
48+
pub(crate) fn take_scalar<T: Copy, I: UnsignedPType>(buffer: &[T], indices: &[I]) -> Buffer<T> {
4849
indices.iter().map(|idx| buffer[idx.as_()]).collect()
4950
}

vortex-compute/src/take/slice/portable.rs

Lines changed: 38 additions & 29 deletions
Original file line numberDiff line numberDiff line change
@@ -15,10 +15,7 @@ use multiversion::multiversion;
1515
use vortex_buffer::Alignment;
1616
use vortex_buffer::Buffer;
1717
use vortex_buffer::BufferMut;
18-
use vortex_dtype::NativePType;
19-
use vortex_dtype::PType;
2018
use vortex_dtype::UnsignedPType;
21-
use vortex_dtype::match_each_native_simd_ptype;
2219
use vortex_dtype::match_each_unsigned_integer_ptype;
2320

2421
/// SIMD types larger than the SIMD register size are beneficial for
@@ -27,38 +24,49 @@ pub const SIMD_WIDTH: usize = 64;
2724

2825
/// Takes the specified indices into a new [`Buffer`] using portable SIMD.
2926
///
30-
/// This function handles the type matching required to satisfy `SimdElement` bounds.
31-
/// For `f16` values, it reinterprets them as `u16` since `f16` doesn't implement `SimdElement`.
27+
/// This function handles the type matching required to satisfy `SimdElement` bounds by casting
28+
/// to unsigned integers of the same size. Falls back to scalar implementation for unsupported
29+
/// type sizes.
3230
#[inline]
33-
pub fn take_portable<T: NativePType, I: UnsignedPType>(buffer: &[T], indices: &[I]) -> Buffer<T> {
34-
if T::PTYPE == PType::F16 {
35-
assert_eq!(size_of::<half::f16>(), size_of::<T>());
36-
37-
// Since Rust does not actually support 16-bit floats, we first reinterpret the data as
38-
// `u16` integers.
39-
// SAFETY: We know that f16 has the same bit pattern as u16, so this transmute is fine to
40-
// make.
41-
let u16_slice: &[u16] =
42-
unsafe { std::slice::from_raw_parts(buffer.as_ptr() as *const u16, buffer.len()) };
43-
return take_with_indices(u16_slice, indices).cast_into::<T>();
31+
pub fn take_portable<T: Copy, I: UnsignedPType>(buffer: &[T], indices: &[I]) -> Buffer<T> {
32+
// SIMD gather operations only care about bit patterns, not semantic type. We cast to unsigned
33+
// integers which implement `SimdElement` and then cast back.
34+
//
35+
// SAFETY: The pointer casts below are safe because:
36+
// - `T` and the target type have the same size (matched by `size_of::<T>()`).
37+
// - The alignment of unsigned integers is always <= their size, and `buffer` came from a valid
38+
// `&[T]` which guarantees proper alignment for types of the same size.
39+
match size_of::<T>() {
40+
1 => {
41+
let buffer: &[u8] =
42+
unsafe { std::slice::from_raw_parts(buffer.as_ptr() as *const u8, buffer.len()) };
43+
take_with_indices(buffer, indices).cast_into::<T>()
44+
}
45+
2 => {
46+
let buffer: &[u16] =
47+
unsafe { std::slice::from_raw_parts(buffer.as_ptr() as *const u16, buffer.len()) };
48+
take_with_indices(buffer, indices).cast_into::<T>()
49+
}
50+
4 => {
51+
let buffer: &[u32] =
52+
unsafe { std::slice::from_raw_parts(buffer.as_ptr() as *const u32, buffer.len()) };
53+
take_with_indices(buffer, indices).cast_into::<T>()
54+
}
55+
8 => {
56+
let buffer: &[u64] =
57+
unsafe { std::slice::from_raw_parts(buffer.as_ptr() as *const u64, buffer.len()) };
58+
take_with_indices(buffer, indices).cast_into::<T>()
59+
}
60+
// Fall back to scalar implementation for unsupported type sizes.
61+
_ => super::take_scalar(buffer, indices),
4462
}
45-
46-
match_each_native_simd_ptype!(T::PTYPE, |TC| {
47-
assert_eq!(size_of::<TC>(), size_of::<T>());
48-
49-
// SAFETY: This is essentially a no-op that tricks the compiler into adding the
50-
// `simd::SimdElement` bound we need to call `take_with_indices`.
51-
let buffer: &[TC] =
52-
unsafe { std::slice::from_raw_parts(buffer.as_ptr() as *const TC, buffer.len()) };
53-
take_with_indices(buffer, indices).cast_into::<T>()
54-
})
5563
}
5664

5765
/// Helper that matches on index type and calls `take_portable_simd`.
5866
///
5967
/// We separate this code out from above to add the [`simd::SimdElement`] constraint.
6068
#[inline]
61-
fn take_with_indices<T: NativePType + simd::SimdElement, I: UnsignedPType>(
69+
fn take_with_indices<T: Copy + Default + simd::SimdElement, I: UnsignedPType>(
6270
buffer: &[T],
6371
indices: &[I],
6472
) -> Buffer<T> {
@@ -78,7 +86,7 @@ fn take_with_indices<T: NativePType + simd::SimdElement, I: UnsignedPType>(
7886
#[multiversion(targets("x86_64+avx2", "x86_64+avx", "aarch64+neon"))]
7987
pub fn take_portable_simd<T, I, const LANE_COUNT: usize>(values: &[T], indices: &[I]) -> Buffer<T>
8088
where
81-
T: NativePType + simd::SimdElement,
89+
T: Copy + Default + simd::SimdElement,
8290
I: UnsignedPType + simd::SimdElement,
8391
simd::LaneCount<LANE_COUNT>: simd::SupportedLaneCount,
8492
simd::Simd<I, LANE_COUNT>: SimdUint<Cast<usize> = simd::Simd<usize, LANE_COUNT>>,
@@ -128,6 +136,7 @@ where
128136
}
129137

130138
#[cfg(test)]
139+
#[allow(clippy::cast_possible_truncation)]
131140
mod tests {
132141
use super::take_portable_simd;
133142

@@ -159,7 +168,7 @@ mod tests {
159168
// Strided by 4: 0, 4, 8, ..., 252.
160169
indices.extend((0u32..64).map(|i| i * 4));
161170
// Repeated: index 42 repeated 32 times.
162-
indices.extend(std::iter::repeat(42u32).take(32));
171+
indices.extend(std::iter::repeat_n(42u32, 32));
163172
// Reverse: 255, 254, ..., 216.
164173
indices.extend((216u32..256).rev());
165174

vortex-compute/src/take/vector/tests.rs

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -120,7 +120,6 @@ fn test_null_vector_take() {
120120
assert!(result.validity().all_false());
121121
}
122122

123-
#[ignore = "TODO(connor): Implement `DecimalVector::take`."]
124123
#[test]
125124
fn test_dvector_take() {
126125
use vortex_buffer::buffer;
@@ -154,7 +153,6 @@ fn test_dvector_take() {
154153
assert_eq!(validity, vec![true, false, true, false]);
155154
}
156155

157-
#[ignore = "TODO(connor): Implement `DecimalVector::take`."]
158156
#[test]
159157
fn test_decimal_vector_take() {
160158
use vortex_buffer::buffer;

0 commit comments

Comments
 (0)