Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
96 changes: 73 additions & 23 deletions compiler/rustc_codegen_llvm/src/intrinsic.rs
Original file line number Diff line number Diff line change
Expand Up @@ -1325,6 +1325,8 @@ fn generic_simd_intrinsic<'ll, 'tcx>(
};
}

let llvm_version = crate::llvm_util::get_version();

/// Converts a vector mask, where each element has a bit width equal to the data elements it is used with,
/// down to an i1 based mask that can be used by llvm intrinsics.
///
Expand Down Expand Up @@ -1808,7 +1810,7 @@ fn generic_simd_intrinsic<'ll, 'tcx>(
);

// Alignment of T, must be a constant integer value:
let alignment = bx.const_i32(bx.align_of(in_elem).bytes() as i32);
let alignment = bx.align_of(in_elem).bytes();

// Truncate the mask vector to a vector of i1s:
let mask = vector_mask_to_bitmask(bx, args[2].immediate(), mask_elem_bitwidth, in_len);
Expand All @@ -1819,11 +1821,23 @@ fn generic_simd_intrinsic<'ll, 'tcx>(
// Type of the vector of elements:
let llvm_elem_vec_ty = llvm_vector_ty(bx, element_ty0, in_len);

return Ok(bx.call_intrinsic(
"llvm.masked.gather",
&[llvm_elem_vec_ty, llvm_pointer_vec_ty],
&[args[1].immediate(), alignment, mask, args[0].immediate()],
));
let args: &[&'ll Value] = if llvm_version < (22, 0, 0) {
let alignment = bx.const_i32(alignment as i32);
&[args[1].immediate(), alignment, mask, args[0].immediate()]
} else {
&[args[1].immediate(), mask, args[0].immediate()]
};

let call =
bx.call_intrinsic("llvm.masked.gather", &[llvm_elem_vec_ty, llvm_pointer_vec_ty], args);
if llvm_version >= (22, 0, 0) {
crate::attributes::apply_to_callsite(
call,
crate::llvm::AttributePlace::Argument(0),
&[crate::llvm::CreateAlignmentAttr(bx.llcx, alignment)],
)
}
return Ok(call);
}

if name == sym::simd_masked_load {
Expand Down Expand Up @@ -1891,18 +1905,30 @@ fn generic_simd_intrinsic<'ll, 'tcx>(
let mask = vector_mask_to_bitmask(bx, args[0].immediate(), m_elem_bitwidth, mask_len);

// Alignment of T, must be a constant integer value:
let alignment = bx.const_i32(bx.align_of(values_elem).bytes() as i32);
let alignment = bx.align_of(values_elem).bytes();

let llvm_pointer = bx.type_ptr();

// Type of the vector of elements:
let llvm_elem_vec_ty = llvm_vector_ty(bx, values_elem, values_len);

return Ok(bx.call_intrinsic(
"llvm.masked.load",
&[llvm_elem_vec_ty, llvm_pointer],
&[args[1].immediate(), alignment, mask, args[2].immediate()],
));
let args: &[&'ll Value] = if llvm_version < (22, 0, 0) {
let alignment = bx.const_i32(alignment as i32);

&[args[1].immediate(), alignment, mask, args[2].immediate()]
} else {
&[args[1].immediate(), mask, args[2].immediate()]
};

let call = bx.call_intrinsic("llvm.masked.load", &[llvm_elem_vec_ty, llvm_pointer], args);
if llvm_version >= (22, 0, 0) {
crate::attributes::apply_to_callsite(
call,
crate::llvm::AttributePlace::Argument(0),
&[crate::llvm::CreateAlignmentAttr(bx.llcx, alignment)],
)
}
return Ok(call);
}

if name == sym::simd_masked_store {
Expand Down Expand Up @@ -1964,18 +1990,29 @@ fn generic_simd_intrinsic<'ll, 'tcx>(
let mask = vector_mask_to_bitmask(bx, args[0].immediate(), m_elem_bitwidth, mask_len);

// Alignment of T, must be a constant integer value:
let alignment = bx.const_i32(bx.align_of(values_elem).bytes() as i32);
let alignment = bx.align_of(values_elem).bytes();

let llvm_pointer = bx.type_ptr();

// Type of the vector of elements:
let llvm_elem_vec_ty = llvm_vector_ty(bx, values_elem, values_len);

return Ok(bx.call_intrinsic(
"llvm.masked.store",
&[llvm_elem_vec_ty, llvm_pointer],
&[args[2].immediate(), args[1].immediate(), alignment, mask],
));
let args: &[&'ll Value] = if llvm_version < (22, 0, 0) {
let alignment = bx.const_i32(alignment as i32);
&[args[2].immediate(), args[1].immediate(), alignment, mask]
} else {
&[args[2].immediate(), args[1].immediate(), mask]
};

let call = bx.call_intrinsic("llvm.masked.store", &[llvm_elem_vec_ty, llvm_pointer], args);
if llvm_version >= (22, 0, 0) {
crate::attributes::apply_to_callsite(
call,
crate::llvm::AttributePlace::Argument(1),
&[crate::llvm::CreateAlignmentAttr(bx.llcx, alignment)],
)
}
return Ok(call);
}

if name == sym::simd_scatter {
Expand Down Expand Up @@ -2040,7 +2077,7 @@ fn generic_simd_intrinsic<'ll, 'tcx>(
);

// Alignment of T, must be a constant integer value:
let alignment = bx.const_i32(bx.align_of(in_elem).bytes() as i32);
let alignment = bx.align_of(in_elem).bytes();

// Truncate the mask vector to a vector of i1s:
let mask = vector_mask_to_bitmask(bx, args[2].immediate(), mask_elem_bitwidth, in_len);
Expand All @@ -2050,12 +2087,25 @@ fn generic_simd_intrinsic<'ll, 'tcx>(

// Type of the vector of elements:
let llvm_elem_vec_ty = llvm_vector_ty(bx, element_ty0, in_len);

return Ok(bx.call_intrinsic(
let args: &[&'ll Value] = if llvm_version < (22, 0, 0) {
let alignment = bx.const_i32(alignment as i32);
&[args[0].immediate(), args[1].immediate(), alignment, mask]
} else {
&[args[0].immediate(), args[1].immediate(), mask]
};
let call = bx.call_intrinsic(
"llvm.masked.scatter",
&[llvm_elem_vec_ty, llvm_pointer_vec_ty],
&[args[0].immediate(), args[1].immediate(), alignment, mask],
));
args,
);
if llvm_version >= (22, 0, 0) {
crate::attributes::apply_to_callsite(
call,
crate::llvm::AttributePlace::Argument(1),
&[crate::llvm::CreateAlignmentAttr(bx.llcx, alignment)],
)
}
return Ok(call);
}

macro_rules! arith_red {
Expand Down
13 changes: 10 additions & 3 deletions tests/codegen-llvm/simd-intrinsic/simd-intrinsic-generic-gather.rs
Original file line number Diff line number Diff line change
@@ -1,6 +1,10 @@
//

//@ compile-flags: -C no-prepopulate-passes
//@ revisions: LLVM21 LLVM22
//@ [LLVM22] min-llvm-version: 22
//@ [LLVM21] max-llvm-major-version: 21
// ignore-tidy-linelength

#![crate_type = "lib"]
#![feature(repr_simd, core_intrinsics)]
Expand All @@ -24,7 +28,8 @@ pub unsafe fn gather_f32x2(
) -> Vec2<f32> {
// CHECK: [[A:%[0-9]+]] = lshr <2 x i32> {{.*}}, {{<i32 31, i32 31>|splat \(i32 31\)}}
// CHECK: [[B:%[0-9]+]] = trunc <2 x i32> [[A]] to <2 x i1>
// CHECK: call <2 x float> @llvm.masked.gather.v2f32.v2p0(<2 x ptr> {{.*}}, i32 {{.*}}, <2 x i1> [[B]], <2 x float> {{.*}})
// LLVM21: call <2 x float> @llvm.masked.gather.v2f32.v2p0(<2 x ptr> {{.*}}, i32 {{.*}}, <2 x i1> [[B]], <2 x float> {{.*}})
// LLVM22: call <2 x float> @llvm.masked.gather.v2f32.v2p0(<2 x ptr> align {{.*}} {{.*}}, <2 x i1> [[B]], <2 x float> {{.*}})
simd_gather(values, pointers, mask)
}

Expand All @@ -37,7 +42,8 @@ pub unsafe fn gather_f32x2_unsigned(
) -> Vec2<f32> {
// CHECK: [[A:%[0-9]+]] = lshr <2 x i32> {{.*}}, {{<i32 31, i32 31>|splat \(i32 31\)}}
// CHECK: [[B:%[0-9]+]] = trunc <2 x i32> [[A]] to <2 x i1>
// CHECK: call <2 x float> @llvm.masked.gather.v2f32.v2p0(<2 x ptr> {{.*}}, i32 {{.*}}, <2 x i1> [[B]], <2 x float> {{.*}})
// LLVM21: call <2 x float> @llvm.masked.gather.v2f32.v2p0(<2 x ptr> {{.*}}, i32 {{.*}}, <2 x i1> [[B]], <2 x float> {{.*}})
// LLVM22: call <2 x float> @llvm.masked.gather.v2f32.v2p0(<2 x ptr> align {{.*}} {{.*}}, <2 x i1> [[B]], <2 x float> {{.*}})
simd_gather(values, pointers, mask)
}

Expand All @@ -50,6 +56,7 @@ pub unsafe fn gather_pf32x2(
) -> Vec2<*const f32> {
// CHECK: [[A:%[0-9]+]] = lshr <2 x i32> {{.*}}, {{<i32 31, i32 31>|splat \(i32 31\)}}
// CHECK: [[B:%[0-9]+]] = trunc <2 x i32> [[A]] to <2 x i1>
// CHECK: call <2 x ptr> @llvm.masked.gather.v2p0.v2p0(<2 x ptr> {{.*}}, i32 {{.*}}, <2 x i1> [[B]], <2 x ptr> {{.*}})
// LLVM21: call <2 x ptr> @llvm.masked.gather.v2p0.v2p0(<2 x ptr> {{.*}}, i32 {{.*}}, <2 x i1> [[B]], <2 x ptr> {{.*}})
// LLVM22: call <2 x ptr> @llvm.masked.gather.v2p0.v2p0(<2 x ptr> align {{.*}} {{.*}}, <2 x i1> [[B]], <2 x ptr> {{.*}})
simd_gather(values, pointers, mask)
}
Original file line number Diff line number Diff line change
@@ -1,4 +1,8 @@
//@ compile-flags: -C no-prepopulate-passes
//@ revisions: LLVM21 LLVM22
//@ [LLVM22] min-llvm-version: 22
//@ [LLVM21] max-llvm-major-version: 21
// ignore-tidy-linelength

#![crate_type = "lib"]
#![feature(repr_simd, core_intrinsics)]
Expand All @@ -18,7 +22,8 @@ pub type Vec4<T> = Simd<T, 4>;
pub unsafe fn load_f32x2(mask: Vec2<i32>, pointer: *const f32, values: Vec2<f32>) -> Vec2<f32> {
// CHECK: [[A:%[0-9]+]] = lshr <2 x i32> {{.*}}, {{<i32 31, i32 31>|splat \(i32 31\)}}
// CHECK: [[B:%[0-9]+]] = trunc <2 x i32> [[A]] to <2 x i1>
// CHECK: call <2 x float> @llvm.masked.load.v2f32.p0(ptr {{.*}}, i32 4, <2 x i1> [[B]], <2 x float> {{.*}})
// LLVM21: call <2 x float> @llvm.masked.load.v2f32.p0(ptr {{.*}}, i32 4, <2 x i1> [[B]], <2 x float> {{.*}})
// LLVM22: call <2 x float> @llvm.masked.load.v2f32.p0(ptr align 4 {{.*}}, <2 x i1> [[B]], <2 x float> {{.*}})
simd_masked_load(mask, pointer, values)
}

Expand All @@ -31,7 +36,8 @@ pub unsafe fn load_f32x2_unsigned(
) -> Vec2<f32> {
// CHECK: [[A:%[0-9]+]] = lshr <2 x i32> {{.*}}, {{<i32 31, i32 31>|splat \(i32 31\)}}
// CHECK: [[B:%[0-9]+]] = trunc <2 x i32> [[A]] to <2 x i1>
// CHECK: call <2 x float> @llvm.masked.load.v2f32.p0(ptr {{.*}}, i32 4, <2 x i1> [[B]], <2 x float> {{.*}})
// LLVM21: call <2 x float> @llvm.masked.load.v2f32.p0(ptr {{.*}}, i32 4, <2 x i1> [[B]], <2 x float> {{.*}})
// LLVM22: call <2 x float> @llvm.masked.load.v2f32.p0(ptr align 4 {{.*}}, <2 x i1> [[B]], <2 x float> {{.*}})
simd_masked_load(mask, pointer, values)
}

Expand All @@ -44,6 +50,7 @@ pub unsafe fn load_pf32x4(
) -> Vec4<*const f32> {
// CHECK: [[A:%[0-9]+]] = lshr <4 x i32> {{.*}}, {{<i32 31, i32 31, i32 31, i32 31>|splat \(i32 31\)}}
// CHECK: [[B:%[0-9]+]] = trunc <4 x i32> [[A]] to <4 x i1>
// CHECK: call <4 x ptr> @llvm.masked.load.v4p0.p0(ptr {{.*}}, i32 {{.*}}, <4 x i1> [[B]], <4 x ptr> {{.*}})
// LLVM21: call <4 x ptr> @llvm.masked.load.v4p0.p0(ptr {{.*}}, i32 {{.*}}, <4 x i1> [[B]], <4 x ptr> {{.*}})
// LLVM22: call <4 x ptr> @llvm.masked.load.v4p0.p0(ptr align {{.*}} {{.*}}, <4 x i1> [[B]], <4 x ptr> {{.*}})
simd_masked_load(mask, pointer, values)
}
Original file line number Diff line number Diff line change
@@ -1,4 +1,8 @@
//@ compile-flags: -C no-prepopulate-passes
//@ revisions: LLVM21 LLVM22
//@ [LLVM22] min-llvm-version: 22
//@ [LLVM21] max-llvm-major-version: 21
// ignore-tidy-linelength

#![crate_type = "lib"]
#![feature(repr_simd, core_intrinsics)]
Expand All @@ -18,7 +22,8 @@ pub type Vec4<T> = Simd<T, 4>;
pub unsafe fn store_f32x2(mask: Vec2<i32>, pointer: *mut f32, values: Vec2<f32>) {
// CHECK: [[A:%[0-9]+]] = lshr <2 x i32> {{.*}}, {{<i32 31, i32 31>|splat \(i32 31\)}}
// CHECK: [[B:%[0-9]+]] = trunc <2 x i32> [[A]] to <2 x i1>
// CHECK: call void @llvm.masked.store.v2f32.p0(<2 x float> {{.*}}, ptr {{.*}}, i32 4, <2 x i1> [[B]])
// LLVM21: call void @llvm.masked.store.v2f32.p0(<2 x float> {{.*}}, ptr {{.*}}, i32 4, <2 x i1> [[B]])
// LLVM22: call void @llvm.masked.store.v2f32.p0(<2 x float> {{.*}}, ptr align 4 {{.*}}, <2 x i1> [[B]])
simd_masked_store(mask, pointer, values)
}

Expand All @@ -27,7 +32,8 @@ pub unsafe fn store_f32x2(mask: Vec2<i32>, pointer: *mut f32, values: Vec2<f32>)
pub unsafe fn store_f32x2_unsigned(mask: Vec2<u32>, pointer: *mut f32, values: Vec2<f32>) {
// CHECK: [[A:%[0-9]+]] = lshr <2 x i32> {{.*}}, {{<i32 31, i32 31>|splat \(i32 31\)}}
// CHECK: [[B:%[0-9]+]] = trunc <2 x i32> [[A]] to <2 x i1>
// CHECK: call void @llvm.masked.store.v2f32.p0(<2 x float> {{.*}}, ptr {{.*}}, i32 4, <2 x i1> [[B]])
// LLVM21: call void @llvm.masked.store.v2f32.p0(<2 x float> {{.*}}, ptr {{.*}}, i32 4, <2 x i1> [[B]])
// LLVM22: call void @llvm.masked.store.v2f32.p0(<2 x float> {{.*}}, ptr align 4 {{.*}}, <2 x i1> [[B]])
simd_masked_store(mask, pointer, values)
}

Expand All @@ -36,6 +42,7 @@ pub unsafe fn store_f32x2_unsigned(mask: Vec2<u32>, pointer: *mut f32, values: V
pub unsafe fn store_pf32x4(mask: Vec4<i32>, pointer: *mut *const f32, values: Vec4<*const f32>) {
// CHECK: [[A:%[0-9]+]] = lshr <4 x i32> {{.*}}, {{<i32 31, i32 31, i32 31, i32 31>|splat \(i32 31\)}}
// CHECK: [[B:%[0-9]+]] = trunc <4 x i32> [[A]] to <4 x i1>
// CHECK: call void @llvm.masked.store.v4p0.p0(<4 x ptr> {{.*}}, ptr {{.*}}, i32 {{.*}}, <4 x i1> [[B]])
// LLVM21: call void @llvm.masked.store.v4p0.p0(<4 x ptr> {{.*}}, ptr {{.*}}, i32 {{.*}}, <4 x i1> [[B]])
// LLVM22: call void @llvm.masked.store.v4p0.p0(<4 x ptr> {{.*}}, ptr align {{.*}} {{.*}}, <4 x i1> [[B]])
simd_masked_store(mask, pointer, values)
}
Original file line number Diff line number Diff line change
@@ -1,6 +1,10 @@
//

//@ compile-flags: -C no-prepopulate-passes
//@ revisions: LLVM21 LLVM22
//@ [LLVM22] min-llvm-version: 22
//@ [LLVM21] max-llvm-major-version: 21
// ignore-tidy-linelength

#![crate_type = "lib"]
#![feature(repr_simd, core_intrinsics)]
Expand All @@ -20,7 +24,8 @@ pub type Vec4<T> = Simd<T, 4>;
pub unsafe fn scatter_f32x2(pointers: Vec2<*mut f32>, mask: Vec2<i32>, values: Vec2<f32>) {
// CHECK: [[A:%[0-9]+]] = lshr <2 x i32> {{.*}}, {{<i32 31, i32 31>|splat \(i32 31\)}}
// CHECK: [[B:%[0-9]+]] = trunc <2 x i32> [[A]] to <2 x i1>
// CHECK: call void @llvm.masked.scatter.v2f32.v2p0(<2 x float> {{.*}}, <2 x ptr> {{.*}}, i32 {{.*}}, <2 x i1> [[B]]
// LLVM21: call void @llvm.masked.scatter.v2f32.v2p0(<2 x float> {{.*}}, <2 x ptr> {{.*}}, i32 {{.*}}, <2 x i1> [[B]]
// LLVM22: call void @llvm.masked.scatter.v2f32.v2p0(<2 x float> {{.*}}, <2 x ptr> align {{.*}} {{.*}}, <2 x i1> [[B]]
simd_scatter(values, pointers, mask)
}

Expand All @@ -29,7 +34,8 @@ pub unsafe fn scatter_f32x2(pointers: Vec2<*mut f32>, mask: Vec2<i32>, values: V
pub unsafe fn scatter_f32x2_unsigned(pointers: Vec2<*mut f32>, mask: Vec2<u32>, values: Vec2<f32>) {
// CHECK: [[A:%[0-9]+]] = lshr <2 x i32> {{.*}}, {{<i32 31, i32 31>|splat \(i32 31\)}}
// CHECK: [[B:%[0-9]+]] = trunc <2 x i32> [[A]] to <2 x i1>
// CHECK: call void @llvm.masked.scatter.v2f32.v2p0(<2 x float> {{.*}}, <2 x ptr> {{.*}}, i32 {{.*}}, <2 x i1> [[B]]
// LLVM21: call void @llvm.masked.scatter.v2f32.v2p0(<2 x float> {{.*}}, <2 x ptr> {{.*}}, i32 {{.*}}, <2 x i1> [[B]]
// LLVM22: call void @llvm.masked.scatter.v2f32.v2p0(<2 x float> {{.*}}, <2 x ptr> align {{.*}} {{.*}}, <2 x i1> [[B]]
simd_scatter(values, pointers, mask)
}

Expand All @@ -42,6 +48,7 @@ pub unsafe fn scatter_pf32x2(
) {
// CHECK: [[A:%[0-9]+]] = lshr <2 x i32> {{.*}}, {{<i32 31, i32 31>|splat \(i32 31\)}}
// CHECK: [[B:%[0-9]+]] = trunc <2 x i32> [[A]] to <2 x i1>
// CHECK: call void @llvm.masked.scatter.v2p0.v2p0(<2 x ptr> {{.*}}, <2 x ptr> {{.*}}, i32 {{.*}}, <2 x i1> [[B]]
// LLVM21: call void @llvm.masked.scatter.v2p0.v2p0(<2 x ptr> {{.*}}, <2 x ptr> {{.*}}, i32 {{.*}}, <2 x i1> [[B]]
// LLVM22: call void @llvm.masked.scatter.v2p0.v2p0(<2 x ptr> {{.*}}, <2 x ptr> align {{.*}} {{.*}}, <2 x i1> [[B]]
simd_scatter(values, pointers, mask)
}
Loading