Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
336 changes: 182 additions & 154 deletions crates/core_simd/src/masks.rs

Large diffs are not rendered by default.

23 changes: 13 additions & 10 deletions crates/core_simd/src/select.rs
Original file line number Diff line number Diff line change
@@ -1,6 +1,4 @@
use crate::simd::{
FixEndianness, LaneCount, Mask, MaskElement, Simd, SimdElement, SupportedLaneCount,
};
use crate::simd::{FixEndianness, LaneCount, Mask, Simd, SimdElement, SupportedLaneCount};

/// Choose elements from two vectors using a mask.
///
Expand Down Expand Up @@ -58,7 +56,7 @@ pub trait Select<T> {
impl<T, U, const N: usize> Select<Simd<T, N>> for Mask<U, N>
where
T: SimdElement,
U: MaskElement,
U: SimdElement,
LaneCount<N>: SupportedLaneCount,
{
#[inline]
Expand Down Expand Up @@ -133,14 +131,19 @@ where

impl<T, U, const N: usize> Select<Mask<T, N>> for Mask<U, N>
where
T: MaskElement,
U: MaskElement,
T: SimdElement,
U: SimdElement,
LaneCount<N>: SupportedLaneCount,
{
#[inline]
fn select(self, true_values: Mask<T, N>, false_values: Mask<T, N>) -> Mask<T, N> {
let selected: Simd<T, N> =
Select::select(self, true_values.to_simd(), false_values.to_simd());
// Safety:
// simd_as between masks is always safe (they're vectors of ints).
// simd_select uses a mask that matches the width and number of elements
let selected: Simd<T::Mask, N> = unsafe {
let mask: Simd<T::Mask, N> = core::intrinsics::simd::simd_as(self.to_simd());
core::intrinsics::simd::simd_select(mask, true_values.to_simd(), false_values.to_simd())
};

// Safety: all values come from masks
unsafe { Mask::from_simd_unchecked(selected) }
Expand All @@ -149,12 +152,12 @@ where

impl<T, const N: usize> Select<Mask<T, N>> for u64
where
T: MaskElement,
T: SimdElement,
LaneCount<N>: SupportedLaneCount,
{
#[inline]
fn select(self, true_values: Mask<T, N>, false_values: Mask<T, N>) -> Mask<T, N> {
let selected: Simd<T, N> =
let selected: Simd<T::Mask, N> =
Select::select(self, true_values.to_simd(), false_values.to_simd());

// Safety: all values come from masks
Expand Down
62 changes: 32 additions & 30 deletions crates/core_simd/src/simd/cmp/eq.rs
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@ macro_rules! impl_number {
where
LaneCount<N>: SupportedLaneCount,
{
type Mask = Mask<<$number as SimdElement>::Mask, N>;
type Mask = Mask<$number, N>;

#[inline]
fn simd_eq(self, other: Self) -> Self::Mask {
Expand All @@ -46,65 +46,67 @@ macro_rules! impl_number {

impl_number! { f32, f64, u8, u16, u32, u64, usize, i8, i16, i32, i64, isize }

macro_rules! impl_mask {
{ $($integer:ty),* } => {
$(
impl<const N: usize> SimdPartialEq for Mask<$integer, N>
where
LaneCount<N>: SupportedLaneCount,
{
type Mask = Self;
// Masks compare lane-wise by comparing their underlying integer representations
impl<T, const N: usize> SimdPartialEq for Mask<T, N>
where
T: SimdElement,
LaneCount<N>: SupportedLaneCount,
{
type Mask = Self;

#[inline]
fn simd_eq(self, other: Self) -> Self::Mask {
// Safety: `self` is a vector, and the result of the comparison
// is always a valid mask.
unsafe { Self::from_simd_unchecked(core::intrinsics::simd::simd_eq(self.to_simd(), other.to_simd())) }
}
#[inline]
fn simd_eq(self, other: Self) -> Self::Mask {
// Safety: `self` is a vector, and the result of the comparison is always a valid mask.
unsafe {
Self::from_simd_unchecked(core::intrinsics::simd::simd_eq(
self.to_simd(),
other.to_simd(),
))
}
}

#[inline]
fn simd_ne(self, other: Self) -> Self::Mask {
// Safety: `self` is a vector, and the result of the comparison
// is always a valid mask.
unsafe { Self::from_simd_unchecked(core::intrinsics::simd::simd_ne(self.to_simd(), other.to_simd())) }
}
#[inline]
fn simd_ne(self, other: Self) -> Self::Mask {
// Safety: `self` is a vector, and the result of the comparison is always a valid mask.
unsafe {
Self::from_simd_unchecked(core::intrinsics::simd::simd_ne(
self.to_simd(),
other.to_simd(),
))
}
)*
}
}

impl_mask! { i8, i16, i32, i64, isize }

impl<T, const N: usize> SimdPartialEq for Simd<*const T, N>
where
LaneCount<N>: SupportedLaneCount,
{
type Mask = Mask<isize, N>;
type Mask = Mask<*const T, N>;

#[inline]
fn simd_eq(self, other: Self) -> Self::Mask {
self.addr().simd_eq(other.addr())
self.addr().simd_eq(other.addr()).cast::<*const T>()
}

#[inline]
fn simd_ne(self, other: Self) -> Self::Mask {
self.addr().simd_ne(other.addr())
self.addr().simd_ne(other.addr()).cast::<*const T>()
}
}

impl<T, const N: usize> SimdPartialEq for Simd<*mut T, N>
where
LaneCount<N>: SupportedLaneCount,
{
type Mask = Mask<isize, N>;
type Mask = Mask<*mut T, N>;

#[inline]
fn simd_eq(self, other: Self) -> Self::Mask {
self.addr().simd_eq(other.addr())
self.addr().simd_eq(other.addr()).cast::<*mut T>()
}

#[inline]
fn simd_ne(self, other: Self) -> Self::Mask {
self.addr().simd_ne(other.addr())
self.addr().simd_ne(other.addr()).cast::<*mut T>()
}
}
144 changes: 79 additions & 65 deletions crates/core_simd/src/simd/cmp/ord.rs
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
use crate::simd::{
LaneCount, Mask, Select, Simd, SupportedLaneCount,
LaneCount, Mask, Select, Simd, SimdElement, SupportedLaneCount,
cmp::SimdPartialEq,
ptr::{SimdConstPtr, SimdMutPtr},
};
Expand Down Expand Up @@ -152,94 +152,108 @@ macro_rules! impl_float {

impl_float! { f32, f64 }

macro_rules! impl_mask {
{ $($integer:ty),* } => {
$(
impl<const N: usize> SimdPartialOrd for Mask<$integer, N>
where
LaneCount<N>: SupportedLaneCount,
{
#[inline]
fn simd_lt(self, other: Self) -> Self::Mask {
// Safety: `self` is a vector, and the result of the comparison
// is always a valid mask.
unsafe { Self::from_simd_unchecked(core::intrinsics::simd::simd_lt(self.to_simd(), other.to_simd())) }
}
impl<T, const N: usize> SimdPartialOrd for Mask<T, N>
where
T: SimdElement,
LaneCount<N>: SupportedLaneCount,
{
#[inline]
fn simd_lt(self, other: Self) -> Self::Mask {
// Use intrinsic to avoid extra bounds on T.
// Safety: `self` is a vector, and the result of the comparison is always a valid mask.
unsafe {
Self::from_simd_unchecked(core::intrinsics::simd::simd_lt(
self.to_simd(),
other.to_simd(),
))
}
}

#[inline]
fn simd_le(self, other: Self) -> Self::Mask {
// Safety: `self` is a vector, and the result of the comparison
// is always a valid mask.
unsafe { Self::from_simd_unchecked(core::intrinsics::simd::simd_le(self.to_simd(), other.to_simd())) }
}
#[inline]
fn simd_le(self, other: Self) -> Self::Mask {
// Use intrinsic to avoid extra bounds on T.
// Safety: `self` is a vector, and the result of the comparison is always a valid mask.
unsafe {
Self::from_simd_unchecked(core::intrinsics::simd::simd_le(
self.to_simd(),
other.to_simd(),
))
}
}

#[inline]
fn simd_gt(self, other: Self) -> Self::Mask {
// Safety: `self` is a vector, and the result of the comparison
// is always a valid mask.
unsafe { Self::from_simd_unchecked(core::intrinsics::simd::simd_gt(self.to_simd(), other.to_simd())) }
}
#[inline]
fn simd_gt(self, other: Self) -> Self::Mask {
// Use intrinsic to avoid extra bounds on T.
// Safety: `self` is a vector, and the result of the comparison is always a valid mask.
unsafe {
Self::from_simd_unchecked(core::intrinsics::simd::simd_gt(
self.to_simd(),
other.to_simd(),
))
}
}

#[inline]
fn simd_ge(self, other: Self) -> Self::Mask {
// Safety: `self` is a vector, and the result of the comparison
// is always a valid mask.
unsafe { Self::from_simd_unchecked(core::intrinsics::simd::simd_ge(self.to_simd(), other.to_simd())) }
}
#[inline]
fn simd_ge(self, other: Self) -> Self::Mask {
// Use intrinsic to avoid extra bounds on T.
// Safety: `self` is a vector, and the result of the comparison is always a valid mask.
unsafe {
Self::from_simd_unchecked(core::intrinsics::simd::simd_ge(
self.to_simd(),
other.to_simd(),
))
}
}
}

impl<const N: usize> SimdOrd for Mask<$integer, N>
where
LaneCount<N>: SupportedLaneCount,
{
#[inline]
fn simd_max(self, other: Self) -> Self {
self.simd_gt(other).select(other, self)
}
impl<T, const N: usize> SimdOrd for Mask<T, N>
where
T: SimdElement,
LaneCount<N>: SupportedLaneCount,
{
#[inline]
fn simd_max(self, other: Self) -> Self {
self.simd_gt(other).select(other, self)
}

#[inline]
fn simd_min(self, other: Self) -> Self {
self.simd_lt(other).select(other, self)
}
#[inline]
fn simd_min(self, other: Self) -> Self {
self.simd_lt(other).select(other, self)
}

#[inline]
#[track_caller]
fn simd_clamp(self, min: Self, max: Self) -> Self {
assert!(
min.simd_le(max).all(),
"each element in `min` must be less than or equal to the corresponding element in `max`",
);
self.simd_max(min).simd_min(max)
}
}
)*
#[inline]
#[track_caller]
fn simd_clamp(self, min: Self, max: Self) -> Self {
assert!(
min.simd_le(max).all(),
"each element in `min` must be less than or equal to the corresponding element in `max`",
);
self.simd_max(min).simd_min(max)
}
}

impl_mask! { i8, i16, i32, i64, isize }

impl<T, const N: usize> SimdPartialOrd for Simd<*const T, N>
where
LaneCount<N>: SupportedLaneCount,
{
#[inline]
fn simd_lt(self, other: Self) -> Self::Mask {
self.addr().simd_lt(other.addr())
self.addr().simd_lt(other.addr()).cast::<*const T>()
}

#[inline]
fn simd_le(self, other: Self) -> Self::Mask {
self.addr().simd_le(other.addr())
self.addr().simd_le(other.addr()).cast::<*const T>()
}

#[inline]
fn simd_gt(self, other: Self) -> Self::Mask {
self.addr().simd_gt(other.addr())
self.addr().simd_gt(other.addr()).cast::<*const T>()
}

#[inline]
fn simd_ge(self, other: Self) -> Self::Mask {
self.addr().simd_ge(other.addr())
self.addr().simd_ge(other.addr()).cast::<*const T>()
}
}

Expand Down Expand Up @@ -274,22 +288,22 @@ where
{
#[inline]
fn simd_lt(self, other: Self) -> Self::Mask {
self.addr().simd_lt(other.addr())
self.addr().simd_lt(other.addr()).cast::<*mut T>()
}

#[inline]
fn simd_le(self, other: Self) -> Self::Mask {
self.addr().simd_le(other.addr())
self.addr().simd_le(other.addr()).cast::<*mut T>()
}

#[inline]
fn simd_gt(self, other: Self) -> Self::Mask {
self.addr().simd_gt(other.addr())
self.addr().simd_gt(other.addr()).cast::<*mut T>()
}

#[inline]
fn simd_ge(self, other: Self) -> Self::Mask {
self.addr().simd_ge(other.addr())
self.addr().simd_ge(other.addr()).cast::<*mut T>()
}
}

Expand Down
11 changes: 7 additions & 4 deletions crates/core_simd/src/simd/num/float.rs
Original file line number Diff line number Diff line change
Expand Up @@ -250,7 +250,7 @@ macro_rules! impl_trait {
where
LaneCount<N>: SupportedLaneCount,
{
type Mask = Mask<<$mask_ty as SimdElement>::Mask, N>;
type Mask = Mask<$ty, N>;
type Scalar = $ty;
type Bits = Simd<$bits_ty, N>;
type Cast<T: SimdElement> = Simd<T, N>;
Expand Down Expand Up @@ -345,7 +345,7 @@ macro_rules! impl_trait {
#[inline]
fn is_sign_negative(self) -> Self::Mask {
let sign_bits = self.to_bits() & Simd::splat((!0 >> 1) + 1);
sign_bits.simd_gt(Simd::splat(0))
sign_bits.simd_gt(Simd::splat(0)).cast::<$ty>()
}

#[inline]
Expand All @@ -367,8 +367,11 @@ macro_rules! impl_trait {
fn is_subnormal(self) -> Self::Mask {
// On some architectures (e.g. armv7 and some ppc) subnormals are flushed to zero,
// so this comparison must be done with integers.
let not_zero = self.abs().to_bits().simd_ne(Self::splat(0.0).to_bits());
not_zero & (self.to_bits() & Self::splat(Self::Scalar::INFINITY).to_bits()).simd_eq(Simd::splat(0))
let not_zero = self.abs().to_bits().simd_ne(Self::splat(0.0).to_bits()).cast::<$ty>();
let exp_zero = (self.to_bits() & Self::splat(Self::Scalar::INFINITY).to_bits())
.simd_eq(Simd::splat(0))
.cast::<$ty>();
not_zero & exp_zero
}

#[inline]
Expand Down
Loading
Loading