Skip to content

Commit 155a726

Browse files
committed
Add AMX autocast tests
1 parent e1cec6b commit 155a726

File tree

1 file changed

+25
-0
lines changed

1 file changed

+25
-0
lines changed

tests/codegen/inject-autocast.rs

Lines changed: 25 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -6,13 +6,32 @@
66

77
use std::simd::{f32x4, i16x8, i64x2};
88

9+
#[repr(simd)]
10+
pub struct Tile([i8; 1024]);
11+
912
#[repr(C, packed)]
1013
pub struct Bar(u32, i64x2, i64x2, i64x2, i64x2, i64x2, i64x2);
1114
// CHECK: %Bar = type <{ i32, <2 x i64>, <2 x i64>, <2 x i64>, <2 x i64>, <2 x i64>, <2 x i64> }>
1215

1316
#[repr(simd)]
1417
pub struct f16x8([f16; 8]);
1518

19+
// CHECK-LABEL: @amx_autocast
20+
#[no_mangle]
21+
pub unsafe fn amx_autocast(m: u16, n: u16, k: u16, a: Tile, b: Tile, c: Tile) -> Tile {
22+
extern "unadjusted" {
23+
#[link_name = "llvm.x86.tdpbuud.internal"]
24+
fn foo(m: u16, n: u16, k: u16, a: Tile, b: Tile, c: Tile) -> Tile;
25+
}
26+
27+
// CHECK: %3 = call x86_amx @llvm.x86.cast.vector.to.tile.v1024i8(<1024 x i8> %0)
28+
// CHECK-NEXT: %4 = call x86_amx @llvm.x86.cast.vector.to.tile.v1024i8(<1024 x i8> %1)
29+
// CHECK-NEXT: %5 = call x86_amx @llvm.x86.cast.vector.to.tile.v1024i8(<1024 x i8> %2)
30+
// CHECK-NEXT: %6 = call x86_amx @llvm.x86.tdpbuud.internal(i16 %m, i16 %n, i16 %k, x86_amx %3, x86_amx %4, x86_amx %5)
31+
// CHECK-NEXT: %7 = call <1024 x i8> @llvm.x86.cast.tile.to.vector.v1024i8(x86_amx %6)
32+
foo(m, n, k, a, b, c)
33+
}
34+
1635
// CHECK-LABEL: @struct_with_i1_vector_autocast
1736
#[no_mangle]
1837
pub unsafe fn struct_with_i1_vector_autocast(a: i64x2, b: i64x2) -> (u8, u8) {
@@ -85,6 +104,12 @@ pub unsafe fn i1_vector_autocast(a: f16x8) -> u8 {
85104
foo(a, 1)
86105
}
87106

107+
// CHECK: declare x86_amx @llvm.x86.tdpbuud.internal(i16, i16, i16, x86_amx, x86_amx, x86_amx)
108+
109+
// CHECK: declare x86_amx @llvm.x86.cast.vector.to.tile.v1024i8(<1024 x i8>)
110+
111+
// CHECK: declare <1024 x i8> @llvm.x86.cast.tile.to.vector.v1024i8(x86_amx)
112+
88113
// CHECK: declare { <2 x i1>, <2 x i1> } @llvm.x86.avx512.vp2intersect.q.128(<2 x i64>, <2 x i64>)
89114

90115
// CHECK: declare <8 x bfloat> @llvm.x86.vcvtneps2bf16128(<4 x float>)

0 commit comments

Comments
 (0)