Skip to content

Commit 23fdc66

Browse files
committed
Implement simd_fma and simd_relaxed_fma in const-eval
1 parent acf2437 commit 23fdc66

File tree

5 files changed

+68
-88
lines changed

5 files changed

+68
-88
lines changed

compiler/rustc_const_eval/src/interpret/intrinsics.rs

Lines changed: 27 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -612,14 +612,14 @@ impl<'tcx, M: Machine<'tcx>> InterpCx<'tcx, M> {
612612
dest,
613613
rustc_apfloat::Round::NearestTiesToEven,
614614
)?,
615-
sym::fmaf16 => self.fma_intrinsic::<Half>(args, dest)?,
616-
sym::fmaf32 => self.fma_intrinsic::<Single>(args, dest)?,
617-
sym::fmaf64 => self.fma_intrinsic::<Double>(args, dest)?,
618-
sym::fmaf128 => self.fma_intrinsic::<Quad>(args, dest)?,
619-
sym::fmuladdf16 => self.float_muladd_intrinsic::<Half>(args, dest)?,
620-
sym::fmuladdf32 => self.float_muladd_intrinsic::<Single>(args, dest)?,
621-
sym::fmuladdf64 => self.float_muladd_intrinsic::<Double>(args, dest)?,
622-
sym::fmuladdf128 => self.float_muladd_intrinsic::<Quad>(args, dest)?,
615+
sym::fmaf16 => self.float_muladd_intrinsic::<Half>(args, dest, true)?,
616+
sym::fmaf32 => self.float_muladd_intrinsic::<Single>(args, dest, true)?,
617+
sym::fmaf64 => self.float_muladd_intrinsic::<Double>(args, dest, true)?,
618+
sym::fmaf128 => self.float_muladd_intrinsic::<Quad>(args, dest, true)?,
619+
sym::fmuladdf16 => self.float_muladd_intrinsic::<Half>(args, dest, false)?,
620+
sym::fmuladdf32 => self.float_muladd_intrinsic::<Single>(args, dest, false)?,
621+
sym::fmuladdf64 => self.float_muladd_intrinsic::<Double>(args, dest, false)?,
622+
sym::fmuladdf128 => self.float_muladd_intrinsic::<Quad>(args, dest, false)?,
623623

624624
// Unsupported intrinsic: skip the return_to_block below.
625625
_ => return interp_ok(false),
@@ -1020,40 +1020,41 @@ impl<'tcx, M: Machine<'tcx>> InterpCx<'tcx, M> {
10201020
interp_ok(())
10211021
}
10221022

1023-
fn fma_intrinsic<F>(
1024-
&mut self,
1025-
args: &[OpTy<'tcx, M::Provenance>],
1026-
dest: &PlaceTy<'tcx, M::Provenance>,
1027-
) -> InterpResult<'tcx, ()>
1023+
fn float_muladd<F>(
1024+
&self,
1025+
a: Scalar<M::Provenance>,
1026+
b: Scalar<M::Provenance>,
1027+
c: Scalar<M::Provenance>,
1028+
deterministic: bool,
1029+
) -> InterpResult<'tcx, F>
10281030
where
10291031
F: rustc_apfloat::Float + rustc_apfloat::FloatConvert<F> + Into<Scalar<M::Provenance>>,
10301032
{
1031-
let a: F = self.read_scalar(&args[0])?.to_float()?;
1032-
let b: F = self.read_scalar(&args[1])?.to_float()?;
1033-
let c: F = self.read_scalar(&args[2])?.to_float()?;
1033+
let a: F = a.to_float()?;
1034+
let b: F = b.to_float()?;
1035+
let c: F = c.to_float()?;
1036+
1037+
let fuse = deterministic || M::float_fuse_mul_add(self);
10341038

1035-
let res = a.mul_add(b, c).value;
1039+
let res = if fuse { a.mul_add(b, c).value } else { ((a * b).value + c).value };
10361040
let res = self.adjust_nan(res, &[a, b, c]);
1037-
self.write_scalar(res, dest)?;
1038-
interp_ok(())
1041+
interp_ok(res)
10391042
}
10401043

10411044
fn float_muladd_intrinsic<F>(
10421045
&mut self,
10431046
args: &[OpTy<'tcx, M::Provenance>],
10441047
dest: &PlaceTy<'tcx, M::Provenance>,
1048+
deterministic: bool,
10451049
) -> InterpResult<'tcx, ()>
10461050
where
10471051
F: rustc_apfloat::Float + rustc_apfloat::FloatConvert<F> + Into<Scalar<M::Provenance>>,
10481052
{
1049-
let a: F = self.read_scalar(&args[0])?.to_float()?;
1050-
let b: F = self.read_scalar(&args[1])?.to_float()?;
1051-
let c: F = self.read_scalar(&args[2])?.to_float()?;
1052-
1053-
let fuse = M::float_fuse_mul_add(self);
1053+
let a = self.read_scalar(&args[0])?;
1054+
let b = self.read_scalar(&args[1])?;
1055+
let c = self.read_scalar(&args[2])?;
10541056

1055-
let res = if fuse { a.mul_add(b, c).value } else { ((a * b).value + c).value };
1056-
let res = self.adjust_nan(res, &[a, b, c]);
1057+
let res = self.float_muladd::<F>(a, b, c, deterministic)?;
10571058
self.write_scalar(res, dest)?;
10581059
interp_ok(())
10591060
}

compiler/rustc_const_eval/src/interpret/intrinsics/simd.rs

Lines changed: 37 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -701,6 +701,43 @@ impl<'tcx, M: Machine<'tcx>> InterpCx<'tcx, M> {
701701
};
702702
}
703703
}
704+
sym::simd_fma | sym::simd_relaxed_fma => {
705+
// `simd_fma` should always deterministically use `mul_add`, whereas `relaxed_fma`
706+
// is non-deterministic, and can use either `mul_add` or `a * b + c`
707+
let deterministic = intrinsic_name == sym::simd_fma;
708+
709+
let (a, a_len) = self.project_to_simd(&args[0])?;
710+
let (b, b_len) = self.project_to_simd(&args[1])?;
711+
let (c, c_len) = self.project_to_simd(&args[2])?;
712+
let (dest, dest_len) = self.project_to_simd(&dest)?;
713+
714+
assert_eq!(dest_len, a_len);
715+
assert_eq!(dest_len, b_len);
716+
assert_eq!(dest_len, c_len);
717+
718+
for i in 0..dest_len {
719+
let a = self.read_scalar(&self.project_index(&a, i)?)?;
720+
let b = self.read_scalar(&self.project_index(&b, i)?)?;
721+
let c = self.read_scalar(&self.project_index(&c, i)?)?;
722+
let dest = self.project_index(&dest, i)?;
723+
724+
let ty::Float(float_ty) = dest.layout.ty.kind() else {
725+
span_bug!(self.cur_span(), "{} operand is not a float", intrinsic_name)
726+
};
727+
728+
let val = match float_ty {
729+
FloatTy::F16 => unimplemented!("f16_f128"),
730+
FloatTy::F32 => {
731+
Scalar::from_f32(self.float_muladd(a, b, c, deterministic)?)
732+
}
733+
FloatTy::F64 => {
734+
Scalar::from_f64(self.float_muladd(a, b, c, deterministic)?)
735+
}
736+
FloatTy::F128 => unimplemented!("f16_f128"),
737+
};
738+
self.write_scalar(val, &dest)?;
739+
}
740+
}
704741

705742
// Unsupported intrinsic: skip the return_to_block below.
706743
_ => return interp_ok(false),

compiler/rustc_const_eval/src/interpret/machine.rs

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -290,7 +290,7 @@ pub trait Machine<'tcx>: Sized {
290290
}
291291

292292
/// Determines whether the `fmuladd` intrinsics fuse the multiply-add or use separate operations.
293-
fn float_fuse_mul_add(_ecx: &mut InterpCx<'tcx, Self>) -> bool;
293+
fn float_fuse_mul_add(_ecx: &InterpCx<'tcx, Self>) -> bool;
294294

295295
/// Called before a basic block terminator is executed.
296296
#[inline]
@@ -676,7 +676,7 @@ pub macro compile_time_machine(<$tcx: lifetime>) {
676676
}
677677

678678
#[inline(always)]
679-
fn float_fuse_mul_add(_ecx: &mut InterpCx<$tcx, Self>) -> bool {
679+
fn float_fuse_mul_add(_ecx: &InterpCx<$tcx, Self>) -> bool {
680680
true
681681
}
682682

src/tools/miri/src/intrinsics/simd.rs

Lines changed: 0 additions & 58 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,3 @@
1-
use rand::Rng;
2-
use rustc_apfloat::Float;
31
use rustc_middle::ty::FloatTy;
42
use rustc_middle::ty;
53

@@ -83,62 +81,6 @@ pub trait EvalContextExt<'tcx>: crate::MiriInterpCxExt<'tcx> {
8381
this.write_scalar(val, &dest)?;
8482
}
8583
}
86-
"fma" | "relaxed_fma" => {
87-
let [a, b, c] = check_intrinsic_arg_count(args)?;
88-
let (a, a_len) = this.project_to_simd(a)?;
89-
let (b, b_len) = this.project_to_simd(b)?;
90-
let (c, c_len) = this.project_to_simd(c)?;
91-
let (dest, dest_len) = this.project_to_simd(dest)?;
92-
93-
assert_eq!(dest_len, a_len);
94-
assert_eq!(dest_len, b_len);
95-
assert_eq!(dest_len, c_len);
96-
97-
for i in 0..dest_len {
98-
let a = this.read_scalar(&this.project_index(&a, i)?)?;
99-
let b = this.read_scalar(&this.project_index(&b, i)?)?;
100-
let c = this.read_scalar(&this.project_index(&c, i)?)?;
101-
let dest = this.project_index(&dest, i)?;
102-
103-
let fuse: bool = intrinsic_name == "fma"
104-
|| (this.machine.float_nondet && this.machine.rng.get_mut().random());
105-
106-
// Works for f32 and f64.
107-
// FIXME: using host floats to work around https://github.com/rust-lang/miri/issues/2468.
108-
let ty::Float(float_ty) = dest.layout.ty.kind() else {
109-
span_bug!(this.cur_span(), "{} operand is not a float", intrinsic_name)
110-
};
111-
let val = match float_ty {
112-
FloatTy::F16 => unimplemented!("f16_f128"),
113-
FloatTy::F32 => {
114-
let a = a.to_f32()?;
115-
let b = b.to_f32()?;
116-
let c = c.to_f32()?;
117-
let res = if fuse {
118-
a.mul_add(b, c).value
119-
} else {
120-
((a * b).value + c).value
121-
};
122-
let res = this.adjust_nan(res, &[a, b, c]);
123-
Scalar::from(res)
124-
}
125-
FloatTy::F64 => {
126-
let a = a.to_f64()?;
127-
let b = b.to_f64()?;
128-
let c = c.to_f64()?;
129-
let res = if fuse {
130-
a.mul_add(b, c).value
131-
} else {
132-
((a * b).value + c).value
133-
};
134-
let res = this.adjust_nan(res, &[a, b, c]);
135-
Scalar::from(res)
136-
}
137-
FloatTy::F128 => unimplemented!("f16_f128"),
138-
};
139-
this.write_scalar(val, &dest)?;
140-
}
141-
}
14284
"expose_provenance" => {
14385
let [op] = check_intrinsic_arg_count(args)?;
14486
let (op, op_len) = this.project_to_simd(op)?;

src/tools/miri/src/machine.rs

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1294,8 +1294,8 @@ impl<'tcx> Machine<'tcx> for MiriMachine<'tcx> {
12941294
}
12951295

12961296
#[inline(always)]
1297-
fn float_fuse_mul_add(ecx: &mut InterpCx<'tcx, Self>) -> bool {
1298-
ecx.machine.float_nondet && ecx.machine.rng.get_mut().random()
1297+
fn float_fuse_mul_add(ecx: &InterpCx<'tcx, Self>) -> bool {
1298+
ecx.machine.float_nondet && ecx.machine.rng.borrow_mut().random()
12991299
}
13001300

13011301
#[inline(always)]

0 commit comments

Comments
 (0)