Skip to content

Commit 819e4b9

Browse files
QelxirosRalfJung
andcommitted
unstably constify float mul_add methods
Co-authored-by: Ralf Jung <[email protected]>
1 parent ce4beeb commit 819e4b9

File tree

18 files changed

+118
-135
lines changed

18 files changed

+118
-135
lines changed

compiler/rustc_const_eval/src/interpret/intrinsics.rs

Lines changed: 46 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -636,6 +636,14 @@ impl<'tcx, M: Machine<'tcx>> InterpCx<'tcx, M> {
636636
dest,
637637
rustc_apfloat::Round::NearestTiesToEven,
638638
)?,
639+
sym::fmaf16 => self.fma_intrinsic::<Half>(args, dest)?,
640+
sym::fmaf32 => self.fma_intrinsic::<Single>(args, dest)?,
641+
sym::fmaf64 => self.fma_intrinsic::<Double>(args, dest)?,
642+
sym::fmaf128 => self.fma_intrinsic::<Quad>(args, dest)?,
643+
sym::fmuladdf16 => self.float_muladd_intrinsic::<Half>(args, dest)?,
644+
sym::fmuladdf32 => self.float_muladd_intrinsic::<Single>(args, dest)?,
645+
sym::fmuladdf64 => self.float_muladd_intrinsic::<Double>(args, dest)?,
646+
sym::fmuladdf128 => self.float_muladd_intrinsic::<Quad>(args, dest)?,
639647

640648
// Unsupported intrinsic: skip the return_to_block below.
641649
_ => return interp_ok(false),
@@ -1035,4 +1043,42 @@ impl<'tcx, M: Machine<'tcx>> InterpCx<'tcx, M> {
10351043
self.write_scalar(res, dest)?;
10361044
interp_ok(())
10371045
}
1046+
1047+
fn fma_intrinsic<F>(
1048+
&mut self,
1049+
args: &[OpTy<'tcx, M::Provenance>],
1050+
dest: &PlaceTy<'tcx, M::Provenance>,
1051+
) -> InterpResult<'tcx, ()>
1052+
where
1053+
F: rustc_apfloat::Float + rustc_apfloat::FloatConvert<F> + Into<Scalar<M::Provenance>>,
1054+
{
1055+
let a: F = self.read_scalar(&args[0])?.to_float()?;
1056+
let b: F = self.read_scalar(&args[1])?.to_float()?;
1057+
let c: F = self.read_scalar(&args[2])?.to_float()?;
1058+
1059+
let res = a.mul_add(b, c).value;
1060+
let res = self.adjust_nan(res, &[a, b, c]);
1061+
self.write_scalar(res, dest)?;
1062+
interp_ok(())
1063+
}
1064+
1065+
fn float_muladd_intrinsic<F>(
1066+
&mut self,
1067+
args: &[OpTy<'tcx, M::Provenance>],
1068+
dest: &PlaceTy<'tcx, M::Provenance>,
1069+
) -> InterpResult<'tcx, ()>
1070+
where
1071+
F: rustc_apfloat::Float + rustc_apfloat::FloatConvert<F> + Into<Scalar<M::Provenance>>,
1072+
{
1073+
let a: F = self.read_scalar(&args[0])?.to_float()?;
1074+
let b: F = self.read_scalar(&args[1])?.to_float()?;
1075+
let c: F = self.read_scalar(&args[2])?.to_float()?;
1076+
1077+
let fuse = M::float_fuse_mul_add(self);
1078+
1079+
let res = if fuse { a.mul_add(b, c).value } else { ((a * b).value + c).value };
1080+
let res = self.adjust_nan(res, &[a, b, c]);
1081+
self.write_scalar(res, dest)?;
1082+
interp_ok(())
1083+
}
10381084
}

compiler/rustc_const_eval/src/interpret/machine.rs

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -289,6 +289,9 @@ pub trait Machine<'tcx>: Sized {
289289
a
290290
}
291291

292+
/// Determines whether the `fmuladd` intrinsics fuse the multiply-add or use separate operations.
293+
fn float_fuse_mul_add(_ecx: &mut InterpCx<'tcx, Self>) -> bool;
294+
292295
/// Called before a basic block terminator is executed.
293296
#[inline]
294297
fn before_terminator(_ecx: &mut InterpCx<'tcx, Self>) -> InterpResult<'tcx> {
@@ -672,6 +675,11 @@ pub macro compile_time_machine(<$tcx: lifetime>) {
672675
match fn_val {}
673676
}
674677

678+
#[inline(always)]
679+
fn float_fuse_mul_add(_ecx: &mut InterpCx<$tcx, Self>) -> bool {
680+
true
681+
}
682+
675683
#[inline(always)]
676684
fn ub_checks(_ecx: &InterpCx<$tcx, Self>) -> InterpResult<$tcx, bool> {
677685
// We can't look at `tcx.sess` here as that can differ across crates, which can lead to

library/core/src/intrinsics/mod.rs

Lines changed: 8 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -1312,28 +1312,28 @@ pub fn log2f128(x: f128) -> f128;
13121312
/// [`f16::mul_add`](../../std/primitive.f16.html#method.mul_add)
13131313
#[rustc_intrinsic]
13141314
#[rustc_nounwind]
1315-
pub fn fmaf16(a: f16, b: f16, c: f16) -> f16;
1315+
pub const fn fmaf16(a: f16, b: f16, c: f16) -> f16;
13161316
/// Returns `a * b + c` for `f32` values.
13171317
///
13181318
/// The stabilized version of this intrinsic is
13191319
/// [`f32::mul_add`](../../std/primitive.f32.html#method.mul_add)
13201320
#[rustc_intrinsic]
13211321
#[rustc_nounwind]
1322-
pub fn fmaf32(a: f32, b: f32, c: f32) -> f32;
1322+
pub const fn fmaf32(a: f32, b: f32, c: f32) -> f32;
13231323
/// Returns `a * b + c` for `f64` values.
13241324
///
13251325
/// The stabilized version of this intrinsic is
13261326
/// [`f64::mul_add`](../../std/primitive.f64.html#method.mul_add)
13271327
#[rustc_intrinsic]
13281328
#[rustc_nounwind]
1329-
pub fn fmaf64(a: f64, b: f64, c: f64) -> f64;
1329+
pub const fn fmaf64(a: f64, b: f64, c: f64) -> f64;
13301330
/// Returns `a * b + c` for `f128` values.
13311331
///
13321332
/// The stabilized version of this intrinsic is
13331333
/// [`f128::mul_add`](../../std/primitive.f128.html#method.mul_add)
13341334
#[rustc_intrinsic]
13351335
#[rustc_nounwind]
1336-
pub fn fmaf128(a: f128, b: f128, c: f128) -> f128;
1336+
pub const fn fmaf128(a: f128, b: f128, c: f128) -> f128;
13371337

13381338
/// Returns `a * b + c` for `f16` values, non-deterministically executing
13391339
/// either a fused multiply-add or two operations with rounding of the
@@ -1347,7 +1347,7 @@ pub fn fmaf128(a: f128, b: f128, c: f128) -> f128;
13471347
/// example.
13481348
#[rustc_intrinsic]
13491349
#[rustc_nounwind]
1350-
pub fn fmuladdf16(a: f16, b: f16, c: f16) -> f16;
1350+
pub const fn fmuladdf16(a: f16, b: f16, c: f16) -> f16;
13511351
/// Returns `a * b + c` for `f32` values, non-deterministically executing
13521352
/// either a fused multiply-add or two operations with rounding of the
13531353
/// intermediate result.
@@ -1360,7 +1360,7 @@ pub fn fmuladdf16(a: f16, b: f16, c: f16) -> f16;
13601360
/// example.
13611361
#[rustc_intrinsic]
13621362
#[rustc_nounwind]
1363-
pub fn fmuladdf32(a: f32, b: f32, c: f32) -> f32;
1363+
pub const fn fmuladdf32(a: f32, b: f32, c: f32) -> f32;
13641364
/// Returns `a * b + c` for `f64` values, non-deterministically executing
13651365
/// either a fused multiply-add or two operations with rounding of the
13661366
/// intermediate result.
@@ -1373,7 +1373,7 @@ pub fn fmuladdf32(a: f32, b: f32, c: f32) -> f32;
13731373
/// example.
13741374
#[rustc_intrinsic]
13751375
#[rustc_nounwind]
1376-
pub fn fmuladdf64(a: f64, b: f64, c: f64) -> f64;
1376+
pub const fn fmuladdf64(a: f64, b: f64, c: f64) -> f64;
13771377
/// Returns `a * b + c` for `f128` values, non-deterministically executing
13781378
/// either a fused multiply-add or two operations with rounding of the
13791379
/// intermediate result.
@@ -1386,7 +1386,7 @@ pub fn fmuladdf64(a: f64, b: f64, c: f64) -> f64;
13861386
/// example.
13871387
#[rustc_intrinsic]
13881388
#[rustc_nounwind]
1389-
pub fn fmuladdf128(a: f128, b: f128, c: f128) -> f128;
1389+
pub const fn fmuladdf128(a: f128, b: f128, c: f128) -> f128;
13901390

13911391
/// Returns the largest integer less than or equal to an `f16`.
13921392
///

library/core/src/num/f128.rs

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1659,7 +1659,8 @@ impl f128 {
16591659
#[doc(alias = "fmaf128", alias = "fusedMultiplyAdd")]
16601660
#[unstable(feature = "f128", issue = "116909")]
16611661
#[must_use = "method returns a new number and does not mutate the original value"]
1662-
pub fn mul_add(self, a: f128, b: f128) -> f128 {
1662+
#[rustc_const_unstable(feature = "const_mul_add", issue = "146724")]
1663+
pub const fn mul_add(self, a: f128, b: f128) -> f128 {
16631664
intrinsics::fmaf128(self, a, b)
16641665
}
16651666

library/core/src/num/f16.rs

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1634,7 +1634,8 @@ impl f16 {
16341634
#[unstable(feature = "f16", issue = "116909")]
16351635
#[doc(alias = "fmaf16", alias = "fusedMultiplyAdd")]
16361636
#[must_use = "method returns a new number and does not mutate the original value"]
1637-
pub fn mul_add(self, a: f16, b: f16) -> f16 {
1637+
#[rustc_const_unstable(feature = "const_mul_add", issue = "146724")]
1638+
pub const fn mul_add(self, a: f16, b: f16) -> f16 {
16381639
intrinsics::fmaf16(self, a, b)
16391640
}
16401641

library/core/src/num/f32.rs

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1799,7 +1799,8 @@ pub mod math {
17991799
#[doc(alias = "fmaf", alias = "fusedMultiplyAdd")]
18001800
#[must_use = "method returns a new number and does not mutate the original value"]
18011801
#[unstable(feature = "core_float_math", issue = "137578")]
1802-
pub fn mul_add(x: f32, y: f32, z: f32) -> f32 {
1802+
#[rustc_const_unstable(feature = "const_mul_add", issue = "146724")]
1803+
pub const fn mul_add(x: f32, y: f32, z: f32) -> f32 {
18031804
intrinsics::fmaf32(x, y, z)
18041805
}
18051806

library/core/src/num/f64.rs

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1797,7 +1797,8 @@ pub mod math {
17971797
#[doc(alias = "fma", alias = "fusedMultiplyAdd")]
17981798
#[unstable(feature = "core_float_math", issue = "137578")]
17991799
#[must_use = "method returns a new number and does not mutate the original value"]
1800-
pub fn mul_add(x: f64, a: f64, b: f64) -> f64 {
1800+
#[rustc_const_unstable(feature = "const_mul_add", issue = "146724")]
1801+
pub const fn mul_add(x: f64, a: f64, b: f64) -> f64 {
18011802
intrinsics::fmaf64(x, a, b)
18021803
}
18031804

library/coretests/tests/floats/f128.rs

Lines changed: 0 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -20,24 +20,6 @@ const TOL_PRECISE: f128 = 1e-28;
2020
// FIXME(f16_f128,miri): many of these have to be disabled since miri does not yet support
2121
// the intrinsics.
2222

23-
#[test]
24-
#[cfg(not(miri))]
25-
#[cfg(target_has_reliable_f128_math)]
26-
fn test_mul_add() {
27-
let nan: f128 = f128::NAN;
28-
let inf: f128 = f128::INFINITY;
29-
let neg_inf: f128 = f128::NEG_INFINITY;
30-
assert_biteq!(12.3f128.mul_add(4.5, 6.7), 62.0500000000000000000000000000000037);
31-
assert_biteq!((-12.3f128).mul_add(-4.5, -6.7), 48.6500000000000000000000000000000049);
32-
assert_biteq!(0.0f128.mul_add(8.9, 1.2), 1.2);
33-
assert_biteq!(3.4f128.mul_add(-0.0, 5.6), 5.6);
34-
assert!(nan.mul_add(7.8, 9.0).is_nan());
35-
assert_biteq!(inf.mul_add(7.8, 9.0), inf);
36-
assert_biteq!(neg_inf.mul_add(7.8, 9.0), neg_inf);
37-
assert_biteq!(8.9f128.mul_add(inf, 3.2), inf);
38-
assert_biteq!((-3.2f128).mul_add(2.4, neg_inf), neg_inf);
39-
}
40-
4123
#[test]
4224
#[cfg(any(miri, target_has_reliable_f128_math))]
4325
fn test_max_recip() {

library/coretests/tests/floats/f16.rs

Lines changed: 0 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -22,24 +22,6 @@ const TOL_P4: f16 = 10.0;
2222
// FIXME(f16_f128,miri): many of these have to be disabled since miri does not yet support
2323
// the intrinsics.
2424

25-
#[test]
26-
#[cfg(not(miri))]
27-
#[cfg(target_has_reliable_f16_math)]
28-
fn test_mul_add() {
29-
let nan: f16 = f16::NAN;
30-
let inf: f16 = f16::INFINITY;
31-
let neg_inf: f16 = f16::NEG_INFINITY;
32-
assert_biteq!(12.3f16.mul_add(4.5, 6.7), 62.031);
33-
assert_biteq!((-12.3f16).mul_add(-4.5, -6.7), 48.625);
34-
assert_biteq!(0.0f16.mul_add(8.9, 1.2), 1.2);
35-
assert_biteq!(3.4f16.mul_add(-0.0, 5.6), 5.6);
36-
assert!(nan.mul_add(7.8, 9.0).is_nan());
37-
assert_biteq!(inf.mul_add(7.8, 9.0), inf);
38-
assert_biteq!(neg_inf.mul_add(7.8, 9.0), neg_inf);
39-
assert_biteq!(8.9f16.mul_add(inf, 3.2), inf);
40-
assert_biteq!((-3.2f16).mul_add(2.4, neg_inf), neg_inf);
41-
}
42-
4325
#[test]
4426
#[cfg(any(miri, target_has_reliable_f16_math))]
4527
fn test_max_recip() {

library/coretests/tests/floats/f32.rs

Lines changed: 0 additions & 21 deletions
This file was deleted.

0 commit comments

Comments
 (0)