diff --git a/compiler/rustc_const_eval/src/interpret/intrinsics.rs b/compiler/rustc_const_eval/src/interpret/intrinsics.rs index 785978b4d7111..968b5fc64c336 100644 --- a/compiler/rustc_const_eval/src/interpret/intrinsics.rs +++ b/compiler/rustc_const_eval/src/interpret/intrinsics.rs @@ -2,6 +2,8 @@ //! looking at their MIR. Intrinsics/functions supported here are shared by CTFE //! and miri. +mod simd; + use std::assert_matches::assert_matches; use rustc_abi::{FieldIdx, HasDataLayout, Size}; @@ -9,8 +11,8 @@ use rustc_apfloat::ieee::{Double, Half, Quad, Single}; use rustc_middle::mir::interpret::{CTFE_ALLOC_SALT, read_target_uint, write_target_uint}; use rustc_middle::mir::{self, BinOp, ConstValue, NonDivergingIntrinsic}; use rustc_middle::ty::layout::TyAndLayout; -use rustc_middle::ty::{Ty, TyCtxt}; -use rustc_middle::{bug, ty}; +use rustc_middle::ty::{FloatTy, Ty, TyCtxt}; +use rustc_middle::{bug, span_bug, ty}; use rustc_span::{Symbol, sym}; use tracing::trace; @@ -121,6 +123,11 @@ impl<'tcx, M: Machine<'tcx>> InterpCx<'tcx, M> { ) -> InterpResult<'tcx, bool> { let instance_args = instance.args; let intrinsic_name = self.tcx.item_name(instance.def_id()); + + if intrinsic_name.as_str().starts_with("simd_") { + return self.eval_simd_intrinsic(intrinsic_name, instance_args, args, dest, ret); + } + let tcx = self.tcx.tcx; match intrinsic_name { @@ -454,37 +461,6 @@ impl<'tcx, M: Machine<'tcx>> InterpCx<'tcx, M> { self.exact_div(&val, &size, dest)?; } - sym::simd_insert => { - let index = u64::from(self.read_scalar(&args[1])?.to_u32()?); - let elem = &args[2]; - let (input, input_len) = self.project_to_simd(&args[0])?; - let (dest, dest_len) = self.project_to_simd(dest)?; - assert_eq!(input_len, dest_len, "Return vector length must match input length"); - // Bounds are not checked by typeck so we have to do it ourselves. - if index >= input_len { - throw_ub_format!( - "`simd_insert` index {index} is out-of-bounds of vector with length {input_len}" - ); - } - - for i in 0..dest_len { - let place = self.project_index(&dest, i)?; - let value = - if i == index { elem.clone() } else { self.project_index(&input, i)? }; - self.copy_op(&value, &place)?; - } - } - sym::simd_extract => { - let index = u64::from(self.read_scalar(&args[1])?.to_u32()?); - let (input, input_len) = self.project_to_simd(&args[0])?; - // Bounds are not checked by typeck so we have to do it ourselves. - if index >= input_len { - throw_ub_format!( - "`simd_extract` index {index} is out-of-bounds of vector with length {input_len}" - ); - } - self.copy_op(&self.project_index(&input, index)?, dest)?; - } sym::black_box => { // These just return their argument self.copy_op(&args[0], dest)?; @@ -1081,4 +1057,66 @@ impl<'tcx, M: Machine<'tcx>> InterpCx<'tcx, M> { self.write_scalar(res, dest)?; interp_ok(()) } + + /// Converts `src` from floating point to integer type `dest_ty` + /// after rounding with mode `round`. + /// Returns `None` if `f` is NaN or out of range. + pub fn float_to_int_checked( + &self, + src: &ImmTy<'tcx, M::Provenance>, + cast_to: TyAndLayout<'tcx>, + round: rustc_apfloat::Round, + ) -> InterpResult<'tcx, Option>> { + fn float_to_int_inner<'tcx, F: rustc_apfloat::Float, M: Machine<'tcx>>( + ecx: &InterpCx<'tcx, M>, + src: F, + cast_to: TyAndLayout<'tcx>, + round: rustc_apfloat::Round, + ) -> (Scalar, rustc_apfloat::Status) { + let int_size = cast_to.layout.size; + match cast_to.ty.kind() { + // Unsigned + ty::Uint(_) => { + let res = src.to_u128_r(int_size.bits_usize(), round, &mut false); + (Scalar::from_uint(res.value, int_size), res.status) + } + // Signed + ty::Int(_) => { + let res = src.to_i128_r(int_size.bits_usize(), round, &mut false); + (Scalar::from_int(res.value, int_size), res.status) + } + // Nothing else + _ => span_bug!( + ecx.cur_span(), + "attempted float-to-int conversion with non-int output type {}", + cast_to.ty, + ), + } + } + + let ty::Float(fty) = src.layout.ty.kind() else { + bug!("float_to_int_checked: non-float input type {}", src.layout.ty) + }; + + let (val, status) = match fty { + FloatTy::F16 => float_to_int_inner(self, src.to_scalar().to_f16()?, cast_to, round), + FloatTy::F32 => float_to_int_inner(self, src.to_scalar().to_f32()?, cast_to, round), + FloatTy::F64 => float_to_int_inner(self, src.to_scalar().to_f64()?, cast_to, round), + FloatTy::F128 => float_to_int_inner(self, src.to_scalar().to_f128()?, cast_to, round), + }; + + if status.intersects( + rustc_apfloat::Status::INVALID_OP + | rustc_apfloat::Status::OVERFLOW + | rustc_apfloat::Status::UNDERFLOW, + ) { + // Floating point value is NaN (flagged with INVALID_OP) or outside the range + // of values of the integer type (flagged with OVERFLOW or UNDERFLOW). + interp_ok(None) + } else { + // Floating point value can be represented by the integer type after rounding. + // The INEXACT flag is ignored on purpose to allow rounding. + interp_ok(Some(ImmTy::from_scalar(val, cast_to))) + } + } } diff --git a/compiler/rustc_const_eval/src/interpret/intrinsics/simd.rs b/compiler/rustc_const_eval/src/interpret/intrinsics/simd.rs new file mode 100644 index 0000000000000..0dba66ae93721 --- /dev/null +++ b/compiler/rustc_const_eval/src/interpret/intrinsics/simd.rs @@ -0,0 +1,782 @@ +use either::Either; +use rustc_abi::Endian; +use rustc_apfloat::{Float, Round}; +use rustc_middle::mir::interpret::{InterpErrorKind, UndefinedBehaviorInfo}; +use rustc_middle::ty::FloatTy; +use rustc_middle::{bug, err_ub_format, mir, span_bug, throw_unsup_format, ty}; +use rustc_span::{Symbol, sym}; +use tracing::trace; + +use super::{ + ImmTy, InterpCx, InterpResult, Machine, OpTy, PlaceTy, Provenance, Scalar, Size, interp_ok, + throw_ub_format, +}; +use crate::interpret::Writeable; + +#[derive(Copy, Clone)] +pub(crate) enum MinMax { + Min, + Max, +} + +impl<'tcx, M: Machine<'tcx>> InterpCx<'tcx, M> { + /// Returns `true` if emulation happened. + /// Here we implement the intrinsics that are common to all CTFE instances; individual machines can add their own + /// intrinsic handling. + pub fn eval_simd_intrinsic( + &mut self, + intrinsic_name: Symbol, + generic_args: ty::GenericArgsRef<'tcx>, + args: &[OpTy<'tcx, M::Provenance>], + dest: &PlaceTy<'tcx, M::Provenance>, + ret: Option, + ) -> InterpResult<'tcx, bool> { + let dest = dest.force_mplace(self)?; + + match intrinsic_name { + sym::simd_insert => { + let index = u64::from(self.read_scalar(&args[1])?.to_u32()?); + let elem = &args[2]; + let (input, input_len) = self.project_to_simd(&args[0])?; + let (dest, dest_len) = self.project_to_simd(&dest)?; + assert_eq!(input_len, dest_len, "Return vector length must match input length"); + // Bounds are not checked by typeck so we have to do it ourselves. + if index >= input_len { + throw_ub_format!( + "`simd_insert` index {index} is out-of-bounds of vector with length {input_len}" + ); + } + + for i in 0..dest_len { + let place = self.project_index(&dest, i)?; + let value = + if i == index { elem.clone() } else { self.project_index(&input, i)? }; + self.copy_op(&value, &place)?; + } + } + sym::simd_extract => { + let index = u64::from(self.read_scalar(&args[1])?.to_u32()?); + let (input, input_len) = self.project_to_simd(&args[0])?; + // Bounds are not checked by typeck so we have to do it ourselves. + if index >= input_len { + throw_ub_format!( + "`simd_extract` index {index} is out-of-bounds of vector with length {input_len}" + ); + } + self.copy_op(&self.project_index(&input, index)?, &dest)?; + } + sym::simd_neg + | sym::simd_fabs + | sym::simd_ceil + | sym::simd_floor + | sym::simd_round + | sym::simd_round_ties_even + | sym::simd_trunc + | sym::simd_ctlz + | sym::simd_ctpop + | sym::simd_cttz + | sym::simd_bswap + | sym::simd_bitreverse => { + let (op, op_len) = self.project_to_simd(&args[0])?; + let (dest, dest_len) = self.project_to_simd(&dest)?; + + assert_eq!(dest_len, op_len); + + #[derive(Copy, Clone)] + enum Op { + MirOp(mir::UnOp), + Abs, + Round(rustc_apfloat::Round), + Numeric(Symbol), + } + let which = match intrinsic_name { + sym::simd_neg => Op::MirOp(mir::UnOp::Neg), + sym::simd_fabs => Op::Abs, + sym::simd_ceil => Op::Round(rustc_apfloat::Round::TowardPositive), + sym::simd_floor => Op::Round(rustc_apfloat::Round::TowardNegative), + sym::simd_round => Op::Round(rustc_apfloat::Round::NearestTiesToAway), + sym::simd_round_ties_even => Op::Round(rustc_apfloat::Round::NearestTiesToEven), + sym::simd_trunc => Op::Round(rustc_apfloat::Round::TowardZero), + sym::simd_ctlz => Op::Numeric(sym::ctlz), + sym::simd_ctpop => Op::Numeric(sym::ctpop), + sym::simd_cttz => Op::Numeric(sym::cttz), + sym::simd_bswap => Op::Numeric(sym::bswap), + sym::simd_bitreverse => Op::Numeric(sym::bitreverse), + _ => unreachable!(), + }; + + for i in 0..dest_len { + let op = self.read_immediate(&self.project_index(&op, i)?)?; + let dest = self.project_index(&dest, i)?; + let val = match which { + Op::MirOp(mir_op) => { + // this already does NaN adjustments + self.unary_op(mir_op, &op)?.to_scalar() + } + Op::Abs => { + // Works for f32 and f64. + let ty::Float(float_ty) = op.layout.ty.kind() else { + span_bug!( + self.cur_span(), + "{} operand is not a float", + intrinsic_name + ) + }; + let op = op.to_scalar(); + // "Bitwise" operation, no NaN adjustments + match float_ty { + FloatTy::F16 => unimplemented!("f16_f128"), + FloatTy::F32 => Scalar::from_f32(op.to_f32()?.abs()), + FloatTy::F64 => Scalar::from_f64(op.to_f64()?.abs()), + FloatTy::F128 => unimplemented!("f16_f128"), + } + } + Op::Round(rounding) => { + let ty::Float(float_ty) = op.layout.ty.kind() else { + span_bug!( + self.cur_span(), + "{} operand is not a float", + intrinsic_name + ) + }; + match float_ty { + FloatTy::F16 => unimplemented!("f16_f128"), + FloatTy::F32 => { + let f = op.to_scalar().to_f32()?; + let res = f.round_to_integral(rounding).value; + let res = self.adjust_nan(res, &[f]); + Scalar::from_f32(res) + } + FloatTy::F64 => { + let f = op.to_scalar().to_f64()?; + let res = f.round_to_integral(rounding).value; + let res = self.adjust_nan(res, &[f]); + Scalar::from_f64(res) + } + FloatTy::F128 => unimplemented!("f16_f128"), + } + } + Op::Numeric(name) => { + self.numeric_intrinsic(name, op.to_scalar(), op.layout, op.layout)? + } + }; + self.write_scalar(val, &dest)?; + } + } + sym::simd_add + | sym::simd_sub + | sym::simd_mul + | sym::simd_div + | sym::simd_rem + | sym::simd_shl + | sym::simd_shr + | sym::simd_and + | sym::simd_or + | sym::simd_xor + | sym::simd_eq + | sym::simd_ne + | sym::simd_lt + | sym::simd_le + | sym::simd_gt + | sym::simd_ge + | sym::simd_fmax + | sym::simd_fmin + | sym::simd_saturating_add + | sym::simd_saturating_sub + | sym::simd_arith_offset => { + use mir::BinOp; + + let (left, left_len) = self.project_to_simd(&args[0])?; + let (right, right_len) = self.project_to_simd(&args[1])?; + let (dest, dest_len) = self.project_to_simd(&dest)?; + + assert_eq!(dest_len, left_len); + assert_eq!(dest_len, right_len); + + enum Op { + MirOp(BinOp), + SaturatingOp(BinOp), + FMinMax(MinMax), + WrappingOffset, + } + let which = match intrinsic_name { + sym::simd_add => Op::MirOp(BinOp::Add), + sym::simd_sub => Op::MirOp(BinOp::Sub), + sym::simd_mul => Op::MirOp(BinOp::Mul), + sym::simd_div => Op::MirOp(BinOp::Div), + sym::simd_rem => Op::MirOp(BinOp::Rem), + sym::simd_shl => Op::MirOp(BinOp::ShlUnchecked), + sym::simd_shr => Op::MirOp(BinOp::ShrUnchecked), + sym::simd_and => Op::MirOp(BinOp::BitAnd), + sym::simd_or => Op::MirOp(BinOp::BitOr), + sym::simd_xor => Op::MirOp(BinOp::BitXor), + sym::simd_eq => Op::MirOp(BinOp::Eq), + sym::simd_ne => Op::MirOp(BinOp::Ne), + sym::simd_lt => Op::MirOp(BinOp::Lt), + sym::simd_le => Op::MirOp(BinOp::Le), + sym::simd_gt => Op::MirOp(BinOp::Gt), + sym::simd_ge => Op::MirOp(BinOp::Ge), + sym::simd_fmax => Op::FMinMax(MinMax::Max), + sym::simd_fmin => Op::FMinMax(MinMax::Min), + sym::simd_saturating_add => Op::SaturatingOp(BinOp::Add), + sym::simd_saturating_sub => Op::SaturatingOp(BinOp::Sub), + sym::simd_arith_offset => Op::WrappingOffset, + _ => unreachable!(), + }; + + for i in 0..dest_len { + let left = self.read_immediate(&self.project_index(&left, i)?)?; + let right = self.read_immediate(&self.project_index(&right, i)?)?; + let dest = self.project_index(&dest, i)?; + let val = match which { + Op::MirOp(mir_op) => { + // this does NaN adjustments. + let val = self.binary_op(mir_op, &left, &right).map_err_kind(|kind| { + match kind { + InterpErrorKind::UndefinedBehavior(UndefinedBehaviorInfo::ShiftOverflow { shift_amount, .. }) => { + // this resets the interpreter backtrace, but it's not worth avoiding that. + let shift_amount = match shift_amount { + Either::Left(v) => v.to_string(), + Either::Right(v) => v.to_string(), + }; + err_ub_format!("overflowing shift by {shift_amount} in `{intrinsic_name}` in lane {i}") + } + kind => kind + } + })?; + if matches!( + mir_op, + BinOp::Eq + | BinOp::Ne + | BinOp::Lt + | BinOp::Le + | BinOp::Gt + | BinOp::Ge + ) { + // Special handling for boolean-returning operations + assert_eq!(val.layout.ty, self.tcx.types.bool); + let val = val.to_scalar().to_bool().unwrap(); + bool_to_simd_element(val, dest.layout.size) + } else { + assert_ne!(val.layout.ty, self.tcx.types.bool); + assert_eq!(val.layout.ty, dest.layout.ty); + val.to_scalar() + } + } + Op::SaturatingOp(mir_op) => self.saturating_arith(mir_op, &left, &right)?, + Op::WrappingOffset => { + let ptr = left.to_scalar().to_pointer(self)?; + let offset_count = right.to_scalar().to_target_isize(self)?; + let pointee_ty = left.layout.ty.builtin_deref(true).unwrap(); + + let pointee_size = + i64::try_from(self.layout_of(pointee_ty)?.size.bytes()).unwrap(); + let offset_bytes = offset_count.wrapping_mul(pointee_size); + let offset_ptr = ptr.wrapping_signed_offset(offset_bytes, self); + Scalar::from_maybe_pointer(offset_ptr, self) + } + Op::FMinMax(op) => self.fminmax_op(op, &left, &right)?, + }; + self.write_scalar(val, &dest)?; + } + } + sym::simd_reduce_and + | sym::simd_reduce_or + | sym::simd_reduce_xor + | sym::simd_reduce_any + | sym::simd_reduce_all + | sym::simd_reduce_max + | sym::simd_reduce_min => { + use mir::BinOp; + + let (op, op_len) = self.project_to_simd(&args[0])?; + + let imm_from_bool = |b| { + ImmTy::from_scalar( + Scalar::from_bool(b), + self.layout_of(self.tcx.types.bool).unwrap(), + ) + }; + + enum Op { + MirOp(BinOp), + MirOpBool(BinOp), + MinMax(MinMax), + } + let which = match intrinsic_name { + sym::simd_reduce_and => Op::MirOp(BinOp::BitAnd), + sym::simd_reduce_or => Op::MirOp(BinOp::BitOr), + sym::simd_reduce_xor => Op::MirOp(BinOp::BitXor), + sym::simd_reduce_any => Op::MirOpBool(BinOp::BitOr), + sym::simd_reduce_all => Op::MirOpBool(BinOp::BitAnd), + sym::simd_reduce_max => Op::MinMax(MinMax::Max), + sym::simd_reduce_min => Op::MinMax(MinMax::Min), + _ => unreachable!(), + }; + + // Initialize with first lane, then proceed with the rest. + let mut res = self.read_immediate(&self.project_index(&op, 0)?)?; + if matches!(which, Op::MirOpBool(_)) { + // Convert to `bool` scalar. + res = imm_from_bool(simd_element_to_bool(res)?); + } + for i in 1..op_len { + let op = self.read_immediate(&self.project_index(&op, i)?)?; + res = match which { + Op::MirOp(mir_op) => self.binary_op(mir_op, &res, &op)?, + Op::MirOpBool(mir_op) => { + let op = imm_from_bool(simd_element_to_bool(op)?); + self.binary_op(mir_op, &res, &op)? + } + Op::MinMax(mmop) => { + if matches!(res.layout.ty.kind(), ty::Float(_)) { + ImmTy::from_scalar(self.fminmax_op(mmop, &res, &op)?, res.layout) + } else { + // Just boring integers, so NaNs to worry about + let mirop = match mmop { + MinMax::Min => BinOp::Le, + MinMax::Max => BinOp::Ge, + }; + if self.binary_op(mirop, &res, &op)?.to_scalar().to_bool()? { + res + } else { + op + } + } + } + }; + } + self.write_immediate(*res, &dest)?; + } + sym::simd_reduce_add_ordered | sym::simd_reduce_mul_ordered => { + use mir::BinOp; + + let (op, op_len) = self.project_to_simd(&args[0])?; + let init = self.read_immediate(&args[1])?; + + let mir_op = match intrinsic_name { + sym::simd_reduce_add_ordered => BinOp::Add, + sym::simd_reduce_mul_ordered => BinOp::Mul, + _ => unreachable!(), + }; + + let mut res = init; + for i in 0..op_len { + let op = self.read_immediate(&self.project_index(&op, i)?)?; + res = self.binary_op(mir_op, &res, &op)?; + } + self.write_immediate(*res, &dest)?; + } + sym::simd_select => { + let (mask, mask_len) = self.project_to_simd(&args[0])?; + let (yes, yes_len) = self.project_to_simd(&args[1])?; + let (no, no_len) = self.project_to_simd(&args[2])?; + let (dest, dest_len) = self.project_to_simd(&dest)?; + + assert_eq!(dest_len, mask_len); + assert_eq!(dest_len, yes_len); + assert_eq!(dest_len, no_len); + + for i in 0..dest_len { + let mask = self.read_immediate(&self.project_index(&mask, i)?)?; + let yes = self.read_immediate(&self.project_index(&yes, i)?)?; + let no = self.read_immediate(&self.project_index(&no, i)?)?; + let dest = self.project_index(&dest, i)?; + + let val = if simd_element_to_bool(mask)? { yes } else { no }; + self.write_immediate(*val, &dest)?; + } + } + // Variant of `select` that takes a bitmask rather than a "vector of bool". + sym::simd_select_bitmask => { + let mask = &args[0]; + let (yes, yes_len) = self.project_to_simd(&args[1])?; + let (no, no_len) = self.project_to_simd(&args[2])?; + let (dest, dest_len) = self.project_to_simd(&dest)?; + let bitmask_len = dest_len.next_multiple_of(8); + if bitmask_len > 64 { + throw_unsup_format!( + "simd_select_bitmask: vectors larger than 64 elements are currently not supported" + ); + } + + assert_eq!(dest_len, yes_len); + assert_eq!(dest_len, no_len); + + // Read the mask, either as an integer or as an array. + let mask: u64 = match mask.layout.ty.kind() { + ty::Uint(_) => { + // Any larger integer type is fine. + assert!(mask.layout.size.bits() >= bitmask_len); + self.read_scalar(mask)?.to_bits(mask.layout.size)?.try_into().unwrap() + } + ty::Array(elem, _len) if elem == &self.tcx.types.u8 => { + // The array must have exactly the right size. + assert_eq!(mask.layout.size.bits(), bitmask_len); + // Read the raw bytes. + let mask = mask.assert_mem_place(); // arrays cannot be immediate + let mask_bytes = + self.read_bytes_ptr_strip_provenance(mask.ptr(), mask.layout.size)?; + // Turn them into a `u64` in the right way. + let mask_size = mask.layout.size.bytes_usize(); + let mut mask_arr = [0u8; 8]; + match self.tcx.data_layout.endian { + Endian::Little => { + // Fill the first N bytes. + mask_arr[..mask_size].copy_from_slice(mask_bytes); + u64::from_le_bytes(mask_arr) + } + Endian::Big => { + // Fill the last N bytes. + let i = mask_arr.len().strict_sub(mask_size); + mask_arr[i..].copy_from_slice(mask_bytes); + u64::from_be_bytes(mask_arr) + } + } + } + _ => bug!("simd_select_bitmask: invalid mask type {}", mask.layout.ty), + }; + + let dest_len = u32::try_from(dest_len).unwrap(); + for i in 0..dest_len { + let bit_i = simd_bitmask_index(i, dest_len, self.tcx.data_layout.endian); + let mask = mask & 1u64.strict_shl(bit_i); + let yes = self.read_immediate(&self.project_index(&yes, i.into())?)?; + let no = self.read_immediate(&self.project_index(&no, i.into())?)?; + let dest = self.project_index(&dest, i.into())?; + + let val = if mask != 0 { yes } else { no }; + self.write_immediate(*val, &dest)?; + } + // The remaining bits of the mask are ignored. + } + // Converts a "vector of bool" into a bitmask. + sym::simd_bitmask => { + let (op, op_len) = self.project_to_simd(&args[0])?; + let bitmask_len = op_len.next_multiple_of(8); + if bitmask_len > 64 { + throw_unsup_format!( + "simd_bitmask: vectors larger than 64 elements are currently not supported" + ); + } + + let op_len = u32::try_from(op_len).unwrap(); + let mut res = 0u64; + for i in 0..op_len { + let op = self.read_immediate(&self.project_index(&op, i.into())?)?; + if simd_element_to_bool(op)? { + let bit_i = simd_bitmask_index(i, op_len, self.tcx.data_layout.endian); + res |= 1u64.strict_shl(bit_i); + } + } + // Write the result, depending on the `dest` type. + // Returns either an unsigned integer or array of `u8`. + match dest.layout.ty.kind() { + ty::Uint(_) => { + // Any larger integer type is fine, it will be zero-extended. + assert!(dest.layout.size.bits() >= bitmask_len); + self.write_scalar(Scalar::from_uint(res, dest.layout.size), &dest)?; + } + ty::Array(elem, _len) if elem == &self.tcx.types.u8 => { + // The array must have exactly the right size. + assert_eq!(dest.layout.size.bits(), bitmask_len); + // We have to write the result byte-for-byte. + let res_size = dest.layout.size.bytes_usize(); + let res_bytes; + let res_bytes_slice = match self.tcx.data_layout.endian { + Endian::Little => { + res_bytes = res.to_le_bytes(); + &res_bytes[..res_size] // take the first N bytes + } + Endian::Big => { + res_bytes = res.to_be_bytes(); + &res_bytes[res_bytes.len().strict_sub(res_size)..] // take the last N bytes + } + }; + self.write_bytes_ptr(dest.ptr(), res_bytes_slice.iter().cloned())?; + } + _ => bug!("simd_bitmask: invalid return type {}", dest.layout.ty), + } + } + sym::simd_cast + | sym::simd_as + | sym::simd_cast_ptr + | sym::simd_with_exposed_provenance => { + let (op, op_len) = self.project_to_simd(&args[0])?; + let (dest, dest_len) = self.project_to_simd(&dest)?; + + assert_eq!(dest_len, op_len); + + let unsafe_cast = intrinsic_name == sym::simd_cast; + let safe_cast = intrinsic_name == sym::simd_as; + let ptr_cast = intrinsic_name == sym::simd_cast_ptr; + let from_exposed_cast = intrinsic_name == sym::simd_with_exposed_provenance; + + for i in 0..dest_len { + let op = self.read_immediate(&self.project_index(&op, i)?)?; + let dest = self.project_index(&dest, i)?; + + let val = match (op.layout.ty.kind(), dest.layout.ty.kind()) { + // Int-to-(int|float): always safe + (ty::Int(_) | ty::Uint(_), ty::Int(_) | ty::Uint(_) | ty::Float(_)) + if safe_cast || unsafe_cast => + self.int_to_int_or_float(&op, dest.layout)?, + // Float-to-float: always safe + (ty::Float(_), ty::Float(_)) if safe_cast || unsafe_cast => + self.float_to_float_or_int(&op, dest.layout)?, + // Float-to-int in safe mode + (ty::Float(_), ty::Int(_) | ty::Uint(_)) if safe_cast => + self.float_to_float_or_int(&op, dest.layout)?, + // Float-to-int in unchecked mode + (ty::Float(_), ty::Int(_) | ty::Uint(_)) if unsafe_cast => { + self.float_to_int_checked(&op, dest.layout, Round::TowardZero)? + .ok_or_else(|| { + err_ub_format!( + "`simd_cast` intrinsic called on {op} which cannot be represented in target type `{:?}`", + dest.layout.ty + ) + })? + } + // Ptr-to-ptr cast + (ty::RawPtr(..), ty::RawPtr(..)) if ptr_cast => + self.ptr_to_ptr(&op, dest.layout)?, + // Int->Ptr casts + (ty::Int(_) | ty::Uint(_), ty::RawPtr(..)) if from_exposed_cast => + self.pointer_with_exposed_provenance_cast(&op, dest.layout)?, + // Error otherwise + _ => + throw_unsup_format!( + "Unsupported SIMD cast from element type {from_ty} to {to_ty}", + from_ty = op.layout.ty, + to_ty = dest.layout.ty, + ), + }; + self.write_immediate(*val, &dest)?; + } + } + sym::simd_shuffle_const_generic => { + let (left, left_len) = self.project_to_simd(&args[0])?; + let (right, right_len) = self.project_to_simd(&args[1])?; + let (dest, dest_len) = self.project_to_simd(&dest)?; + + let index = generic_args[2].expect_const().to_value().valtree.unwrap_branch(); + let index_len = index.len(); + + assert_eq!(left_len, right_len); + assert_eq!(u64::try_from(index_len).unwrap(), dest_len); + + for i in 0..dest_len { + let src_index: u64 = + index[usize::try_from(i).unwrap()].unwrap_leaf().to_u32().into(); + let dest = self.project_index(&dest, i)?; + + let val = if src_index < left_len { + self.read_immediate(&self.project_index(&left, src_index)?)? + } else if src_index < left_len.strict_add(right_len) { + let right_idx = src_index.strict_sub(left_len); + self.read_immediate(&self.project_index(&right, right_idx)?)? + } else { + throw_ub_format!( + "`simd_shuffle_const_generic` index {src_index} is out-of-bounds for 2 vectors with length {dest_len}" + ); + }; + self.write_immediate(*val, &dest)?; + } + } + sym::simd_shuffle => { + let (left, left_len) = self.project_to_simd(&args[0])?; + let (right, right_len) = self.project_to_simd(&args[1])?; + let (index, index_len) = self.project_to_simd(&args[2])?; + let (dest, dest_len) = self.project_to_simd(&dest)?; + + assert_eq!(left_len, right_len); + assert_eq!(index_len, dest_len); + + for i in 0..dest_len { + let src_index: u64 = self + .read_immediate(&self.project_index(&index, i)?)? + .to_scalar() + .to_u32()? + .into(); + let dest = self.project_index(&dest, i)?; + + let val = if src_index < left_len { + self.read_immediate(&self.project_index(&left, src_index)?)? + } else if src_index < left_len.strict_add(right_len) { + let right_idx = src_index.strict_sub(left_len); + self.read_immediate(&self.project_index(&right, right_idx)?)? + } else { + throw_ub_format!( + "`simd_shuffle` index {src_index} is out-of-bounds for 2 vectors with length {dest_len}" + ); + }; + self.write_immediate(*val, &dest)?; + } + } + sym::simd_gather => { + let (passthru, passthru_len) = self.project_to_simd(&args[0])?; + let (ptrs, ptrs_len) = self.project_to_simd(&args[1])?; + let (mask, mask_len) = self.project_to_simd(&args[2])?; + let (dest, dest_len) = self.project_to_simd(&dest)?; + + assert_eq!(dest_len, passthru_len); + assert_eq!(dest_len, ptrs_len); + assert_eq!(dest_len, mask_len); + + for i in 0..dest_len { + let passthru = self.read_immediate(&self.project_index(&passthru, i)?)?; + let ptr = self.read_immediate(&self.project_index(&ptrs, i)?)?; + let mask = self.read_immediate(&self.project_index(&mask, i)?)?; + let dest = self.project_index(&dest, i)?; + + let val = if simd_element_to_bool(mask)? { + let place = self.deref_pointer(&ptr)?; + self.read_immediate(&place)? + } else { + passthru + }; + self.write_immediate(*val, &dest)?; + } + } + sym::simd_scatter => { + let (value, value_len) = self.project_to_simd(&args[0])?; + let (ptrs, ptrs_len) = self.project_to_simd(&args[1])?; + let (mask, mask_len) = self.project_to_simd(&args[2])?; + + assert_eq!(ptrs_len, value_len); + assert_eq!(ptrs_len, mask_len); + + for i in 0..ptrs_len { + let value = self.read_immediate(&self.project_index(&value, i)?)?; + let ptr = self.read_immediate(&self.project_index(&ptrs, i)?)?; + let mask = self.read_immediate(&self.project_index(&mask, i)?)?; + + if simd_element_to_bool(mask)? { + let place = self.deref_pointer(&ptr)?; + self.write_immediate(*value, &place)?; + } + } + } + sym::simd_masked_load => { + let (mask, mask_len) = self.project_to_simd(&args[0])?; + let ptr = self.read_pointer(&args[1])?; + let (default, default_len) = self.project_to_simd(&args[2])?; + let (dest, dest_len) = self.project_to_simd(&dest)?; + + assert_eq!(dest_len, mask_len); + assert_eq!(dest_len, default_len); + + for i in 0..dest_len { + let mask = self.read_immediate(&self.project_index(&mask, i)?)?; + let default = self.read_immediate(&self.project_index(&default, i)?)?; + let dest = self.project_index(&dest, i)?; + + let val = if simd_element_to_bool(mask)? { + // Size * u64 is implemented as always checked + let ptr = ptr.wrapping_offset(dest.layout.size * i, self); + let place = self.ptr_to_mplace(ptr, dest.layout); + self.read_immediate(&place)? + } else { + default + }; + self.write_immediate(*val, &dest)?; + } + } + sym::simd_masked_store => { + let (mask, mask_len) = self.project_to_simd(&args[0])?; + let ptr = self.read_pointer(&args[1])?; + let (vals, vals_len) = self.project_to_simd(&args[2])?; + + assert_eq!(mask_len, vals_len); + + for i in 0..vals_len { + let mask = self.read_immediate(&self.project_index(&mask, i)?)?; + let val = self.read_immediate(&self.project_index(&vals, i)?)?; + + if simd_element_to_bool(mask)? { + // Size * u64 is implemented as always checked + let ptr = ptr.wrapping_offset(val.layout.size * i, self); + let place = self.ptr_to_mplace(ptr, val.layout); + self.write_immediate(*val, &place)? + }; + } + } + + // Unsupported intrinsic: skip the return_to_block below. + _ => return interp_ok(false), + } + + trace!("{:?}", self.dump_place(&dest.clone().into())); + self.return_to_block(ret)?; + interp_ok(true) + } + + fn fminmax_op( + &self, + op: MinMax, + left: &ImmTy<'tcx, Prov>, + right: &ImmTy<'tcx, Prov>, + ) -> InterpResult<'tcx, Scalar> { + assert_eq!(left.layout.ty, right.layout.ty); + let ty::Float(float_ty) = left.layout.ty.kind() else { + bug!("fmax operand is not a float") + }; + let left = left.to_scalar(); + let right = right.to_scalar(); + interp_ok(match float_ty { + FloatTy::F16 => unimplemented!("f16_f128"), + FloatTy::F32 => { + let left = left.to_f32()?; + let right = right.to_f32()?; + let res = match op { + MinMax::Min => left.min(right), + MinMax::Max => left.max(right), + }; + let res = self.adjust_nan(res, &[left, right]); + Scalar::from_f32(res) + } + FloatTy::F64 => { + let left = left.to_f64()?; + let right = right.to_f64()?; + let res = match op { + MinMax::Min => left.min(right), + MinMax::Max => left.max(right), + }; + let res = self.adjust_nan(res, &[left, right]); + Scalar::from_f64(res) + } + FloatTy::F128 => unimplemented!("f16_f128"), + }) + } +} + +fn simd_bitmask_index(idx: u32, vec_len: u32, endianness: Endian) -> u32 { + assert!(idx < vec_len); + match endianness { + Endian::Little => idx, + #[expect(clippy::arithmetic_side_effects)] // idx < vec_len + Endian::Big => vec_len - 1 - idx, // reverse order of bits + } +} + +fn bool_to_simd_element(b: bool, size: Size) -> Scalar { + // SIMD uses all-1 as pattern for "true". In two's complement, + // -1 has all its bits set to one and `from_int` will truncate or + // sign-extend it to `size` as required. + let val = if b { -1 } else { 0 }; + Scalar::from_int(val, size) +} + +fn simd_element_to_bool(elem: ImmTy<'_, Prov>) -> InterpResult<'_, bool> { + assert!( + matches!(elem.layout.ty.kind(), ty::Int(_) | ty::Uint(_)), + "SIMD mask element type must be an integer, but this is `{}`", + elem.layout.ty + ); + let val = elem.to_scalar().to_int(elem.layout.size)?; + interp_ok(match val { + 0 => false, + -1 => true, + _ => throw_ub_format!("each element of a SIMD mask must be all-0-bits or all-1-bits"), + }) +} diff --git a/src/tools/miri/src/helpers.rs b/src/tools/miri/src/helpers.rs index e0c077e99319a..d6646f9586aa6 100644 --- a/src/tools/miri/src/helpers.rs +++ b/src/tools/miri/src/helpers.rs @@ -5,7 +5,6 @@ use std::{cmp, iter}; use rand::RngCore; use rustc_abi::{Align, ExternAbi, FieldIdx, FieldsShape, Size, Variants}; use rustc_apfloat::Float; -use rustc_apfloat::ieee::{Double, Half, Quad, Single}; use rustc_hir::Safety; use rustc_hir::def::{DefKind, Namespace}; use rustc_hir::def_id::{CRATE_DEF_INDEX, CrateNum, DefId, LOCAL_CRATE}; @@ -14,7 +13,7 @@ use rustc_middle::middle::codegen_fn_attrs::CodegenFnAttrFlags; use rustc_middle::middle::dependency_format::Linkage; use rustc_middle::middle::exported_symbols::ExportedSymbol; use rustc_middle::ty::layout::{LayoutOf, MaybeResult, TyAndLayout}; -use rustc_middle::ty::{self, FloatTy, IntTy, Ty, TyCtxt, UintTy}; +use rustc_middle::ty::{self, IntTy, Ty, TyCtxt, UintTy}; use rustc_session::config::CrateType; use rustc_span::{Span, Symbol}; use rustc_symbol_mangling::mangle_internal_symbol; @@ -961,75 +960,6 @@ pub trait EvalContextExt<'tcx>: crate::MiriInterpCxExt<'tcx> { this.alloc_mark_immutable(provenance.get_alloc_id().unwrap()).unwrap(); } - /// Converts `src` from floating point to integer type `dest_ty` - /// after rounding with mode `round`. - /// Returns `None` if `f` is NaN or out of range. - fn float_to_int_checked( - &self, - src: &ImmTy<'tcx>, - cast_to: TyAndLayout<'tcx>, - round: rustc_apfloat::Round, - ) -> InterpResult<'tcx, Option>> { - let this = self.eval_context_ref(); - - fn float_to_int_inner<'tcx, F: rustc_apfloat::Float>( - ecx: &MiriInterpCx<'tcx>, - src: F, - cast_to: TyAndLayout<'tcx>, - round: rustc_apfloat::Round, - ) -> (Scalar, rustc_apfloat::Status) { - let int_size = cast_to.layout.size; - match cast_to.ty.kind() { - // Unsigned - ty::Uint(_) => { - let res = src.to_u128_r(int_size.bits_usize(), round, &mut false); - (Scalar::from_uint(res.value, int_size), res.status) - } - // Signed - ty::Int(_) => { - let res = src.to_i128_r(int_size.bits_usize(), round, &mut false); - (Scalar::from_int(res.value, int_size), res.status) - } - // Nothing else - _ => - span_bug!( - ecx.cur_span(), - "attempted float-to-int conversion with non-int output type {}", - cast_to.ty, - ), - } - } - - let ty::Float(fty) = src.layout.ty.kind() else { - bug!("float_to_int_checked: non-float input type {}", src.layout.ty) - }; - - let (val, status) = match fty { - FloatTy::F16 => - float_to_int_inner::(this, src.to_scalar().to_f16()?, cast_to, round), - FloatTy::F32 => - float_to_int_inner::(this, src.to_scalar().to_f32()?, cast_to, round), - FloatTy::F64 => - float_to_int_inner::(this, src.to_scalar().to_f64()?, cast_to, round), - FloatTy::F128 => - float_to_int_inner::(this, src.to_scalar().to_f128()?, cast_to, round), - }; - - if status.intersects( - rustc_apfloat::Status::INVALID_OP - | rustc_apfloat::Status::OVERFLOW - | rustc_apfloat::Status::UNDERFLOW, - ) { - // Floating point value is NaN (flagged with INVALID_OP) or outside the range - // of values of the integer type (flagged with OVERFLOW or UNDERFLOW). - interp_ok(None) - } else { - // Floating point value can be represented by the integer type after rounding. - // The INEXACT flag is ignored on purpose to allow rounding. - interp_ok(Some(ImmTy::from_scalar(val, cast_to))) - } - } - /// Returns an integer type that is twice wide as `ty` fn get_twice_wide_int_ty(&self, ty: Ty<'tcx>) -> Ty<'tcx> { let this = self.eval_context_ref(); @@ -1194,20 +1124,6 @@ pub(crate) fn bool_to_simd_element(b: bool, size: Size) -> Scalar { Scalar::from_int(val, size) } -pub(crate) fn simd_element_to_bool(elem: ImmTy<'_>) -> InterpResult<'_, bool> { - assert!( - matches!(elem.layout.ty.kind(), ty::Int(_) | ty::Uint(_)), - "SIMD mask element type must be an integer, but this is `{}`", - elem.layout.ty - ); - let val = elem.to_scalar().to_int(elem.layout.size)?; - interp_ok(match val { - 0 => false, - -1 => true, - _ => throw_ub_format!("each element of a SIMD mask must be all-0-bits or all-1-bits"), - }) -} - /// Check whether an operation that writes to a target buffer was successful. /// Accordingly select return value. /// Local helper function to be used in Windows shims. diff --git a/src/tools/miri/src/intrinsics/mod.rs b/src/tools/miri/src/intrinsics/mod.rs index a80b939d84ea9..f09fc6c187896 100644 --- a/src/tools/miri/src/intrinsics/mod.rs +++ b/src/tools/miri/src/intrinsics/mod.rs @@ -118,7 +118,7 @@ pub trait EvalContextExt<'tcx>: crate::MiriInterpCxExt<'tcx> { return this.emulate_atomic_intrinsic(name, generic_args, args, dest); } if let Some(name) = intrinsic_name.strip_prefix("simd_") { - return this.emulate_simd_intrinsic(name, generic_args, args, dest); + return this.emulate_simd_intrinsic(name, args, dest); } match intrinsic_name { diff --git a/src/tools/miri/src/intrinsics/simd.rs b/src/tools/miri/src/intrinsics/simd.rs index b26516c0ff0e8..5f75657e0a220 100644 --- a/src/tools/miri/src/intrinsics/simd.rs +++ b/src/tools/miri/src/intrinsics/simd.rs @@ -1,21 +1,12 @@ -use either::Either; use rand::Rng; -use rustc_abi::{Endian, HasDataLayout}; -use rustc_apfloat::{Float, Round}; +use rustc_apfloat::Float; use rustc_middle::ty::FloatTy; -use rustc_middle::{mir, ty}; -use rustc_span::{Symbol, sym}; +use rustc_middle::ty; use super::check_intrinsic_arg_count; -use crate::helpers::{ToHost, ToSoft, bool_to_simd_element, simd_element_to_bool}; +use crate::helpers::{ToHost, ToSoft}; use crate::*; -#[derive(Copy, Clone)] -pub(crate) enum MinMax { - Min, - Max, -} - impl<'tcx> EvalContextExt<'tcx> for crate::MiriInterpCx<'tcx> {} pub trait EvalContextExt<'tcx>: crate::MiriInterpCxExt<'tcx> { /// Calls the simd intrinsic `intrinsic`; the `simd_` prefix has already been removed. @@ -23,20 +14,12 @@ pub trait EvalContextExt<'tcx>: crate::MiriInterpCxExt<'tcx> { fn emulate_simd_intrinsic( &mut self, intrinsic_name: &str, - generic_args: ty::GenericArgsRef<'tcx>, args: &[OpTy<'tcx>], dest: &MPlaceTy<'tcx>, ) -> InterpResult<'tcx, EmulateItemResult> { let this = self.eval_context_mut(); match intrinsic_name { #[rustfmt::skip] - | "neg" - | "fabs" - | "ceil" - | "floor" - | "round" - | "round_ties_even" - | "trunc" | "fsqrt" | "fsin" | "fcos" @@ -45,11 +28,6 @@ pub trait EvalContextExt<'tcx>: crate::MiriInterpCxExt<'tcx> { | "flog" | "flog2" | "flog10" - | "ctlz" - | "ctpop" - | "cttz" - | "bswap" - | "bitreverse" => { let [op] = check_intrinsic_arg_count(args)?; let (op, op_len) = this.project_to_simd(op)?; @@ -57,235 +35,51 @@ pub trait EvalContextExt<'tcx>: crate::MiriInterpCxExt<'tcx> { assert_eq!(dest_len, op_len); - #[derive(Copy, Clone)] - enum Op<'a> { - MirOp(mir::UnOp), - Abs, - Round(rustc_apfloat::Round), - Numeric(Symbol), - HostOp(&'a str), - } - let which = match intrinsic_name { - "neg" => Op::MirOp(mir::UnOp::Neg), - "fabs" => Op::Abs, - "ceil" => Op::Round(rustc_apfloat::Round::TowardPositive), - "floor" => Op::Round(rustc_apfloat::Round::TowardNegative), - "round" => Op::Round(rustc_apfloat::Round::NearestTiesToAway), - "round_ties_even" => Op::Round(rustc_apfloat::Round::NearestTiesToEven), - "trunc" => Op::Round(rustc_apfloat::Round::TowardZero), - "ctlz" => Op::Numeric(sym::ctlz), - "ctpop" => Op::Numeric(sym::ctpop), - "cttz" => Op::Numeric(sym::cttz), - "bswap" => Op::Numeric(sym::bswap), - "bitreverse" => Op::Numeric(sym::bitreverse), - _ => Op::HostOp(intrinsic_name), - }; - for i in 0..dest_len { let op = this.read_immediate(&this.project_index(&op, i)?)?; let dest = this.project_index(&dest, i)?; - let val = match which { - Op::MirOp(mir_op) => { - // This already does NaN adjustments - this.unary_op(mir_op, &op)?.to_scalar() - } - Op::Abs => { - // Works for f32 and f64. - let ty::Float(float_ty) = op.layout.ty.kind() else { - span_bug!(this.cur_span(), "{} operand is not a float", intrinsic_name) - }; - let op = op.to_scalar(); - // "Bitwise" operation, no NaN adjustments - match float_ty { - FloatTy::F16 => unimplemented!("f16_f128"), - FloatTy::F32 => Scalar::from_f32(op.to_f32()?.abs()), - FloatTy::F64 => Scalar::from_f64(op.to_f64()?.abs()), - FloatTy::F128 => unimplemented!("f16_f128"), - } - } - Op::HostOp(host_op) => { - let ty::Float(float_ty) = op.layout.ty.kind() else { - span_bug!(this.cur_span(), "{} operand is not a float", intrinsic_name) + let ty::Float(float_ty) = op.layout.ty.kind() else { + span_bug!(this.cur_span(), "{} operand is not a float", intrinsic_name) + }; + // Using host floats except for sqrt (but it's fine, these operations do not + // have guaranteed precision). + let val = match float_ty { + FloatTy::F16 => unimplemented!("f16_f128"), + FloatTy::F32 => { + let f = op.to_scalar().to_f32()?; + let res = match intrinsic_name { + "fsqrt" => math::sqrt(f), + "fsin" => f.to_host().sin().to_soft(), + "fcos" => f.to_host().cos().to_soft(), + "fexp" => f.to_host().exp().to_soft(), + "fexp2" => f.to_host().exp2().to_soft(), + "flog" => f.to_host().ln().to_soft(), + "flog2" => f.to_host().log2().to_soft(), + "flog10" => f.to_host().log10().to_soft(), + _ => bug!(), }; - // Using host floats except for sqrt (but it's fine, these operations do not - // have guaranteed precision). - match float_ty { - FloatTy::F16 => unimplemented!("f16_f128"), - FloatTy::F32 => { - let f = op.to_scalar().to_f32()?; - let res = match host_op { - "fsqrt" => math::sqrt(f), - "fsin" => f.to_host().sin().to_soft(), - "fcos" => f.to_host().cos().to_soft(), - "fexp" => f.to_host().exp().to_soft(), - "fexp2" => f.to_host().exp2().to_soft(), - "flog" => f.to_host().ln().to_soft(), - "flog2" => f.to_host().log2().to_soft(), - "flog10" => f.to_host().log10().to_soft(), - _ => bug!(), - }; - let res = this.adjust_nan(res, &[f]); - Scalar::from(res) - } - FloatTy::F64 => { - let f = op.to_scalar().to_f64()?; - let res = match host_op { - "fsqrt" => math::sqrt(f), - "fsin" => f.to_host().sin().to_soft(), - "fcos" => f.to_host().cos().to_soft(), - "fexp" => f.to_host().exp().to_soft(), - "fexp2" => f.to_host().exp2().to_soft(), - "flog" => f.to_host().ln().to_soft(), - "flog2" => f.to_host().log2().to_soft(), - "flog10" => f.to_host().log10().to_soft(), - _ => bug!(), - }; - let res = this.adjust_nan(res, &[f]); - Scalar::from(res) - } - FloatTy::F128 => unimplemented!("f16_f128"), - } + let res = this.adjust_nan(res, &[f]); + Scalar::from(res) } - Op::Round(rounding) => { - let ty::Float(float_ty) = op.layout.ty.kind() else { - span_bug!(this.cur_span(), "{} operand is not a float", intrinsic_name) + FloatTy::F64 => { + let f = op.to_scalar().to_f64()?; + let res = match intrinsic_name { + "fsqrt" => math::sqrt(f), + "fsin" => f.to_host().sin().to_soft(), + "fcos" => f.to_host().cos().to_soft(), + "fexp" => f.to_host().exp().to_soft(), + "fexp2" => f.to_host().exp2().to_soft(), + "flog" => f.to_host().ln().to_soft(), + "flog2" => f.to_host().log2().to_soft(), + "flog10" => f.to_host().log10().to_soft(), + _ => bug!(), }; - match float_ty { - FloatTy::F16 => unimplemented!("f16_f128"), - FloatTy::F32 => { - let f = op.to_scalar().to_f32()?; - let res = f.round_to_integral(rounding).value; - let res = this.adjust_nan(res, &[f]); - Scalar::from_f32(res) - } - FloatTy::F64 => { - let f = op.to_scalar().to_f64()?; - let res = f.round_to_integral(rounding).value; - let res = this.adjust_nan(res, &[f]); - Scalar::from_f64(res) - } - FloatTy::F128 => unimplemented!("f16_f128"), - } - } - Op::Numeric(name) => { - this.numeric_intrinsic(name, op.to_scalar(), op.layout, op.layout)? - } - }; - this.write_scalar(val, &dest)?; - } - } - #[rustfmt::skip] - | "add" - | "sub" - | "mul" - | "div" - | "rem" - | "shl" - | "shr" - | "and" - | "or" - | "xor" - | "eq" - | "ne" - | "lt" - | "le" - | "gt" - | "ge" - | "fmax" - | "fmin" - | "saturating_add" - | "saturating_sub" - | "arith_offset" - => { - use mir::BinOp; - - let [left, right] = check_intrinsic_arg_count(args)?; - let (left, left_len) = this.project_to_simd(left)?; - let (right, right_len) = this.project_to_simd(right)?; - let (dest, dest_len) = this.project_to_simd(dest)?; - - assert_eq!(dest_len, left_len); - assert_eq!(dest_len, right_len); - - enum Op { - MirOp(BinOp), - SaturatingOp(BinOp), - FMinMax(MinMax), - WrappingOffset, - } - let which = match intrinsic_name { - "add" => Op::MirOp(BinOp::Add), - "sub" => Op::MirOp(BinOp::Sub), - "mul" => Op::MirOp(BinOp::Mul), - "div" => Op::MirOp(BinOp::Div), - "rem" => Op::MirOp(BinOp::Rem), - "shl" => Op::MirOp(BinOp::ShlUnchecked), - "shr" => Op::MirOp(BinOp::ShrUnchecked), - "and" => Op::MirOp(BinOp::BitAnd), - "or" => Op::MirOp(BinOp::BitOr), - "xor" => Op::MirOp(BinOp::BitXor), - "eq" => Op::MirOp(BinOp::Eq), - "ne" => Op::MirOp(BinOp::Ne), - "lt" => Op::MirOp(BinOp::Lt), - "le" => Op::MirOp(BinOp::Le), - "gt" => Op::MirOp(BinOp::Gt), - "ge" => Op::MirOp(BinOp::Ge), - "fmax" => Op::FMinMax(MinMax::Max), - "fmin" => Op::FMinMax(MinMax::Min), - "saturating_add" => Op::SaturatingOp(BinOp::Add), - "saturating_sub" => Op::SaturatingOp(BinOp::Sub), - "arith_offset" => Op::WrappingOffset, - _ => unreachable!(), - }; - - for i in 0..dest_len { - let left = this.read_immediate(&this.project_index(&left, i)?)?; - let right = this.read_immediate(&this.project_index(&right, i)?)?; - let dest = this.project_index(&dest, i)?; - let val = match which { - Op::MirOp(mir_op) => { - // This does NaN adjustments. - let val = this.binary_op(mir_op, &left, &right).map_err_kind(|kind| { - match kind { - InterpErrorKind::UndefinedBehavior(UndefinedBehaviorInfo::ShiftOverflow { shift_amount, .. }) => { - // This resets the interpreter backtrace, but it's not worth avoiding that. - let shift_amount = match shift_amount { - Either::Left(v) => v.to_string(), - Either::Right(v) => v.to_string(), - }; - err_ub_format!("overflowing shift by {shift_amount} in `simd_{intrinsic_name}` in lane {i}") - } - kind => kind - } - })?; - if matches!(mir_op, BinOp::Eq | BinOp::Ne | BinOp::Lt | BinOp::Le | BinOp::Gt | BinOp::Ge) { - // Special handling for boolean-returning operations - assert_eq!(val.layout.ty, this.tcx.types.bool); - let val = val.to_scalar().to_bool().unwrap(); - bool_to_simd_element(val, dest.layout.size) - } else { - assert_ne!(val.layout.ty, this.tcx.types.bool); - assert_eq!(val.layout.ty, dest.layout.ty); - val.to_scalar() - } - } - Op::SaturatingOp(mir_op) => { - this.saturating_arith(mir_op, &left, &right)? - } - Op::WrappingOffset => { - let ptr = left.to_scalar().to_pointer(this)?; - let offset_count = right.to_scalar().to_target_isize(this)?; - let pointee_ty = left.layout.ty.builtin_deref(true).unwrap(); - - let pointee_size = i64::try_from(this.layout_of(pointee_ty)?.size.bytes()).unwrap(); - let offset_bytes = offset_count.wrapping_mul(pointee_size); - let offset_ptr = ptr.wrapping_signed_offset(offset_bytes, this); - Scalar::from_maybe_pointer(offset_ptr, this) - } - Op::FMinMax(op) => { - this.fminmax_op(op, &left, &right)? + let res = this.adjust_nan(res, &[f]); + Scalar::from(res) } + FloatTy::F128 => unimplemented!("f16_f128"), }; + this.write_scalar(val, &dest)?; } } @@ -345,279 +139,25 @@ pub trait EvalContextExt<'tcx>: crate::MiriInterpCxExt<'tcx> { this.write_scalar(val, &dest)?; } } - #[rustfmt::skip] - | "reduce_and" - | "reduce_or" - | "reduce_xor" - | "reduce_any" - | "reduce_all" - | "reduce_max" - | "reduce_min" => { - use mir::BinOp; - - let [op] = check_intrinsic_arg_count(args)?; - let (op, op_len) = this.project_to_simd(op)?; - - let imm_from_bool = - |b| ImmTy::from_scalar(Scalar::from_bool(b), this.machine.layouts.bool); - - enum Op { - MirOp(BinOp), - MirOpBool(BinOp), - MinMax(MinMax), - } - let which = match intrinsic_name { - "reduce_and" => Op::MirOp(BinOp::BitAnd), - "reduce_or" => Op::MirOp(BinOp::BitOr), - "reduce_xor" => Op::MirOp(BinOp::BitXor), - "reduce_any" => Op::MirOpBool(BinOp::BitOr), - "reduce_all" => Op::MirOpBool(BinOp::BitAnd), - "reduce_max" => Op::MinMax(MinMax::Max), - "reduce_min" => Op::MinMax(MinMax::Min), - _ => unreachable!(), - }; - - // Initialize with first lane, then proceed with the rest. - let mut res = this.read_immediate(&this.project_index(&op, 0)?)?; - if matches!(which, Op::MirOpBool(_)) { - // Convert to `bool` scalar. - res = imm_from_bool(simd_element_to_bool(res)?); - } - for i in 1..op_len { - let op = this.read_immediate(&this.project_index(&op, i)?)?; - res = match which { - Op::MirOp(mir_op) => { - this.binary_op(mir_op, &res, &op)? - } - Op::MirOpBool(mir_op) => { - let op = imm_from_bool(simd_element_to_bool(op)?); - this.binary_op(mir_op, &res, &op)? - } - Op::MinMax(mmop) => { - if matches!(res.layout.ty.kind(), ty::Float(_)) { - ImmTy::from_scalar(this.fminmax_op(mmop, &res, &op)?, res.layout) - } else { - // Just boring integers, so NaNs to worry about - let mirop = match mmop { - MinMax::Min => BinOp::Le, - MinMax::Max => BinOp::Ge, - }; - if this.binary_op(mirop, &res, &op)?.to_scalar().to_bool()? { - res - } else { - op - } - } - } - }; - } - this.write_immediate(*res, dest)?; - } - #[rustfmt::skip] - | "reduce_add_ordered" - | "reduce_mul_ordered" => { - use mir::BinOp; - - let [op, init] = check_intrinsic_arg_count(args)?; - let (op, op_len) = this.project_to_simd(op)?; - let init = this.read_immediate(init)?; - - let mir_op = match intrinsic_name { - "reduce_add_ordered" => BinOp::Add, - "reduce_mul_ordered" => BinOp::Mul, - _ => unreachable!(), - }; - - let mut res = init; - for i in 0..op_len { - let op = this.read_immediate(&this.project_index(&op, i)?)?; - res = this.binary_op(mir_op, &res, &op)?; - } - this.write_immediate(*res, dest)?; - } - "select" => { - let [mask, yes, no] = check_intrinsic_arg_count(args)?; - let (mask, mask_len) = this.project_to_simd(mask)?; - let (yes, yes_len) = this.project_to_simd(yes)?; - let (no, no_len) = this.project_to_simd(no)?; - let (dest, dest_len) = this.project_to_simd(dest)?; - - assert_eq!(dest_len, mask_len); - assert_eq!(dest_len, yes_len); - assert_eq!(dest_len, no_len); - - for i in 0..dest_len { - let mask = this.read_immediate(&this.project_index(&mask, i)?)?; - let yes = this.read_immediate(&this.project_index(&yes, i)?)?; - let no = this.read_immediate(&this.project_index(&no, i)?)?; - let dest = this.project_index(&dest, i)?; - - let val = if simd_element_to_bool(mask)? { yes } else { no }; - this.write_immediate(*val, &dest)?; - } - } - // Variant of `select` that takes a bitmask rather than a "vector of bool". - "select_bitmask" => { - let [mask, yes, no] = check_intrinsic_arg_count(args)?; - let (yes, yes_len) = this.project_to_simd(yes)?; - let (no, no_len) = this.project_to_simd(no)?; - let (dest, dest_len) = this.project_to_simd(dest)?; - let bitmask_len = dest_len.next_multiple_of(8); - if bitmask_len > 64 { - throw_unsup_format!( - "simd_select_bitmask: vectors larger than 64 elements are currently not supported" - ); - } - - assert_eq!(dest_len, yes_len); - assert_eq!(dest_len, no_len); - - // Read the mask, either as an integer or as an array. - let mask: u64 = match mask.layout.ty.kind() { - ty::Uint(_) => { - // Any larger integer type is fine. - assert!(mask.layout.size.bits() >= bitmask_len); - this.read_scalar(mask)?.to_bits(mask.layout.size)?.try_into().unwrap() - } - ty::Array(elem, _len) if elem == &this.tcx.types.u8 => { - // The array must have exactly the right size. - assert_eq!(mask.layout.size.bits(), bitmask_len); - // Read the raw bytes. - let mask = mask.assert_mem_place(); // arrays cannot be immediate - let mask_bytes = - this.read_bytes_ptr_strip_provenance(mask.ptr(), mask.layout.size)?; - // Turn them into a `u64` in the right way. - let mask_size = mask.layout.size.bytes_usize(); - let mut mask_arr = [0u8; 8]; - match this.data_layout().endian { - Endian::Little => { - // Fill the first N bytes. - mask_arr[..mask_size].copy_from_slice(mask_bytes); - u64::from_le_bytes(mask_arr) - } - Endian::Big => { - // Fill the last N bytes. - let i = mask_arr.len().strict_sub(mask_size); - mask_arr[i..].copy_from_slice(mask_bytes); - u64::from_be_bytes(mask_arr) - } - } - } - _ => bug!("simd_select_bitmask: invalid mask type {}", mask.layout.ty), - }; - - let dest_len = u32::try_from(dest_len).unwrap(); - for i in 0..dest_len { - let bit_i = simd_bitmask_index(i, dest_len, this.data_layout().endian); - let mask = mask & 1u64.strict_shl(bit_i); - let yes = this.read_immediate(&this.project_index(&yes, i.into())?)?; - let no = this.read_immediate(&this.project_index(&no, i.into())?)?; - let dest = this.project_index(&dest, i.into())?; - - let val = if mask != 0 { yes } else { no }; - this.write_immediate(*val, &dest)?; - } - // The remaining bits of the mask are ignored. - } - // Converts a "vector of bool" into a bitmask. - "bitmask" => { - let [op] = check_intrinsic_arg_count(args)?; - let (op, op_len) = this.project_to_simd(op)?; - let bitmask_len = op_len.next_multiple_of(8); - if bitmask_len > 64 { - throw_unsup_format!( - "simd_bitmask: vectors larger than 64 elements are currently not supported" - ); - } - - let op_len = u32::try_from(op_len).unwrap(); - let mut res = 0u64; - for i in 0..op_len { - let op = this.read_immediate(&this.project_index(&op, i.into())?)?; - if simd_element_to_bool(op)? { - let bit_i = simd_bitmask_index(i, op_len, this.data_layout().endian); - res |= 1u64.strict_shl(bit_i); - } - } - // Write the result, depending on the `dest` type. - // Returns either an unsigned integer or array of `u8`. - match dest.layout.ty.kind() { - ty::Uint(_) => { - // Any larger integer type is fine, it will be zero-extended. - assert!(dest.layout.size.bits() >= bitmask_len); - this.write_int(res, dest)?; - } - ty::Array(elem, _len) if elem == &this.tcx.types.u8 => { - // The array must have exactly the right size. - assert_eq!(dest.layout.size.bits(), bitmask_len); - // We have to write the result byte-for-byte. - let res_size = dest.layout.size.bytes_usize(); - let res_bytes; - let res_bytes_slice = match this.data_layout().endian { - Endian::Little => { - res_bytes = res.to_le_bytes(); - &res_bytes[..res_size] // take the first N bytes - } - Endian::Big => { - res_bytes = res.to_be_bytes(); - &res_bytes[res_bytes.len().strict_sub(res_size)..] // take the last N bytes - } - }; - this.write_bytes_ptr(dest.ptr(), res_bytes_slice.iter().cloned())?; - } - _ => bug!("simd_bitmask: invalid return type {}", dest.layout.ty), - } - } - "cast" | "as" | "cast_ptr" | "expose_provenance" | "with_exposed_provenance" => { + "expose_provenance" => { let [op] = check_intrinsic_arg_count(args)?; let (op, op_len) = this.project_to_simd(op)?; let (dest, dest_len) = this.project_to_simd(dest)?; assert_eq!(dest_len, op_len); - let unsafe_cast = intrinsic_name == "cast"; - let safe_cast = intrinsic_name == "as"; - let ptr_cast = intrinsic_name == "cast_ptr"; - let expose_cast = intrinsic_name == "expose_provenance"; - let from_exposed_cast = intrinsic_name == "with_exposed_provenance"; - for i in 0..dest_len { let op = this.read_immediate(&this.project_index(&op, i)?)?; let dest = this.project_index(&dest, i)?; let val = match (op.layout.ty.kind(), dest.layout.ty.kind()) { - // Int-to-(int|float): always safe - (ty::Int(_) | ty::Uint(_), ty::Int(_) | ty::Uint(_) | ty::Float(_)) - if safe_cast || unsafe_cast => - this.int_to_int_or_float(&op, dest.layout)?, - // Float-to-float: always safe - (ty::Float(_), ty::Float(_)) if safe_cast || unsafe_cast => - this.float_to_float_or_int(&op, dest.layout)?, - // Float-to-int in safe mode - (ty::Float(_), ty::Int(_) | ty::Uint(_)) if safe_cast => - this.float_to_float_or_int(&op, dest.layout)?, - // Float-to-int in unchecked mode - (ty::Float(_), ty::Int(_) | ty::Uint(_)) if unsafe_cast => { - this.float_to_int_checked(&op, dest.layout, Round::TowardZero)? - .ok_or_else(|| { - err_ub_format!( - "`simd_cast` intrinsic called on {op} which cannot be represented in target type `{:?}`", - dest.layout.ty - ) - })? - } - // Ptr-to-ptr cast - (ty::RawPtr(..), ty::RawPtr(..)) if ptr_cast => - this.ptr_to_ptr(&op, dest.layout)?, // Ptr/Int casts - (ty::RawPtr(..), ty::Int(_) | ty::Uint(_)) if expose_cast => + (ty::RawPtr(..), ty::Int(_) | ty::Uint(_)) => this.pointer_expose_provenance_cast(&op, dest.layout)?, - (ty::Int(_) | ty::Uint(_), ty::RawPtr(..)) if from_exposed_cast => - this.pointer_with_exposed_provenance_cast(&op, dest.layout)?, // Error otherwise _ => throw_unsup_format!( - "Unsupported SIMD cast from element type {from_ty} to {to_ty}", + "Unsupported `simd_expose_provenance` from element type {from_ty} to {to_ty}", from_ty = op.layout.ty, to_ty = dest.layout.ty, ), @@ -625,210 +165,9 @@ pub trait EvalContextExt<'tcx>: crate::MiriInterpCxExt<'tcx> { this.write_immediate(*val, &dest)?; } } - "shuffle_const_generic" => { - let [left, right] = check_intrinsic_arg_count(args)?; - let (left, left_len) = this.project_to_simd(left)?; - let (right, right_len) = this.project_to_simd(right)?; - let (dest, dest_len) = this.project_to_simd(dest)?; - - let index = generic_args[2].expect_const().to_value().valtree.unwrap_branch(); - let index_len = index.len(); - - assert_eq!(left_len, right_len); - assert_eq!(u64::try_from(index_len).unwrap(), dest_len); - - for i in 0..dest_len { - let src_index: u64 = - index[usize::try_from(i).unwrap()].unwrap_leaf().to_u32().into(); - let dest = this.project_index(&dest, i)?; - - let val = if src_index < left_len { - this.read_immediate(&this.project_index(&left, src_index)?)? - } else if src_index < left_len.strict_add(right_len) { - let right_idx = src_index.strict_sub(left_len); - this.read_immediate(&this.project_index(&right, right_idx)?)? - } else { - throw_ub_format!( - "`simd_shuffle_const_generic` index {src_index} is out-of-bounds for 2 vectors with length {dest_len}" - ); - }; - this.write_immediate(*val, &dest)?; - } - } - "shuffle" => { - let [left, right, index] = check_intrinsic_arg_count(args)?; - let (left, left_len) = this.project_to_simd(left)?; - let (right, right_len) = this.project_to_simd(right)?; - let (index, index_len) = this.project_to_simd(index)?; - let (dest, dest_len) = this.project_to_simd(dest)?; - - assert_eq!(left_len, right_len); - assert_eq!(index_len, dest_len); - - for i in 0..dest_len { - let src_index: u64 = this - .read_immediate(&this.project_index(&index, i)?)? - .to_scalar() - .to_u32()? - .into(); - let dest = this.project_index(&dest, i)?; - - let val = if src_index < left_len { - this.read_immediate(&this.project_index(&left, src_index)?)? - } else if src_index < left_len.strict_add(right_len) { - let right_idx = src_index.strict_sub(left_len); - this.read_immediate(&this.project_index(&right, right_idx)?)? - } else { - throw_ub_format!( - "`simd_shuffle` index {src_index} is out-of-bounds for 2 vectors with length {dest_len}" - ); - }; - this.write_immediate(*val, &dest)?; - } - } - "gather" => { - let [passthru, ptrs, mask] = check_intrinsic_arg_count(args)?; - let (passthru, passthru_len) = this.project_to_simd(passthru)?; - let (ptrs, ptrs_len) = this.project_to_simd(ptrs)?; - let (mask, mask_len) = this.project_to_simd(mask)?; - let (dest, dest_len) = this.project_to_simd(dest)?; - - assert_eq!(dest_len, passthru_len); - assert_eq!(dest_len, ptrs_len); - assert_eq!(dest_len, mask_len); - - for i in 0..dest_len { - let passthru = this.read_immediate(&this.project_index(&passthru, i)?)?; - let ptr = this.read_immediate(&this.project_index(&ptrs, i)?)?; - let mask = this.read_immediate(&this.project_index(&mask, i)?)?; - let dest = this.project_index(&dest, i)?; - - let val = if simd_element_to_bool(mask)? { - let place = this.deref_pointer(&ptr)?; - this.read_immediate(&place)? - } else { - passthru - }; - this.write_immediate(*val, &dest)?; - } - } - "scatter" => { - let [value, ptrs, mask] = check_intrinsic_arg_count(args)?; - let (value, value_len) = this.project_to_simd(value)?; - let (ptrs, ptrs_len) = this.project_to_simd(ptrs)?; - let (mask, mask_len) = this.project_to_simd(mask)?; - - assert_eq!(ptrs_len, value_len); - assert_eq!(ptrs_len, mask_len); - - for i in 0..ptrs_len { - let value = this.read_immediate(&this.project_index(&value, i)?)?; - let ptr = this.read_immediate(&this.project_index(&ptrs, i)?)?; - let mask = this.read_immediate(&this.project_index(&mask, i)?)?; - - if simd_element_to_bool(mask)? { - let place = this.deref_pointer(&ptr)?; - this.write_immediate(*value, &place)?; - } - } - } - "masked_load" => { - let [mask, ptr, default] = check_intrinsic_arg_count(args)?; - let (mask, mask_len) = this.project_to_simd(mask)?; - let ptr = this.read_pointer(ptr)?; - let (default, default_len) = this.project_to_simd(default)?; - let (dest, dest_len) = this.project_to_simd(dest)?; - - assert_eq!(dest_len, mask_len); - assert_eq!(dest_len, default_len); - - for i in 0..dest_len { - let mask = this.read_immediate(&this.project_index(&mask, i)?)?; - let default = this.read_immediate(&this.project_index(&default, i)?)?; - let dest = this.project_index(&dest, i)?; - - let val = if simd_element_to_bool(mask)? { - // Size * u64 is implemented as always checked - let ptr = ptr.wrapping_offset(dest.layout.size * i, this); - let place = this.ptr_to_mplace(ptr, dest.layout); - this.read_immediate(&place)? - } else { - default - }; - this.write_immediate(*val, &dest)?; - } - } - "masked_store" => { - let [mask, ptr, vals] = check_intrinsic_arg_count(args)?; - let (mask, mask_len) = this.project_to_simd(mask)?; - let ptr = this.read_pointer(ptr)?; - let (vals, vals_len) = this.project_to_simd(vals)?; - - assert_eq!(mask_len, vals_len); - - for i in 0..vals_len { - let mask = this.read_immediate(&this.project_index(&mask, i)?)?; - let val = this.read_immediate(&this.project_index(&vals, i)?)?; - - if simd_element_to_bool(mask)? { - // Size * u64 is implemented as always checked - let ptr = ptr.wrapping_offset(val.layout.size * i, this); - let place = this.ptr_to_mplace(ptr, val.layout); - this.write_immediate(*val, &place)? - }; - } - } _ => return interp_ok(EmulateItemResult::NotSupported), } interp_ok(EmulateItemResult::NeedsReturn) } - - fn fminmax_op( - &self, - op: MinMax, - left: &ImmTy<'tcx>, - right: &ImmTy<'tcx>, - ) -> InterpResult<'tcx, Scalar> { - let this = self.eval_context_ref(); - assert_eq!(left.layout.ty, right.layout.ty); - let ty::Float(float_ty) = left.layout.ty.kind() else { - bug!("fmax operand is not a float") - }; - let left = left.to_scalar(); - let right = right.to_scalar(); - interp_ok(match float_ty { - FloatTy::F16 => unimplemented!("f16_f128"), - FloatTy::F32 => { - let left = left.to_f32()?; - let right = right.to_f32()?; - let res = match op { - MinMax::Min => left.min(right), - MinMax::Max => left.max(right), - }; - let res = this.adjust_nan(res, &[left, right]); - Scalar::from_f32(res) - } - FloatTy::F64 => { - let left = left.to_f64()?; - let right = right.to_f64()?; - let res = match op { - MinMax::Min => left.min(right), - MinMax::Max => left.max(right), - }; - let res = this.adjust_nan(res, &[left, right]); - Scalar::from_f64(res) - } - FloatTy::F128 => unimplemented!("f16_f128"), - }) - } -} - -fn simd_bitmask_index(idx: u32, vec_len: u32, endianness: Endian) -> u32 { - assert!(idx < vec_len); - match endianness { - Endian::Little => idx, - #[expect(clippy::arithmetic_side_effects)] // idx < vec_len - Endian::Big => vec_len - 1 - idx, // reverse order of bits - } }