Skip to content

Commit ecc19e2

Browse files
robert3005a10y
andauthored
chore: Always use primitive array simd take (#3647)
Fixed the cases where we determined we can't use a SIMD take. Also noticed that our SIMD take was unchecked, which is not correct because we didn't validate the indices. Signed-off-by: Robert Kruszewski <[email protected]> --------- Signed-off-by: Robert Kruszewski <[email protected]> Signed-off-by: Andrew Duffy <[email protected]> Co-authored-by: Andrew Duffy <[email protected]>
1 parent 1703f88 commit ecc19e2

File tree

1 file changed

+61
-53
lines changed
  • vortex-array/src/arrays/primitive/compute

1 file changed

+61
-53
lines changed
Lines changed: 61 additions & 53 deletions
Original file line numberDiff line numberDiff line change
@@ -1,67 +1,67 @@
1+
use std::mem::{MaybeUninit, transmute};
12
use std::simd;
23

34
use num_traits::AsPrimitive;
45
use simd::num::SimdUint;
56
use vortex_buffer::{Alignment, Buffer, BufferMut};
67
use vortex_dtype::{
7-
NativePType, Nullability, PType, match_each_integer_ptype, match_each_native_ptype,
8-
match_each_native_simd_ptype, match_each_unsigned_integer_ptype,
8+
DType, NativePType, PType, match_each_native_simd_ptype, match_each_unsigned_integer_ptype,
99
};
10-
use vortex_error::VortexResult;
10+
use vortex_error::{VortexResult, vortex_bail};
1111

1212
use crate::arrays::PrimitiveVTable;
1313
use crate::arrays::primitive::PrimitiveArray;
14-
use crate::compute::{TakeKernel, TakeKernelAdapter};
14+
use crate::compute::{TakeKernel, TakeKernelAdapter, cast};
1515
use crate::vtable::ValidityHelper;
1616
use crate::{Array, ArrayRef, IntoArray, ToCanonical, register_kernel};
1717

18+
// SIMD types larger than the SIMD register size are beneficial for
19+
// performance as this leads to better instruction level parallelism.
20+
const SIMD_WIDTH: usize = 64;
21+
1822
impl TakeKernel for PrimitiveVTable {
1923
#[allow(clippy::cognitive_complexity)]
2024
fn take(&self, array: &PrimitiveArray, indices: &dyn Array) -> VortexResult<ArrayRef> {
21-
let indices = indices.to_primitive()?;
22-
23-
if array.ptype() != PType::F16
24-
&& indices.dtype().is_unsigned_int()
25-
&& indices.all_valid()?
26-
&& array.all_valid()?
27-
{
28-
// TODO(alex): handle nullable codes & values
29-
match_each_unsigned_integer_ptype!(indices.ptype(), |C| {
25+
let unsigned_indices = match indices.dtype() {
26+
DType::Primitive(p, n) => {
27+
if p.is_unsigned_int() {
28+
indices.to_primitive()?
29+
} else {
30+
// This will fail unless all values can be converted to unsigned
31+
cast(indices, &DType::Primitive(p.to_unsigned(), *n))?.to_primitive()?
32+
}
33+
}
34+
_ => vortex_bail!("Invalid indices dtype: {}", indices.dtype()),
35+
};
36+
37+
let validity = array.validity().take(unsigned_indices.as_ref())?;
38+
if array.ptype() == PType::F16 {
39+
// Special handling for f16 to treat as opaque u16
40+
let decoded = match_each_unsigned_integer_ptype!(unsigned_indices.ptype(), |C| {
41+
take_primitive_simd::<C, u16, SIMD_WIDTH>(
42+
unsigned_indices.as_slice(),
43+
array.reinterpret_cast(PType::U16).as_slice(),
44+
)
45+
});
46+
Ok(PrimitiveArray::new(decoded, validity)
47+
.reinterpret_cast(PType::F16)
48+
.into_array())
49+
} else {
50+
match_each_unsigned_integer_ptype!(unsigned_indices.ptype(), |C| {
3051
match_each_native_simd_ptype!(array.ptype(), |V| {
31-
// SIMD types larger than the SIMD register size are beneficial for
32-
// performance as this leads to better instruction level parallelism.
33-
let decoded = take_primitive_simd::<C, V, 64>(
34-
indices.as_slice(),
52+
let decoded = take_primitive_simd::<C, V, SIMD_WIDTH>(
53+
unsigned_indices.as_slice(),
3554
array.as_slice(),
36-
array.dtype().nullability() | indices.dtype().nullability(),
3755
);
38-
39-
return Ok(decoded.into_array()) as VortexResult<ArrayRef>;
56+
Ok(PrimitiveArray::new(decoded, validity).into_array())
4057
})
41-
});
42-
}
43-
44-
// TODO(joe): if the true count of take indices validity is low, only take array values with
45-
// valid indices.
46-
let validity = array.validity().take(indices.as_ref())?;
47-
match_each_native_ptype!(array.ptype(), |T| {
48-
match_each_integer_ptype!(indices.ptype(), |I| {
49-
let values = take_primitive(array.as_slice::<T>(), indices.as_slice::<I>());
50-
Ok(PrimitiveArray::new(values, validity).into_array())
5158
})
52-
})
59+
}
5360
}
5461
}
5562

5663
register_kernel!(TakeKernelAdapter(PrimitiveVTable).lift());
5764

58-
fn take_primitive<T: NativePType, I: NativePType + AsPrimitive<usize>>(
59-
array: &[T],
60-
indices: &[I],
61-
) -> Buffer<T> {
62-
indices.iter().map(|idx| array[idx.as_()]).collect()
63-
}
64-
6565
/// Takes elements from an array using SIMD indexing.
6666
///
6767
/// # Type Parameters
@@ -77,11 +77,7 @@ fn take_primitive<T: NativePType, I: NativePType + AsPrimitive<usize>>(
7777
/// # Returns
7878
/// A `Buffer<V>` containing the gathered values where each index has been replaced with
7979
/// the corresponding value from the source array.
80-
fn take_primitive_simd<I, V, const LANE_COUNT: usize>(
81-
indices: &[I],
82-
values: &[V],
83-
nullability: Nullability,
84-
) -> PrimitiveArray
80+
fn take_primitive_simd<I, V, const LANE_COUNT: usize>(indices: &[I], values: &[V]) -> Buffer<V>
8581
where
8682
I: simd::SimdElement + AsPrimitive<usize>,
8783
V: simd::SimdElement + NativePType,
@@ -102,15 +98,18 @@ where
10298
let mask = simd::Mask::from_bitmask(u64::MAX);
10399
let codes_chunk = simd::Simd::<I, LANE_COUNT>::from_slice(&indices[offset..]);
104100

101+
let selection = simd::Simd::gather_select(
102+
values,
103+
mask,
104+
codes_chunk.cast::<usize>(),
105+
simd::Simd::<V, LANE_COUNT>::default(),
106+
);
107+
105108
unsafe {
106-
let selection = simd::Simd::gather_select_unchecked(
107-
values,
108-
mask,
109-
codes_chunk.cast::<usize>(),
110-
simd::Simd::<V, LANE_COUNT>::default(),
109+
selection.store_select_unchecked(
110+
transmute::<&mut [MaybeUninit<V>], &mut [V]>(&mut buf_slice[offset..][..64]),
111+
mask.cast(),
111112
);
112-
113-
selection.store_select_ptr(buf_slice.as_mut_ptr().add(offset) as *mut V, mask.cast());
114113
}
115114
}
116115

@@ -126,15 +125,15 @@ where
126125
buffer.set_len(indices_len);
127126
}
128127

129-
PrimitiveArray::new(buffer.freeze(), nullability.into())
128+
buffer.freeze()
130129
}
131130

132131
#[cfg(test)]
133132
mod test {
134133
use vortex_buffer::buffer;
135134
use vortex_scalar::Scalar;
136135

137-
use crate::arrays::primitive::compute::take::take_primitive;
136+
use crate::arrays::primitive::compute::take::take_primitive_simd;
138137
use crate::arrays::{BoolArray, PrimitiveArray};
139138
use crate::compute::take;
140139
use crate::validity::Validity;
@@ -143,7 +142,7 @@ mod test {
143142
#[test]
144143
fn test_take() {
145144
let a = vec![1i32, 2, 3, 4, 5];
146-
let result = take_primitive(&a, &[0, 0, 4, 2]);
145+
let result = take_primitive_simd::<u8, i32, 64>(&[0, 0, 4, 2], &a);
147146
assert_eq!(result.as_slice(), &[1i32, 1, 5, 3]);
148147
}
149148

@@ -164,4 +163,13 @@ mod test {
164163
// the third index is null
165164
assert_eq!(actual.scalar_at(2).unwrap(), Scalar::null_typed::<i32>());
166165
}
166+
167+
#[test]
168+
fn test_take_out_of_bounds() {
169+
let indices = vec![2_000_000u32; 64];
170+
let values = vec![1i32];
171+
172+
let result = take_primitive_simd::<u32, i32, 64>(&indices, &values);
173+
assert_eq!(result.as_slice(), [0i32; 64]);
174+
}
167175
}

0 commit comments

Comments
 (0)