diff --git a/crates/core_simd/src/masks.rs b/crates/core_simd/src/masks.rs index 7baa9647591..beb01b733ac 100644 --- a/crates/core_simd/src/masks.rs +++ b/crates/core_simd/src/masks.rs @@ -2,7 +2,8 @@ //! Types representing #![allow(non_camel_case_types)] -use crate::simd::{LaneCount, Select, Simd, SimdCast, SimdElement, SupportedLaneCount}; +use crate::core_simd::vector::sealed::MaskElement; +use crate::simd::{LaneCount, Select, Simd, SimdElement, SupportedLaneCount}; use core::cmp::Ordering; use core::{fmt, mem}; @@ -29,92 +30,7 @@ macro_rules! impl_fix_endianness { impl_fix_endianness! { u8, u16, u32, u64 } -mod sealed { - use super::*; - - /// Not only does this seal the `MaskElement` trait, but these functions prevent other traits - /// from bleeding into the parent bounds. - /// - /// For example, `eq` could be provided by requiring `MaskElement: PartialEq`, but that would - /// prevent us from ever removing that bound, or from implementing `MaskElement` on - /// non-`PartialEq` types in the future. - pub trait Sealed { - fn valid(values: Simd) -> bool - where - LaneCount: SupportedLaneCount, - Self: SimdElement; - - fn eq(self, other: Self) -> bool; - - fn to_usize(self) -> usize; - fn max_unsigned() -> u64; - - type Unsigned: SimdElement; - - const TRUE: Self; - - const FALSE: Self; - } -} -use sealed::Sealed; - -/// Marker trait for types that may be used as SIMD mask elements. -/// -/// # Safety -/// Type must be a signed integer. -pub unsafe trait MaskElement: SimdElement + SimdCast + Sealed {} - -macro_rules! impl_element { - { $ty:ty, $unsigned:ty } => { - impl Sealed for $ty { - #[inline] - fn valid(value: Simd) -> bool - where - LaneCount: SupportedLaneCount, - { - // We can't use `Simd` directly, because `Simd`'s functions call this function and - // we will end up with an infinite loop. 
- // Safety: `value` is an integer vector - unsafe { - use core::intrinsics::simd; - let falses: Simd = simd::simd_eq(value, Simd::splat(0 as _)); - let trues: Simd = simd::simd_eq(value, Simd::splat(-1 as _)); - let valid: Simd = simd::simd_or(falses, trues); - simd::simd_reduce_all(valid) - } - } - - #[inline] - fn eq(self, other: Self) -> bool { self == other } - - #[inline] - fn to_usize(self) -> usize { - self as usize - } - - #[inline] - fn max_unsigned() -> u64 { - <$unsigned>::MAX as u64 - } - - type Unsigned = $unsigned; - - const TRUE: Self = -1; - const FALSE: Self = 0; - } - - // Safety: this is a valid mask element type - unsafe impl MaskElement for $ty {} - } -} - -impl_element! { i8, u8 } -impl_element! { i16, u16 } -impl_element! { i32, u32 } -impl_element! { i64, u64 } -impl_element! { isize, usize } - -/// A SIMD vector mask for `N` elements of width specified by `Element`. +/// A SIMD vector mask for `N` elements matching the element type `T`. /// /// Masks represent boolean inclusion/exclusion on a per-element basis. /// @@ -122,21 +38,21 @@ impl_element! { isize, usize } /// and/or Rust versions, and code should not assume that it is equivalent to /// `[T; N]`. #[repr(transparent)] -pub struct Mask(Simd) +pub struct Mask(Simd) where - T: MaskElement, + T: SimdElement, LaneCount: SupportedLaneCount; impl Copy for Mask where - T: MaskElement, + T: SimdElement, LaneCount: SupportedLaneCount, { } impl Clone for Mask where - T: MaskElement, + T: SimdElement, LaneCount: SupportedLaneCount, { #[inline] @@ -147,14 +63,18 @@ where impl Mask where - T: MaskElement, + T: SimdElement, LaneCount: SupportedLaneCount, { /// Constructs a mask by setting all elements to the given value. 
#[inline] #[rustc_const_unstable(feature = "portable_simd", issue = "86656")] pub const fn splat(value: bool) -> Self { - Self(Simd::splat(if value { T::TRUE } else { T::FALSE })) + Self(Simd::splat(if value { + ::TRUE + } else { + ::FALSE + })) } /// Converts an array of bools to a SIMD mask. @@ -201,10 +121,10 @@ where /// All elements must be either 0 or -1. #[inline] #[must_use = "method returns a new mask and does not mutate the original value"] - pub unsafe fn from_simd_unchecked(value: Simd) -> Self { + pub unsafe fn from_simd_unchecked(value: Simd) -> Self { // Safety: the caller must confirm this invariant unsafe { - core::intrinsics::assume(::valid(value)); + core::intrinsics::assume(::valid(value)); } Self(value) } @@ -217,8 +137,11 @@ where #[inline] #[must_use = "method returns a new mask and does not mutate the original value"] #[track_caller] - pub fn from_simd(value: Simd) -> Self { - assert!(T::valid(value), "all values must be either 0 or -1",); + pub fn from_simd(value: Simd) -> Self { + assert!( + ::valid(value), + "all values must be either 0 or -1", + ); // Safety: the validity has been checked unsafe { Self::from_simd_unchecked(value) } } @@ -227,16 +150,17 @@ where /// represents `true`. #[inline] #[must_use = "method returns a new vector and does not mutate the original value"] - pub fn to_simd(self) -> Simd { + pub fn to_simd(self) -> Simd { self.0 } /// Converts the mask to a mask of any other element size. #[inline] #[must_use = "method returns a new mask and does not mutate the original value"] - pub fn cast(self) -> Mask { - // Safety: mask elements are integers - unsafe { Mask(core::intrinsics::simd::simd_as(self.0)) } + pub fn cast(self) -> Mask { + // Safety: mask elements are integers; cast between underlying mask widths + let ints: Simd = unsafe { core::intrinsics::simd::simd_as(self.0) }; + Mask(ints) } /// Tests the value of the specified element. 
@@ -247,7 +171,12 @@ where #[must_use = "method returns a new bool and does not mutate the original value"] pub unsafe fn test_unchecked(&self, index: usize) -> bool { // Safety: the caller must confirm this invariant - unsafe { T::eq(*self.0.as_array().get_unchecked(index), T::TRUE) } + unsafe { + ::eq( + *self.0.as_array().get_unchecked(index), + ::TRUE, + ) + } } /// Tests the value of the specified element. @@ -258,7 +187,7 @@ where #[must_use = "method returns a new bool and does not mutate the original value"] #[track_caller] pub fn test(&self, index: usize) -> bool { - T::eq(self.0[index], T::TRUE) + ::eq(self.0[index], ::TRUE) } /// Sets the value of the specified element. @@ -269,7 +198,11 @@ where pub unsafe fn set_unchecked(&mut self, index: usize, value: bool) { // Safety: the caller must confirm this invariant unsafe { - *self.0.as_mut_array().get_unchecked_mut(index) = if value { T::TRUE } else { T::FALSE } + *self.0.as_mut_array().get_unchecked_mut(index) = if value { + ::TRUE + } else { + ::FALSE + } } } @@ -280,7 +213,11 @@ where #[inline] #[track_caller] pub fn set(&mut self, index: usize, value: bool) { - self.0[index] = if value { T::TRUE } else { T::FALSE } + self.0[index] = if value { + ::TRUE + } else { + ::FALSE + } } /// Returns true if any element is set, or false otherwise. @@ -314,7 +251,7 @@ where mask: Mask, ) -> U where - T: MaskElement, + T: SimdElement, LaneCount: SupportedLaneCount, LaneCount: SupportedLaneCount, { @@ -350,7 +287,10 @@ where #[inline] #[must_use = "method returns a new mask and does not mutate the original value"] pub fn from_bitmask(bitmask: u64) -> Self { - Self(bitmask.select(Simd::splat(T::TRUE), Simd::splat(T::FALSE))) + Self(bitmask.select( + Simd::splat(::TRUE), + Simd::splat(::FALSE), + )) } /// Finds the index of the first set element. 
@@ -394,25 +334,28 @@ where ); // Safety: the input and output are integer vectors - let index: Simd = unsafe { core::intrinsics::simd::simd_cast(index) }; + let index: Simd = unsafe { core::intrinsics::simd::simd_cast(index) }; - let masked_index = self.select(index, Self::splat(true).to_simd()); + // Safety: the mask and inputs are integer vectors + let masked_index = unsafe { + core::intrinsics::simd::simd_select(self.to_simd(), index, Self::splat(true).to_simd()) + }; // Safety: the input and output are integer vectors - let masked_index: Simd = + let masked_index: Simd<::Unsigned, N> = unsafe { core::intrinsics::simd::simd_cast(masked_index) }; - // Safety: the input is an integer vector - let min_index: T::Unsigned = + // Safety: the input is an integer vector + let min_index: ::Unsigned = unsafe { core::intrinsics::simd::simd_reduce_min(masked_index) }; // Safety: the return value is the unsigned version of T - let min_index: T = unsafe { core::mem::transmute_copy(&min_index) }; + let min_index: T::Mask = unsafe { core::mem::transmute_copy(&min_index) }; - if min_index.eq(T::TRUE) { + if ::eq(min_index, ::TRUE) { None } else { - Some(min_index.to_usize()) + Some(::to_usize(min_index)) } } } @@ -420,7 +363,7 @@ where // vector/array conversion impl From<[bool; N]> for Mask where - T: MaskElement, + T: SimdElement, LaneCount: SupportedLaneCount, { #[inline] @@ -431,7 +374,7 @@ where impl From> for [bool; N] where - T: MaskElement, + T: SimdElement, LaneCount: SupportedLaneCount, { #[inline] @@ -442,7 +385,7 @@ where impl Default for Mask where - T: MaskElement, + T: SimdElement, LaneCount: SupportedLaneCount, { #[inline] @@ -453,29 +396,32 @@ where impl PartialEq for Mask where - T: MaskElement + PartialEq, + T: SimdElement, LaneCount: SupportedLaneCount, { #[inline] fn eq(&self, other: &Self) -> bool { - self.0 == other.0 + // Use intrinsic to avoid additional bound on T + // Safety: `self.0` is an integer vector + unsafe { 
Self(core::intrinsics::simd::simd_eq(self.0, other.0)).all() } } } impl PartialOrd for Mask where - T: MaskElement + PartialOrd, + T: SimdElement, LaneCount: SupportedLaneCount, { #[inline] fn partial_cmp(&self, other: &Self) -> Option { - self.0.partial_cmp(&other.0) + // TODO use SIMD equality + self.to_array().partial_cmp(&other.to_array()) } } impl fmt::Debug for Mask where - T: MaskElement + fmt::Debug, + T: SimdElement, LaneCount: SupportedLaneCount, { #[inline] @@ -488,12 +434,13 @@ where impl core::ops::BitAnd for Mask where - T: MaskElement, + T: SimdElement, LaneCount: SupportedLaneCount, { type Output = Self; #[inline] fn bitand(self, rhs: Self) -> Self { + // Use intrinsic to avoid additional bound on T // Safety: `self` is an integer vector unsafe { Self(core::intrinsics::simd::simd_and(self.0, rhs.0)) } } @@ -501,7 +448,7 @@ where impl core::ops::BitAnd for Mask where - T: MaskElement, + T: SimdElement, LaneCount: SupportedLaneCount, { type Output = Self; @@ -513,7 +460,7 @@ where impl core::ops::BitAnd> for bool where - T: MaskElement, + T: SimdElement, LaneCount: SupportedLaneCount, { type Output = Mask; @@ -525,12 +472,13 @@ where impl core::ops::BitOr for Mask where - T: MaskElement, + T: SimdElement, LaneCount: SupportedLaneCount, { type Output = Self; #[inline] fn bitor(self, rhs: Self) -> Self { + // Use intrinsic to avoid additional bound on T // Safety: `self` is an integer vector unsafe { Self(core::intrinsics::simd::simd_or(self.0, rhs.0)) } } @@ -538,7 +486,7 @@ where impl core::ops::BitOr for Mask where - T: MaskElement, + T: SimdElement, LaneCount: SupportedLaneCount, { type Output = Self; @@ -550,7 +498,7 @@ where impl core::ops::BitOr> for bool where - T: MaskElement, + T: SimdElement, LaneCount: SupportedLaneCount, { type Output = Mask; @@ -562,12 +510,13 @@ where impl core::ops::BitXor for Mask where - T: MaskElement, + T: SimdElement, LaneCount: SupportedLaneCount, { type Output = Self; #[inline] fn bitxor(self, rhs: Self) -> 
Self::Output { + // Use intrinsic to avoid additional bound on T // Safety: `self` is an integer vector unsafe { Self(core::intrinsics::simd::simd_xor(self.0, rhs.0)) } } @@ -575,7 +524,7 @@ where impl core::ops::BitXor for Mask where - T: MaskElement, + T: SimdElement, LaneCount: SupportedLaneCount, { type Output = Self; @@ -587,7 +536,7 @@ where impl core::ops::BitXor> for bool where - T: MaskElement, + T: SimdElement, LaneCount: SupportedLaneCount, { type Output = Mask; @@ -599,7 +548,7 @@ where impl core::ops::Not for Mask where - T: MaskElement, + T: SimdElement, LaneCount: SupportedLaneCount, { type Output = Mask; @@ -611,7 +560,7 @@ where impl core::ops::BitAndAssign for Mask where - T: MaskElement, + T: SimdElement, LaneCount: SupportedLaneCount, { #[inline] @@ -622,7 +571,7 @@ where impl core::ops::BitAndAssign for Mask where - T: MaskElement, + T: SimdElement, LaneCount: SupportedLaneCount, { #[inline] @@ -633,7 +582,7 @@ where impl core::ops::BitOrAssign for Mask where - T: MaskElement, + T: SimdElement, LaneCount: SupportedLaneCount, { #[inline] @@ -644,7 +593,7 @@ where impl core::ops::BitOrAssign for Mask where - T: MaskElement, + T: SimdElement, LaneCount: SupportedLaneCount, { #[inline] @@ -655,7 +604,7 @@ where impl core::ops::BitXorAssign for Mask where - T: MaskElement, + T: SimdElement, LaneCount: SupportedLaneCount, { #[inline] @@ -666,7 +615,7 @@ where impl core::ops::BitXorAssign for Mask where - T: MaskElement, + T: SimdElement, LaneCount: SupportedLaneCount, { #[inline] @@ -675,23 +624,102 @@ where } } -macro_rules! impl_from { - { $from:ty => $($to:ty),* } => { +macro_rules! 
impl_scalar_from_scalar { + () => {}; + ($head:ty $(, $tail:ty)*) => { $( - impl From> for Mask<$to, N> - where - LaneCount: SupportedLaneCount, - { - #[inline] - fn from(value: Mask<$from, N>) -> Self { - value.cast() + impl From> for Mask<$tail, N> + where + LaneCount: SupportedLaneCount, + { + #[inline] + fn from(value: Mask<$head, N>) -> Self { value.cast() } + } + impl From> for Mask<$head, N> + where + LaneCount: SupportedLaneCount, + { + #[inline] + fn from(value: Mask<$tail, N>) -> Self { value.cast() } + } + )* + impl_scalar_from_scalar! { $( $tail ),* } + }; +} + +macro_rules! impl_scalar_from_ptr { + ( $( $scalar:ty ),* $(,)? ) => { + $( + // From pointer to scalar + impl From> for Mask<$scalar, N> + where + P: core::ptr::Pointee, + LaneCount: SupportedLaneCount, + { + #[inline] + fn from(value: Mask<*const P, N>) -> Self { value.cast() } + } + + impl From> for Mask<$scalar, N> + where + P: core::ptr::Pointee, + LaneCount: SupportedLaneCount, + { + #[inline] + fn from(value: Mask<*mut P, N>) -> Self { value.cast() } + } + + // From scalar to pointer + impl From> for Mask<*const P, N> + where + P: core::ptr::Pointee, + LaneCount: SupportedLaneCount, + { + #[inline] + fn from(value: Mask<$scalar, N>) -> Self { value.cast() } + } + + impl From> for Mask<*mut P, N> + where + P: core::ptr::Pointee, + LaneCount: SupportedLaneCount, + { + #[inline] + fn from(value: Mask<$scalar, N>) -> Self { value.cast() } } - } )* + }; +} + +macro_rules! impl_scalar_from { + ( $( $scalar:ty ),* $(,)? ) => { + impl_scalar_from_scalar! { $( $scalar ),* } + impl_scalar_from_ptr! { $( $scalar ),* } + }; +} + +impl_scalar_from! 
{ i8, i16, i32, i64, isize, u8, u16, u32, u64, usize, f32, f64 } + +impl From> for Mask<*mut U, N> +where + T: core::ptr::Pointee, + U: core::ptr::Pointee, + LaneCount: SupportedLaneCount, +{ + #[inline] + fn from(value: Mask<*const T, N>) -> Self { + value.cast() + } +} + +impl From> for Mask<*const U, N> +where + T: core::ptr::Pointee, + U: core::ptr::Pointee, + LaneCount: SupportedLaneCount, +{ + #[inline] + fn from(value: Mask<*mut T, N>) -> Self { + value.cast() } } -impl_from! { i8 => i16, i32, i64, isize } -impl_from! { i16 => i32, i64, isize, i8 } -impl_from! { i32 => i64, isize, i8, i16 } -impl_from! { i64 => isize, i8, i16, i32 } -impl_from! { isize => i8, i16, i32, i64 } diff --git a/crates/core_simd/src/select.rs b/crates/core_simd/src/select.rs index 5240b9b0c71..eb479582c6c 100644 --- a/crates/core_simd/src/select.rs +++ b/crates/core_simd/src/select.rs @@ -1,6 +1,4 @@ -use crate::simd::{ - FixEndianness, LaneCount, Mask, MaskElement, Simd, SimdElement, SupportedLaneCount, -}; +use crate::simd::{FixEndianness, LaneCount, Mask, Simd, SimdElement, SupportedLaneCount}; /// Choose elements from two vectors using a mask. /// @@ -58,7 +56,7 @@ pub trait Select { impl Select> for Mask where T: SimdElement, - U: MaskElement, + U: SimdElement, LaneCount: SupportedLaneCount, { #[inline] @@ -133,14 +131,19 @@ where impl Select> for Mask where - T: MaskElement, - U: MaskElement, + T: SimdElement, + U: SimdElement, LaneCount: SupportedLaneCount, { #[inline] fn select(self, true_values: Mask, false_values: Mask) -> Mask { - let selected: Simd = - Select::select(self, true_values.to_simd(), false_values.to_simd()); + // Safety: + // simd_as between masks is always safe (they're vectors of ints). 
+ // simd_select uses a mask that matches the width and number of elements + let selected: Simd = unsafe { + let mask: Simd = core::intrinsics::simd::simd_as(self.to_simd()); + core::intrinsics::simd::simd_select(mask, true_values.to_simd(), false_values.to_simd()) + }; // Safety: all values come from masks unsafe { Mask::from_simd_unchecked(selected) } @@ -149,12 +152,12 @@ where impl Select> for u64 where - T: MaskElement, + T: SimdElement, LaneCount: SupportedLaneCount, { #[inline] fn select(self, true_values: Mask, false_values: Mask) -> Mask { - let selected: Simd = + let selected: Simd = Select::select(self, true_values.to_simd(), false_values.to_simd()); // Safety: all values come from masks diff --git a/crates/core_simd/src/simd/cmp/eq.rs b/crates/core_simd/src/simd/cmp/eq.rs index 789fc0bb942..c5f05d4acaa 100644 --- a/crates/core_simd/src/simd/cmp/eq.rs +++ b/crates/core_simd/src/simd/cmp/eq.rs @@ -24,7 +24,7 @@ macro_rules! impl_number { where LaneCount: SupportedLaneCount, { - type Mask = Mask<<$number as SimdElement>::Mask, N>; + type Mask = Mask<$number, N>; #[inline] fn simd_eq(self, other: Self) -> Self::Mask { @@ -46,49 +46,51 @@ macro_rules! impl_number { impl_number! { f32, f64, u8, u16, u32, u64, usize, i8, i16, i32, i64, isize } -macro_rules! impl_mask { - { $($integer:ty),* } => { - $( - impl SimdPartialEq for Mask<$integer, N> - where - LaneCount: SupportedLaneCount, - { - type Mask = Self; +// Masks compare lane-wise by comparing their underlying integer representations +impl SimdPartialEq for Mask +where + T: SimdElement, + LaneCount: SupportedLaneCount, +{ + type Mask = Self; - #[inline] - fn simd_eq(self, other: Self) -> Self::Mask { - // Safety: `self` is a vector, and the result of the comparison - // is always a valid mask. 
- unsafe { Self::from_simd_unchecked(core::intrinsics::simd::simd_eq(self.to_simd(), other.to_simd())) } - } + #[inline] + fn simd_eq(self, other: Self) -> Self::Mask { + // Safety: `self` is a vector, and the result of the comparison is always a valid mask. + unsafe { + Self::from_simd_unchecked(core::intrinsics::simd::simd_eq( + self.to_simd(), + other.to_simd(), + )) + } + } - #[inline] - fn simd_ne(self, other: Self) -> Self::Mask { - // Safety: `self` is a vector, and the result of the comparison - // is always a valid mask. - unsafe { Self::from_simd_unchecked(core::intrinsics::simd::simd_ne(self.to_simd(), other.to_simd())) } - } + #[inline] + fn simd_ne(self, other: Self) -> Self::Mask { + // Safety: `self` is a vector, and the result of the comparison is always a valid mask. + unsafe { + Self::from_simd_unchecked(core::intrinsics::simd::simd_ne( + self.to_simd(), + other.to_simd(), + )) } - )* } } -impl_mask! { i8, i16, i32, i64, isize } - impl SimdPartialEq for Simd<*const T, N> where LaneCount: SupportedLaneCount, { - type Mask = Mask; + type Mask = Mask<*const T, N>; #[inline] fn simd_eq(self, other: Self) -> Self::Mask { - self.addr().simd_eq(other.addr()) + self.addr().simd_eq(other.addr()).cast::<*const T>() } #[inline] fn simd_ne(self, other: Self) -> Self::Mask { - self.addr().simd_ne(other.addr()) + self.addr().simd_ne(other.addr()).cast::<*const T>() } } @@ -96,15 +98,15 @@ impl SimdPartialEq for Simd<*mut T, N> where LaneCount: SupportedLaneCount, { - type Mask = Mask; + type Mask = Mask<*mut T, N>; #[inline] fn simd_eq(self, other: Self) -> Self::Mask { - self.addr().simd_eq(other.addr()) + self.addr().simd_eq(other.addr()).cast::<*mut T>() } #[inline] fn simd_ne(self, other: Self) -> Self::Mask { - self.addr().simd_ne(other.addr()) + self.addr().simd_ne(other.addr()).cast::<*mut T>() } } diff --git a/crates/core_simd/src/simd/cmp/ord.rs b/crates/core_simd/src/simd/cmp/ord.rs index 1b1c689ad45..a9d5faf7e60 100644 --- 
a/crates/core_simd/src/simd/cmp/ord.rs +++ b/crates/core_simd/src/simd/cmp/ord.rs @@ -1,5 +1,5 @@ use crate::simd::{ - LaneCount, Mask, Select, Simd, SupportedLaneCount, + LaneCount, Mask, Select, Simd, SimdElement, SupportedLaneCount, cmp::SimdPartialEq, ptr::{SimdConstPtr, SimdMutPtr}, }; @@ -152,94 +152,108 @@ macro_rules! impl_float { impl_float! { f32, f64 } -macro_rules! impl_mask { - { $($integer:ty),* } => { - $( - impl SimdPartialOrd for Mask<$integer, N> - where - LaneCount: SupportedLaneCount, - { - #[inline] - fn simd_lt(self, other: Self) -> Self::Mask { - // Safety: `self` is a vector, and the result of the comparison - // is always a valid mask. - unsafe { Self::from_simd_unchecked(core::intrinsics::simd::simd_lt(self.to_simd(), other.to_simd())) } - } +impl SimdPartialOrd for Mask +where + T: SimdElement, + LaneCount: SupportedLaneCount, +{ + #[inline] + fn simd_lt(self, other: Self) -> Self::Mask { + // Use intrinsic to avoid extra bounds on T. + // Safety: `self` is a vector, and the result of the comparison is always a valid mask. + unsafe { + Self::from_simd_unchecked(core::intrinsics::simd::simd_lt( + self.to_simd(), + other.to_simd(), + )) + } + } - #[inline] - fn simd_le(self, other: Self) -> Self::Mask { - // Safety: `self` is a vector, and the result of the comparison - // is always a valid mask. - unsafe { Self::from_simd_unchecked(core::intrinsics::simd::simd_le(self.to_simd(), other.to_simd())) } - } + #[inline] + fn simd_le(self, other: Self) -> Self::Mask { + // Use intrinsic to avoid extra bounds on T. + // Safety: `self` is a vector, and the result of the comparison is always a valid mask. + unsafe { + Self::from_simd_unchecked(core::intrinsics::simd::simd_le( + self.to_simd(), + other.to_simd(), + )) + } + } - #[inline] - fn simd_gt(self, other: Self) -> Self::Mask { - // Safety: `self` is a vector, and the result of the comparison - // is always a valid mask. 
- unsafe { Self::from_simd_unchecked(core::intrinsics::simd::simd_gt(self.to_simd(), other.to_simd())) } - } + #[inline] + fn simd_gt(self, other: Self) -> Self::Mask { + // Use intrinsic to avoid extra bounds on T. + // Safety: `self` is a vector, and the result of the comparison is always a valid mask. + unsafe { + Self::from_simd_unchecked(core::intrinsics::simd::simd_gt( + self.to_simd(), + other.to_simd(), + )) + } + } - #[inline] - fn simd_ge(self, other: Self) -> Self::Mask { - // Safety: `self` is a vector, and the result of the comparison - // is always a valid mask. - unsafe { Self::from_simd_unchecked(core::intrinsics::simd::simd_ge(self.to_simd(), other.to_simd())) } - } + #[inline] + fn simd_ge(self, other: Self) -> Self::Mask { + // Use intrinsic to avoid extra bounds on T. + // Safety: `self` is a vector, and the result of the comparison is always a valid mask. + unsafe { + Self::from_simd_unchecked(core::intrinsics::simd::simd_ge( + self.to_simd(), + other.to_simd(), + )) } + } +} - impl SimdOrd for Mask<$integer, N> - where - LaneCount: SupportedLaneCount, - { - #[inline] - fn simd_max(self, other: Self) -> Self { - self.simd_gt(other).select(other, self) - } +impl SimdOrd for Mask +where + T: SimdElement, + LaneCount: SupportedLaneCount, +{ + #[inline] + fn simd_max(self, other: Self) -> Self { + self.simd_gt(other).select(other, self) + } - #[inline] - fn simd_min(self, other: Self) -> Self { - self.simd_lt(other).select(other, self) - } + #[inline] + fn simd_min(self, other: Self) -> Self { + self.simd_lt(other).select(other, self) + } - #[inline] - #[track_caller] - fn simd_clamp(self, min: Self, max: Self) -> Self { - assert!( - min.simd_le(max).all(), - "each element in `min` must be less than or equal to the corresponding element in `max`", - ); - self.simd_max(min).simd_min(max) - } - } - )* + #[inline] + #[track_caller] + fn simd_clamp(self, min: Self, max: Self) -> Self { + assert!( + min.simd_le(max).all(), + "each element in `min` must 
be less than or equal to the corresponding element in `max`", + ); + self.simd_max(min).simd_min(max) } } -impl_mask! { i8, i16, i32, i64, isize } - impl SimdPartialOrd for Simd<*const T, N> where LaneCount: SupportedLaneCount, { #[inline] fn simd_lt(self, other: Self) -> Self::Mask { - self.addr().simd_lt(other.addr()) + self.addr().simd_lt(other.addr()).cast::<*const T>() } #[inline] fn simd_le(self, other: Self) -> Self::Mask { - self.addr().simd_le(other.addr()) + self.addr().simd_le(other.addr()).cast::<*const T>() } #[inline] fn simd_gt(self, other: Self) -> Self::Mask { - self.addr().simd_gt(other.addr()) + self.addr().simd_gt(other.addr()).cast::<*const T>() } #[inline] fn simd_ge(self, other: Self) -> Self::Mask { - self.addr().simd_ge(other.addr()) + self.addr().simd_ge(other.addr()).cast::<*const T>() } } @@ -274,22 +288,22 @@ where { #[inline] fn simd_lt(self, other: Self) -> Self::Mask { - self.addr().simd_lt(other.addr()) + self.addr().simd_lt(other.addr()).cast::<*mut T>() } #[inline] fn simd_le(self, other: Self) -> Self::Mask { - self.addr().simd_le(other.addr()) + self.addr().simd_le(other.addr()).cast::<*mut T>() } #[inline] fn simd_gt(self, other: Self) -> Self::Mask { - self.addr().simd_gt(other.addr()) + self.addr().simd_gt(other.addr()).cast::<*mut T>() } #[inline] fn simd_ge(self, other: Self) -> Self::Mask { - self.addr().simd_ge(other.addr()) + self.addr().simd_ge(other.addr()).cast::<*mut T>() } } diff --git a/crates/core_simd/src/simd/num/float.rs b/crates/core_simd/src/simd/num/float.rs index 76ab5748c63..d425f9537f8 100644 --- a/crates/core_simd/src/simd/num/float.rs +++ b/crates/core_simd/src/simd/num/float.rs @@ -250,7 +250,7 @@ macro_rules! impl_trait { where LaneCount: SupportedLaneCount, { - type Mask = Mask<<$mask_ty as SimdElement>::Mask, N>; + type Mask = Mask<$ty, N>; type Scalar = $ty; type Bits = Simd<$bits_ty, N>; type Cast = Simd; @@ -345,7 +345,7 @@ macro_rules! 
impl_trait { #[inline] fn is_sign_negative(self) -> Self::Mask { let sign_bits = self.to_bits() & Simd::splat((!0 >> 1) + 1); - sign_bits.simd_gt(Simd::splat(0)) + sign_bits.simd_gt(Simd::splat(0)).cast::<$ty>() } #[inline] @@ -367,8 +367,11 @@ macro_rules! impl_trait { fn is_subnormal(self) -> Self::Mask { // On some architectures (e.g. armv7 and some ppc) subnormals are flushed to zero, // so this comparison must be done with integers. - let not_zero = self.abs().to_bits().simd_ne(Self::splat(0.0).to_bits()); - not_zero & (self.to_bits() & Self::splat(Self::Scalar::INFINITY).to_bits()).simd_eq(Simd::splat(0)) + let not_zero = self.abs().to_bits().simd_ne(Self::splat(0.0).to_bits()).cast::<$ty>(); + let exp_zero = (self.to_bits() & Self::splat(Self::Scalar::INFINITY).to_bits()) + .simd_eq(Simd::splat(0)) + .cast::<$ty>(); + not_zero & exp_zero } #[inline] diff --git a/crates/core_simd/src/simd/num/int.rs b/crates/core_simd/src/simd/num/int.rs index 5a292407d05..434599f39af 100644 --- a/crates/core_simd/src/simd/num/int.rs +++ b/crates/core_simd/src/simd/num/int.rs @@ -251,7 +251,7 @@ macro_rules! 
impl_trait { where LaneCount: SupportedLaneCount, { - type Mask = Mask<<$ty as SimdElement>::Mask, N>; + type Mask = Mask<$ty, N>; type Scalar = $ty; type Unsigned = Simd<$unsigned, N>; type Cast = Simd; diff --git a/crates/core_simd/src/simd/ptr/const_ptr.rs b/crates/core_simd/src/simd/ptr/const_ptr.rs index 36452e7ae92..0dba1cd8770 100644 --- a/crates/core_simd/src/simd/ptr/const_ptr.rs +++ b/crates/core_simd/src/simd/ptr/const_ptr.rs @@ -98,7 +98,7 @@ where type Isize = Simd; type CastPtr = Simd<*const U, N>; type MutPtr = Simd<*mut T, N>; - type Mask = Mask; + type Mask = Mask<*const T, N>; #[inline] fn is_null(self) -> Self::Mask { diff --git a/crates/core_simd/src/simd/ptr/mut_ptr.rs b/crates/core_simd/src/simd/ptr/mut_ptr.rs index c644f390c20..9baa422ffe9 100644 --- a/crates/core_simd/src/simd/ptr/mut_ptr.rs +++ b/crates/core_simd/src/simd/ptr/mut_ptr.rs @@ -95,7 +95,7 @@ where type Isize = Simd; type CastPtr = Simd<*mut U, N>; type ConstPtr = Simd<*const T, N>; - type Mask = Mask; + type Mask = Mask<*mut T, N>; #[inline] fn is_null(self) -> Self::Mask { diff --git a/crates/core_simd/src/swizzle.rs b/crates/core_simd/src/swizzle.rs index 81085a9ee4a..e155980b558 100644 --- a/crates/core_simd/src/swizzle.rs +++ b/crates/core_simd/src/swizzle.rs @@ -1,4 +1,4 @@ -use crate::simd::{LaneCount, Mask, MaskElement, Simd, SimdElement, SupportedLaneCount}; +use crate::simd::{LaneCount, Mask, Simd, SimdElement, SupportedLaneCount}; /// Constructs a new SIMD vector by copying elements from selected elements in other vectors. 
/// @@ -160,7 +160,7 @@ pub trait Swizzle { #[must_use = "method returns a new mask and does not mutate the original inputs"] fn swizzle_mask(mask: Mask) -> Mask where - T: MaskElement, + T: SimdElement, LaneCount: SupportedLaneCount, LaneCount: SupportedLaneCount, { @@ -176,7 +176,7 @@ pub trait Swizzle { #[must_use = "method returns a new mask and does not mutate the original inputs"] fn concat_swizzle_mask(first: Mask, second: Mask) -> Mask where - T: MaskElement, + T: SimdElement, LaneCount: SupportedLaneCount, LaneCount: SupportedLaneCount, { @@ -518,7 +518,7 @@ where impl Mask where - T: MaskElement, + T: SimdElement, LaneCount: SupportedLaneCount, { /// Reverse the order of the elements in the mask. @@ -556,11 +556,8 @@ where pub fn shift_elements_left(self, padding: bool) -> Self { // Safety: swizzles are safe for masks unsafe { - Self::from_simd_unchecked(self.to_simd().shift_elements_left::(if padding { - T::TRUE - } else { - T::FALSE - })) + let padding = Mask::::splat(padding).to_simd()[0]; + Self::from_simd_unchecked(self.to_simd().shift_elements_left::(padding)) } } @@ -571,11 +568,8 @@ where pub fn shift_elements_right(self, padding: bool) -> Self { // Safety: swizzles are safe for masks unsafe { - Self::from_simd_unchecked(self.to_simd().shift_elements_right::(if padding { - T::TRUE - } else { - T::FALSE - })) + let padding = Mask::::splat(padding).to_simd()[0]; + Self::from_simd_unchecked(self.to_simd().shift_elements_right::(padding)) } } @@ -661,11 +655,8 @@ where { // Safety: swizzles are safe for masks unsafe { - Mask::::from_simd_unchecked(self.to_simd().resize::(if value { - T::TRUE - } else { - T::FALSE - })) + let padding = Mask::::splat(value).to_simd()[0]; + Mask::::from_simd_unchecked(self.to_simd().resize::(padding)) } } diff --git a/crates/core_simd/src/vector.rs b/crates/core_simd/src/vector.rs index c00cfcdd41f..2f9653af48b 100644 --- a/crates/core_simd/src/vector.rs +++ b/crates/core_simd/src/vector.rs @@ -1,5 +1,5 @@ use 
crate::simd::{ - LaneCount, Mask, MaskElement, SupportedLaneCount, Swizzle, + LaneCount, Mask, SupportedLaneCount, Swizzle, cmp::SimdPartialOrd, num::SimdUint, ptr::{SimdConstPtr, SimdMutPtr}, @@ -399,7 +399,7 @@ where /// ``` #[must_use] #[inline] - pub fn load_select_or_default(slice: &[T], enable: Mask<::Mask, N>) -> Self + pub fn load_select_or_default(slice: &[T], enable: Mask) -> Self where T: Default, { @@ -427,11 +427,7 @@ where /// ``` #[must_use] #[inline] - pub fn load_select( - slice: &[T], - mut enable: Mask<::Mask, N>, - or: Self, - ) -> Self { + pub fn load_select(slice: &[T], mut enable: Mask, or: Self) -> Self { enable &= mask_up_to(slice.len()); // SAFETY: We performed the bounds check by updating the mask. &[T] is properly aligned to // the element. @@ -448,11 +444,7 @@ where /// Enabled loads must not exceed the length of `slice`. #[must_use] #[inline] - pub unsafe fn load_select_unchecked( - slice: &[T], - enable: Mask<::Mask, N>, - or: Self, - ) -> Self { + pub unsafe fn load_select_unchecked(slice: &[T], enable: Mask, or: Self) -> Self { let ptr = slice.as_ptr(); // SAFETY: The safety of reading elements from `slice` is ensured by the caller. unsafe { Self::load_select_ptr(ptr, enable, or) } @@ -468,11 +460,7 @@ where /// Enabled `ptr` elements must be safe to read as if by `std::ptr::read`. #[must_use] #[inline] - pub unsafe fn load_select_ptr( - ptr: *const T, - enable: Mask<::Mask, N>, - or: Self, - ) -> Self { + pub unsafe fn load_select_ptr(ptr: *const T, enable: Mask, or: Self) -> Self { // SAFETY: The safety of reading elements through `ptr` is ensured by the caller. 
unsafe { core::intrinsics::simd::simd_masked_load(enable.to_simd(), ptr, or) } } @@ -539,11 +527,11 @@ where #[inline] pub fn gather_select( slice: &[T], - enable: Mask, + enable: Mask, idxs: Simd, or: Self, ) -> Self { - let enable: Mask = enable & idxs.simd_lt(Simd::splat(slice.len())); + let enable: Mask = enable & idxs.simd_lt(Simd::splat(slice.len())); // Safety: We have masked-off out-of-bounds indices. unsafe { Self::gather_select_unchecked(slice, enable, idxs, or) } } @@ -580,7 +568,7 @@ where #[cfg_attr(miri, track_caller)] // even without panics, this helps for Miri backtraces pub unsafe fn gather_select_unchecked( slice: &[T], - enable: Mask, + enable: Mask, idxs: Simd, or: Self, ) -> Self { @@ -588,7 +576,7 @@ where // Ferris forgive me, I have done pointer arithmetic here. let ptrs = base_ptr.wrapping_add(idxs); // Safety: The caller is responsible for determining the indices are okay to read - unsafe { Self::gather_select_ptr(ptrs, enable, or) } + unsafe { Self::gather_select_ptr(ptrs, enable.cast::<*const T>(), or) } } /// Reads elementwise from pointers into a SIMD vector. @@ -648,7 +636,7 @@ where #[cfg_attr(miri, track_caller)] // even without panics, this helps for Miri backtraces pub unsafe fn gather_select_ptr( source: Simd<*const T, N>, - enable: Mask, + enable: Mask<*const T, N>, or: Self, ) -> Self { // Safety: The caller is responsible for upholding all invariants @@ -674,7 +662,7 @@ where /// assert_eq!(arr, [0, -4, -3, 0]); /// ``` #[inline] - pub fn store_select(self, slice: &mut [T], mut enable: Mask<::Mask, N>) { + pub fn store_select(self, slice: &mut [T], mut enable: Mask) { enable &= mask_up_to(slice.len()); // SAFETY: We performed the bounds check by updating the mask. &[T] is properly aligned to // the element. 
@@ -702,11 +690,7 @@ where /// assert_eq!(arr, [0, -4, -3, -2]); /// ``` #[inline] - pub unsafe fn store_select_unchecked( - self, - slice: &mut [T], - enable: Mask<::Mask, N>, - ) { + pub unsafe fn store_select_unchecked(self, slice: &mut [T], enable: Mask) { let ptr = slice.as_mut_ptr(); // SAFETY: The safety of writing elements in `slice` is ensured by the caller. unsafe { self.store_select_ptr(ptr, enable) } @@ -721,7 +705,7 @@ where /// Memory addresses for element are calculated [`pointer::wrapping_offset`] and /// each enabled element must satisfy the same conditions as [`core::ptr::write`]. #[inline] - pub unsafe fn store_select_ptr(self, ptr: *mut T, enable: Mask<::Mask, N>) { + pub unsafe fn store_select_ptr(self, ptr: *mut T, enable: Mask) { // SAFETY: The safety of writing elements through `ptr` is ensured by the caller. unsafe { core::intrinsics::simd::simd_masked_store(enable.to_simd(), ptr, self) } } @@ -768,8 +752,8 @@ where /// assert_eq!(vec, vec![-41, 11, 12, 82, 14, 15, 16, 17, 18]); /// ``` #[inline] - pub fn scatter_select(self, slice: &mut [T], enable: Mask, idxs: Simd) { - let enable: Mask = enable & idxs.simd_lt(Simd::splat(slice.len())); + pub fn scatter_select(self, slice: &mut [T], enable: Mask, idxs: Simd) { + let enable: Mask = enable & idxs.simd_lt(Simd::splat(slice.len())); // Safety: We have masked-off out-of-bounds indices. unsafe { self.scatter_select_unchecked(slice, enable, idxs) } } @@ -808,7 +792,7 @@ where pub unsafe fn scatter_select_unchecked( self, slice: &mut [T], - enable: Mask, + enable: Mask, idxs: Simd, ) { // Safety: This block works with *mut T derived from &mut 'a [T], @@ -827,7 +811,7 @@ where // Ferris forgive me, I have done pointer arithmetic here. 
let ptrs = base_ptr.wrapping_add(idxs); // The ptrs have been bounds-masked to prevent memory-unsafe writes insha'allah - self.scatter_select_ptr(ptrs, enable); + self.scatter_select_ptr(ptrs, enable.cast::<*mut T>()); // Cleared ☢️ *mut T Zone } } @@ -880,7 +864,7 @@ where /// ``` #[inline] #[cfg_attr(miri, track_caller)] // even without panics, this helps for Miri backtraces - pub unsafe fn scatter_select_ptr(self, dest: Simd<*mut T, N>, enable: Mask) { + pub unsafe fn scatter_select_ptr(self, dest: Simd<*mut T, N>, enable: Mask<*mut T, N>) { // Safety: The caller is responsible for upholding all invariants unsafe { core::intrinsics::simd::simd_scatter(self, dest, enable.to_simd()) } } @@ -926,7 +910,7 @@ where let mask = unsafe { let tfvec: Simd<::Mask, N> = core::intrinsics::simd::simd_eq(*self, *other); - Mask::from_simd_unchecked(tfvec) + Mask::::from_simd_unchecked(tfvec) }; // Two vectors are equal if all elements are equal when compared elementwise @@ -940,7 +924,7 @@ where let mask = unsafe { let tfvec: Simd<::Mask, N> = core::intrinsics::simd::simd_ne(*self, *other); - Mask::from_simd_unchecked(tfvec) + Mask::::from_simd_unchecked(tfvec) }; // Two vectors are non-equal if any elements are non-equal when compared elementwise @@ -1090,8 +1074,32 @@ where } } -mod sealed { +pub(crate) mod sealed { + use super::*; + use crate::simd::SimdCast; + pub trait Sealed {} + + /// These functions prevent other traits from bleeding into the parent bounds. + /// + /// For example, `eq` could be provided by requiring `MaskElement: PartialEq`, but that would + /// prevent us from ever removing that bound, or from implementing `MaskElement` on + /// non-`PartialEq` types in the future. 
+ pub trait MaskElement: SimdCast + Sealed { + fn valid(value: Simd) -> bool + where + LaneCount: SupportedLaneCount, + Self: SimdElement; + + fn eq(self, other: Self) -> bool; + + fn to_usize(self) -> usize; + + type Unsigned: SimdElement; + + const TRUE: Self; + const FALSE: Self; + } } use sealed::Sealed; @@ -1105,7 +1113,7 @@ use sealed::Sealed; /// even when no soundness guarantees are broken by allowing the user to try. pub unsafe trait SimdElement: Sealed + Copy { /// The mask element type corresponding to this element type. - type Mask: MaskElement; + type Mask: sealed::MaskElement; } impl Sealed for u8 {} @@ -1149,6 +1157,54 @@ impl Sealed for i8 {} unsafe impl SimdElement for i8 { type Mask = i8; } +macro_rules! impl_mask_element { + ($ty:ty, $unsigned:ty) => { + impl sealed::MaskElement for $ty { + #[inline] + fn valid(value: Simd) -> bool + where + LaneCount: SupportedLaneCount, + Self: SimdElement, + { + // We can't use `Simd` directly, because `Simd`'s functions call this function and + // we will end up with an infinite loop. + // Safety: `value` is an integer vector + unsafe { + use core::intrinsics::simd; + let falses: Simd = simd::simd_eq(value, Simd::splat(0 as _)); + let trues: Simd = simd::simd_eq(value, Simd::splat(-1 as _)); + let valid: Simd = simd::simd_or(falses, trues); + simd::simd_reduce_all(valid) + } + } + + #[inline] + fn eq(self, other: Self) -> bool { + self == other + } + + #[inline] + fn to_usize(self) -> usize { + self as usize + } + + type Unsigned = $unsigned; + + const TRUE: Self = -1; + const FALSE: Self = 0; + } + }; +} + +impl_mask_element! { i8, u8 } + +impl_mask_element! { i16, u16 } + +impl_mask_element! { i32, u32 } + +impl_mask_element! { i64, u64 } + +impl_mask_element! 
{ isize, usize } impl Sealed for i16 {} @@ -1233,20 +1289,27 @@ where fn mask_up_to(len: usize) -> Mask where LaneCount: SupportedLaneCount, - M: MaskElement, + M: SimdElement, { let index = lane_indices::(); - let max_value: u64 = M::max_unsigned(); - macro_rules! case { - ($ty:ty) => { - if N < <$ty>::MAX as usize && max_value as $ty as u64 == max_value { - return index.cast().simd_lt(Simd::splat(len.min(N) as $ty)).cast(); - } - }; + // Choose the comparison width based on the mask element size of M + match core::mem::size_of::() { + 1 => { + let idx: Simd = index.cast(); + idx.simd_lt(Simd::splat(len.min(N) as u8)).cast() + } + 2 => { + let idx: Simd = index.cast(); + idx.simd_lt(Simd::splat(len.min(N) as u16)).cast() + } + 4 => { + let idx: Simd = index.cast(); + idx.simd_lt(Simd::splat(len.min(N) as u32)).cast() + } + 8 => { + let idx: Simd = index.cast(); + idx.simd_lt(Simd::splat(len.min(N) as u64)).cast() + } + _ => index.simd_lt(Simd::splat(len)).cast(), } - case!(u8); - case!(u16); - case!(u32); - case!(u64); - index.simd_lt(Simd::splat(len)).cast() } diff --git a/crates/core_simd/tests/masks.rs b/crates/core_simd/tests/masks.rs index 53fb2367b60..700505c3cb0 100644 --- a/crates/core_simd/tests/masks.rs +++ b/crates/core_simd/tests/masks.rs @@ -113,7 +113,7 @@ macro_rules! test_mask_api { #[test] fn cast() { - fn cast_impl() + fn cast_impl() where Mask<$type, 8>: Into>, {