Skip to content

Commit b5ba195

Browse files
Merge pull request #139 from rust-lang/feat/gather
Add SimdArray trait and safe gather/scatter API (#139)

This PR has four parts, none of which makes much sense without the others:
- The introduction of the `SimdArray` trait for abstraction over vectors.
- The implementation of private vector-of-pointers types.
- Using these to allow constructing vectors with `SimdArray::gather_{or, or_default, select}`.
- Using these to allow writing vectors with `SimdArray::scatter{,_select}`.
2 parents 3872723 + 1529ed4 commit b5ba195

File tree

5 files changed

+313
-1
lines changed

5 files changed

+313
-1
lines changed

crates/core_simd/src/array.rs

Lines changed: 248 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,248 @@
1+
use crate::intrinsics;
2+
use crate::masks::*;
3+
use crate::vector::ptr::{SimdConstPtr, SimdMutPtr};
4+
use crate::vector::*;
5+
6+
/// A representation of a vector as an "array" with indices, implementing
7+
/// operations applicable to any vector type based solely on "having lanes",
8+
/// and describing relationships between vector and scalar types.
///
/// Implementors supply `Scalar` and `splat`; the gather/scatter operations are
/// provided as default methods on top of those.
9+
pub trait SimdArray<const LANES: usize>: crate::LanesAtMost32
10+
where
11+
SimdUsize<LANES>: crate::LanesAtMost32,
12+
SimdIsize<LANES>: crate::LanesAtMost32,
13+
MaskSize<LANES>: crate::Mask,
14+
Self: Sized,
15+
{
16+
/// The scalar type in every lane of this vector type.
17+
type Scalar: Copy + Sized;
18+
/// The number of lanes for this vector.
19+
const LANES: usize = LANES;
20+
21+
/// Generates a SIMD vector with the same value in every lane.
22+
#[must_use]
23+
fn splat(val: Self::Scalar) -> Self;
24+
25+
/// SIMD gather: construct a SIMD vector by reading from a slice, using potentially discontiguous indices.
26+
/// If an index is out of bounds, that lane instead selects the value from the "or" vector.
27+
/// ```
28+
/// # use core_simd::*;
29+
/// let vec: Vec<i32> = vec![10, 11, 12, 13, 14, 15, 16, 17, 18];
30+
/// let idxs = SimdUsize::<4>::from_array([9, 3, 0, 5]);
31+
/// let alt = SimdI32::from_array([-5, -4, -3, -2]);
32+
///
33+
/// let result = SimdI32::<4>::gather_or(&vec, idxs, alt); // Note the lane that is out-of-bounds.
34+
/// assert_eq!(result, SimdI32::from_array([-5, 13, 10, 15]));
35+
/// ```
36+
#[must_use]
37+
#[inline]
38+
fn gather_or(slice: &[Self::Scalar], idxs: SimdUsize<LANES>, or: Self) -> Self {
39+
// Every lane is enabled by the caller-side mask, so only the bounds check
// performed inside `gather_select` can make a lane fall back to `or`.
Self::gather_select(slice, MaskSize::splat(true), idxs, or)
40+
}
41+
42+
/// SIMD gather: construct a SIMD vector by reading from a slice, using potentially discontiguous indices.
43+
/// Out-of-bounds indices instead use the default value for that lane (0).
44+
/// ```
45+
/// # use core_simd::*;
46+
/// let vec: Vec<i32> = vec![10, 11, 12, 13, 14, 15, 16, 17, 18];
47+
/// let idxs = SimdUsize::<4>::from_array([9, 3, 0, 5]);
48+
///
49+
/// let result = SimdI32::<4>::gather_or_default(&vec, idxs); // Note the lane that is out-of-bounds.
50+
/// assert_eq!(result, SimdI32::from_array([0, 13, 10, 15]));
51+
/// ```
52+
#[must_use]
53+
#[inline]
54+
fn gather_or_default(slice: &[Self::Scalar], idxs: SimdUsize<LANES>) -> Self
55+
where
56+
Self::Scalar: Default,
57+
{
58+
// The fallback vector is `Scalar::default()` splatted across every lane.
Self::gather_or(slice, idxs, Self::splat(Self::Scalar::default()))
59+
}
60+
61+
/// SIMD gather: construct a SIMD vector by reading from a slice, using potentially discontiguous indices.
62+
/// Out-of-bounds or masked indices instead select the value from the "or" vector.
63+
/// ```
64+
/// # use core_simd::*;
65+
/// let vec: Vec<i32> = vec![10, 11, 12, 13, 14, 15, 16, 17, 18];
66+
/// let idxs = SimdUsize::<4>::from_array([9, 3, 0, 5]);
67+
/// let alt = SimdI32::from_array([-5, -4, -3, -2]);
68+
/// let mask = MaskSize::from_array([true, true, true, false]); // Note the mask of the last lane.
69+
///
70+
/// let result = SimdI32::<4>::gather_select(&vec, mask, idxs, alt); // Note the lane that is out-of-bounds.
71+
/// assert_eq!(result, SimdI32::from_array([-5, 13, 10, -2]));
72+
/// ```
73+
#[must_use]
74+
#[inline]
75+
fn gather_select(
76+
slice: &[Self::Scalar],
77+
mask: MaskSize<LANES>,
78+
idxs: SimdUsize<LANES>,
79+
or: Self,
80+
) -> Self {
81+
// A lane stays enabled only if the caller enabled it AND its index is below `slice.len()`.
let mask = (mask & idxs.lanes_lt(SimdUsize::splat(slice.len()))).to_int();
82+
let base_ptr = SimdConstPtr::splat(slice.as_ptr());
83+
// Pointer arithmetic: offset the base pointer by each lane's index via `wrapping_add`.
84+
let ptrs = base_ptr.wrapping_add(idxs);
85+
// SAFETY: every lane whose index is >= `slice.len()` was cleared from `mask` above,
// so only lanes pointing inside `slice` remain enabled.
// NOTE(review): this relies on `simd_gather` not reading through the pointers of
// disabled lanes (and yielding `or` for them) -- confirm against the intrinsic's contract.
86+
unsafe { intrinsics::simd_gather(or, ptrs, mask) }
87+
}
88+
89+
/// SIMD scatter: write a SIMD vector's values into a slice, using potentially discontiguous indices.
90+
/// Out-of-bounds indices are not written.
91+
/// `scatter` writes "in order", so if an index receives two writes, only the last is guaranteed.
92+
/// ```
93+
/// # use core_simd::*;
94+
/// let mut vec: Vec<i32> = vec![10, 11, 12, 13, 14, 15, 16, 17, 18];
95+
/// let idxs = SimdUsize::<4>::from_array([9, 3, 0, 0]);
96+
/// let vals = SimdI32::from_array([-27, 82, -41, 124]);
97+
///
98+
/// vals.scatter(&mut vec, idxs); // index 0 receives two writes.
99+
/// assert_eq!(vec, vec![124, 11, 12, 82, 14, 15, 16, 17, 18]);
100+
/// ```
101+
#[inline]
102+
fn scatter(self, slice: &mut [Self::Scalar], idxs: SimdUsize<LANES>) {
103+
// Every lane is enabled by the caller-side mask, so only the bounds check
// performed inside `scatter_select` can suppress a write.
self.scatter_select(slice, MaskSize::splat(true), idxs)
104+
}
105+
106+
/// SIMD scatter: write a SIMD vector's values into a slice, using potentially discontiguous indices.
107+
/// Out-of-bounds or masked indices are not written.
108+
/// `scatter_select` writes "in order", so if an index receives two writes, only the last is guaranteed.
109+
/// ```
110+
/// # use core_simd::*;
111+
/// let mut vec: Vec<i32> = vec![10, 11, 12, 13, 14, 15, 16, 17, 18];
112+
/// let idxs = SimdUsize::<4>::from_array([9, 3, 0, 0]);
113+
/// let vals = SimdI32::from_array([-27, 82, -41, 124]);
114+
/// let mask = MaskSize::from_array([true, true, true, false]); // Note the mask of the last lane.
115+
///
116+
/// vals.scatter_select(&mut vec, mask, idxs); // index 0's second write is masked, thus omitted.
117+
/// assert_eq!(vec, vec![-41, 11, 12, 82, 14, 15, 16, 17, 18]);
118+
/// ```
119+
#[inline]
120+
fn scatter_select(
121+
self,
122+
slice: &mut [Self::Scalar],
123+
mask: MaskSize<LANES>,
124+
idxs: SimdUsize<LANES>,
125+
) {
126+
// We must construct our scatter mask before we derive a pointer!
// (Computing the mask calls `slice.len()`, which borrows `slice`.)
127+
let mask = (mask & idxs.lanes_lt(SimdUsize::splat(slice.len()))).to_int();
128+
// SAFETY: This block works with *mut T derived from &mut 'a [T],
129+
// which means it is delicate in Rust's borrowing model, circa 2021:
130+
// &mut 'a [T] asserts uniqueness, so deriving &'a [T] invalidates live *mut Ts!
131+
// Even though this block is largely safe methods, it must be almost exactly this way
132+
// to prevent invalidating the raw ptrs while they're live.
133+
// Thus, every value used inside this block must already be computed before entering it:
134+
// 0. idxs we want to write to, which are used to construct the mask.
135+
// 1. mask, which depends on an initial &'a [T] and the idxs.
136+
// 2. actual values to scatter (self).
137+
// 3. &mut [T] which will become our base ptr.
138+
unsafe {
139+
// Now Entering ☢️ *mut T Zone
140+
let base_ptr = SimdMutPtr::splat(slice.as_mut_ptr());
141+
// Pointer arithmetic: offset the base pointer by each lane's index via `wrapping_add`.
142+
let ptrs = base_ptr.wrapping_add(idxs);
143+
// Every lane whose index is >= `slice.len()` was cleared from `mask` above,
// so only lanes pointing inside `slice` remain enabled for writing.
// NOTE(review): this relies on `simd_scatter` not writing through the pointers of
// disabled lanes -- confirm against the intrinsic's contract.
144+
intrinsics::simd_scatter(self, ptrs, mask)
145+
// Cleared ☢️ *mut T Zone
146+
}
147+
}
148+
}
149+
150+
// Generates `impl SimdArray<LANES> for $simd<LANES>` for a vector type.
//
// Arm 1: `$simd { type Scalar = $scalar; }` -- the common case. Emits the impl with the
//        given `Scalar` and a `splat` that converts `[val; LANES]` into the vector via `.into()`.
// Arm 2: `$simd $impl` -- accepts any single token tree as the complete impl body and pastes
//        it after the same `where` clause, for types that need a hand-written body.
//
// Arms are tried in order, so an invocation of the arm-1 shape never reaches arm 2.
macro_rules! impl_simdarray_for {
151+
($simd:ident {type Scalar = $scalar:ident;}) => {
152+
impl<const LANES: usize> SimdArray<LANES> for $simd<LANES>
153+
where SimdUsize<LANES>: crate::LanesAtMost32,
154+
SimdIsize<LANES>: crate::LanesAtMost32,
155+
MaskSize<LANES>: crate::Mask,
156+
Self: crate::LanesAtMost32,
157+
{
158+
type Scalar = $scalar;
159+
160+
#[must_use]
161+
#[inline]
162+
fn splat(val: Self::Scalar) -> Self {
163+
// Build a `[Scalar; LANES]` array of copies and convert it into the vector type.
[val; LANES].into()
164+
}
165+
}
166+
};
167+
168+
($simd:ident $impl:tt) => {
169+
impl<const LANES: usize> SimdArray<LANES> for $simd<LANES>
170+
where SimdUsize<LANES>: crate::LanesAtMost32,
171+
SimdIsize<LANES>: crate::LanesAtMost32,
172+
MaskSize<LANES>: crate::Mask,
173+
Self: crate::LanesAtMost32,
174+
$impl
175+
}
176+
}
177+
178+
// `SimdArray` implementations for each vector type, pairing the vector with its lane scalar.
// Every invocation below uses the `{ type Scalar = ...; }` form of `impl_simdarray_for!`.
impl_simdarray_for! {
179+
SimdUsize {
180+
type Scalar = usize;
181+
}
182+
}
183+
184+
impl_simdarray_for! {
185+
SimdIsize {
186+
type Scalar = isize;
187+
}
188+
}
189+
190+
impl_simdarray_for! {
191+
SimdI8 {
192+
type Scalar = i8;
193+
}
194+
}
195+
196+
impl_simdarray_for! {
197+
SimdI16 {
198+
type Scalar = i16;
199+
}
200+
}
201+
202+
impl_simdarray_for! {
203+
SimdI32 {
204+
type Scalar = i32;
205+
}
206+
}
207+
208+
impl_simdarray_for! {
209+
SimdI64 {
210+
type Scalar = i64;
211+
}
212+
}
213+
214+
impl_simdarray_for! {
215+
SimdU8 {
216+
type Scalar = u8;
217+
}
218+
}
219+
220+
impl_simdarray_for! {
221+
SimdU16 {
222+
type Scalar = u16;
223+
}
224+
}
225+
226+
impl_simdarray_for! {
227+
SimdU32 {
228+
type Scalar = u32;
229+
}
230+
}
231+
232+
impl_simdarray_for! {
233+
SimdU64 {
234+
type Scalar = u64;
235+
}
236+
}
237+
238+
impl_simdarray_for! {
239+
SimdF32 {
240+
type Scalar = f32;
241+
}
242+
}
243+
244+
impl_simdarray_for! {
245+
SimdF64 {
246+
type Scalar = f64;
247+
}
248+
}

crates/core_simd/src/intrinsics.rs

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -45,7 +45,7 @@ extern "platform-intrinsic" {
4545

4646
/// fabs
4747
pub(crate) fn simd_fabs<T>(x: T) -> T;
48-
48+
4949
/// fsqrt
5050
pub(crate) fn simd_fsqrt<T>(x: T) -> T;
5151

@@ -66,6 +66,9 @@ extern "platform-intrinsic" {
6666
pub(crate) fn simd_shuffle16<T, U>(x: T, y: T, idx: [u32; 16]) -> U;
6767
pub(crate) fn simd_shuffle32<T, U>(x: T, y: T, idx: [u32; 32]) -> U;
6868

69+
pub(crate) fn simd_gather<T, U, V>(val: T, ptr: U, mask: V) -> T;
70+
pub(crate) fn simd_scatter<T, U, V>(val: T, ptr: U, mask: V);
71+
6972
// {s,u}add.sat
7073
pub(crate) fn simd_saturating_add<T>(x: T, y: T) -> T;
7174

crates/core_simd/src/lib.rs

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -36,3 +36,6 @@ pub use masks::*;
3636

3737
mod vector;
3838
pub use vector::*;
39+
40+
mod array;
41+
pub use array::SimdArray;

crates/core_simd/src/vector.rs

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -5,3 +5,6 @@ mod uint;
55
pub use float::*;
66
pub use int::*;
77
pub use uint::*;
8+
9+
// Vectors of pointers are not for public use at the current time.
10+
pub(crate) mod ptr;

crates/core_simd/src/vector/ptr.rs

Lines changed: 55 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,55 @@
1+
//! Private implementation details of public gather/scatter APIs.
2+
use crate::SimdUsize;
3+
use core::mem;
4+
5+
/// A vector of *const T.
///
/// Holds `LANES` raw const pointers in a `#[repr(simd)]` struct.
6+
#[derive(Debug, Copy, Clone)]
7+
#[repr(simd)]
8+
pub(crate) struct SimdConstPtr<T, const LANES: usize>([*const T; LANES]);
9+
10+
impl<T, const LANES: usize> SimdConstPtr<T, LANES>
11+
where
12+
SimdUsize<LANES>: crate::LanesAtMost32,
13+
T: Sized,
14+
{
15+
/// Builds a pointer vector holding `ptr` in every lane.
#[inline]
16+
#[must_use]
17+
pub fn splat(ptr: *const T) -> Self {
18+
Self([ptr; LANES])
19+
}
20+
21+
/// Offsets each lane's pointer by the matching lane of `addend`, counted in elements of `T`:
/// every index is multiplied by `size_of::<T>()` and added to that lane's address.
/// Only addresses are computed here; nothing is dereferenced.
#[inline]
22+
#[must_use]
23+
pub fn wrapping_add(self, addend: SimdUsize<LANES>) -> Self {
24+
// SAFETY: NOTE(review) -- `transmute_copy` reinterprets `Self` as `SimdUsize<LANES>`
// and the sum back as `Self`. This presumes both types have the same size and layout
// (LANES pointer-sized lanes; `T: Sized` is required by the impl's bounds) -- confirm.
// NOTE(review): the "wrapping" behaviour depends on `+` and `*` for `SimdUsize`
// wrapping on overflow, which is not shown in this file -- confirm.
unsafe {
25+
let x: SimdUsize<LANES> = mem::transmute_copy(&self);
26+
mem::transmute_copy(&{ x + (addend * mem::size_of::<T>()) })
27+
}
28+
}
29+
}
30+
31+
/// A vector of *mut T. Be very careful around potential aliasing.
///
/// Holds `LANES` raw mutable pointers in a `#[repr(simd)]` struct.
32+
#[derive(Debug, Copy, Clone)]
33+
#[repr(simd)]
34+
pub(crate) struct SimdMutPtr<T, const LANES: usize>([*mut T; LANES]);
35+
36+
impl<T, const LANES: usize> SimdMutPtr<T, LANES>
37+
where
38+
SimdUsize<LANES>: crate::LanesAtMost32,
39+
T: Sized,
40+
{
41+
/// Builds a pointer vector holding `ptr` in every lane.
#[inline]
42+
#[must_use]
43+
pub fn splat(ptr: *mut T) -> Self {
44+
Self([ptr; LANES])
45+
}
46+
47+
/// Offsets each lane's pointer by the matching lane of `addend`, counted in elements of `T`:
/// every index is multiplied by `size_of::<T>()` and added to that lane's address.
/// Only addresses are computed here; nothing is dereferenced or written.
#[inline]
48+
#[must_use]
49+
pub fn wrapping_add(self, addend: SimdUsize<LANES>) -> Self {
50+
// SAFETY: NOTE(review) -- `transmute_copy` reinterprets `Self` as `SimdUsize<LANES>`
// and the sum back as `Self`. This presumes both types have the same size and layout
// (LANES pointer-sized lanes; `T: Sized` is required by the impl's bounds) -- confirm.
// NOTE(review): the "wrapping" behaviour depends on `+` and `*` for `SimdUsize`
// wrapping on overflow, which is not shown in this file -- confirm.
unsafe {
51+
let x: SimdUsize<LANES> = mem::transmute_copy(&self);
52+
mem::transmute_copy(&{ x + (addend * mem::size_of::<T>()) })
53+
}
54+
}
55+
}

0 commit comments

Comments
 (0)