Skip to content

Commit ac970b6

Browse files
committed
Make most SIMD intrinsics const-compatible (Port the Miri implementation to rustc_const_eval)
Remaining: Math functions (`fsqrt`, `fsin`, `fcos`, `fexp`, `fexp2`, `flog`, `flog2`, `flog10`), Funnel Shifts, `dyn` extract-inserts and FMA
1 parent 52618eb commit ac970b6

File tree

7 files changed

+1033
-972
lines changed

7 files changed

+1033
-972
lines changed

compiler/rustc_const_eval/src/interpret/intrinsics.rs

Lines changed: 121 additions & 103 deletions
Original file line numberDiff line numberDiff line change
@@ -9,8 +9,8 @@ use rustc_apfloat::ieee::{Double, Half, Quad, Single};
99
use rustc_middle::mir::interpret::{CTFE_ALLOC_SALT, read_target_uint, write_target_uint};
1010
use rustc_middle::mir::{self, BinOp, ConstValue, NonDivergingIntrinsic};
1111
use rustc_middle::ty::layout::TyAndLayout;
12-
use rustc_middle::ty::{Ty, TyCtxt};
13-
use rustc_middle::{bug, ty};
12+
use rustc_middle::ty::{FloatTy, Ty, TyCtxt};
13+
use rustc_middle::{bug, span_bug, ty};
1414
use rustc_span::{Symbol, sym};
1515
use tracing::trace;
1616

@@ -22,6 +22,15 @@ use super::{
2222
throw_ub_custom, throw_ub_format,
2323
};
2424
use crate::fluent_generated as fluent;
25+
use crate::interpret::Projectable;
26+
27+
#[derive(Copy, Clone)]
28+
pub(super) enum MinMax {
29+
MinNum,
30+
MaxNum,
31+
Minimum,
32+
Maximum,
33+
}
2534

2635
/// Directly returns an `Allocation` containing an absolute path representation of the given type.
2736
pub(crate) fn alloc_type_name<'tcx>(tcx: TyCtxt<'tcx>, ty: Ty<'tcx>) -> (AllocId, u64) {
@@ -123,6 +132,10 @@ impl<'tcx, M: Machine<'tcx>> InterpCx<'tcx, M> {
123132
let intrinsic_name = self.tcx.item_name(instance.def_id());
124133
let tcx = self.tcx.tcx;
125134

135+
if intrinsic_name.as_str().starts_with("simd_") {
136+
return self.eval_simd_intrinsic(intrinsic_name, instance, args, dest, ret);
137+
}
138+
126139
match intrinsic_name {
127140
sym::type_name => {
128141
let tp_ty = instance.args.type_at(0);
@@ -453,38 +466,6 @@ impl<'tcx, M: Machine<'tcx>> InterpCx<'tcx, M> {
453466
let size = ImmTy::from_int(pointee_layout.size.bytes(), ret_layout);
454467
self.exact_div(&val, &size, dest)?;
455468
}
456-
457-
sym::simd_insert => {
458-
let index = u64::from(self.read_scalar(&args[1])?.to_u32()?);
459-
let elem = &args[2];
460-
let (input, input_len) = self.project_to_simd(&args[0])?;
461-
let (dest, dest_len) = self.project_to_simd(dest)?;
462-
assert_eq!(input_len, dest_len, "Return vector length must match input length");
463-
// Bounds are not checked by typeck so we have to do it ourselves.
464-
if index >= input_len {
465-
throw_ub_format!(
466-
"`simd_insert` index {index} is out-of-bounds of vector with length {input_len}"
467-
);
468-
}
469-
470-
for i in 0..dest_len {
471-
let place = self.project_index(&dest, i)?;
472-
let value =
473-
if i == index { elem.clone() } else { self.project_index(&input, i)? };
474-
self.copy_op(&value, &place)?;
475-
}
476-
}
477-
sym::simd_extract => {
478-
let index = u64::from(self.read_scalar(&args[1])?.to_u32()?);
479-
let (input, input_len) = self.project_to_simd(&args[0])?;
480-
// Bounds are not checked by typeck so we have to do it ourselves.
481-
if index >= input_len {
482-
throw_ub_format!(
483-
"`simd_extract` index {index} is out-of-bounds of vector with length {input_len}"
484-
);
485-
}
486-
self.copy_op(&self.project_index(&input, index)?, dest)?;
487-
}
488469
sym::black_box => {
489470
// These just return their argument
490471
self.copy_op(&args[0], dest)?;
@@ -510,25 +491,33 @@ impl<'tcx, M: Machine<'tcx>> InterpCx<'tcx, M> {
510491
self.write_scalar(Scalar::from_target_usize(align.bytes(), self), dest)?;
511492
}
512493

513-
sym::minnumf16 => self.float_min_intrinsic::<Half>(args, dest)?,
514-
sym::minnumf32 => self.float_min_intrinsic::<Single>(args, dest)?,
515-
sym::minnumf64 => self.float_min_intrinsic::<Double>(args, dest)?,
516-
sym::minnumf128 => self.float_min_intrinsic::<Quad>(args, dest)?,
494+
sym::minnumf16 => self.float_minmax_intrinsic::<Half>(args, dest, MinMax::MinNum)?,
495+
sym::minnumf32 => self.float_minmax_intrinsic::<Single>(args, dest, MinMax::MinNum)?,
496+
sym::minnumf64 => self.float_minmax_intrinsic::<Double>(args, dest, MinMax::MinNum)?,
497+
sym::minnumf128 => self.float_minmax_intrinsic::<Quad>(args, dest, MinMax::MinNum)?,
517498

518-
sym::minimumf16 => self.float_minimum_intrinsic::<Half>(args, dest)?,
519-
sym::minimumf32 => self.float_minimum_intrinsic::<Single>(args, dest)?,
520-
sym::minimumf64 => self.float_minimum_intrinsic::<Double>(args, dest)?,
521-
sym::minimumf128 => self.float_minimum_intrinsic::<Quad>(args, dest)?,
499+
sym::minimumf16 => self.float_minmax_intrinsic::<Half>(args, dest, MinMax::Minimum)?,
500+
sym::minimumf32 => {
501+
self.float_minmax_intrinsic::<Single>(args, dest, MinMax::Minimum)?
502+
}
503+
sym::minimumf64 => {
504+
self.float_minmax_intrinsic::<Double>(args, dest, MinMax::Minimum)?
505+
}
506+
sym::minimumf128 => self.float_minmax_intrinsic::<Quad>(args, dest, MinMax::Minimum)?,
522507

523-
sym::maxnumf16 => self.float_max_intrinsic::<Half>(args, dest)?,
524-
sym::maxnumf32 => self.float_max_intrinsic::<Single>(args, dest)?,
525-
sym::maxnumf64 => self.float_max_intrinsic::<Double>(args, dest)?,
526-
sym::maxnumf128 => self.float_max_intrinsic::<Quad>(args, dest)?,
508+
sym::maxnumf16 => self.float_minmax_intrinsic::<Half>(args, dest, MinMax::MaxNum)?,
509+
sym::maxnumf32 => self.float_minmax_intrinsic::<Single>(args, dest, MinMax::MaxNum)?,
510+
sym::maxnumf64 => self.float_minmax_intrinsic::<Double>(args, dest, MinMax::MaxNum)?,
511+
sym::maxnumf128 => self.float_minmax_intrinsic::<Quad>(args, dest, MinMax::MaxNum)?,
527512

528-
sym::maximumf16 => self.float_maximum_intrinsic::<Half>(args, dest)?,
529-
sym::maximumf32 => self.float_maximum_intrinsic::<Single>(args, dest)?,
530-
sym::maximumf64 => self.float_maximum_intrinsic::<Double>(args, dest)?,
531-
sym::maximumf128 => self.float_maximum_intrinsic::<Quad>(args, dest)?,
513+
sym::maximumf16 => self.float_minmax_intrinsic::<Half>(args, dest, MinMax::Maximum)?,
514+
sym::maximumf32 => {
515+
self.float_minmax_intrinsic::<Single>(args, dest, MinMax::Maximum)?
516+
}
517+
sym::maximumf64 => {
518+
self.float_minmax_intrinsic::<Double>(args, dest, MinMax::Maximum)?
519+
}
520+
sym::maximumf128 => self.float_minmax_intrinsic::<Quad>(args, dest, MinMax::Maximum)?,
532521

533522
sym::copysignf16 => self.float_copysign_intrinsic::<Half>(args, dest)?,
534523
sym::copysignf32 => self.float_copysign_intrinsic::<Single>(args, dest)?,
@@ -917,78 +906,45 @@ impl<'tcx, M: Machine<'tcx>> InterpCx<'tcx, M> {
917906
interp_ok(Scalar::from_bool(lhs_bytes == rhs_bytes))
918907
}
919908

920-
fn float_min_intrinsic<F>(
909+
fn float_minmax_intrinsic<F>(
921910
&mut self,
922911
args: &[OpTy<'tcx, M::Provenance>],
923912
dest: &PlaceTy<'tcx, M::Provenance>,
913+
op: MinMax,
924914
) -> InterpResult<'tcx, ()>
925915
where
926916
F: rustc_apfloat::Float + rustc_apfloat::FloatConvert<F> + Into<Scalar<M::Provenance>>,
927917
{
928-
let a: F = self.read_scalar(&args[0])?.to_float()?;
929-
let b: F = self.read_scalar(&args[1])?.to_float()?;
930-
let res = if a == b {
931-
// They are definitely not NaN (those are never equal), but they could be `+0` and `-0`.
932-
// Let the machine decide which one to return.
933-
M::equal_float_min_max(self, a, b)
934-
} else {
935-
self.adjust_nan(a.min(b), &[a, b])
936-
};
918+
let res = self.float_minmax::<F>(&args[0], &args[1], op)?;
937919
self.write_scalar(res, dest)?;
938920
interp_ok(())
939921
}
940922

941-
fn float_max_intrinsic<F>(
942-
&mut self,
943-
args: &[OpTy<'tcx, M::Provenance>],
944-
dest: &PlaceTy<'tcx, M::Provenance>,
945-
) -> InterpResult<'tcx, ()>
923+
pub(super) fn float_minmax<F>(
924+
&self,
925+
a: &impl Projectable<'tcx, M::Provenance>,
926+
b: &impl Projectable<'tcx, M::Provenance>,
927+
op: MinMax,
928+
) -> InterpResult<'tcx, Scalar<M::Provenance>>
946929
where
947930
F: rustc_apfloat::Float + rustc_apfloat::FloatConvert<F> + Into<Scalar<M::Provenance>>,
948931
{
949-
let a: F = self.read_scalar(&args[0])?.to_float()?;
950-
let b: F = self.read_scalar(&args[1])?.to_float()?;
951-
let res = if a == b {
932+
let a: F = self.read_scalar(a)?.to_float()?;
933+
let b: F = self.read_scalar(b)?.to_float()?;
934+
let res = if matches!(op, MinMax::MaxNum | MinMax::MinNum) && a == b {
952935
// They are definitely not NaN (those are never equal), but they could be `+0` and `-0`.
953936
// Let the machine decide which one to return.
954937
M::equal_float_min_max(self, a, b)
955938
} else {
956-
self.adjust_nan(a.max(b), &[a, b])
939+
let res = match op {
940+
MinMax::MinNum => a.min(b),
941+
MinMax::MaxNum => a.max(b),
942+
MinMax::Minimum => a.minimum(b),
943+
MinMax::Maximum => a.maximum(b),
944+
};
945+
self.adjust_nan(res, &[a, b])
957946
};
958-
self.write_scalar(res, dest)?;
959-
interp_ok(())
960-
}
961-
962-
fn float_minimum_intrinsic<F>(
963-
&mut self,
964-
args: &[OpTy<'tcx, M::Provenance>],
965-
dest: &PlaceTy<'tcx, M::Provenance>,
966-
) -> InterpResult<'tcx, ()>
967-
where
968-
F: rustc_apfloat::Float + rustc_apfloat::FloatConvert<F> + Into<Scalar<M::Provenance>>,
969-
{
970-
let a: F = self.read_scalar(&args[0])?.to_float()?;
971-
let b: F = self.read_scalar(&args[1])?.to_float()?;
972-
let res = a.minimum(b);
973-
let res = self.adjust_nan(res, &[a, b]);
974-
self.write_scalar(res, dest)?;
975-
interp_ok(())
976-
}
977-
978-
fn float_maximum_intrinsic<F>(
979-
&mut self,
980-
args: &[OpTy<'tcx, M::Provenance>],
981-
dest: &PlaceTy<'tcx, M::Provenance>,
982-
) -> InterpResult<'tcx, ()>
983-
where
984-
F: rustc_apfloat::Float + rustc_apfloat::FloatConvert<F> + Into<Scalar<M::Provenance>>,
985-
{
986-
let a: F = self.read_scalar(&args[0])?.to_float()?;
987-
let b: F = self.read_scalar(&args[1])?.to_float()?;
988-
let res = a.maximum(b);
989-
let res = self.adjust_nan(res, &[a, b]);
990-
self.write_scalar(res, dest)?;
991-
interp_ok(())
947+
interp_ok(res.into())
992948
}
993949

994950
fn float_copysign_intrinsic<F>(
@@ -1035,4 +991,66 @@ impl<'tcx, M: Machine<'tcx>> InterpCx<'tcx, M> {
1035991
self.write_scalar(res, dest)?;
1036992
interp_ok(())
1037993
}
994+
995+
fn float_to_int_inner<F: rustc_apfloat::Float>(
996+
&self,
997+
src: F,
998+
cast_to: TyAndLayout<'tcx>,
999+
round: rustc_apfloat::Round,
1000+
) -> (Scalar<M::Provenance>, rustc_apfloat::Status) {
1001+
let int_size = cast_to.layout.size;
1002+
match cast_to.ty.kind() {
1003+
// Unsigned
1004+
ty::Uint(_) => {
1005+
let res = src.to_u128_r(int_size.bits_usize(), round, &mut false);
1006+
(Scalar::from_uint(res.value, int_size), res.status)
1007+
}
1008+
// Signed
1009+
ty::Int(_) => {
1010+
let res = src.to_i128_r(int_size.bits_usize(), round, &mut false);
1011+
(Scalar::from_int(res.value, int_size), res.status)
1012+
}
1013+
// Nothing else
1014+
_ => span_bug!(
1015+
self.cur_span(),
1016+
"attempted float-to-int conversion with non-int output type {}",
1017+
cast_to.ty,
1018+
),
1019+
}
1020+
}
1021+
1022+
/// Converts `src` from floating point to integer type `dest_ty`
1023+
/// after rounding with mode `round`.
1024+
/// Returns `None` if `f` is NaN or out of range.
1025+
pub fn float_to_int_checked(
1026+
&self,
1027+
src: &ImmTy<'tcx, M::Provenance>,
1028+
cast_to: TyAndLayout<'tcx>,
1029+
round: rustc_apfloat::Round,
1030+
) -> InterpResult<'tcx, Option<ImmTy<'tcx, M::Provenance>>> {
1031+
let ty::Float(fty) = src.layout.ty.kind() else {
1032+
bug!("float_to_int_checked: non-float input type {}", src.layout.ty)
1033+
};
1034+
1035+
let (val, status) = match fty {
1036+
FloatTy::F16 => self.float_to_int_inner(src.to_scalar().to_f16()?, cast_to, round),
1037+
FloatTy::F32 => self.float_to_int_inner(src.to_scalar().to_f32()?, cast_to, round),
1038+
FloatTy::F64 => self.float_to_int_inner(src.to_scalar().to_f64()?, cast_to, round),
1039+
FloatTy::F128 => self.float_to_int_inner(src.to_scalar().to_f128()?, cast_to, round),
1040+
};
1041+
1042+
if status.intersects(
1043+
rustc_apfloat::Status::INVALID_OP
1044+
| rustc_apfloat::Status::OVERFLOW
1045+
| rustc_apfloat::Status::UNDERFLOW,
1046+
) {
1047+
// Floating point value is NaN (flagged with INVALID_OP) or outside the range
1048+
// of values of the integer type (flagged with OVERFLOW or UNDERFLOW).
1049+
interp_ok(None)
1050+
} else {
1051+
// Floating point value can be represented by the integer type after rounding.
1052+
// The INEXACT flag is ignored on purpose to allow rounding.
1053+
interp_ok(Some(ImmTy::from_scalar(val, cast_to)))
1054+
}
1055+
}
10381056
}

compiler/rustc_const_eval/src/interpret/mod.rs

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,7 @@ mod operand;
1212
mod operator;
1313
mod place;
1414
mod projection;
15+
mod simd_intrinsics;
1516
mod stack;
1617
mod step;
1718
mod traits;

0 commit comments

Comments
 (0)