diff --git a/crates/core_simd/examples/dot_product.rs b/crates/core_simd/examples/dot_product.rs
index 4ef32bfa60b..9071029e61d 100644
--- a/crates/core_simd/examples/dot_product.rs
+++ b/crates/core_simd/examples/dot_product.rs
@@ -4,6 +4,7 @@
 // Add these imports to use the stdsimd library
 #![feature(portable_simd)]
 use core_simd::simd::prelude::*;
+use std_float::StdFloat;
 
 // This is your barebones dot product implementation:
 // Take 2 vectors, multiply them element wise and *then*
@@ -71,7 +72,6 @@ pub fn dot_prod_simd_1(a: &[f32], b: &[f32]) -> f32 {
 
 // A lot of knowledgeable use of SIMD comes from knowing specific instructions that are
 // available - let's try to use the `mul_add` instruction, which is the fused-multiply-add we were looking for.
-use std_float::StdFloat;
 pub fn dot_prod_simd_2(a: &[f32], b: &[f32]) -> f32 {
     assert_eq!(a.len(), b.len());
     // TODO handle remainder when a.len() % 4 != 0
diff --git a/crates/core_simd/examples/spectral_norm.rs b/crates/core_simd/examples/spectral_norm.rs
index bc7934c2522..5b40d2c0b92 100644
--- a/crates/core_simd/examples/spectral_norm.rs
+++ b/crates/core_simd/examples/spectral_norm.rs
@@ -8,7 +8,7 @@ fn a(i: usize, j: usize) -> f64 {
 
 fn mult_av(v: &[f64], out: &mut [f64]) {
     assert!(v.len() == out.len());
-    assert!(v.len() % 2 == 0);
+    assert!(v.len().is_multiple_of(2));
 
     for (i, out) in out.iter_mut().enumerate() {
         let mut sum = f64x2::splat(0.0);
@@ -26,7 +26,7 @@ fn mult_av(v: &[f64], out: &mut [f64]) {
 
 fn mult_atv(v: &[f64], out: &mut [f64]) {
     assert!(v.len() == out.len());
-    assert!(v.len() % 2 == 0);
+    assert!(v.len().is_multiple_of(2));
 
     for (i, out) in out.iter_mut().enumerate() {
         let mut sum = f64x2::splat(0.0);
@@ -48,7 +48,7 @@ fn mult_atav(v: &[f64], out: &mut [f64], tmp: &mut [f64]) {
 }
 
 pub fn spectral_norm(n: usize) -> f64 {
-    assert!(n % 2 == 0, "only even lengths are accepted");
+    assert!(n.is_multiple_of(2), "only even lengths are accepted");
 
     let mut u = vec![1.0; n];
     let mut v = u.clone();
diff --git a/crates/std_float/examples/fma.rs b/crates/std_float/examples/fma.rs
new file mode 100644
index 00000000000..ab139014ec5
--- /dev/null
+++ b/crates/std_float/examples/fma.rs
@@ -0,0 +1,58 @@
+//! Demonstrates fused multiply-add (FMA) operations.
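+//!
+//! An FMA computes `(a * b) + c` with a single rounding step, so it is more
+//! accurate than an unfused multiply followed by an add, and it *may* be
+//! faster on targets with a dedicated FMA instruction.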
+
+#![feature(portable_simd)]
+use core_simd::simd::prelude::*;
+use std_float::StdFloat;
+
+fn main() {
+    let a = f32x4::from_array([1.0, 2.0, 3.0, 4.0]);
+    let b = f32x4::from_array([2.0, 3.0, 4.0, 5.0]);
+    let c = f32x4::from_array([10.0, 10.0, 10.0, 10.0]);
+
+    println!("FMA: a*b + c");
+    println!("a = {:?}", a.to_array());
+    println!("b = {:?}", b.to_array());
+    println!("c = {:?}", c.to_array());
+    println!("result = {:?}", a.mul_add(b, c).to_array());
+    println!();
+
+    // Polynomial: p(x) = 2x³ + 3x² + 4x + 5
+    // Horner form: ((2x + 3)x + 4)x + 5
+    let x = f32x8::from_array([0.0, 1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0]);
+    let result = f32x8::splat(2.0)
+        .mul_add(x, f32x8::splat(3.0))
+        .mul_add(x, f32x8::splat(4.0))
+        .mul_add(x, f32x8::splat(5.0));
+
+    println!("Polynomial p(x) = 2x³ + 3x² + 4x + 5");
+    println!("x = {:?}", x.to_array());
+    println!("p(x) = {:?}", result.to_array());
+    println!();
+
+    let v1 = f32x4::from_array([1.0, 2.0, 3.0, 4.0]);
+    let v2 = f32x4::from_array([5.0, 6.0, 7.0, 8.0]);
+
+    let mut acc = 0.0;
+    for i in 0..4 {
+        acc = v1[i].mul_add(v2[i], acc);
+    }
+
+    println!("Dot product using FMA:");
+    println!("v1 · v2 = {}", acc);
+    println!();
+
+    let large = f32x4::splat(1e10);
+    let small = f32x4::splat(1.0);
+
+    let fma_result = large.mul_add(f32x4::splat(1.0), small);
+    let separate_result = large * f32x4::splat(1.0) + small;
+
+    println!("Accuracy comparison (1e10 * 1.0 + 1.0):");
+    println!("FMA result: {:?}", fma_result.to_array());
+    println!("Separate ops: {:?}", separate_result.to_array());
+    println!("Both round to the same value here: 1.0 is below f32 precision at 1e10");
+}
diff --git a/crates/std_float/src/lib.rs b/crates/std_float/src/lib.rs
index c3c9b76e50b..bf2ed882f77 100644
--- a/crates/std_float/src/lib.rs
+++ b/crates/std_float/src/lib.rs
@@ -56,6 +56,32 @@ pub trait StdFloat: Sealed + Sized {
         unsafe { intrinsics::simd_fma(self, a, b) }
     }
 
+    /// Elementwise fused multiply-subtract. Computes `(self * a) - b` with only one rounding error,
+    /// yielding a more accurate result than an unfused multiply-subtract.
+    ///
+    /// Using `mul_sub` *may* be more performant than an unfused multiply-subtract if the target
+    /// architecture has a dedicated `fma` CPU instruction. However, this is not always
+    /// true, and will be heavily dependent on designing algorithms with specific target
+    /// hardware in mind.
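+    ///
+    /// # Example
+    ///
+    /// An illustrative sketch of the semantics (marked `ignore` so it is not
+    /// run as a doctest):
+    ///
+    /// ```ignore
+    /// let x = f32x4::splat(2.0);
+    /// let y = f32x4::splat(3.0);
+    /// let z = f32x4::splat(1.0);
+    /// // Each lane computes (2.0 * 3.0) - 1.0 == 5.0.
+    /// assert_eq!(x.mul_sub(y, z), f32x4::splat(5.0));
+    /// ```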
+    #[inline]
+    #[must_use = "method returns a new vector and does not mutate the original value"]
+    fn mul_sub(self, a: Self, b: Self) -> Self {
+        unsafe { intrinsics::simd_fma(self, a, intrinsics::simd_neg(b)) }
+    }
+
     /// Produces a vector where every element has the square root value
     /// of the equivalently-indexed element in `self`
     #[inline]
diff --git a/crates/std_float/tests/fma.rs b/crates/std_float/tests/fma.rs
new file mode 100644
index 00000000000..57453897225
--- /dev/null
+++ b/crates/std_float/tests/fma.rs
@@ -0,0 +1,176 @@
+#![feature(portable_simd)]
+
+use core_simd::simd::prelude::*;
+use std_float::StdFloat;
+
+#[test]
+fn test_mul_add_basic() {
+    let a = f32x4::from_array([2.0, 3.0, 4.0, 5.0]);
+    let b = f32x4::from_array([10.0, 10.0, 10.0, 10.0]);
+    let c = f32x4::from_array([1.0, 2.0, 3.0, 4.0]);
+    assert_eq!(a.mul_add(b, c), f32x4::from_array([21.0, 32.0, 43.0, 54.0]));
+}
+
+#[test]
+fn test_mul_add_f64() {
+    let a = f64x4::from_array([2.0, 3.0, 4.0, 5.0]);
+    let b = f64x4::from_array([10.0, 10.0, 10.0, 10.0]);
+    let c = f64x4::from_array([1.0, 2.0, 3.0, 4.0]);
+    assert_eq!(a.mul_add(b, c), f64x4::from_array([21.0, 32.0, 43.0, 54.0]));
+}
+
+#[test]
+fn test_mul_sub_basic() {
+    let a = f32x4::from_array([2.0, 3.0, 4.0, 5.0]);
+    let b = f32x4::from_array([10.0, 10.0, 10.0, 10.0]);
+    let c = f32x4::from_array([1.0, 2.0, 3.0, 4.0]);
+    assert_eq!(a.mul_sub(b, c), f32x4::from_array([19.0, 28.0, 37.0, 46.0]));
+}
+
+#[test]
+fn test_mul_sub_f64() {
+    let a = f64x4::from_array([2.0, 3.0, 4.0, 5.0]);
+    let b = f64x4::from_array([10.0, 10.0, 10.0, 10.0]);
+    let c = f64x4::from_array([1.0, 2.0, 3.0, 4.0]);
+    assert_eq!(a.mul_sub(b, c), f64x4::from_array([19.0, 28.0, 37.0, 46.0]));
+}
+
+#[test]
+fn test_fma_accuracy_catastrophic_cancellation() {
+    let epsilon = 1e-4_f32;
+    let x = 1.0 + epsilon;
+    let y = 1.0 - epsilon;
+
+    let a = f32x4::splat(x);
+    let b = f32x4::splat(y);
+    let c = f32x4::splat(-1.0);
+
+    let fma_result = a.mul_add(b, c);
+    let separate_result = a * b + c;
+
+    let expected = -epsilon * epsilon;
+
+    let fma_error = (fma_result[0] - expected).abs();
+    let sep_error = (separate_result[0] - expected).abs();
+
+    assert!(fma_error <= sep_error);
+}
+
+#[test]
+fn test_fma_accuracy_discriminant() {
+    let b = f64x2::splat(1e8);
+    let four_ac = f64x2::splat(1.0);
+
+    let fma_discriminant = b.mul_add(b, -four_ac);
+    let sep_discriminant = b * b - four_ac;
+
+    let expected = 1e16 - 1.0;
+
+    let fma_error = ((fma_discriminant[0] - expected) / expected).abs();
+    let sep_error = ((sep_discriminant[0] - expected) / expected).abs();
+
+    assert!(fma_error <= sep_error);
+}
+
+#[test]
+fn test_fma_accuracy_polynomial() {
+    let x = f64x2::splat(1.00001);
+    let a = f64x2::splat(1.0);
+    let b = f64x2::splat(-2.0);
+    let c = f64x2::splat(1.0);
+
+    let fma_result = a.mul_add(x, b).mul_add(x, c);
+    let sep_result = (a * x + b) * x + c;
+
+    let expected = (x[0] - 1.0) * (x[0] - 1.0);
+
+    let fma_error = (fma_result[0] - expected).abs();
+    let sep_error = (sep_result[0] - expected).abs();
+
+    assert!(fma_error < sep_error || (fma_error - sep_error).abs() < 1e-15);
+}
+
+#[test]
+fn test_negative_values() {
+    let a = f32x4::from_array([-2.0, -3.0, -4.0, -5.0]);
+    let b = f32x4::splat(2.0);
+    let c = f32x4::splat(1.0);
+    assert_eq!(a.mul_add(b, c), f32x4::from_array([-3.0, -5.0, -7.0, -9.0]));
+    assert_eq!(
+        a.mul_sub(b, c),
+        f32x4::from_array([-5.0, -7.0, -9.0, -11.0])
+    );
+}
+
+#[test]
+fn test_infinity() {
+    let a = f32x4::from_array([f32::INFINITY, 1.0, 2.0, 3.0]);
+    let b = f32x4::splat(2.0);
+    let c = f32x4::splat(1.0);
+    let result = a.mul_add(b, c);
+    assert_eq!(result[0], f32::INFINITY);
+    assert_eq!(result[1], 3.0);
+}
+
+#[test]
+fn test_nan_propagation() {
+    let a = f32x4::from_array([f32::NAN, 2.0, 3.0, 4.0]);
+    let b = f32x4::splat(2.0);
+    let c = f32x4::splat(1.0);
+    let result = a.mul_add(b, c);
+    assert!(result[0].is_nan());
+    assert_eq!(result[1], 5.0);
+}
+
+#[test]
+fn test_different_sizes() {
+    let a2 = f32x2::from_array([3.0, 4.0]);
+    let b2 = f32x2::from_array([2.0, 2.0]);
+    let c2 = f32x2::from_array([1.0, 1.0]);
+    assert_eq!(a2.mul_add(b2, c2), f32x2::from_array([7.0, 9.0]));
+
+    let a8 = f32x8::splat(2.0);
+    let b8 = f32x8::splat(3.0);
+    let c8 = f32x8::splat(4.0);
+    assert_eq!(a8.mul_add(b8, c8), f32x8::splat(10.0));
+}
+
+#[test]
+fn test_polynomial_evaluation() {
+    let x = f32x4::from_array([1.0, 2.0, 3.0, 4.0]);
+    let result = f32x4::splat(2.0)
+        .mul_add(x, f32x4::splat(3.0))
+        .mul_add(x, f32x4::splat(5.0));
+    assert_eq!(result, f32x4::from_array([10.0, 19.0, 32.0, 49.0]));
+}
+
+#[test]
+fn test_max_min_values() {
+    let a = f32x4::from_array([f32::MAX, f32::MIN, 1.0, -1.0]);
+    let b = f32x4::splat(1.0);
+    let c = f32x4::splat(0.0);
+    let result = a.mul_add(b, c);
+    assert_eq!(result[0], f32::MAX);
+    assert_eq!(result[1], f32::MIN);
+}
+
+#[test]
+fn test_subnormal_values() {
+    let subnormal = f32::MIN_POSITIVE / 2.0;
+    let a = f32x4::splat(subnormal);
+    let b = f32x4::splat(2.0);
+    let c = f32x4::splat(0.0);
+    let result = a.mul_add(b, c);
+    assert!(result[0].is_finite());
+
+    // On platforms with flush-to-zero (FTZ) mode (e.g., ARM NEON), subnormal
+    // values in SIMD operations may be flushed to zero for performance.
+    // We accept either the mathematically correct result or zero.
+    let expected = subnormal * 2.0;
+    assert!(
+        result[0] == expected || result[0] == 0.0,
+        "Expected {} (or 0.0 due to FTZ), got {}",
+        expected,
+        result[0]
+    );
+}