Skip to content

Commit 9656971

Browse files
committed
Add autocast for i1 vectors
1 parent 360f534 commit 9656971

File tree

3 files changed

+151
-17
lines changed

3 files changed

+151
-17
lines changed

compiler/rustc_codegen_llvm/src/abi.rs

Lines changed: 22 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -375,26 +375,31 @@ impl<'ll, CX: Borrow<SCx<'ll>>> GenericCx<'ll, CX> {
375375
return true;
376376
}
377377

378-
// Some LLVM intrinsics return **non-packed** structs, but they can't be mimicked from Rust
379-
// due to auto field-alignment in non-packed structs (packed structs are represented in LLVM
380-
// as, well, packed structs, so they won't match with those either)
381-
if self.type_kind(llvm_ty) == TypeKind::Struct
382-
&& self.type_kind(rust_ty) == TypeKind::Struct
383-
{
384-
let rust_element_tys = self.struct_element_types(rust_ty);
385-
let llvm_element_tys = self.struct_element_types(llvm_ty);
378+
match self.type_kind(llvm_ty) {
379+
// Some LLVM intrinsics return **non-packed** structs, but they can't be mimicked from Rust
380+
// due to auto field-alignment in non-packed structs (packed structs are represented in LLVM
381+
// as, well, packed structs, so they won't match with those either)
382+
TypeKind::Struct if self.type_kind(rust_ty) == TypeKind::Struct => {
383+
let rust_element_tys = self.struct_element_types(rust_ty);
384+
let llvm_element_tys = self.struct_element_types(llvm_ty);
385+
386+
if rust_element_tys.len() != llvm_element_tys.len() {
387+
return false;
388+
}
386389

387-
if rust_element_tys.len() != llvm_element_tys.len() {
388-
return false;
390+
iter::zip(rust_element_tys, llvm_element_tys).all(
391+
|(rust_element_ty, llvm_element_ty)| {
392+
self.equate_ty(rust_element_ty, llvm_element_ty)
393+
},
394+
)
389395
}
396+
TypeKind::Vector if self.element_type(llvm_ty) == self.type_i1() => {
397+
let element_count = self.vector_length(llvm_ty) as u64;
398+
let int_width = element_count.next_power_of_two().max(8);
390399

391-
iter::zip(rust_element_tys, llvm_element_tys).all(
392-
|(rust_element_ty, llvm_element_ty)| {
393-
self.equate_ty(rust_element_ty, llvm_element_ty)
394-
},
395-
)
396-
} else {
397-
false
400+
rust_ty == self.type_ix(int_width)
401+
}
402+
_ => false,
398403
}
399404
}
400405
}

compiler/rustc_codegen_llvm/src/builder.rs

Lines changed: 47 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1703,6 +1703,46 @@ impl<'a, 'll, CX: Borrow<SCx<'ll>>> GenericBuilder<'a, 'll, CX> {
17031703
}
17041704
}
17051705
impl<'a, 'll, 'tcx> Builder<'a, 'll, 'tcx> {
1706+
fn trunc_int_to_i1_vector(&mut self, val: &'ll Value, dest_ty: &'ll Type) -> &'ll Value {
1707+
let vector_length = self.vector_length(dest_ty) as u64;
1708+
let int_width = vector_length.next_power_of_two().max(8);
1709+
1710+
let bitcasted = self.bitcast(val, self.type_vector(self.type_i1(), int_width));
1711+
if vector_length == int_width {
1712+
bitcasted
1713+
} else {
1714+
let shuffle_mask =
1715+
(0..vector_length).map(|i| self.const_i32(i as i32)).collect::<Vec<_>>();
1716+
self.shuffle_vector(bitcasted, bitcasted, self.const_vector(&shuffle_mask))
1717+
}
1718+
}
1719+
1720+
fn zext_i1_vector_to_int(
1721+
&mut self,
1722+
mut val: &'ll Value,
1723+
src_ty: &'ll Type,
1724+
dest_ty: &'ll Type,
1725+
) -> &'ll Value {
1726+
let vector_length = self.vector_length(src_ty) as u64;
1727+
let int_width = vector_length.next_power_of_two().max(8);
1728+
1729+
if vector_length != int_width {
1730+
let shuffle_indices = match vector_length {
1731+
0 => unreachable!("zero length vectors are not allowed"),
1732+
1 => vec![0, 1, 1, 1, 1, 1, 1, 1],
1733+
2 => vec![0, 1, 2, 3, 2, 3, 2, 3],
1734+
3 => vec![0, 1, 2, 3, 4, 5, 3, 4],
1735+
4.. => (0..int_width as i32).collect(),
1736+
};
1737+
let shuffle_mask =
1738+
shuffle_indices.into_iter().map(|i| self.const_i32(i)).collect::<Vec<_>>();
1739+
val =
1740+
self.shuffle_vector(val, self.const_null(src_ty), self.const_vector(&shuffle_mask));
1741+
}
1742+
1743+
self.bitcast(val, dest_ty)
1744+
}
1745+
17061746
fn autocast(
17071747
&mut self,
17081748
llfn: &'ll Value,
@@ -1731,6 +1771,13 @@ impl<'a, 'll, 'tcx> Builder<'a, 'll, 'tcx> {
17311771
}
17321772
ret
17331773
}
1774+
TypeKind::Vector if self.element_type(llvm_ty) == self.type_i1() => {
1775+
if is_argument {
1776+
self.trunc_int_to_i1_vector(val, dest_ty)
1777+
} else {
1778+
self.zext_i1_vector_to_int(val, src_ty, dest_ty)
1779+
}
1780+
}
17341781
_ => unreachable!(),
17351782
}
17361783
}
Lines changed: 82 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,82 @@
1+
//@ compile-flags: -C opt-level=0
2+
//@ only-x86_64
3+
4+
#![feature(link_llvm_intrinsics, abi_unadjusted, repr_simd, simd_ffi, portable_simd, f16)]
5+
#![crate_type = "lib"]
6+
7+
use std::simd::i64x2;
8+
9+
#[repr(simd)]
10+
pub struct Tile([i8; 1024]);
11+
12+
#[repr(C, packed)]
13+
pub struct Bar(u32, i64x2, i64x2, i64x2, i64x2, i64x2, i64x2);
14+
// CHECK: %Bar = type <{ i32, <2 x i64>, <2 x i64>, <2 x i64>, <2 x i64>, <2 x i64>, <2 x i64> }>
15+
16+
#[repr(simd)]
17+
pub struct f16x8([f16; 8]);
18+
19+
// CHECK-LABEL: @struct_with_i1_vector_autocast
20+
#[no_mangle]
21+
pub unsafe fn struct_with_i1_vector_autocast(a: i64x2, b: i64x2) -> (u8, u8) {
22+
extern "unadjusted" {
23+
#[link_name = "llvm.x86.avx512.vp2intersect.q.128"]
24+
fn foo(a: i64x2, b: i64x2) -> (u8, u8);
25+
}
26+
27+
// CHECK: %2 = call { <2 x i1>, <2 x i1> } @llvm.x86.avx512.vp2intersect.q.128(<2 x i64> %0, <2 x i64> %1)
28+
// CHECK-NEXT: %3 = extractvalue { <2 x i1>, <2 x i1> } %2, 0
29+
// CHECK-NEXT: %4 = shufflevector <2 x i1> %3, <2 x i1> zeroinitializer, <8 x i32> <i32 0, i32 1, i32 2, i32 3, i32 2, i32 3, i32 2, i32 3>
30+
// CHECK-NEXT: %5 = bitcast <8 x i1> %4 to i8
31+
// CHECK-NEXT: %6 = insertvalue { i8, i8 } poison, i8 %5, 0
32+
// CHECK-NEXT: %7 = extractvalue { <2 x i1>, <2 x i1> } %2, 1
33+
// CHECK-NEXT: %8 = shufflevector <2 x i1> %7, <2 x i1> zeroinitializer, <8 x i32> <i32 0, i32 1, i32 2, i32 3, i32 2, i32 3, i32 2, i32 3>
34+
// CHECK-NEXT: %9 = bitcast <8 x i1> %8 to i8
35+
// CHECK-NEXT: %10 = insertvalue { i8, i8 } %6, i8 %9, 1
36+
foo(a, b)
37+
}
38+
39+
// CHECK-LABEL: @struct_autocast
40+
#[no_mangle]
41+
pub unsafe fn struct_autocast(key_metadata: u32, key: i64x2) -> Bar {
42+
extern "unadjusted" {
43+
#[link_name = "llvm.x86.encodekey128"]
44+
fn foo(key_metadata: u32, key: i64x2) -> Bar;
45+
}
46+
47+
// CHECK: %1 = call { i32, <2 x i64>, <2 x i64>, <2 x i64>, <2 x i64>, <2 x i64>, <2 x i64> } @llvm.x86.encodekey128(i32 %key_metadata, <2 x i64> %0)
48+
// CHECK-NEXT: %2 = extractvalue { i32, <2 x i64>, <2 x i64>, <2 x i64>, <2 x i64>, <2 x i64>, <2 x i64> } %1, 0
49+
// CHECK-NEXT: %3 = insertvalue %Bar poison, i32 %2, 0
50+
// CHECK-NEXT: %4 = extractvalue { i32, <2 x i64>, <2 x i64>, <2 x i64>, <2 x i64>, <2 x i64>, <2 x i64> } %1, 1
51+
// CHECK-NEXT: %5 = insertvalue %Bar %3, <2 x i64> %4, 1
52+
// CHECK-NEXT: %6 = extractvalue { i32, <2 x i64>, <2 x i64>, <2 x i64>, <2 x i64>, <2 x i64>, <2 x i64> } %1, 2
53+
// CHECK-NEXT: %7 = insertvalue %Bar %5, <2 x i64> %6, 2
54+
// CHECK-NEXT: %8 = extractvalue { i32, <2 x i64>, <2 x i64>, <2 x i64>, <2 x i64>, <2 x i64>, <2 x i64> } %1, 3
55+
// CHECK-NEXT: %9 = insertvalue %Bar %7, <2 x i64> %8, 3
56+
// CHECK-NEXT: %10 = extractvalue { i32, <2 x i64>, <2 x i64>, <2 x i64>, <2 x i64>, <2 x i64>, <2 x i64> } %1, 4
57+
// CHECK-NEXT: %11 = insertvalue %Bar %9, <2 x i64> %10, 4
58+
// CHECK-NEXT: %12 = extractvalue { i32, <2 x i64>, <2 x i64>, <2 x i64>, <2 x i64>, <2 x i64>, <2 x i64> } %1, 5
59+
// CHECK-NEXT: %13 = insertvalue %Bar %11, <2 x i64> %12, 5
60+
// CHECK-NEXT: %14 = extractvalue { i32, <2 x i64>, <2 x i64>, <2 x i64>, <2 x i64>, <2 x i64>, <2 x i64> } %1, 6
61+
// CHECK-NEXT: %15 = insertvalue %Bar %13, <2 x i64> %14, 6
62+
foo(key_metadata, key)
63+
}
64+
65+
// CHECK-LABEL: @i1_vector_autocast
66+
#[no_mangle]
67+
pub unsafe fn i1_vector_autocast(a: f16x8) -> u8 {
68+
extern "unadjusted" {
69+
#[link_name = "llvm.x86.avx512fp16.fpclass.ph.128"]
70+
fn foo(a: f16x8, b: i32) -> u8;
71+
}
72+
73+
// CHECK: %1 = call <8 x i1> @llvm.x86.avx512fp16.fpclass.ph.128(<8 x half> %0, i32 1)
74+
// CHECK-NEXT: %_0 = bitcast <8 x i1> %1 to i8
75+
foo(a, 1)
76+
}
77+
78+
// CHECK: declare { <2 x i1>, <2 x i1> } @llvm.x86.avx512.vp2intersect.q.128(<2 x i64>, <2 x i64>)
79+
80+
// CHECK: declare { i32, <2 x i64>, <2 x i64>, <2 x i64>, <2 x i64>, <2 x i64>, <2 x i64> } @llvm.x86.encodekey128(i32, <2 x i64>)
81+
82+
// CHECK: declare <8 x i1> @llvm.x86.avx512fp16.fpclass.ph.128(<8 x half>, i32 immarg)

0 commit comments

Comments
 (0)