Skip to content

Commit f102de7

Browse files
committed
Add mul_add
1 parent 74e6262 commit f102de7

File tree

4 files changed

+64
-0
lines changed

4 files changed

+64
-0
lines changed

crates/core_simd/src/intrinsics.rs

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -49,6 +49,9 @@ extern "platform-intrinsic" {
4949
/// fsqrt
5050
pub(crate) fn simd_fsqrt<T>(x: T) -> T;
5151

52+
/// fma
53+
pub(crate) fn simd_fma<T>(x: T, y: T, z: T) -> T;
54+
5255
pub(crate) fn simd_eq<T, U>(x: T, y: T) -> U;
5356
pub(crate) fn simd_ne<T, U>(x: T, y: T) -> U;
5457
pub(crate) fn simd_lt<T, U>(x: T, y: T) -> U;

crates/core_simd/src/vector/float.rs

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -36,6 +36,18 @@ macro_rules! impl_float_vector {
3636
unsafe { crate::intrinsics::simd_fabs(self) }
3737
}
3838

39+
/// Fused multiply-add. Computes `(self * a) + b` with only one rounding error,
40+
/// yielding a more accurate result than an unfused multiply-add.
41+
///
42+
/// Using `mul_add` *may* be more performant than an unfused multiply-add if the target
43+
/// architecture has a dedicated `fma` CPU instruction. However, this is not always
44+
/// true, and will be heavily dependent on designing algorithms with specific target
45+
/// hardware in mind.
46+
#[inline]
47+
pub fn mul_add(self, a: Self, b: Self) -> Self {
48+
unsafe { crate::intrinsics::simd_fma(self, a, b) }
49+
}
50+
3951
/// Produces a vector where every lane has the square root value
4052
/// of the equivalently-indexed lane in `self`
4153
#[inline]

crates/core_simd/tests/ops_macros.rs

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -435,6 +435,14 @@ macro_rules! impl_float_tests {
435435
)
436436
}
437437

438+
fn mul_add<const LANES: usize>() {
439+
test_helpers::test_ternary_elementwise(
440+
&Vector::<LANES>::mul_add,
441+
&Scalar::mul_add,
442+
&|_, _, _| true,
443+
)
444+
}
445+
438446
fn sqrt<const LANES: usize>() {
439447
test_helpers::test_unary_elementwise(
440448
&Vector::<LANES>::sqrt,

crates/test_helpers/src/lib.rs

Lines changed: 41 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -278,6 +278,47 @@ pub fn test_binary_scalar_lhs_elementwise<
278278
});
279279
}
280280

281+
/// Test a ternary vector function against a ternary scalar function, applied elementwise.
282+
#[inline(never)]
283+
pub fn test_ternary_elementwise<
284+
Scalar1,
285+
Scalar2,
286+
Scalar3,
287+
ScalarResult,
288+
Vector1,
289+
Vector2,
290+
Vector3,
291+
VectorResult,
292+
const LANES: usize,
293+
>(
294+
fv: &dyn Fn(Vector1, Vector2, Vector3) -> VectorResult,
295+
fs: &dyn Fn(Scalar1, Scalar2, Scalar3) -> ScalarResult,
296+
check: &dyn Fn([Scalar1; LANES], [Scalar2; LANES], [Scalar3; LANES]) -> bool,
297+
) where
298+
Scalar1: Copy + Default + core::fmt::Debug + DefaultStrategy,
299+
Scalar2: Copy + Default + core::fmt::Debug + DefaultStrategy,
300+
Scalar3: Copy + Default + core::fmt::Debug + DefaultStrategy,
301+
ScalarResult: Copy + Default + biteq::BitEq + core::fmt::Debug + DefaultStrategy,
302+
Vector1: Into<[Scalar1; LANES]> + From<[Scalar1; LANES]> + Copy,
303+
Vector2: Into<[Scalar2; LANES]> + From<[Scalar2; LANES]> + Copy,
304+
Vector3: Into<[Scalar3; LANES]> + From<[Scalar3; LANES]> + Copy,
305+
VectorResult: Into<[ScalarResult; LANES]> + From<[ScalarResult; LANES]> + Copy,
306+
{
307+
test_3(&|x: [Scalar1; LANES], y: [Scalar2; LANES], z: [Scalar3; LANES]| {
308+
proptest::prop_assume!(check(x, y, z));
309+
let result_1: [ScalarResult; LANES] = fv(x.into(), y.into(), z.into()).into();
310+
let result_2: [ScalarResult; LANES] = {
311+
let mut result = [ScalarResult::default(); LANES];
312+
for ((i1, (i2, i3)), o) in x.iter().zip(y.iter().zip(z.iter())).zip(result.iter_mut()) {
313+
*o = fs(*i1, *i2, *i3);
314+
}
315+
result
316+
};
317+
crate::prop_assert_biteq!(result_1, result_2);
318+
Ok(())
319+
});
320+
}
321+
281322
/// Expand a const-generic test into separate tests for each possible lane count.
282323
#[macro_export]
283324
macro_rules! test_lanes {

0 commit comments

Comments
 (0)