Skip to content

Commit 9666c78

Browse files
committed
Implement simd_fma and simd_relaxed_fma in const-eval
1 parent acf2437 commit 9666c78

File tree

5 files changed

+88
-90
lines changed

5 files changed

+88
-90
lines changed

compiler/rustc_const_eval/src/interpret/intrinsics.rs

Lines changed: 44 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -25,6 +25,15 @@ use super::{
2525
};
2626
use crate::fluent_generated as fluent;
2727

28+
#[derive(Copy, Clone, Debug, PartialEq, Eq)]
29+
enum MulAddType {
30+
/// Used with `fma` and `simd_fma`, always uses fused-multiply-add
31+
Fused,
32+
/// Used with `fmuladd` and `simd_relaxed_fma`, nondeterministically chooses between a fused
33+
/// multiply-add and a separate multiply followed by an add
34+
Nondeterministic,
35+
}
36+
2837
/// Directly returns an `Allocation` containing an absolute path representation of the given type.
2938
pub(crate) fn alloc_type_name<'tcx>(tcx: TyCtxt<'tcx>, ty: Ty<'tcx>) -> (AllocId, u64) {
3039
let path = crate::util::type_name(tcx, ty);
@@ -612,14 +621,22 @@ impl<'tcx, M: Machine<'tcx>> InterpCx<'tcx, M> {
612621
dest,
613622
rustc_apfloat::Round::NearestTiesToEven,
614623
)?,
615-
sym::fmaf16 => self.fma_intrinsic::<Half>(args, dest)?,
616-
sym::fmaf32 => self.fma_intrinsic::<Single>(args, dest)?,
617-
sym::fmaf64 => self.fma_intrinsic::<Double>(args, dest)?,
618-
sym::fmaf128 => self.fma_intrinsic::<Quad>(args, dest)?,
619-
sym::fmuladdf16 => self.float_muladd_intrinsic::<Half>(args, dest)?,
620-
sym::fmuladdf32 => self.float_muladd_intrinsic::<Single>(args, dest)?,
621-
sym::fmuladdf64 => self.float_muladd_intrinsic::<Double>(args, dest)?,
622-
sym::fmuladdf128 => self.float_muladd_intrinsic::<Quad>(args, dest)?,
624+
sym::fmaf16 => self.float_muladd_intrinsic::<Half>(args, dest, MulAddType::Fused)?,
625+
sym::fmaf32 => self.float_muladd_intrinsic::<Single>(args, dest, MulAddType::Fused)?,
626+
sym::fmaf64 => self.float_muladd_intrinsic::<Double>(args, dest, MulAddType::Fused)?,
627+
sym::fmaf128 => self.float_muladd_intrinsic::<Quad>(args, dest, MulAddType::Fused)?,
628+
sym::fmuladdf16 => {
629+
self.float_muladd_intrinsic::<Half>(args, dest, MulAddType::Nondeterministic)?
630+
}
631+
sym::fmuladdf32 => {
632+
self.float_muladd_intrinsic::<Single>(args, dest, MulAddType::Nondeterministic)?
633+
}
634+
sym::fmuladdf64 => {
635+
self.float_muladd_intrinsic::<Double>(args, dest, MulAddType::Nondeterministic)?
636+
}
637+
sym::fmuladdf128 => {
638+
self.float_muladd_intrinsic::<Quad>(args, dest, MulAddType::Nondeterministic)?
639+
}
623640

624641
// Unsupported intrinsic: skip the return_to_block below.
625642
_ => return interp_ok(false),
@@ -1020,40 +1037,41 @@ impl<'tcx, M: Machine<'tcx>> InterpCx<'tcx, M> {
10201037
interp_ok(())
10211038
}
10221039

1023-
fn fma_intrinsic<F>(
1024-
&mut self,
1025-
args: &[OpTy<'tcx, M::Provenance>],
1026-
dest: &PlaceTy<'tcx, M::Provenance>,
1027-
) -> InterpResult<'tcx, ()>
1040+
fn float_muladd<F>(
1041+
&self,
1042+
a: Scalar<M::Provenance>,
1043+
b: Scalar<M::Provenance>,
1044+
c: Scalar<M::Provenance>,
1045+
typ: MulAddType,
1046+
) -> InterpResult<'tcx, Scalar<M::Provenance>>
10281047
where
10291048
F: rustc_apfloat::Float + rustc_apfloat::FloatConvert<F> + Into<Scalar<M::Provenance>>,
10301049
{
1031-
let a: F = self.read_scalar(&args[0])?.to_float()?;
1032-
let b: F = self.read_scalar(&args[1])?.to_float()?;
1033-
let c: F = self.read_scalar(&args[2])?.to_float()?;
1050+
let a: F = a.to_float()?;
1051+
let b: F = b.to_float()?;
1052+
let c: F = c.to_float()?;
1053+
1054+
let fuse = typ == MulAddType::Fused || M::float_fuse_mul_add(self);
10341055

1035-
let res = a.mul_add(b, c).value;
1056+
let res = if fuse { a.mul_add(b, c).value } else { ((a * b).value + c).value };
10361057
let res = self.adjust_nan(res, &[a, b, c]);
1037-
self.write_scalar(res, dest)?;
1038-
interp_ok(())
1058+
interp_ok(res.into())
10391059
}
10401060

10411061
fn float_muladd_intrinsic<F>(
10421062
&mut self,
10431063
args: &[OpTy<'tcx, M::Provenance>],
10441064
dest: &PlaceTy<'tcx, M::Provenance>,
1065+
typ: MulAddType,
10451066
) -> InterpResult<'tcx, ()>
10461067
where
10471068
F: rustc_apfloat::Float + rustc_apfloat::FloatConvert<F> + Into<Scalar<M::Provenance>>,
10481069
{
1049-
let a: F = self.read_scalar(&args[0])?.to_float()?;
1050-
let b: F = self.read_scalar(&args[1])?.to_float()?;
1051-
let c: F = self.read_scalar(&args[2])?.to_float()?;
1052-
1053-
let fuse = M::float_fuse_mul_add(self);
1070+
let a = self.read_scalar(&args[0])?;
1071+
let b = self.read_scalar(&args[1])?;
1072+
let c = self.read_scalar(&args[2])?;
10541073

1055-
let res = if fuse { a.mul_add(b, c).value } else { ((a * b).value + c).value };
1056-
let res = self.adjust_nan(res, &[a, b, c]);
1074+
let res = self.float_muladd::<F>(a, b, c, typ)?;
10571075
self.write_scalar(res, dest)?;
10581076
interp_ok(())
10591077
}

compiler/rustc_const_eval/src/interpret/intrinsics/simd.rs

Lines changed: 40 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
use either::Either;
22
use rustc_abi::Endian;
3+
use rustc_apfloat::ieee::{Double, Single};
34
use rustc_apfloat::{Float, Round};
45
use rustc_middle::mir::interpret::{InterpErrorKind, UndefinedBehaviorInfo};
56
use rustc_middle::ty::FloatTy;
@@ -8,8 +9,8 @@ use rustc_span::{Symbol, sym};
89
use tracing::trace;
910

1011
use super::{
11-
ImmTy, InterpCx, InterpResult, Machine, OpTy, PlaceTy, Provenance, Scalar, Size, interp_ok,
12-
throw_ub_format,
12+
ImmTy, InterpCx, InterpResult, Machine, MulAddType, OpTy, PlaceTy, Provenance, Scalar, Size,
13+
interp_ok, throw_ub_format,
1314
};
1415
use crate::interpret::Writeable;
1516

@@ -701,6 +702,43 @@ impl<'tcx, M: Machine<'tcx>> InterpCx<'tcx, M> {
701702
};
702703
}
703704
}
705+
sym::simd_fma | sym::simd_relaxed_fma => {
706+
// `simd_fma` must always use the fused `mul_add`, whereas `simd_relaxed_fma` is
707+
// nondeterministic and may use either `mul_add` or `a * b + c`
708+
let typ = match intrinsic_name {
709+
sym::simd_fma => MulAddType::Fused,
710+
sym::simd_relaxed_fma => MulAddType::Nondeterministic,
711+
_ => unreachable!(),
712+
};
713+
714+
let (a, a_len) = self.project_to_simd(&args[0])?;
715+
let (b, b_len) = self.project_to_simd(&args[1])?;
716+
let (c, c_len) = self.project_to_simd(&args[2])?;
717+
let (dest, dest_len) = self.project_to_simd(&dest)?;
718+
719+
assert_eq!(dest_len, a_len);
720+
assert_eq!(dest_len, b_len);
721+
assert_eq!(dest_len, c_len);
722+
723+
for i in 0..dest_len {
724+
let a = self.read_scalar(&self.project_index(&a, i)?)?;
725+
let b = self.read_scalar(&self.project_index(&b, i)?)?;
726+
let c = self.read_scalar(&self.project_index(&c, i)?)?;
727+
let dest = self.project_index(&dest, i)?;
728+
729+
let ty::Float(float_ty) = dest.layout.ty.kind() else {
730+
span_bug!(self.cur_span(), "{} operand is not a float", intrinsic_name)
731+
};
732+
733+
let val = match float_ty {
734+
FloatTy::F16 => unimplemented!("f16_f128"),
735+
FloatTy::F32 => self.float_muladd::<Single>(a, b, c, typ)?,
736+
FloatTy::F64 => self.float_muladd::<Double>(a, b, c, typ)?,
737+
FloatTy::F128 => unimplemented!("f16_f128"),
738+
};
739+
self.write_scalar(val, &dest)?;
740+
}
741+
}
704742

705743
// Unsupported intrinsic: skip the return_to_block below.
706744
_ => return interp_ok(false),

compiler/rustc_const_eval/src/interpret/machine.rs

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -290,7 +290,7 @@ pub trait Machine<'tcx>: Sized {
290290
}
291291

292292
/// Whether `fmuladd` and `simd_relaxed_fma` fuse the multiply-add or use separate operations.
293-
fn float_fuse_mul_add(_ecx: &mut InterpCx<'tcx, Self>) -> bool;
293+
fn float_fuse_mul_add(_ecx: &InterpCx<'tcx, Self>) -> bool;
294294

295295
/// Called before a basic block terminator is executed.
296296
#[inline]
@@ -676,7 +676,7 @@ pub macro compile_time_machine(<$tcx: lifetime>) {
676676
}
677677

678678
#[inline(always)]
679-
fn float_fuse_mul_add(_ecx: &mut InterpCx<$tcx, Self>) -> bool {
679+
fn float_fuse_mul_add(_ecx: &InterpCx<$tcx, Self>) -> bool {
680680
true
681681
}
682682

src/tools/miri/src/intrinsics/simd.rs

Lines changed: 0 additions & 58 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,3 @@
1-
use rand::Rng;
2-
use rustc_apfloat::Float;
31
use rustc_middle::ty::FloatTy;
42
use rustc_middle::ty;
53

@@ -83,62 +81,6 @@ pub trait EvalContextExt<'tcx>: crate::MiriInterpCxExt<'tcx> {
8381
this.write_scalar(val, &dest)?;
8482
}
8583
}
86-
"fma" | "relaxed_fma" => {
87-
let [a, b, c] = check_intrinsic_arg_count(args)?;
88-
let (a, a_len) = this.project_to_simd(a)?;
89-
let (b, b_len) = this.project_to_simd(b)?;
90-
let (c, c_len) = this.project_to_simd(c)?;
91-
let (dest, dest_len) = this.project_to_simd(dest)?;
92-
93-
assert_eq!(dest_len, a_len);
94-
assert_eq!(dest_len, b_len);
95-
assert_eq!(dest_len, c_len);
96-
97-
for i in 0..dest_len {
98-
let a = this.read_scalar(&this.project_index(&a, i)?)?;
99-
let b = this.read_scalar(&this.project_index(&b, i)?)?;
100-
let c = this.read_scalar(&this.project_index(&c, i)?)?;
101-
let dest = this.project_index(&dest, i)?;
102-
103-
let fuse: bool = intrinsic_name == "fma"
104-
|| (this.machine.float_nondet && this.machine.rng.get_mut().random());
105-
106-
// Works for f32 and f64.
107-
// FIXME: using host floats to work around https://github.com/rust-lang/miri/issues/2468.
108-
let ty::Float(float_ty) = dest.layout.ty.kind() else {
109-
span_bug!(this.cur_span(), "{} operand is not a float", intrinsic_name)
110-
};
111-
let val = match float_ty {
112-
FloatTy::F16 => unimplemented!("f16_f128"),
113-
FloatTy::F32 => {
114-
let a = a.to_f32()?;
115-
let b = b.to_f32()?;
116-
let c = c.to_f32()?;
117-
let res = if fuse {
118-
a.mul_add(b, c).value
119-
} else {
120-
((a * b).value + c).value
121-
};
122-
let res = this.adjust_nan(res, &[a, b, c]);
123-
Scalar::from(res)
124-
}
125-
FloatTy::F64 => {
126-
let a = a.to_f64()?;
127-
let b = b.to_f64()?;
128-
let c = c.to_f64()?;
129-
let res = if fuse {
130-
a.mul_add(b, c).value
131-
} else {
132-
((a * b).value + c).value
133-
};
134-
let res = this.adjust_nan(res, &[a, b, c]);
135-
Scalar::from(res)
136-
}
137-
FloatTy::F128 => unimplemented!("f16_f128"),
138-
};
139-
this.write_scalar(val, &dest)?;
140-
}
141-
}
14284
"expose_provenance" => {
14385
let [op] = check_intrinsic_arg_count(args)?;
14486
let (op, op_len) = this.project_to_simd(op)?;

src/tools/miri/src/machine.rs

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1294,8 +1294,8 @@ impl<'tcx> Machine<'tcx> for MiriMachine<'tcx> {
12941294
}
12951295

12961296
#[inline(always)]
1297-
fn float_fuse_mul_add(ecx: &mut InterpCx<'tcx, Self>) -> bool {
1298-
ecx.machine.float_nondet && ecx.machine.rng.get_mut().random()
1297+
fn float_fuse_mul_add(ecx: &InterpCx<'tcx, Self>) -> bool {
1298+
ecx.machine.float_nondet && ecx.machine.rng.borrow_mut().random()
12991299
}
13001300

13011301
#[inline(always)]

0 commit comments

Comments
 (0)