Skip to content

Commit 3872723

Browse files
Merge pull request #138 from rust-lang/feature/various-fns
Add various fns - Sum/Product traits - recip/to_degrees/to_radians/min/max/clamp/signum/copysign; #14 - mul_add: #14, fixes #102
2 parents 15b4e28 + b0a9fe5 commit 3872723

File tree

7 files changed

+353
-1
lines changed

7 files changed

+353
-1
lines changed

crates/core_simd/src/intrinsics.rs

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -49,6 +49,9 @@ extern "platform-intrinsic" {
4949
/// fsqrt
5050
pub(crate) fn simd_fsqrt<T>(x: T) -> T;
5151

52+
/// fma
53+
pub(crate) fn simd_fma<T>(x: T, y: T, z: T) -> T;
54+
5255
pub(crate) fn simd_eq<T, U>(x: T, y: T) -> U;
5356
pub(crate) fn simd_ne<T, U>(x: T, y: T) -> U;
5457
pub(crate) fn simd_lt<T, U>(x: T, y: T) -> U;

crates/core_simd/src/iter.rs

Lines changed: 52 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,52 @@
1+
/// Implements `core::iter::Sum` and `core::iter::Product` for the vector type
/// `crate::$type<LANES>`, over both owned (`Self`) and borrowed (`&'a Self`) items.
///
/// * Sums are folded starting from `Default::default()`
///   (presumably the all-zero vector; confirm against the vector definition).
/// * Products are folded starting from a vector with every lane set to one, the
///   multiplicative identity. Seeding a product with the default value would make
///   every product collapse to that value (zero) and would make the product of an
///   empty iterator zero instead of one.
macro_rules! impl_traits {
    { $type:ident } => {
        impl<const LANES: usize> core::iter::Sum<Self> for crate::$type<LANES>
        where
            Self: crate::LanesAtMost32,
        {
            fn sum<I: core::iter::Iterator<Item = Self>>(iter: I) -> Self {
                iter.fold(Default::default(), core::ops::Add::add)
            }
        }

        impl<const LANES: usize> core::iter::Product<Self> for crate::$type<LANES>
        where
            Self: crate::LanesAtMost32,
        {
            fn product<I: core::iter::Iterator<Item = Self>>(iter: I) -> Self {
                // `1 as _` casts the literal to whichever scalar type `splat` expects,
                // so the same macro body works for integer and float vectors.
                // NOTE(review): assumes `splat` is available under this `where` bound.
                iter.fold(Self::splat(1 as _), core::ops::Mul::mul)
            }
        }

        impl<'a, const LANES: usize> core::iter::Sum<&'a Self> for crate::$type<LANES>
        where
            Self: crate::LanesAtMost32,
        {
            fn sum<I: core::iter::Iterator<Item = &'a Self>>(iter: I) -> Self {
                iter.fold(Default::default(), core::ops::Add::add)
            }
        }

        impl<'a, const LANES: usize> core::iter::Product<&'a Self> for crate::$type<LANES>
        where
            Self: crate::LanesAtMost32,
        {
            fn product<I: core::iter::Iterator<Item = &'a Self>>(iter: I) -> Self {
                // Start from the multiplicative identity (all lanes one), not the default.
                iter.fold(Self::splat(1 as _), core::ops::Mul::mul)
            }
        }
    }
}
40+
41+
// Generate the `Sum`/`Product` impls for each vector type named below.
impl_traits! { SimdF32 }
42+
impl_traits! { SimdF64 }
43+
impl_traits! { SimdU8 }
44+
impl_traits! { SimdU16 }
45+
impl_traits! { SimdU32 }
46+
impl_traits! { SimdU64 }
47+
impl_traits! { SimdUsize }
48+
impl_traits! { SimdI8 }
49+
impl_traits! { SimdI16 }
50+
impl_traits! { SimdI32 }
51+
impl_traits! { SimdI64 }
52+
impl_traits! { SimdIsize }

crates/core_simd/src/lib.rs

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,7 @@ pub use to_bytes::ToBytes;
2222
mod comparisons;
2323
mod fmt;
2424
mod intrinsics;
25+
mod iter;
2526
mod ops;
2627
mod round;
2728

crates/core_simd/src/vector/float.rs

Lines changed: 93 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@
44
/// `$lanes` of float `$type`, which uses `$bits_ty` as its binary
55
/// representation. Called from `define_float_vector!`.
66
macro_rules! impl_float_vector {
7-
{ $name:ident, $type:ty, $bits_ty:ident, $mask_ty:ident, $mask_impl_ty:ident } => {
7+
{ $name:ident, $type:ident, $bits_ty:ident, $mask_ty:ident, $mask_impl_ty:ident } => {
88
impl_vector! { $name, $type }
99
impl_float_reductions! { $name, $type }
1010

@@ -36,13 +36,44 @@ macro_rules! impl_float_vector {
3636
unsafe { crate::intrinsics::simd_fabs(self) }
3737
}
3838

39+
/// Fused multiply-add. Computes `(self * a) + b` with only one rounding error,
40+
/// yielding a more accurate result than an unfused multiply-add.
41+
///
42+
/// Using `mul_add` *may* be more performant than an unfused multiply-add if the target
43+
/// architecture has a dedicated `fma` CPU instruction. However, this is not always
44+
/// true, and will be heavily dependent on designing algorithms with specific target
45+
/// hardware in mind.
46+
#[inline]
47+
pub fn mul_add(self, a: Self, b: Self) -> Self {
48+
unsafe { crate::intrinsics::simd_fma(self, a, b) }
49+
}
50+
3951
/// Computes the square root of every lane of `self`, placing each result in the
/// equivalently-indexed lane of the returned vector.
#[inline]
#[cfg(feature = "std")]
pub fn sqrt(self) -> Self {
    use crate::intrinsics::simd_fsqrt;
    // NOTE(review): relies on `simd_fsqrt` accepting this vector type.
    unsafe { simd_fsqrt(self) }
}
58+
59+
/// Takes the reciprocal (inverse) of each lane, `1/x`.
60+
#[inline]
61+
pub fn recip(self) -> Self {
62+
Self::splat(1.0) / self
63+
}
64+
65+
/// Converts each lane from radians to degrees.
66+
#[inline]
67+
pub fn to_degrees(self) -> Self {
68+
// to_degrees uses a special constant for better precision, so extract that constant
69+
self * Self::splat($type::to_degrees(1.))
70+
}
71+
72+
/// Converts each lane from degrees to radians.
73+
#[inline]
74+
pub fn to_radians(self) -> Self {
75+
self * Self::splat($type::to_radians(1.))
76+
}
4677
}
4778

4879
impl<const LANES: usize> $name<LANES>
@@ -97,6 +128,67 @@ macro_rules! impl_float_vector {
97128
pub fn is_normal(self) -> crate::$mask_ty<LANES> {
98129
!(self.abs().lanes_eq(Self::splat(0.0)) | self.is_nan() | self.is_subnormal() | self.is_infinite())
99130
}
131+
132+
/// Replaces each lane with a number that represents its sign.
133+
///
134+
/// * `1.0` if the number is positive, `+0.0`, or `INFINITY`
135+
/// * `-1.0` if the number is negative, `-0.0`, or `NEG_INFINITY`
136+
/// * `NAN` if the number is `NAN`
137+
#[inline]
138+
pub fn signum(self) -> Self {
139+
self.is_nan().select(Self::splat($type::NAN), Self::splat(1.0).copysign(self))
140+
}
141+
142+
/// Returns each lane with the magnitude of `self` and the sign of `sign`.
143+
///
144+
/// If any lane is a `NAN`, then a `NAN` with the sign of `sign` is returned.
145+
#[inline]
146+
pub fn copysign(self, sign: Self) -> Self {
147+
let sign_bit = sign.to_bits() & Self::splat(-0.).to_bits();
148+
let magnitude = self.to_bits() & !Self::splat(-0.).to_bits();
149+
Self::from_bits(sign_bit | magnitude)
150+
}
151+
152+
/// Returns the minimum of each lane.
153+
///
154+
/// If one of the values is `NAN`, then the other value is returned.
155+
#[inline]
156+
pub fn min(self, other: Self) -> Self {
157+
// TODO consider using an intrinsic
158+
self.is_nan().select(
159+
other,
160+
self.lanes_ge(other).select(other, self)
161+
)
162+
}
163+
164+
/// Returns the maximum of each lane.
165+
///
166+
/// If one of the values is `NAN`, then the other value is returned.
167+
#[inline]
168+
pub fn max(self, other: Self) -> Self {
169+
// TODO consider using an intrinsic
170+
self.is_nan().select(
171+
other,
172+
self.lanes_le(other).select(other, self)
173+
)
174+
}
175+
176+
/// Restrict each lane to a certain interval unless it is NaN.
177+
///
178+
/// For each lane in `self`, returns the corresponding lane in `max` if the lane is
179+
/// greater than `max`, and the corresponding lane in `min` if the lane is less
180+
/// than `min`. Otherwise returns the lane in `self`.
181+
#[inline]
182+
pub fn clamp(self, min: Self, max: Self) -> Self {
183+
assert!(
184+
min.lanes_le(max).all(),
185+
"each lane in `min` must be less than or equal to the corresponding lane in `max`",
186+
);
187+
let mut x = self;
188+
x = x.lanes_lt(min).select(min, x);
189+
x = x.lanes_gt(max).select(max, x);
190+
x
191+
}
100192
}
101193
};
102194
}

crates/core_simd/src/vector/int.rs

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -33,14 +33,28 @@ macro_rules! impl_integer_vector {
3333
crate::$mask_ty<LANES>: crate::Mask,
3434
{
3535
/// Returns true for each positive lane and false if it is zero or negative.
36+
#[inline]
3637
pub fn is_positive(self) -> crate::$mask_ty<LANES> {
3738
self.lanes_gt(Self::splat(0))
3839
}
3940

4041
/// Returns true for each negative lane and false if it is zero or positive.
42+
#[inline]
4143
pub fn is_negative(self) -> crate::$mask_ty<LANES> {
4244
self.lanes_lt(Self::splat(0))
4345
}
46+
47+
/// Returns numbers representing the sign of each lane.
48+
/// * `0` if the number is zero
49+
/// * `1` if the number is positive
50+
/// * `-1` if the number is negative
51+
#[inline]
52+
pub fn signum(self) -> Self {
53+
self.is_positive().select(
54+
Self::splat(1),
55+
self.is_negative().select(Self::splat(-1), Self::splat(0))
56+
)
57+
}
4458
}
4559
}
4660
}

crates/core_simd/tests/ops_macros.rs

Lines changed: 128 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -247,6 +247,15 @@ macro_rules! impl_signed_tests {
247247
&|_| true,
248248
);
249249
}
250+
251+
fn signum<const LANES: usize>() {
252+
test_helpers::test_unary_elementwise(
253+
&Vector::<LANES>::signum,
254+
&Scalar::signum,
255+
&|_| true,
256+
)
257+
}
258+
250259
}
251260

252261
test_helpers::test_lanes_panic! {
@@ -426,13 +435,132 @@ macro_rules! impl_float_tests {
426435
)
427436
}
428437

438+
fn mul_add<const LANES: usize>() {
439+
test_helpers::test_ternary_elementwise(
440+
&Vector::<LANES>::mul_add,
441+
&Scalar::mul_add,
442+
&|_, _, _| true,
443+
)
444+
}
445+
429446
fn sqrt<const LANES: usize>() {
430447
test_helpers::test_unary_elementwise(
431448
&Vector::<LANES>::sqrt,
432449
&Scalar::sqrt,
433450
&|_| true,
434451
)
435452
}
453+
454+
fn recip<const LANES: usize>() {
455+
test_helpers::test_unary_elementwise(
456+
&Vector::<LANES>::recip,
457+
&Scalar::recip,
458+
&|_| true,
459+
)
460+
}
461+
462+
fn to_degrees<const LANES: usize>() {
463+
test_helpers::test_unary_elementwise(
464+
&Vector::<LANES>::to_degrees,
465+
&Scalar::to_degrees,
466+
&|_| true,
467+
)
468+
}
469+
470+
fn to_radians<const LANES: usize>() {
471+
test_helpers::test_unary_elementwise(
472+
&Vector::<LANES>::to_radians,
473+
&Scalar::to_radians,
474+
&|_| true,
475+
)
476+
}
477+
478+
fn signum<const LANES: usize>() {
479+
test_helpers::test_unary_elementwise(
480+
&Vector::<LANES>::signum,
481+
&Scalar::signum,
482+
&|_| true,
483+
)
484+
}
485+
486+
fn copysign<const LANES: usize>() {
487+
test_helpers::test_binary_elementwise(
488+
&Vector::<LANES>::copysign,
489+
&Scalar::copysign,
490+
&|_, _| true,
491+
)
492+
}
493+
494+
fn min<const LANES: usize>() {
495+
// Regular conditions (both values aren't zero)
496+
test_helpers::test_binary_elementwise(
497+
&Vector::<LANES>::min,
498+
&Scalar::min,
499+
// Reject the case where both values are zero with different signs
500+
&|a, b| {
501+
for (a, b) in a.iter().zip(b.iter()) {
502+
if *a == 0. && *b == 0. && a.signum() != b.signum() {
503+
return false;
504+
}
505+
}
506+
true
507+
}
508+
);
509+
510+
// Special case where both values are zero
511+
let p_zero = Vector::<LANES>::splat(0.);
512+
let n_zero = Vector::<LANES>::splat(-0.);
513+
assert!(p_zero.min(n_zero).to_array().iter().all(|x| *x == 0.));
514+
assert!(n_zero.min(p_zero).to_array().iter().all(|x| *x == 0.));
515+
}
516+
517+
fn max<const LANES: usize>() {
518+
// Regular conditions (both values aren't zero)
519+
test_helpers::test_binary_elementwise(
520+
&Vector::<LANES>::max,
521+
&Scalar::max,
522+
// Reject the case where both values are zero with different signs
523+
&|a, b| {
524+
for (a, b) in a.iter().zip(b.iter()) {
525+
if *a == 0. && *b == 0. && a.signum() != b.signum() {
526+
return false;
527+
}
528+
}
529+
true
530+
}
531+
);
532+
533+
// Special case where both values are zero
534+
let p_zero = Vector::<LANES>::splat(0.);
535+
let n_zero = Vector::<LANES>::splat(-0.);
536+
assert!(p_zero.max(n_zero).to_array().iter().all(|x| *x == 0.));
537+
assert!(n_zero.max(p_zero).to_array().iter().all(|x| *x == 0.));
538+
}
539+
540+
fn clamp<const LANES: usize>() {
541+
test_helpers::test_3(&|value: [Scalar; LANES], mut min: [Scalar; LANES], mut max: [Scalar; LANES]| {
542+
for (min, max) in min.iter_mut().zip(max.iter_mut()) {
543+
if max < min {
544+
core::mem::swap(min, max);
545+
}
546+
if min.is_nan() {
547+
*min = Scalar::NEG_INFINITY;
548+
}
549+
if max.is_nan() {
550+
*max = Scalar::INFINITY;
551+
}
552+
}
553+
554+
let mut result_scalar = [Scalar::default(); LANES];
555+
for i in 0..LANES {
556+
result_scalar[i] = value[i].clamp(min[i], max[i]);
557+
}
558+
let result_vector = Vector::from_array(value).clamp(min.into(), max.into()).to_array();
559+
test_helpers::prop_assert_biteq!(result_scalar, result_vector);
560+
Ok(())
561+
})
562+
}
563+
436564
fn horizontal_sum<const LANES: usize>() {
437565
test_helpers::test_1(&|x| {
438566
test_helpers::prop_assert_biteq! (

0 commit comments

Comments
 (0)