1+ use std:: mem:: { MaybeUninit , transmute} ;
12use std:: simd;
23
34use num_traits:: AsPrimitive ;
45use simd:: num:: SimdUint ;
56use vortex_buffer:: { Alignment , Buffer , BufferMut } ;
67use vortex_dtype:: {
7- NativePType , Nullability , PType , match_each_integer_ptype, match_each_native_ptype,
8- match_each_native_simd_ptype, match_each_unsigned_integer_ptype,
8+ DType , NativePType , PType , match_each_native_simd_ptype, match_each_unsigned_integer_ptype,
99} ;
10- use vortex_error:: VortexResult ;
10+ use vortex_error:: { VortexResult , vortex_bail } ;
1111
1212use crate :: arrays:: PrimitiveVTable ;
1313use crate :: arrays:: primitive:: PrimitiveArray ;
14- use crate :: compute:: { TakeKernel , TakeKernelAdapter } ;
14+ use crate :: compute:: { TakeKernel , TakeKernelAdapter , cast } ;
1515use crate :: vtable:: ValidityHelper ;
1616use crate :: { Array , ArrayRef , IntoArray , ToCanonical , register_kernel} ;
1717
18+ // SIMD types larger than the SIMD register size are beneficial for
19+ // performance as this leads to better instruction level parallelism.
20+ const SIMD_WIDTH : usize = 64 ;
21+
1822impl TakeKernel for PrimitiveVTable {
1923 #[ allow( clippy:: cognitive_complexity) ]
2024 fn take ( & self , array : & PrimitiveArray , indices : & dyn Array ) -> VortexResult < ArrayRef > {
21- let indices = indices. to_primitive ( ) ?;
22-
23- if array. ptype ( ) != PType :: F16
24- && indices. dtype ( ) . is_unsigned_int ( )
25- && indices. all_valid ( ) ?
26- && array. all_valid ( ) ?
27- {
28- // TODO(alex): handle nullable codes & values
29- match_each_unsigned_integer_ptype ! ( indices. ptype( ) , |C | {
25+ let unsigned_indices = match indices. dtype ( ) {
26+ DType :: Primitive ( p, n) => {
27+ if p. is_unsigned_int ( ) {
28+ indices. to_primitive ( ) ?
29+ } else {
30+ // This will fail if all values cannot be converted to unsigned
31+ cast ( indices, & DType :: Primitive ( p. to_unsigned ( ) , * n) ) ?. to_primitive ( ) ?
32+ }
33+ }
34+ _ => vortex_bail ! ( "Invalid indices dtype: {}" , indices. dtype( ) ) ,
35+ } ;
36+
37+ let validity = array. validity ( ) . take ( unsigned_indices. as_ref ( ) ) ?;
38+ if array. ptype ( ) == PType :: F16 {
39+ // Special handling for f16 to treat as opaque u16
40+ let decoded = match_each_unsigned_integer_ptype ! ( unsigned_indices. ptype( ) , |C | {
41+ take_primitive_simd:: <C , u16 , SIMD_WIDTH >(
42+ unsigned_indices. as_slice( ) ,
43+ array. reinterpret_cast( PType :: U16 ) . as_slice( ) ,
44+ )
45+ } ) ;
46+ Ok ( PrimitiveArray :: new ( decoded, validity)
47+ . reinterpret_cast ( PType :: F16 )
48+ . into_array ( ) )
49+ } else {
50+ match_each_unsigned_integer_ptype ! ( unsigned_indices. ptype( ) , |C | {
3051 match_each_native_simd_ptype!( array. ptype( ) , |V | {
31- // SIMD types larger than the SIMD register size are beneficial for
32- // performance as this leads to better instruction level parallelism.
33- let decoded = take_primitive_simd:: <C , V , 64 >(
34- indices. as_slice( ) ,
52+ let decoded = take_primitive_simd:: <C , V , SIMD_WIDTH >(
53+ unsigned_indices. as_slice( ) ,
3554 array. as_slice( ) ,
36- array. dtype( ) . nullability( ) | indices. dtype( ) . nullability( ) ,
3755 ) ;
38-
39- return Ok ( decoded. into_array( ) ) as VortexResult <ArrayRef >;
56+ Ok ( PrimitiveArray :: new( decoded, validity) . into_array( ) )
4057 } )
41- } ) ;
42- }
43-
44- // TODO(joe): if the true count of take indices validity is low, only take array values with
45- // valid indices.
46- let validity = array. validity ( ) . take ( indices. as_ref ( ) ) ?;
47- match_each_native_ptype ! ( array. ptype( ) , |T | {
48- match_each_integer_ptype!( indices. ptype( ) , |I | {
49- let values = take_primitive( array. as_slice:: <T >( ) , indices. as_slice:: <I >( ) ) ;
50- Ok ( PrimitiveArray :: new( values, validity) . into_array( ) )
5158 } )
52- } )
59+ }
5360 }
5461}
5562
5663register_kernel ! ( TakeKernelAdapter ( PrimitiveVTable ) . lift( ) ) ;
5764
58- fn take_primitive < T : NativePType , I : NativePType + AsPrimitive < usize > > (
59- array : & [ T ] ,
60- indices : & [ I ] ,
61- ) -> Buffer < T > {
62- indices. iter ( ) . map ( |idx| array[ idx. as_ ( ) ] ) . collect ( )
63- }
64-
6565/// Takes elements from an array using SIMD indexing.
6666///
6767/// # Type Parameters
@@ -77,11 +77,7 @@ fn take_primitive<T: NativePType, I: NativePType + AsPrimitive<usize>>(
7777/// # Returns
7878/// A `PrimitiveArray` containing the gathered values where each index has been replaced with
7979/// the corresponding value from the source array.
80- fn take_primitive_simd < I , V , const LANE_COUNT : usize > (
81- indices : & [ I ] ,
82- values : & [ V ] ,
83- nullability : Nullability ,
84- ) -> PrimitiveArray
80+ fn take_primitive_simd < I , V , const LANE_COUNT : usize > ( indices : & [ I ] , values : & [ V ] ) -> Buffer < V >
8581where
8682 I : simd:: SimdElement + AsPrimitive < usize > ,
8783 V : simd:: SimdElement + NativePType ,
@@ -102,15 +98,18 @@ where
10298 let mask = simd:: Mask :: from_bitmask ( u64:: MAX ) ;
10399 let codes_chunk = simd:: Simd :: < I , LANE_COUNT > :: from_slice ( & indices[ offset..] ) ;
104100
101+ let selection = simd:: Simd :: gather_select (
102+ values,
103+ mask,
104+ codes_chunk. cast :: < usize > ( ) ,
105+ simd:: Simd :: < V , LANE_COUNT > :: default ( ) ,
106+ ) ;
107+
105108 unsafe {
106- let selection = simd:: Simd :: gather_select_unchecked (
107- values,
108- mask,
109- codes_chunk. cast :: < usize > ( ) ,
110- simd:: Simd :: < V , LANE_COUNT > :: default ( ) ,
109+ selection. store_select_unchecked (
110+ transmute :: < & mut [ MaybeUninit < V > ] , & mut [ V ] > ( & mut buf_slice[ offset..] [ ..64 ] ) ,
111+ mask. cast ( ) ,
111112 ) ;
112-
113- selection. store_select_ptr ( buf_slice. as_mut_ptr ( ) . add ( offset) as * mut V , mask. cast ( ) ) ;
114113 }
115114 }
116115
@@ -126,15 +125,15 @@ where
126125 buffer. set_len ( indices_len) ;
127126 }
128127
129- PrimitiveArray :: new ( buffer. freeze ( ) , nullability . into ( ) )
128+ buffer. freeze ( )
130129}
131130
132131#[ cfg( test) ]
133132mod test {
134133 use vortex_buffer:: buffer;
135134 use vortex_scalar:: Scalar ;
136135
137- use crate :: arrays:: primitive:: compute:: take:: take_primitive ;
136+ use crate :: arrays:: primitive:: compute:: take:: take_primitive_simd ;
138137 use crate :: arrays:: { BoolArray , PrimitiveArray } ;
139138 use crate :: compute:: take;
140139 use crate :: validity:: Validity ;
@@ -143,7 +142,7 @@ mod test {
143142 #[ test]
144143 fn test_take ( ) {
145144 let a = vec ! [ 1i32 , 2 , 3 , 4 , 5 ] ;
146- let result = take_primitive ( & a , & [ 0 , 0 , 4 , 2 ] ) ;
145+ let result = take_primitive_simd :: < u8 , i32 , 64 > ( & [ 0 , 0 , 4 , 2 ] , & a ) ;
147146 assert_eq ! ( result. as_slice( ) , & [ 1i32 , 1 , 5 , 3 ] ) ;
148147 }
149148
@@ -164,4 +163,13 @@ mod test {
164163 // the third index is null
165164 assert_eq ! ( actual. scalar_at( 2 ) . unwrap( ) , Scalar :: null_typed:: <i32 >( ) ) ;
166165 }
166+
167+ #[ test]
168+ fn test_take_out_of_bounds ( ) {
169+ let indices = vec ! [ 2_000_000u32 ; 64 ] ;
170+ let values = vec ! [ 1i32 ] ;
171+
172+ let result = take_primitive_simd :: < u32 , i32 , 64 > ( & indices, & values) ;
173+ assert_eq ! ( result. as_slice( ) , [ 0i32 ; 64 ] ) ;
174+ }
167175}
0 commit comments