Skip to content

Commit b03159d

Browse files
committed
Add alignment parameter to simd_masked_{load,store}
1 parent 4a54b26 commit b03159d

File tree

17 files changed

+405
-107
lines changed

17 files changed

+405
-107
lines changed

compiler/rustc_codegen_llvm/src/intrinsic.rs

Lines changed: 28 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,7 @@ use rustc_hir::def_id::LOCAL_CRATE;
1313
use rustc_hir::{self as hir};
1414
use rustc_middle::mir::BinOp;
1515
use rustc_middle::ty::layout::{FnAbiOf, HasTyCtxt, HasTypingEnv, LayoutOf};
16-
use rustc_middle::ty::{self, GenericArgsRef, Instance, Ty, TyCtxt, TypingEnv};
16+
use rustc_middle::ty::{self, GenericArgsRef, Instance, SimdAlign, Ty, TyCtxt, TypingEnv};
1717
use rustc_middle::{bug, span_bug};
1818
use rustc_span::{Span, Symbol, sym};
1919
use rustc_symbol_mangling::{mangle_internal_symbol, symbol_name_for_instance_in_crate};
@@ -1828,15 +1828,34 @@ fn generic_simd_intrinsic<'ll, 'tcx>(
18281828
));
18291829
}
18301830

1831+
fn llvm_alignment<'ll, 'tcx>(
1832+
bx: &mut Builder<'_, 'll, 'tcx>,
1833+
alignment: SimdAlign,
1834+
vector_ty: Ty<'tcx>,
1835+
element_ty: Ty<'tcx>,
1836+
) -> &'ll Value {
1837+
let alignment = match alignment {
1838+
SimdAlign::Unaligned => 1,
1839+
SimdAlign::Element => bx.align_of(element_ty).bytes(),
1840+
SimdAlign::Vector => bx.align_of(vector_ty).bytes(),
1841+
};
1842+
1843+
bx.const_i32(alignment as i32)
1844+
}
1845+
18311846
if name == sym::simd_masked_load {
1832-
// simd_masked_load(mask: <N x i{M}>, pointer: *_ T, values: <N x T>) -> <N x T>
1847+
// simd_masked_load<_, _, _, const ALIGN: SimdAlign>(mask: <N x i{M}>, pointer: *_ T, values: <N x T>) -> <N x T>
18331848
// * N: number of elements in the input vectors
18341849
// * T: type of the element to load
18351850
// * M: any integer width is supported, will be truncated to i1
18361851
// Loads contiguous elements from memory behind `pointer`, but only for
18371852
// those lanes whose `mask` bit is enabled.
18381853
// The memory addresses corresponding to the “off” lanes are not accessed.
18391854

1855+
let alignment = fn_args[3].expect_const().to_value().valtree.unwrap_branch()[0]
1856+
.unwrap_leaf()
1857+
.to_simd_alignment();
1858+
18401859
// The element type of the "mask" argument must be a signed integer type of any width
18411860
let mask_ty = in_ty;
18421861
let (mask_len, mask_elem) = (in_len, in_elem);
@@ -1893,7 +1912,7 @@ fn generic_simd_intrinsic<'ll, 'tcx>(
18931912
let mask = vector_mask_to_bitmask(bx, args[0].immediate(), m_elem_bitwidth, mask_len);
18941913

18951914
// Alignment selected by the `ALIGN` const parameter, must be a constant integer value:
1896-
let alignment = bx.const_i32(bx.align_of(values_elem).bytes() as i32);
1915+
let alignment = llvm_alignment(bx, alignment, values_ty, values_elem);
18971916

18981917
let llvm_pointer = bx.type_ptr();
18991918

@@ -1908,14 +1927,18 @@ fn generic_simd_intrinsic<'ll, 'tcx>(
19081927
}
19091928

19101929
if name == sym::simd_masked_store {
1911-
// simd_masked_store(mask: <N x i{M}>, pointer: *mut T, values: <N x T>) -> ()
1930+
// simd_masked_store<_, _, _, const ALIGN: SimdAlign>(mask: <N x i{M}>, pointer: *mut T, values: <N x T>) -> ()
19121931
// * N: number of elements in the input vectors
19131932
// * T: type of the element to store
19141933
// * M: any integer width is supported, will be truncated to i1
19151934
// Stores contiguous elements to memory behind `pointer`, but only for
19161935
// those lanes whose `mask` bit is enabled.
19171936
// The memory addresses corresponding to the “off” lanes are not accessed.
19181937

1938+
let alignment = fn_args[3].expect_const().to_value().valtree.unwrap_branch()[0]
1939+
.unwrap_leaf()
1940+
.to_simd_alignment();
1941+
19191942
// The element type of the "mask" argument must be a signed integer type of any width
19201943
let mask_ty = in_ty;
19211944
let (mask_len, mask_elem) = (in_len, in_elem);
@@ -1966,7 +1989,7 @@ fn generic_simd_intrinsic<'ll, 'tcx>(
19661989
let mask = vector_mask_to_bitmask(bx, args[0].immediate(), m_elem_bitwidth, mask_len);
19671990

19681991
// Alignment selected by the `ALIGN` const parameter, must be a constant integer value:
1969-
let alignment = bx.const_i32(bx.align_of(values_elem).bytes() as i32);
1992+
let alignment = llvm_alignment(bx, alignment, values_ty, values_elem);
19701993

19711994
let llvm_pointer = bx.type_ptr();
19721995

compiler/rustc_hir_analysis/src/check/intrinsic.rs

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -695,8 +695,8 @@ pub(crate) fn check_intrinsic_type(
695695
(1, 0, vec![param(0), param(0), param(0)], param(0))
696696
}
697697
sym::simd_gather => (3, 0, vec![param(0), param(1), param(2)], param(0)),
698-
sym::simd_masked_load => (3, 0, vec![param(0), param(1), param(2)], param(2)),
699-
sym::simd_masked_store => (3, 0, vec![param(0), param(1), param(2)], tcx.types.unit),
698+
sym::simd_masked_load => (3, 1, vec![param(0), param(1), param(2)], param(2)),
699+
sym::simd_masked_store => (3, 1, vec![param(0), param(1), param(2)], tcx.types.unit),
700700
sym::simd_scatter => (3, 0, vec![param(0), param(1), param(2)], tcx.types.unit),
701701
sym::simd_insert | sym::simd_insert_dyn => {
702702
(2, 0, vec![param(0), tcx.types.u32, param(1)], param(0))

compiler/rustc_middle/src/ty/consts/int.rs

Lines changed: 24 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -39,6 +39,15 @@ pub enum AtomicOrdering {
3939
SeqCst = 4,
4040
}
4141

42+
/// An enum to represent the compiler-side view of `intrinsics::simd::SimdAlign`.
43+
#[derive(Debug, Copy, Clone)]
44+
pub enum SimdAlign {
45+
// These values must match `intrinsics::simd::SimdAlign`!
46+
Unaligned = 0,
47+
Element = 1,
48+
Vector = 2,
49+
}
50+
4251
impl std::fmt::Debug for ConstInt {
4352
fn fmt(&self, fmt: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
4453
let Self { int, signed, is_ptr_sized_integral } = *self;
@@ -350,6 +359,21 @@ impl ScalarInt {
350359
}
351360
}
352361

362+
#[inline]
363+
pub fn to_simd_alignment(self) -> SimdAlign {
364+
use SimdAlign::*;
365+
let val = self.to_u32();
366+
if val == Unaligned as u32 {
367+
Unaligned
368+
} else if val == Element as u32 {
369+
Element
370+
} else if val == Vector as u32 {
371+
Vector
372+
} else {
373+
panic!("not a valid simd alignment")
374+
}
375+
}
376+
353377
/// Converts the `ScalarInt` to `bool`.
354378
/// Panics if the `size` of the `ScalarInt` is not equal to 1 byte.
355379
/// Errors if it is not a valid `bool`.

compiler/rustc_middle/src/ty/mod.rs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -75,7 +75,7 @@ pub use self::closure::{
7575
};
7676
pub use self::consts::{
7777
AnonConstKind, AtomicOrdering, Const, ConstInt, ConstKind, ConstToValTreeResult, Expr,
78-
ExprKind, ScalarInt, UnevaluatedConst, ValTree, ValTreeKind, Value,
78+
ExprKind, ScalarInt, SimdAlign, UnevaluatedConst, ValTree, ValTreeKind, Value,
7979
};
8080
pub use self::context::{
8181
CtxtInterners, CurrentGcx, DeducedParamAttrs, Feed, FreeRegionInfo, GlobalCtxt, Lift, TyCtxt,

library/core/src/intrinsics/simd.rs

Lines changed: 19 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,8 @@
22
//!
33
//! In this module, a "vector" is any `repr(simd)` type.
44
5+
use crate::marker::ConstParamTy;
6+
57
/// Inserts an element into a vector, returning the updated vector.
68
///
79
/// `T` must be a vector with element type `U`, and `idx` must be `const`.
@@ -377,6 +379,19 @@ pub unsafe fn simd_gather<T, U, V>(val: T, ptr: U, mask: V) -> T;
377379
#[rustc_nounwind]
378380
pub unsafe fn simd_scatter<T, U, V>(val: T, ptr: U, mask: V);
379381

382+
/// A type for alignment options for SIMD masked load/store intrinsics.
383+
#[derive(Debug, ConstParamTy, PartialEq, Eq)]
384+
pub enum SimdAlign {
385+
// These values must match the compiler's `SimdAlign` defined in
386+
// `rustc_middle/src/ty/consts/int.rs`!
387+
/// No alignment requirements on the pointer
388+
Unaligned = 0,
389+
/// The pointer must be aligned to the element type of the SIMD vector
390+
Element = 1,
391+
/// The pointer must be aligned to the SIMD vector type
392+
Vector = 2,
393+
}
394+
380395
/// Reads a vector of pointers.
381396
///
382397
/// `T` must be a vector.
@@ -392,13 +407,12 @@ pub unsafe fn simd_scatter<T, U, V>(val: T, ptr: U, mask: V);
392407
/// `val`.
393408
///
394409
/// # Safety
395-
/// Unmasked values in `T` must be readable as if by `<ptr>::read` (e.g. aligned to the element
396-
/// type).
410+
/// `ptr` must be aligned according to the `ALIGN` parameter, see [`SimdAlign`] for details.
397411
///
398412
/// `mask` must only contain `0` or `!0` values.
399413
#[rustc_intrinsic]
400414
#[rustc_nounwind]
401-
pub unsafe fn simd_masked_load<V, U, T>(mask: V, ptr: U, val: T) -> T;
415+
pub unsafe fn simd_masked_load<V, U, T, const ALIGN: SimdAlign>(mask: V, ptr: U, val: T) -> T;
402416

403417
/// Writes to a vector of pointers.
404418
///
@@ -414,13 +428,12 @@ pub unsafe fn simd_masked_load<V, U, T>(mask: V, ptr: U, val: T) -> T;
414428
/// Otherwise if the corresponding value in `mask` is `0`, do nothing.
415429
///
416430
/// # Safety
417-
/// Unmasked values in `T` must be writeable as if by `<ptr>::write` (e.g. aligned to the element
418-
/// type).
431+
/// `ptr` must be aligned according to the `ALIGN` parameter, see [`SimdAlign`] for details.
419432
///
420433
/// `mask` must only contain `0` or `!0` values.
421434
#[rustc_intrinsic]
422435
#[rustc_nounwind]
423-
pub unsafe fn simd_masked_store<V, U, T>(mask: V, ptr: U, val: T);
436+
pub unsafe fn simd_masked_store<V, U, T, const ALIGN: SimdAlign>(mask: V, ptr: U, val: T);
424437

425438
/// Adds two simd vectors elementwise, with saturation.
426439
///

library/portable-simd/crates/core_simd/src/vector.rs

Lines changed: 16 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -474,7 +474,14 @@ where
474474
or: Self,
475475
) -> Self {
476476
// SAFETY: The safety of reading elements through `ptr` is ensured by the caller.
477-
unsafe { core::intrinsics::simd::simd_masked_load(enable.to_int(), ptr, or) }
477+
unsafe {
478+
core::intrinsics::simd::simd_masked_load::<
479+
_,
480+
_,
481+
_,
482+
{ core::intrinsics::simd::SimdAlign::Element },
483+
>(enable.to_int(), ptr, or)
484+
}
478485
}
479486

480487
/// Reads from potentially discontiguous indices in `slice` to construct a SIMD vector.
@@ -723,7 +730,14 @@ where
723730
#[inline]
724731
pub unsafe fn store_select_ptr(self, ptr: *mut T, enable: Mask<<T as SimdElement>::Mask, N>) {
725732
// SAFETY: The safety of writing elements through `ptr` is ensured by the caller.
726-
unsafe { core::intrinsics::simd::simd_masked_store(enable.to_int(), ptr, self) }
733+
unsafe {
734+
core::intrinsics::simd::simd_masked_store::<
735+
_,
736+
_,
737+
_,
738+
{ core::intrinsics::simd::SimdAlign::Element },
739+
>(enable.to_int(), ptr, self)
740+
}
727741
}
728742

729743
/// Writes the values in a SIMD vector to potentially discontiguous indices in `slice`.

src/tools/miri/tests/pass/intrinsics/portable-simd.rs

Lines changed: 19 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -680,25 +680,39 @@ fn simd_float_intrinsics() {
680680
}
681681

682682
fn simd_masked_loadstore() {
683+
use intrinsics::*;
684+
683685
// The buffer is deliberately too short, so reading the last element would be UB.
684686
let buf = [3i32; 3];
685687
let default = i32x4::splat(0);
686688
let mask = i32x4::from_array([!0, !0, !0, 0]);
687-
let vals = unsafe { intrinsics::simd_masked_load(mask, buf.as_ptr(), default) };
689+
let vals =
690+
unsafe { simd_masked_load::<_, _, _, { SimdAlign::Element }>(mask, buf.as_ptr(), default) };
688691
assert_eq!(vals, i32x4::from_array([3, 3, 3, 0]));
689692
// Also read in a way that the *first* element is OOB.
690693
let mask2 = i32x4::from_array([0, !0, !0, !0]);
691-
let vals =
692-
unsafe { intrinsics::simd_masked_load(mask2, buf.as_ptr().wrapping_sub(1), default) };
694+
let vals = unsafe {
695+
simd_masked_load::<_, _, _, { SimdAlign::Element }>(
696+
mask2,
697+
buf.as_ptr().wrapping_sub(1),
698+
default,
699+
)
700+
};
693701
assert_eq!(vals, i32x4::from_array([0, 3, 3, 3]));
694702

695703
// The buffer is deliberately too short, so writing the last element would be UB.
696704
let mut buf = [42i32; 3];
697705
let vals = i32x4::from_array([1, 2, 3, 4]);
698-
unsafe { intrinsics::simd_masked_store(mask, buf.as_mut_ptr(), vals) };
706+
unsafe { simd_masked_store::<_, _, _, { SimdAlign::Element }>(mask, buf.as_mut_ptr(), vals) };
699707
assert_eq!(buf, [1, 2, 3]);
700708
// Also write in a way that the *first* element is OOB.
701-
unsafe { intrinsics::simd_masked_store(mask2, buf.as_mut_ptr().wrapping_sub(1), vals) };
709+
unsafe {
710+
simd_masked_store::<_, _, _, { SimdAlign::Element }>(
711+
mask2,
712+
buf.as_mut_ptr().wrapping_sub(1),
713+
vals,
714+
)
715+
};
702716
assert_eq!(buf, [2, 3, 4]);
703717
}
704718

tests/assembly-llvm/simd-intrinsic-mask-load.rs

Lines changed: 35 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,7 @@
99
//@ assembly-output: emit-asm
1010
//@ compile-flags: --crate-type=lib -Copt-level=3 -C panic=abort
1111

12-
#![feature(no_core, lang_items, repr_simd, intrinsics)]
12+
#![feature(no_core, lang_items, repr_simd, intrinsics, adt_const_params)]
1313
#![no_core]
1414
#![allow(non_camel_case_types)]
1515

@@ -35,7 +35,7 @@ pub struct f64x4([f64; 4]);
3535
pub struct m64x4([i64; 4]);
3636

3737
#[rustc_intrinsic]
38-
unsafe fn simd_masked_load<M, P, T>(mask: M, pointer: P, values: T) -> T;
38+
unsafe fn simd_masked_load<M, P, T, const ALIGN: SimdAlign>(mask: M, pointer: P, values: T) -> T;
3939

4040
// CHECK-LABEL: load_i8x16
4141
#[no_mangle]
@@ -56,7 +56,11 @@ pub unsafe extern "C" fn load_i8x16(mask: m8x16, pointer: *const i8) -> i8x16 {
5656
// x86-avx512-NOT: vpsllw
5757
// x86-avx512: vpmovb2m k1, xmm0
5858
// x86-avx512-NEXT: vmovdqu8 xmm0 {k1} {z}, xmmword ptr [rdi]
59-
simd_masked_load(mask, pointer, i8x16([0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]))
59+
simd_masked_load::<_, _, _, { SimdAlign::Element }>(
60+
mask,
61+
pointer,
62+
i8x16([0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]),
63+
)
6064
}
6165

6266
// CHECK-LABEL: load_f32x8
@@ -68,7 +72,29 @@ pub unsafe extern "C" fn load_f32x8(mask: m32x8, pointer: *const f32) -> f32x8 {
6872
// x86-avx512-NOT: vpslld
6973
// x86-avx512: vpmovd2m k1, ymm0
7074
// x86-avx512-NEXT: vmovups ymm0 {k1} {z}, ymmword ptr [rdi]
71-
simd_masked_load(mask, pointer, f32x8([0_f32, 0_f32, 0_f32, 0_f32, 0_f32, 0_f32, 0_f32, 0_f32]))
75+
simd_masked_load::<_, _, _, { SimdAlign::Element }>(
76+
mask,
77+
pointer,
78+
f32x8([0_f32, 0_f32, 0_f32, 0_f32, 0_f32, 0_f32, 0_f32, 0_f32]),
79+
)
80+
}
81+
82+
// CHECK-LABEL: load_f32x8_aligned
83+
#[no_mangle]
84+
pub unsafe extern "C" fn load_f32x8_aligned(mask: m32x8, pointer: *const f32) -> f32x8 {
85+
// x86-avx2-NOT: vpslld
86+
// x86-avx2: vmaskmovps ymm0, ymm0, ymmword ptr [rdi]
87+
//
88+
// x86-avx512-NOT: vpslld
89+
// x86-avx512: vpmovd2m k1, ymm0
90+
// x86-avx512-NEXT: vmovaps ymm0 {k1} {z}, ymmword ptr [rdi]
91+
//
92+
// On AVX-512, this aligned version should generate `vmovaps` instead of `vmovups`.
93+
simd_masked_load::<_, _, _, { SimdAlign::Vector }>(
94+
mask,
95+
pointer,
96+
f32x8([0_f32, 0_f32, 0_f32, 0_f32, 0_f32, 0_f32, 0_f32, 0_f32]),
97+
)
7298
}
7399

74100
// CHECK-LABEL: load_f64x4
@@ -79,5 +105,9 @@ pub unsafe extern "C" fn load_f64x4(mask: m64x4, pointer: *const f64) -> f64x4 {
79105
//
80106
// x86-avx512-NOT: vpsllq
81107
// x86-avx512: vpmovq2m k1, ymm0
82-
simd_masked_load(mask, pointer, f64x4([0_f64, 0_f64, 0_f64, 0_f64]))
108+
simd_masked_load::<_, _, _, { SimdAlign::Element }>(
109+
mask,
110+
pointer,
111+
f64x4([0_f64, 0_f64, 0_f64, 0_f64]),
112+
)
83113
}

0 commit comments

Comments
 (0)