Skip to content

Commit 74e6262

Browse files
committed
Add min/max/clamp
1 parent b936f34 commit 74e6262

File tree

3 files changed

+132
-0
lines changed

3 files changed

+132
-0
lines changed

crates/core_simd/src/vector/float.rs

Lines changed: 41 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -136,6 +136,47 @@ macro_rules! impl_float_vector {
136136
let magnitude = self.to_bits() & !Self::splat(-0.).to_bits();
137137
Self::from_bits(sign_bit | magnitude)
138138
}
139+
140+
/// Returns the minimum of each lane.
141+
///
142+
/// If one of the values is `NAN`, then the other value is returned.
143+
#[inline]
144+
pub fn min(self, other: Self) -> Self {
145+
// TODO consider using an intrinsic
146+
self.is_nan().select(
147+
other,
148+
self.lanes_ge(other).select(other, self)
149+
)
150+
}
151+
152+
/// Returns the maximum of each lane.
153+
///
154+
/// If one of the values is `NAN`, then the other value is returned.
155+
#[inline]
156+
pub fn max(self, other: Self) -> Self {
157+
// TODO consider using an intrinsic
158+
self.is_nan().select(
159+
other,
160+
self.lanes_le(other).select(other, self)
161+
)
162+
}
163+
164+
/// Restrict each lane to a certain interval unless it is NaN.
165+
///
166+
/// For each lane in `self`, returns the corresponding lane in `max` if the lane is
167+
/// greater than `max`, and the corresponding lane in `min` if the lane is less
168+
/// than `min`. Otherwise returns the lane in `self`.
169+
#[inline]
170+
pub fn clamp(self, min: Self, max: Self) -> Self {
171+
assert!(
172+
min.lanes_le(max).all(),
173+
"each lane in `min` must be less than or equal to the corresponding lane in `max`",
174+
);
175+
let mut x = self;
176+
x = x.lanes_lt(min).select(min, x);
177+
x = x.lanes_gt(max).select(max, x);
178+
x
179+
}
139180
}
140181
};
141182
}

crates/core_simd/tests/ops_macros.rs

Lines changed: 70 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -483,6 +483,76 @@ macro_rules! impl_float_tests {
483483
)
484484
}
485485

486+
fn min<const LANES: usize>() {
487+
// Regular conditions (both values aren't zero)
488+
test_helpers::test_binary_elementwise(
489+
&Vector::<LANES>::min,
490+
&Scalar::min,
491+
// Reject the case where both values are zero with different signs
492+
&|a, b| {
493+
for (a, b) in a.iter().zip(b.iter()) {
494+
if *a == 0. && *b == 0. && a.signum() != b.signum() {
495+
return false;
496+
}
497+
}
498+
true
499+
}
500+
);
501+
502+
// Special case where both values are zero
503+
let p_zero = Vector::<LANES>::splat(0.);
504+
let n_zero = Vector::<LANES>::splat(-0.);
505+
assert!(p_zero.min(n_zero).to_array().iter().all(|x| *x == 0.));
506+
assert!(n_zero.min(p_zero).to_array().iter().all(|x| *x == 0.));
507+
}
508+
509+
fn max<const LANES: usize>() {
510+
// Regular conditions (both values aren't zero)
511+
test_helpers::test_binary_elementwise(
512+
&Vector::<LANES>::max,
513+
&Scalar::max,
514+
// Reject the case where both values are zero with different signs
515+
&|a, b| {
516+
for (a, b) in a.iter().zip(b.iter()) {
517+
if *a == 0. && *b == 0. && a.signum() != b.signum() {
518+
return false;
519+
}
520+
}
521+
true
522+
}
523+
);
524+
525+
// Special case where both values are zero
526+
let p_zero = Vector::<LANES>::splat(0.);
527+
let n_zero = Vector::<LANES>::splat(-0.);
528+
assert!(p_zero.min(n_zero).to_array().iter().all(|x| *x == 0.));
529+
assert!(n_zero.min(p_zero).to_array().iter().all(|x| *x == 0.));
530+
}
531+
532+
fn clamp<const LANES: usize>() {
533+
test_helpers::test_3(&|value: [Scalar; LANES], mut min: [Scalar; LANES], mut max: [Scalar; LANES]| {
534+
for (min, max) in min.iter_mut().zip(max.iter_mut()) {
535+
if max < min {
536+
core::mem::swap(min, max);
537+
}
538+
if min.is_nan() {
539+
*min = Scalar::NEG_INFINITY;
540+
}
541+
if max.is_nan() {
542+
*max = Scalar::INFINITY;
543+
}
544+
}
545+
546+
let mut result_scalar = [Scalar::default(); LANES];
547+
for i in 0..LANES {
548+
result_scalar[i] = value[i].clamp(min[i], max[i]);
549+
}
550+
let result_vector = Vector::from_array(value).clamp(min.into(), max.into()).to_array();
551+
test_helpers::prop_assert_biteq!(result_scalar, result_vector);
552+
Ok(())
553+
})
554+
}
555+
486556
fn horizontal_sum<const LANES: usize>() {
487557
test_helpers::test_1(&|x| {
488558
test_helpers::prop_assert_biteq! (

crates/test_helpers/src/lib.rs

Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -97,6 +97,27 @@ pub fn test_2<A: core::fmt::Debug + DefaultStrategy, B: core::fmt::Debug + Defau
9797
.unwrap();
9898
}
9999

100+
/// Test a function that takes two values.
101+
pub fn test_3<
102+
A: core::fmt::Debug + DefaultStrategy,
103+
B: core::fmt::Debug + DefaultStrategy,
104+
C: core::fmt::Debug + DefaultStrategy,
105+
>(
106+
f: &dyn Fn(A, B, C) -> proptest::test_runner::TestCaseResult,
107+
) {
108+
let mut runner = proptest::test_runner::TestRunner::default();
109+
runner
110+
.run(
111+
&(
112+
A::default_strategy(),
113+
B::default_strategy(),
114+
C::default_strategy(),
115+
),
116+
|(a, b, c)| f(a, b, c),
117+
)
118+
.unwrap();
119+
}
120+
100121
/// Test a unary vector function against a unary scalar function, applied elementwise.
101122
#[inline(never)]
102123
pub fn test_unary_elementwise<Scalar, ScalarResult, Vector, VectorResult, const LANES: usize>(

0 commit comments

Comments
 (0)