@@ -5,7 +5,7 @@ use std::arch::x86_64::*;
55
66use num_traits:: AsPrimitive ;
77use vortex_buffer:: { Alignment , Buffer , BufferMut } ;
8- use vortex_dtype:: { NativePType , Nullability } ;
8+ use vortex_dtype:: { NativePType , Nullability , PType } ;
99
1010use crate :: arrays:: primitive:: PrimitiveArray ;
1111
@@ -20,7 +20,13 @@ pub fn is_avx2_available() -> bool {
2020 false
2121}
2222
23- /// AVX2-optimized take operation dispatch
23+ /// AVX2-optimized take operation dispatch.
24+ ///
25+ /// This returns None if the AVX2 feature is not detected at runtime, signalling to the caller
26+ /// that it should fall back to the scalar implementation.
27+ ///
28+ /// If AVX2 is available, this returns a PrimitiveArray containing the result of the take operation
29+ /// accelerated using AVX2 instructions.
2430#[ cfg( target_arch = "x86_64" ) ]
2531pub fn take_primitive_avx2 < I , V > (
2632 indices : & [ I ] ,
@@ -36,42 +42,42 @@ where
3642 }
3743
3844 // Dispatch to type-specific implementations
39- match ( std :: any :: TypeId :: of :: < I > ( ) , std :: any :: TypeId :: of :: < V > ( ) ) {
45+ match ( I :: PTYPE , V :: PTYPE ) {
4046 // u32 indices, i32 values
41- ( i , v ) if i == std :: any :: TypeId :: of :: < u32 > ( ) && v == std :: any :: TypeId :: of :: < i32 > ( ) => {
47+ ( PType :: U32 , PType :: I32 ) => {
4248 let indices = unsafe { std:: mem:: transmute :: < & [ I ] , & [ u32 ] > ( indices) } ;
4349 let values = unsafe { std:: mem:: transmute :: < & [ V ] , & [ i32 ] > ( values) } ;
44- let result = unsafe { take_i32_u32_avx2 ( indices, values) } ;
50+ let result = unsafe { take_u32_i32_avx2 ( indices, values) } ;
4551 Some ( PrimitiveArray :: new (
4652 unsafe { std:: mem:: transmute :: < Buffer < i32 > , Buffer < V > > ( result) } ,
4753 nullability. into ( ) ,
4854 ) )
4955 }
5056 // u32 indices, f32 values
51- ( i , v ) if i == std :: any :: TypeId :: of :: < u32 > ( ) && v == std :: any :: TypeId :: of :: < f32 > ( ) => {
57+ ( PType :: U32 , PType :: F32 ) => {
5258 let indices = unsafe { std:: mem:: transmute :: < & [ I ] , & [ u32 ] > ( indices) } ;
5359 let values = unsafe { std:: mem:: transmute :: < & [ V ] , & [ f32 ] > ( values) } ;
54- let result = unsafe { take_f32_u32_avx2 ( indices, values) } ;
60+ let result = unsafe { take_u32_f32_avx2 ( indices, values) } ;
5561 Some ( PrimitiveArray :: new (
5662 unsafe { std:: mem:: transmute :: < Buffer < f32 > , Buffer < V > > ( result) } ,
5763 nullability. into ( ) ,
5864 ) )
5965 }
6066 // u64 indices, i64 values
61- ( i , v ) if i == std :: any :: TypeId :: of :: < u64 > ( ) && v == std :: any :: TypeId :: of :: < i64 > ( ) => {
67+ ( PType :: U64 , PType :: I64 ) => {
6268 let indices = unsafe { std:: mem:: transmute :: < & [ I ] , & [ u64 ] > ( indices) } ;
6369 let values = unsafe { std:: mem:: transmute :: < & [ V ] , & [ i64 ] > ( values) } ;
64- let result = unsafe { take_i64_u64_avx2 ( indices, values) } ;
70+ let result = unsafe { take_u64_i64_avx2 ( indices, values) } ;
6571 Some ( PrimitiveArray :: new (
6672 unsafe { std:: mem:: transmute :: < Buffer < i64 > , Buffer < V > > ( result) } ,
6773 nullability. into ( ) ,
6874 ) )
6975 }
7076 // u64 indices, f64 values
71- ( i , v ) if i == std :: any :: TypeId :: of :: < u64 > ( ) && v == std :: any :: TypeId :: of :: < f64 > ( ) => {
77+ ( PType :: U64 , PType :: F64 ) => {
7278 let indices = unsafe { std:: mem:: transmute :: < & [ I ] , & [ u64 ] > ( indices) } ;
7379 let values = unsafe { std:: mem:: transmute :: < & [ V ] , & [ f64 ] > ( values) } ;
74- let result = unsafe { take_f64_u64_avx2 ( indices, values) } ;
80+ let result = unsafe { take_u64_f64_avx2 ( indices, values) } ;
7581 Some ( PrimitiveArray :: new (
7682 unsafe { std:: mem:: transmute :: < Buffer < f64 > , Buffer < V > > ( result) } ,
7783 nullability. into ( ) ,
@@ -95,26 +101,30 @@ where
95101}
96102
97103/// AVX2 implementation for i32 values with u32 indices
104+ ///
105+ /// # Safety
106+ ///
107+ /// Caller must ensure that all of the indices point to valid elements in the values array.
108+ /// Failure to do so will result in potentially accessing out of bounds memory.
98109#[ cfg( target_arch = "x86_64" ) ]
99110#[ target_feature( enable = "avx2" ) ]
100- unsafe fn take_i32_u32_avx2 ( indices : & [ u32 ] , values : & [ i32 ] ) -> Buffer < i32 > {
111+ unsafe fn take_u32_i32_avx2 ( indices : & [ u32 ] , values : & [ i32 ] ) -> Buffer < i32 > {
101112 const SIMD_WIDTH : usize = 8 ; // 256 bits / 32 bits per element
102113 let indices_len = indices. len ( ) ;
103114
104115 let mut buffer =
105116 BufferMut :: < i32 > :: with_capacity_aligned ( indices_len, Alignment :: of :: < __m256i > ( ) ) ;
106117
107- let output_ptr = buffer. spare_capacity_mut ( ) . as_mut_ptr ( ) as * mut i32 ;
118+ let output_ptr: * mut i32 = buffer. spare_capacity_mut ( ) . as_mut_ptr ( ) . cast ( ) ;
108119 let values_ptr = values. as_ptr ( ) ;
109120
110121 // Process chunks of 8 elements
111122 let chunks = indices_len / SIMD_WIDTH ;
112123 for chunk_idx in 0 ..chunks {
113124 let offset = chunk_idx * SIMD_WIDTH ;
114125
115- // Load 8 u32 indices
116- let indices_vec =
117- unsafe { _mm256_loadu_si256 ( indices. as_ptr ( ) . add ( offset) as * const __m256i ) } ;
126+ // Load the next 8 indices into a vector
127+ let indices_vec = unsafe { _mm256_loadu_si256 ( indices. as_ptr ( ) . add ( offset) . cast ( ) ) } ;
118128
119129 // Gather 8 i32 values using the indices
120130 // Scale of 4 because i32 is 4 bytes
@@ -137,14 +147,14 @@ unsafe fn take_i32_u32_avx2(indices: &[u32], values: &[i32]) -> Buffer<i32> {
137147/// AVX2 implementation for f32 values with u32 indices
138148#[ cfg( target_arch = "x86_64" ) ]
139149#[ target_feature( enable = "avx2" ) ]
140- unsafe fn take_f32_u32_avx2 ( indices : & [ u32 ] , values : & [ f32 ] ) -> Buffer < f32 > {
150+ unsafe fn take_u32_f32_avx2 ( indices : & [ u32 ] , values : & [ f32 ] ) -> Buffer < f32 > {
141151 const SIMD_WIDTH : usize = 8 ; // 256 bits / 32 bits per element
142152 let indices_len = indices. len ( ) ;
143153
144154 let mut buffer =
145155 BufferMut :: < f32 > :: with_capacity_aligned ( indices_len, Alignment :: of :: < __m256 > ( ) ) ;
146156
147- let output_ptr = buffer. spare_capacity_mut ( ) . as_mut_ptr ( ) as * mut f32 ;
157+ let output_ptr: * mut f32 = buffer. spare_capacity_mut ( ) . as_mut_ptr ( ) . cast ( ) ;
148158 let values_ptr = values. as_ptr ( ) ;
149159
150160 // Process chunks of 8 elements
@@ -177,7 +187,7 @@ unsafe fn take_f32_u32_avx2(indices: &[u32], values: &[f32]) -> Buffer<f32> {
177187#[ cfg( target_arch = "x86_64" ) ]
178188#[ target_feature( enable = "avx2" ) ]
179189#[ allow( clippy:: cast_possible_truncation) ]
180- unsafe fn take_i64_u64_avx2 ( indices : & [ u64 ] , values : & [ i64 ] ) -> Buffer < i64 > {
190+ unsafe fn take_u64_i64_avx2 ( indices : & [ u64 ] , values : & [ i64 ] ) -> Buffer < i64 > {
181191 const SIMD_WIDTH : usize = 4 ; // 256 bits / 64 bits per element
182192 let indices_len = indices. len ( ) ;
183193
@@ -218,7 +228,7 @@ unsafe fn take_i64_u64_avx2(indices: &[u64], values: &[i64]) -> Buffer<i64> {
218228#[ cfg( target_arch = "x86_64" ) ]
219229#[ target_feature( enable = "avx2" ) ]
220230#[ allow( clippy:: cast_possible_truncation) ]
221- unsafe fn take_f64_u64_avx2 ( indices : & [ u64 ] , values : & [ f64 ] ) -> Buffer < f64 > {
231+ unsafe fn take_u64_f64_avx2 ( indices : & [ u64 ] , values : & [ f64 ] ) -> Buffer < f64 > {
222232 const SIMD_WIDTH : usize = 4 ; // 256 bits / 64 bits per element
223233 let indices_len = indices. len ( ) ;
224234
@@ -257,7 +267,6 @@ unsafe fn take_f64_u64_avx2(indices: &[u64], values: &[f64]) -> Buffer<f64> {
257267
258268#[ cfg( test) ]
259269mod tests {
260-
261270 use super :: * ;
262271
263272 #[ test]
0 commit comments