Skip to content

Commit ea02805

Browse files
committed
Implement select generically
1 parent de13b20 commit ea02805

File tree

2 files changed

+75
-77
lines changed

2 files changed

+75
-77
lines changed

crates/core_simd/src/select.rs

Lines changed: 68 additions & 70 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,5 @@
1+
use crate::{LaneCount, Mask, MaskElement, Simd, SimdElement, SupportedLaneCount};
2+
13
mod sealed {
24
pub trait Sealed {}
35
}
@@ -9,79 +11,75 @@ pub trait Select<Mask>: Sealed {
911
fn select(mask: Mask, true_values: Self, false_values: Self) -> Self;
1012
}
1113

12-
macro_rules! impl_select {
13-
{
14-
$mask:ident ($bits_ty:ident): $($type:ident),*
15-
} => {
16-
$(
17-
impl<const LANES: usize> Sealed for crate::$type<LANES> where crate::LaneCount<LANES>: crate::SupportedLaneCount {}
18-
impl<const LANES: usize> Select<crate::$mask<LANES>> for crate::$type<LANES>
19-
where
20-
crate::LaneCount<LANES>: crate::SupportedLaneCount,
21-
{
22-
#[doc(hidden)]
23-
#[inline]
24-
fn select(mask: crate::$mask<LANES>, true_values: Self, false_values: Self) -> Self {
25-
unsafe { crate::intrinsics::simd_select(mask.to_int(), true_values, false_values) }
26-
}
27-
}
28-
)*
14+
impl<Element, const LANES: usize> Sealed for Simd<Element, LANES>
15+
where
16+
Element: SimdElement,
17+
LaneCount<LANES>: SupportedLaneCount,
18+
{
19+
}
2920

30-
impl<const LANES: usize> Sealed for crate::$mask<LANES>
31-
where
32-
crate::LaneCount<LANES>: crate::SupportedLaneCount,
33-
{}
21+
impl<Element, const LANES: usize> Select<Mask<Element::Mask, LANES>> for Simd<Element, LANES>
22+
where
23+
Element: SimdElement,
24+
LaneCount<LANES>: SupportedLaneCount,
25+
{
26+
#[inline]
27+
fn select(mask: Mask<Element::Mask, LANES>, true_values: Self, false_values: Self) -> Self {
28+
unsafe { crate::intrinsics::simd_select(mask.to_int(), true_values, false_values) }
29+
}
30+
}
3431

35-
impl<const LANES: usize> Select<Self> for crate::$mask<LANES>
36-
where
37-
crate::LaneCount<LANES>: crate::SupportedLaneCount,
38-
{
39-
#[doc(hidden)]
40-
#[inline]
41-
fn select(mask: Self, true_values: Self, false_values: Self) -> Self {
42-
mask & true_values | !mask & false_values
43-
}
44-
}
32+
impl<Element, const LANES: usize> Sealed for Mask<Element, LANES>
33+
where
34+
Element: MaskElement,
35+
LaneCount<LANES>: SupportedLaneCount,
36+
{
37+
}
4538

46-
impl<const LANES: usize> crate::$mask<LANES>
47-
where
48-
crate::LaneCount<LANES>: crate::SupportedLaneCount,
49-
{
50-
/// Choose lanes from two vectors.
51-
///
52-
/// For each lane in the mask, choose the corresponding lane from `true_values` if
53-
/// that lane mask is true, and `false_values` if that lane mask is false.
54-
///
55-
/// ```
56-
/// # #![feature(portable_simd)]
57-
/// # use core_simd::{Mask32, SimdI32};
58-
/// let a = SimdI32::from_array([0, 1, 2, 3]);
59-
/// let b = SimdI32::from_array([4, 5, 6, 7]);
60-
/// let mask = Mask32::from_array([true, false, false, true]);
61-
/// let c = mask.select(a, b);
62-
/// assert_eq!(c.to_array(), [0, 5, 6, 3]);
63-
/// ```
64-
///
65-
/// `select` can also be used on masks:
66-
/// ```
67-
/// # #![feature(portable_simd)]
68-
/// # use core_simd::Mask32;
69-
/// let a = Mask32::from_array([true, true, false, false]);
70-
/// let b = Mask32::from_array([false, false, true, true]);
71-
/// let mask = Mask32::from_array([true, false, false, true]);
72-
/// let c = mask.select(a, b);
73-
/// assert_eq!(c.to_array(), [true, false, true, false]);
74-
/// ```
75-
#[inline]
76-
pub fn select<S: Select<Self>>(self, true_values: S, false_values: S) -> S {
77-
S::select(self, true_values, false_values)
78-
}
79-
}
39+
impl<Element, const LANES: usize> Select<Self> for Mask<Element, LANES>
40+
where
41+
Element: MaskElement,
42+
LaneCount<LANES>: SupportedLaneCount,
43+
{
44+
#[doc(hidden)]
45+
#[inline]
46+
fn select(mask: Self, true_values: Self, false_values: Self) -> Self {
47+
mask & true_values | !mask & false_values
8048
}
8149
}
8250

83-
impl_select! { Mask8 (SimdI8): SimdU8, SimdI8 }
84-
impl_select! { Mask16 (SimdI16): SimdU16, SimdI16 }
85-
impl_select! { Mask32 (SimdI32): SimdU32, SimdI32, SimdF32}
86-
impl_select! { Mask64 (SimdI64): SimdU64, SimdI64, SimdF64}
87-
impl_select! { MaskSize (SimdIsize): SimdUsize, SimdIsize }
51+
impl<Element, const LANES: usize> Mask<Element, LANES>
52+
where
53+
Element: MaskElement,
54+
LaneCount<LANES>: SupportedLaneCount,
55+
{
56+
/// Choose lanes from two vectors.
57+
///
58+
/// For each lane in the mask, choose the corresponding lane from `true_values` if
59+
/// that lane mask is true, and `false_values` if that lane mask is false.
60+
///
61+
/// ```
62+
/// # #![feature(portable_simd)]
63+
/// # use core_simd::{Mask32, SimdI32};
64+
/// let a = SimdI32::from_array([0, 1, 2, 3]);
65+
/// let b = SimdI32::from_array([4, 5, 6, 7]);
66+
/// let mask = Mask32::from_array([true, false, false, true]);
67+
/// let c = mask.select(a, b);
68+
/// assert_eq!(c.to_array(), [0, 5, 6, 3]);
69+
/// ```
70+
///
71+
/// `select` can also be used on masks:
72+
/// ```
73+
/// # #![feature(portable_simd)]
74+
/// # use core_simd::Mask32;
75+
/// let a = Mask32::from_array([true, true, false, false]);
76+
/// let b = Mask32::from_array([false, false, true, true]);
77+
/// let mask = Mask32::from_array([true, false, false, true]);
78+
/// let c = mask.select(a, b);
79+
/// assert_eq!(c.to_array(), [true, false, true, false]);
80+
/// ```
81+
#[inline]
82+
pub fn select<S: Select<Self>>(self, true_values: S, false_values: S) -> S {
83+
S::select(self, true_values, false_values)
84+
}
85+
}

crates/core_simd/src/vector.rs

Lines changed: 7 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,7 @@ pub use uint::*;
99
// Vectors of pointers are not for public use at the current time.
1010
pub(crate) mod ptr;
1111

12-
use crate::{LaneCount, SupportedLaneCount};
12+
use crate::{LaneCount, MaskElement, SupportedLaneCount};
1313

1414
/// A SIMD vector of `LANES` elements of type `Element`.
1515
#[repr(simd)]
@@ -338,32 +338,32 @@ use sealed::Sealed;
338338
/// Marker trait for types that may be used as SIMD vector elements.
339339
pub unsafe trait SimdElement: Sealed + Copy {
340340
/// The mask element type corresponding to this element type.
341-
type Mask: SimdElement;
341+
type Mask: MaskElement;
342342
}
343343

344344
impl Sealed for u8 {}
345345
unsafe impl SimdElement for u8 {
346-
type Mask = u8;
346+
type Mask = i8;
347347
}
348348

349349
impl Sealed for u16 {}
350350
unsafe impl SimdElement for u16 {
351-
type Mask = u16;
351+
type Mask = i16;
352352
}
353353

354354
impl Sealed for u32 {}
355355
unsafe impl SimdElement for u32 {
356-
type Mask = u32;
356+
type Mask = i32;
357357
}
358358

359359
impl Sealed for u64 {}
360360
unsafe impl SimdElement for u64 {
361-
type Mask = u64;
361+
type Mask = i64;
362362
}
363363

364364
impl Sealed for usize {}
365365
unsafe impl SimdElement for usize {
366-
type Mask = usize;
366+
type Mask = isize;
367367
}
368368

369369
impl Sealed for i8 {}

0 commit comments

Comments
 (0)