@@ -15,10 +15,7 @@ use multiversion::multiversion;
 use vortex_buffer::Alignment;
 use vortex_buffer::Buffer;
 use vortex_buffer::BufferMut;
-use vortex_dtype::NativePType;
-use vortex_dtype::PType;
 use vortex_dtype::UnsignedPType;
-use vortex_dtype::match_each_native_simd_ptype;
 use vortex_dtype::match_each_unsigned_integer_ptype;
 
 /// SIMD types larger than the SIMD register size are beneficial for
@@ -27,38 +24,49 @@ pub const SIMD_WIDTH: usize = 64;
 
 /// Takes the specified indices into a new [`Buffer`] using portable SIMD.
 ///
-/// This function handles the type matching required to satisfy `SimdElement` bounds.
-/// For `f16` values, it reinterprets them as `u16` since `f16` doesn't implement `SimdElement`.
+/// This function handles the type matching required to satisfy `SimdElement` bounds by casting
+/// to unsigned integers of the same size. It falls back to a scalar implementation for
+/// unsupported type sizes.
 #[inline]
-pub fn take_portable<T: NativePType, I: UnsignedPType>(buffer: &[T], indices: &[I]) -> Buffer<T> {
-    if T::PTYPE == PType::F16 {
-        assert_eq!(size_of::<half::f16>(), size_of::<T>());
-
-        // Since Rust does not actually support 16-bit floats, we first reinterpret the data as
-        // `u16` integers.
-        // SAFETY: We know that f16 has the same bit pattern as u16, so this transmute is fine to
-        // make.
-        let u16_slice: &[u16] =
-            unsafe { std::slice::from_raw_parts(buffer.as_ptr() as *const u16, buffer.len()) };
-        return take_with_indices(u16_slice, indices).cast_into::<T>();
+pub fn take_portable<T: Copy, I: UnsignedPType>(buffer: &[T], indices: &[I]) -> Buffer<T> {
+    // SIMD gather operations only care about bit patterns, not semantic type. We cast to unsigned
+    // integers which implement `SimdElement` and then cast back.
+    //
+    // SAFETY: The pointer casts below are safe because:
+    // - `T` and the target type have the same size (matched by `size_of::<T>()`).
+    // - The alignment of unsigned integers is always <= their size, and `buffer` came from a valid
+    //   `&[T]` which guarantees proper alignment for types of the same size.
+    match size_of::<T>() {
+        1 => {
+            let buffer: &[u8] =
+                unsafe { std::slice::from_raw_parts(buffer.as_ptr() as *const u8, buffer.len()) };
+            take_with_indices(buffer, indices).cast_into::<T>()
+        }
+        2 => {
+            let buffer: &[u16] =
+                unsafe { std::slice::from_raw_parts(buffer.as_ptr() as *const u16, buffer.len()) };
+            take_with_indices(buffer, indices).cast_into::<T>()
+        }
+        4 => {
+            let buffer: &[u32] =
+                unsafe { std::slice::from_raw_parts(buffer.as_ptr() as *const u32, buffer.len()) };
+            take_with_indices(buffer, indices).cast_into::<T>()
+        }
+        8 => {
+            let buffer: &[u64] =
+                unsafe { std::slice::from_raw_parts(buffer.as_ptr() as *const u64, buffer.len()) };
+            take_with_indices(buffer, indices).cast_into::<T>()
+        }
+        // Fall back to the scalar implementation for unsupported type sizes.
+        _ => super::take_scalar(buffer, indices),
     }
-
-    match_each_native_simd_ptype!(T::PTYPE, |TC| {
-        assert_eq!(size_of::<TC>(), size_of::<T>());
-
-        // SAFETY: This is essentially a no-op that tricks the compiler into adding the
-        // `simd::SimdElement` bound we need to call `take_with_indices`.
-        let buffer: &[TC] =
-            unsafe { std::slice::from_raw_parts(buffer.as_ptr() as *const TC, buffer.len()) };
-        take_with_indices(buffer, indices).cast_into::<T>()
-    })
 }
 
 /// Helper that matches on index type and calls `take_portable_simd`.
 ///
 /// We separate this code out from above to add the [`simd::SimdElement`] constraint.
 #[inline]
-fn take_with_indices<T: NativePType + simd::SimdElement, I: UnsignedPType>(
+fn take_with_indices<T: Copy + Default + simd::SimdElement, I: UnsignedPType>(
     buffer: &[T],
     indices: &[I],
 ) -> Buffer<T> {
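The core trick in the new `take_portable` is that a gather only moves bytes, so any `T` of a supported size can be viewed as the unsigned integer of that size. Below is a minimal sketch of that reinterpretation on plain slices, outside the diff and using only `std`; `view_as_u32` is an illustrative name, not part of the crate:

```rust
// Sketch: view an f32 slice as a u32 slice of the same length. f32 and u32 have
// identical size and alignment, and every bit pattern is a valid u32, so a gather
// over the u32 view moves exactly the same bytes as a gather over the f32 view.
fn view_as_u32(values: &[f32]) -> &[u32] {
    assert_eq!(std::mem::size_of::<f32>(), std::mem::size_of::<u32>());
    assert_eq!(std::mem::align_of::<f32>(), std::mem::align_of::<u32>());
    // SAFETY: same size, same alignment, no invalid bit patterns in the target type.
    unsafe { std::slice::from_raw_parts(values.as_ptr() as *const u32, values.len()) }
}

fn main() {
    let values = [1.5f32, -2.0, 3.25];
    let bits = view_as_u32(&values);
    assert_eq!(bits[0], 1.5f32.to_bits());
}
```

In the diff, the same cast is repeated once per size arm because the target type must be a concrete `u8`/`u16`/`u32`/`u64` for the `SimdElement` bound on `take_with_indices` to hold.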
@@ -78,7 +86,7 @@ fn take_with_indices<T: NativePType + simd::SimdElement, I: UnsignedPType>(
 #[multiversion(targets("x86_64+avx2", "x86_64+avx", "aarch64+neon"))]
 pub fn take_portable_simd<T, I, const LANE_COUNT: usize>(values: &[T], indices: &[I]) -> Buffer<T>
 where
-    T: NativePType + simd::SimdElement,
+    T: Copy + Default + simd::SimdElement,
     I: UnsignedPType + simd::SimdElement,
     simd::LaneCount<LANE_COUNT>: simd::SupportedLaneCount,
     simd::Simd<I, LANE_COUNT>: SimdUint<Cast<usize> = simd::Simd<usize, LANE_COUNT>>,
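The last bound above says the index vector can be cast lane-wise to `Simd<usize, LANE_COUNT>`, which is the index type the portable gather API expects. The sketch below shows the general shape of such a gather loop; it assumes a nightly toolchain (`std::simd` is unstable) and in-bounds indices, and `gather_u32` is illustrative, not the crate's actual implementation:

```rust
#![feature(portable_simd)]
use std::simd::{LaneCount, Simd, SupportedLaneCount};

// Gather `values[indices[i]]` for every index, N lanes at a time, with a scalar tail.
fn gather_u32<const N: usize>(values: &[u32], indices: &[u32]) -> Vec<u32>
where
    LaneCount<N>: SupportedLaneCount,
{
    let mut out = Vec::with_capacity(indices.len());
    let mut chunks = indices.chunks_exact(N);
    for chunk in &mut chunks {
        // Widen the u32 index lanes to usize, which is what the gather API takes.
        let idx = Simd::<u32, N>::from_slice(chunk).cast::<usize>();
        // Out-of-bounds lanes would read u32::default(); all indices are in bounds here.
        let gathered = Simd::<u32, N>::gather_or_default(values, idx);
        out.extend_from_slice(&gathered.to_array());
    }
    // Scalar loop for the remainder that doesn't fill a full vector.
    out.extend(chunks.remainder().iter().map(|&i| values[i as usize]));
    out
}
```

A caller would pick the lane count at the call site, e.g. `gather_u32::<8>(&values, &indices)`. In the real function, `#[multiversion]` compiles a clone of the body for each listed target and dispatches at runtime, so the same portable source can use AVX2 or NEON registers where available.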
@@ -128,6 +136,7 @@ where
 }
 
 #[cfg(test)]
+#[allow(clippy::cast_possible_truncation)]
 mod tests {
     use super::take_portable_simd;
 
@@ -159,7 +168,7 @@ mod tests {
         // Strided by 4: 0, 4, 8, ..., 252.
         indices.extend((0u32..64).map(|i| i * 4));
         // Repeated: index 42 repeated 32 times.
-        indices.extend(std::iter::repeat(42u32).take(32));
+        indices.extend(std::iter::repeat_n(42u32, 32));
         // Reverse: 255, 254, ..., 216.
         indices.extend((216u32..256).rev());
 
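The only change in this hunk swaps `repeat(x).take(n)` for `std::iter::repeat_n(x, n)`, which states the repeat count directly instead of truncating an infinite iterator. A quick equivalence check, outside the diff and assuming a toolchain where `repeat_n` is stable:

```rust
// The two iterator forms produce identical sequences.
fn main() {
    let old: Vec<u32> = std::iter::repeat(42u32).take(32).collect();
    let new: Vec<u32> = std::iter::repeat_n(42u32, 32).collect();
    assert_eq!(old, new);
}
```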