Skip to content

Commit adfc126

Browse files
committed
touch ups
Signed-off-by: Andrew Duffy <andrew@a10y.dev>
1 parent 564df56 commit adfc126

File tree

1 file changed

+30
-21
lines changed
  • vortex-array/src/arrays/primitive/compute/take

1 file changed

+30
-21
lines changed

vortex-array/src/arrays/primitive/compute/take/x86.rs

Lines changed: 30 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,7 @@ use std::arch::x86_64::*;
55

66
use num_traits::AsPrimitive;
77
use vortex_buffer::{Alignment, Buffer, BufferMut};
8-
use vortex_dtype::{NativePType, Nullability};
8+
use vortex_dtype::{NativePType, Nullability, PType};
99

1010
use crate::arrays::primitive::PrimitiveArray;
1111

@@ -20,7 +20,13 @@ pub fn is_avx2_available() -> bool {
2020
false
2121
}
2222

23-
/// AVX2-optimized take operation dispatch
23+
/// AVX2-optimized take operation dispatch.
24+
///
25+
/// This returns `None` if the AVX2 feature is not detected at runtime, signalling to the caller
26+
/// that it should fall back to the scalar implementation.
27+
///
28+
/// If AVX2 is available, this returns a PrimitiveArray containing the result of the take operation
29+
/// accelerated using AVX2 instructions.
2430
#[cfg(target_arch = "x86_64")]
2531
pub fn take_primitive_avx2<I, V>(
2632
indices: &[I],
@@ -36,42 +42,42 @@ where
3642
}
3743

3844
// Dispatch to type-specific implementations
39-
match (std::any::TypeId::of::<I>(), std::any::TypeId::of::<V>()) {
45+
match (I::PTYPE, V::PTYPE) {
4046
// u32 indices, i32 values
41-
(i, v) if i == std::any::TypeId::of::<u32>() && v == std::any::TypeId::of::<i32>() => {
47+
(PType::U32, PType::I32) => {
4248
let indices = unsafe { std::mem::transmute::<&[I], &[u32]>(indices) };
4349
let values = unsafe { std::mem::transmute::<&[V], &[i32]>(values) };
44-
let result = unsafe { take_i32_u32_avx2(indices, values) };
50+
let result = unsafe { take_u32_i32_avx2(indices, values) };
4551
Some(PrimitiveArray::new(
4652
unsafe { std::mem::transmute::<Buffer<i32>, Buffer<V>>(result) },
4753
nullability.into(),
4854
))
4955
}
5056
// u32 indices, f32 values
51-
(i, v) if i == std::any::TypeId::of::<u32>() && v == std::any::TypeId::of::<f32>() => {
57+
(PType::U32, PType::F32) => {
5258
let indices = unsafe { std::mem::transmute::<&[I], &[u32]>(indices) };
5359
let values = unsafe { std::mem::transmute::<&[V], &[f32]>(values) };
54-
let result = unsafe { take_f32_u32_avx2(indices, values) };
60+
let result = unsafe { take_u32_f32_avx2(indices, values) };
5561
Some(PrimitiveArray::new(
5662
unsafe { std::mem::transmute::<Buffer<f32>, Buffer<V>>(result) },
5763
nullability.into(),
5864
))
5965
}
6066
// u64 indices, i64 values
61-
(i, v) if i == std::any::TypeId::of::<u64>() && v == std::any::TypeId::of::<i64>() => {
67+
(PType::U64, PType::I64) => {
6268
let indices = unsafe { std::mem::transmute::<&[I], &[u64]>(indices) };
6369
let values = unsafe { std::mem::transmute::<&[V], &[i64]>(values) };
64-
let result = unsafe { take_i64_u64_avx2(indices, values) };
70+
let result = unsafe { take_u64_i64_avx2(indices, values) };
6571
Some(PrimitiveArray::new(
6672
unsafe { std::mem::transmute::<Buffer<i64>, Buffer<V>>(result) },
6773
nullability.into(),
6874
))
6975
}
7076
// u64 indices, f64 values
71-
(i, v) if i == std::any::TypeId::of::<u64>() && v == std::any::TypeId::of::<f64>() => {
77+
(PType::U64, PType::F64) => {
7278
let indices = unsafe { std::mem::transmute::<&[I], &[u64]>(indices) };
7379
let values = unsafe { std::mem::transmute::<&[V], &[f64]>(values) };
74-
let result = unsafe { take_f64_u64_avx2(indices, values) };
80+
let result = unsafe { take_u64_f64_avx2(indices, values) };
7581
Some(PrimitiveArray::new(
7682
unsafe { std::mem::transmute::<Buffer<f64>, Buffer<V>>(result) },
7783
nullability.into(),
@@ -95,26 +101,30 @@ where
95101
}
96102

97103
/// AVX2 implementation for i32 values with u32 indices
104+
///
105+
/// # Safety
106+
///
107+
/// Caller must ensure that all of the indices point to valid elements in the values array.
108+
/// Failure to do so may result in out-of-bounds memory accesses.
98109
#[cfg(target_arch = "x86_64")]
99110
#[target_feature(enable = "avx2")]
100-
unsafe fn take_i32_u32_avx2(indices: &[u32], values: &[i32]) -> Buffer<i32> {
111+
unsafe fn take_u32_i32_avx2(indices: &[u32], values: &[i32]) -> Buffer<i32> {
101112
const SIMD_WIDTH: usize = 8; // 256 bits / 32 bits per element
102113
let indices_len = indices.len();
103114

104115
let mut buffer =
105116
BufferMut::<i32>::with_capacity_aligned(indices_len, Alignment::of::<__m256i>());
106117

107-
let output_ptr = buffer.spare_capacity_mut().as_mut_ptr() as *mut i32;
118+
let output_ptr: *mut i32 = buffer.spare_capacity_mut().as_mut_ptr().cast();
108119
let values_ptr = values.as_ptr();
109120

110121
// Process chunks of 8 elements
111122
let chunks = indices_len / SIMD_WIDTH;
112123
for chunk_idx in 0..chunks {
113124
let offset = chunk_idx * SIMD_WIDTH;
114125

115-
// Load 8 u32 indices
116-
let indices_vec =
117-
unsafe { _mm256_loadu_si256(indices.as_ptr().add(offset) as *const __m256i) };
126+
// Load the next 8 indices into a vector
127+
let indices_vec = unsafe { _mm256_loadu_si256(indices.as_ptr().add(offset).cast()) };
118128

119129
// Gather 8 i32 values using the indices
120130
// Scale of 4 because i32 is 4 bytes
@@ -137,14 +147,14 @@ unsafe fn take_i32_u32_avx2(indices: &[u32], values: &[i32]) -> Buffer<i32> {
137147
/// AVX2 implementation for f32 values with u32 indices
138148
#[cfg(target_arch = "x86_64")]
139149
#[target_feature(enable = "avx2")]
140-
unsafe fn take_f32_u32_avx2(indices: &[u32], values: &[f32]) -> Buffer<f32> {
150+
unsafe fn take_u32_f32_avx2(indices: &[u32], values: &[f32]) -> Buffer<f32> {
141151
const SIMD_WIDTH: usize = 8; // 256 bits / 32 bits per element
142152
let indices_len = indices.len();
143153

144154
let mut buffer =
145155
BufferMut::<f32>::with_capacity_aligned(indices_len, Alignment::of::<__m256>());
146156

147-
let output_ptr = buffer.spare_capacity_mut().as_mut_ptr() as *mut f32;
157+
let output_ptr: *mut f32 = buffer.spare_capacity_mut().as_mut_ptr().cast();
148158
let values_ptr = values.as_ptr();
149159

150160
// Process chunks of 8 elements
@@ -177,7 +187,7 @@ unsafe fn take_f32_u32_avx2(indices: &[u32], values: &[f32]) -> Buffer<f32> {
177187
#[cfg(target_arch = "x86_64")]
178188
#[target_feature(enable = "avx2")]
179189
#[allow(clippy::cast_possible_truncation)]
180-
unsafe fn take_i64_u64_avx2(indices: &[u64], values: &[i64]) -> Buffer<i64> {
190+
unsafe fn take_u64_i64_avx2(indices: &[u64], values: &[i64]) -> Buffer<i64> {
181191
const SIMD_WIDTH: usize = 4; // 256 bits / 64 bits per element
182192
let indices_len = indices.len();
183193

@@ -218,7 +228,7 @@ unsafe fn take_i64_u64_avx2(indices: &[u64], values: &[i64]) -> Buffer<i64> {
218228
#[cfg(target_arch = "x86_64")]
219229
#[target_feature(enable = "avx2")]
220230
#[allow(clippy::cast_possible_truncation)]
221-
unsafe fn take_f64_u64_avx2(indices: &[u64], values: &[f64]) -> Buffer<f64> {
231+
unsafe fn take_u64_f64_avx2(indices: &[u64], values: &[f64]) -> Buffer<f64> {
222232
const SIMD_WIDTH: usize = 4; // 256 bits / 64 bits per element
223233
let indices_len = indices.len();
224234

@@ -257,7 +267,6 @@ unsafe fn take_f64_u64_avx2(indices: &[u64], values: &[f64]) -> Buffer<f64> {
257267

258268
#[cfg(test)]
259269
mod tests {
260-
261270
use super::*;
262271

263272
#[test]

0 commit comments

Comments
 (0)