Skip to content

Commit 1146930

Browse files
committed
Guarantee Mask has the same layout as Simd. Implement select as a trait that also supports bitmasks.
1 parent 936d58b commit 1146930

File tree

8 files changed

+167
-61
lines changed

8 files changed

+167
-61
lines changed

crates/core_simd/src/masks.rs

Lines changed: 4 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -2,17 +2,10 @@
22
//! Types representing
33
#![allow(non_camel_case_types)]
44

5-
#[cfg_attr(
6-
not(all(target_arch = "x86_64", target_feature = "avx512f")),
7-
path = "masks/full_masks.rs"
8-
)]
9-
#[cfg_attr(
10-
all(target_arch = "x86_64", target_feature = "avx512f"),
11-
path = "masks/bitmask.rs"
12-
)]
5+
#[path = "masks/full_masks.rs"]
136
mod mask_impl;
147

15-
use crate::simd::{LaneCount, Simd, SimdCast, SimdElement, SupportedLaneCount};
8+
use crate::simd::{LaneCount, Select, Simd, SimdCast, SimdElement, SupportedLaneCount};
169
use core::cmp::Ordering;
1710
use core::{fmt, mem};
1811

@@ -105,9 +98,8 @@ impl_element! { isize, usize }
10598
///
10699
/// Masks represent boolean inclusion/exclusion on a per-element basis.
107100
///
108-
/// The layout of this type is unspecified, and may change between platforms
109-
/// and/or Rust versions, and code should not assume that it is equivalent to
110-
/// `[T; N]`.
101+
/// The layout of this type is equivalent to `Simd<T, N>`, but elements
102+
/// are guaranteed to be either 0 or -1.
111103
#[repr(transparent)]
112104
pub struct Mask<T, const N: usize>(mask_impl::Mask<T, N>)
113105
where

crates/core_simd/src/mod.rs

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -29,6 +29,7 @@ pub mod simd {
2929
pub use crate::core_simd::cast::*;
3030
pub use crate::core_simd::lane_count::{LaneCount, SupportedLaneCount};
3131
pub use crate::core_simd::masks::*;
32+
pub use crate::core_simd::select::*;
3233
pub use crate::core_simd::swizzle::*;
3334
pub use crate::core_simd::to_bytes::ToBytes;
3435
pub use crate::core_simd::vector::*;

crates/core_simd/src/ops.rs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
use crate::simd::{LaneCount, Simd, SimdElement, SupportedLaneCount, cmp::SimdPartialEq};
1+
use crate::simd::{LaneCount, Select, Simd, SimdElement, SupportedLaneCount, cmp::SimdPartialEq};
22
use core::ops::{Add, Mul};
33
use core::ops::{BitAnd, BitOr, BitXor};
44
use core::ops::{Div, Rem, Sub};

crates/core_simd/src/select.rs

Lines changed: 155 additions & 42 deletions
Original file line numberDiff line numberDiff line change
@@ -1,54 +1,167 @@
11
use crate::simd::{LaneCount, Mask, MaskElement, Simd, SimdElement, SupportedLaneCount};
22

3-
impl<T, const N: usize> Mask<T, N>
3+
/// Choose elements from two vectors using a mask.
4+
///
5+
/// For each element in the mask, choose the corresponding element from `true_values` if
6+
/// that element mask is true, and `false_values` if that element mask is false.
7+
///
8+
/// If the mask is `u64`, it's treated as a bitmask with the least significant bit
9+
/// corresponding to the first element.
10+
///
11+
/// # Examples
12+
///
13+
/// ## Selecting values from `Simd`
14+
/// ```
15+
/// # #![feature(portable_simd)]
16+
/// # #[cfg(feature = "as_crate")] use core_simd::simd;
17+
/// # #[cfg(not(feature = "as_crate"))] use core::simd;
18+
/// # use simd::{Simd, Mask, Select};
19+
/// let a = Simd::from_array([0, 1, 2, 3]);
20+
/// let b = Simd::from_array([4, 5, 6, 7]);
21+
/// let mask = Mask::<i32, 4>::from_array([true, false, false, true]);
22+
/// let c = mask.select(a, b);
23+
/// assert_eq!(c.to_array(), [0, 5, 6, 3]);
24+
/// ```
25+
///
26+
/// ## Selecting values from `Mask`
27+
/// ```
28+
/// # #![feature(portable_simd)]
29+
/// # #[cfg(feature = "as_crate")] use core_simd::simd;
30+
/// # #[cfg(not(feature = "as_crate"))] use core::simd;
31+
/// # use simd::{Mask, Select};
32+
/// let a = Mask::<i32, 4>::from_array([true, true, false, false]);
33+
/// let b = Mask::<i32, 4>::from_array([false, false, true, true]);
34+
/// let mask = Mask::<i32, 4>::from_array([true, false, false, true]);
35+
/// let c = mask.select(a, b);
36+
/// assert_eq!(c.to_array(), [true, false, true, false]);
37+
/// ```
38+
///
39+
/// ## Selecting with a bitmask
40+
/// ```
41+
/// # #![feature(portable_simd)]
42+
/// # #[cfg(feature = "as_crate")] use core_simd::simd;
43+
/// # #[cfg(not(feature = "as_crate"))] use core::simd;
44+
/// # use simd::{Mask, Select};
45+
/// let a = Mask::<i32, 4>::from_array([true, true, false, false]);
46+
/// let b = Mask::<i32, 4>::from_array([false, false, true, true]);
47+
/// let mask = 0b1001;
48+
/// let c = mask.select(a, b);
49+
/// assert_eq!(c.to_array(), [true, false, true, false]);
50+
/// ```
51+
pub trait Select<T> {
52+
/// Choose elements
53+
fn select(self, true_values: T, false_values: T) -> T;
54+
}
55+
56+
impl<T, U, const N: usize> Select<Simd<T, N>> for Mask<U, N>
57+
where
58+
T: SimdElement,
59+
U: MaskElement,
60+
LaneCount<N>: SupportedLaneCount,
61+
{
62+
#[inline]
63+
fn select(self, true_values: Simd<T, N>, false_values: Simd<T, N>) -> Simd<T, N> {
64+
// Safety:
65+
// simd_as between masks is always safe (they're vectors of ints).
66+
// simd_select uses a mask that matches the width and number of elements
67+
unsafe {
68+
let mask: Simd<T::Mask, N> = core::intrinsics::simd::simd_as(self.to_simd());
69+
core::intrinsics::simd::simd_select(mask, true_values, false_values)
70+
}
71+
}
72+
}
73+
74+
impl<T, const N: usize> Select<Simd<T, N>> for u64
75+
where
76+
T: SimdElement,
77+
LaneCount<N>: SupportedLaneCount,
78+
{
79+
#[inline]
80+
fn select(self, true_values: Simd<T, N>, false_values: Simd<T, N>) -> Simd<T, N> {
81+
const {
82+
assert!(N <= 64, "number of elements can't be greater than 64");
83+
}
84+
85+
// LLVM assumes bit order should match endianness
86+
let bitmask = if cfg!(target_endian = "big") {
87+
let rev = self.reverse_bits();
88+
if N < 64 {
89+
// Shift things back to the right
90+
rev >> (64 - N)
91+
} else {
92+
rev
93+
}
94+
} else {
95+
self
96+
};
97+
98+
#[inline]
99+
unsafe fn select_impl<T, U, const M: usize, const N: usize>(
100+
bitmask: U,
101+
true_values: Simd<T, N>,
102+
false_values: Simd<T, N>,
103+
) -> Simd<T, N>
104+
where
105+
T: SimdElement,
106+
LaneCount<M>: SupportedLaneCount,
107+
LaneCount<N>: SupportedLaneCount,
108+
{
109+
let default = true_values[0];
110+
let true_values = true_values.resize::<M>(default);
111+
let false_values = false_values.resize::<M>(default);
112+
113+
// Safety: the caller guarantees that the size of U matches M
114+
let selected = unsafe {
115+
core::intrinsics::simd::simd_select_bitmask(bitmask, true_values, false_values)
116+
};
117+
118+
selected.resize::<N>(default)
119+
}
120+
121+
// TODO modify simd_bitmask_select to truncate input, making this unnecessary
122+
if N <= 8 {
123+
// Safety: bitmask matches length
124+
unsafe { select_impl::<T, u8, 8, N>(bitmask as u8, true_values, false_values) }
125+
} else if N <= 16 {
126+
// Safety: bitmask matches length
127+
unsafe { select_impl::<T, u16, 16, N>(bitmask as u16, true_values, false_values) }
128+
} else if N <= 32 {
129+
// Safety: bitmask matches length
130+
unsafe { select_impl::<T, u32, 32, N>(bitmask as u32, true_values, false_values) }
131+
} else {
132+
// Safety: bitmask matches length
133+
unsafe { select_impl::<T, u64, 64, N>(bitmask, true_values, false_values) }
134+
}
135+
}
136+
}
137+
138+
impl<T, U, const N: usize> Select<Mask<T, N>> for Mask<U, N>
4139
where
5140
T: MaskElement,
141+
U: MaskElement,
6142
LaneCount<N>: SupportedLaneCount,
7143
{
8-
/// Choose elements from two vectors.
9-
///
10-
/// For each element in the mask, choose the corresponding element from `true_values` if
11-
/// that element mask is true, and `false_values` if that element mask is false.
12-
///
13-
/// # Examples
14-
/// ```
15-
/// # #![feature(portable_simd)]
16-
/// # use core::simd::{Simd, Mask};
17-
/// let a = Simd::from_array([0, 1, 2, 3]);
18-
/// let b = Simd::from_array([4, 5, 6, 7]);
19-
/// let mask = Mask::from_array([true, false, false, true]);
20-
/// let c = mask.select(a, b);
21-
/// assert_eq!(c.to_array(), [0, 5, 6, 3]);
22-
/// ```
23144
#[inline]
24-
#[must_use = "method returns a new vector and does not mutate the original inputs"]
25-
pub fn select<U>(self, true_values: Simd<U, N>, false_values: Simd<U, N>) -> Simd<U, N>
26-
where
27-
U: SimdElement<Mask = T>,
28-
{
29-
// Safety: The mask has been cast to a vector of integers,
30-
// and the operands to select between are vectors of the same type and length.
31-
unsafe { core::intrinsics::simd::simd_select(self.to_simd(), true_values, false_values) }
145+
fn select(self, true_values: Mask<T, N>, false_values: Mask<T, N>) -> Mask<T, N> {
146+
let selected: Simd<T, N> =
147+
Select::select(self, true_values.to_simd(), false_values.to_simd());
148+
149+
// Safety: all values come from masks
150+
unsafe { Mask::from_simd_unchecked(selected) }
32151
}
152+
}
33153

34-
/// Choose elements from two masks.
35-
///
36-
/// For each element in the mask, choose the corresponding element from `true_values` if
37-
/// that element mask is true, and `false_values` if that element mask is false.
38-
///
39-
/// # Examples
40-
/// ```
41-
/// # #![feature(portable_simd)]
42-
/// # use core::simd::Mask;
43-
/// let a = Mask::<i32, 4>::from_array([true, true, false, false]);
44-
/// let b = Mask::<i32, 4>::from_array([false, false, true, true]);
45-
/// let mask = Mask::<i32, 4>::from_array([true, false, false, true]);
46-
/// let c = mask.select_mask(a, b);
47-
/// assert_eq!(c.to_array(), [true, false, true, false]);
48-
/// ```
154+
impl<T, const N: usize> Select<Mask<T, N>> for u64
155+
where
156+
T: MaskElement,
157+
LaneCount<N>: SupportedLaneCount,
158+
{
49159
#[inline]
50-
#[must_use = "method returns a new mask and does not mutate the original inputs"]
51-
pub fn select_mask(self, true_values: Self, false_values: Self) -> Self {
52-
self & true_values | !self & false_values
160+
fn select(self, true_values: Mask<T, N>, false_values: Mask<T, N>) -> Mask<T, N> {
161+
let selected: Simd<T, N> =
162+
Select::select(self, true_values.to_simd(), false_values.to_simd());
163+
164+
// Safety: all values come from masks
165+
unsafe { Mask::from_simd_unchecked(selected) }
53166
}
54167
}

crates/core_simd/src/simd/cmp/ord.rs

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
use crate::simd::{
2-
LaneCount, Mask, Simd, SupportedLaneCount,
2+
LaneCount, Mask, Select, Simd, SupportedLaneCount,
33
cmp::SimdPartialEq,
44
ptr::{SimdConstPtr, SimdMutPtr},
55
};
@@ -194,12 +194,12 @@ macro_rules! impl_mask {
194194
{
195195
#[inline]
196196
fn simd_max(self, other: Self) -> Self {
197-
self.simd_gt(other).select_mask(other, self)
197+
self.simd_gt(other).select(other, self)
198198
}
199199

200200
#[inline]
201201
fn simd_min(self, other: Self) -> Self {
202-
self.simd_lt(other).select_mask(other, self)
202+
self.simd_lt(other).select(other, self)
203203
}
204204

205205
#[inline]

crates/core_simd/src/simd/num/float.rs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
use super::sealed::Sealed;
22
use crate::simd::{
3-
LaneCount, Mask, Simd, SimdCast, SimdElement, SupportedLaneCount,
3+
LaneCount, Mask, Select, Simd, SimdCast, SimdElement, SupportedLaneCount,
44
cmp::{SimdPartialEq, SimdPartialOrd},
55
};
66

crates/core_simd/src/simd/num/int.rs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
use super::sealed::Sealed;
22
use crate::simd::{
3-
LaneCount, Mask, Simd, SimdCast, SimdElement, SupportedLaneCount, cmp::SimdOrd,
3+
LaneCount, Mask, Select, Simd, SimdCast, SimdElement, SupportedLaneCount, cmp::SimdOrd,
44
cmp::SimdPartialOrd, num::SimdUint,
55
};
66

crates/core_simd/src/swizzle_dyn.rs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
use crate::simd::{LaneCount, Simd, SupportedLaneCount};
1+
use crate::simd::{LaneCount, Select, Simd, SupportedLaneCount};
22
use core::mem;
33

44
impl<const N: usize> Simd<u8, N>

0 commit comments

Comments
 (0)