|
6 | 6 |
|
7 | 7 | use std::simd::{f32x4, i16x8, i64x2}; |
8 | 8 |
|
| 9 | +#[repr(simd)] |
| 10 | +pub struct Tile([i8; 1024]); |
| 11 | + |
9 | 12 | #[repr(C, packed)] |
10 | 13 | pub struct Bar(u32, i64x2, i64x2, i64x2, i64x2, i64x2, i64x2); |
11 | 14 | // CHECK: %Bar = type <{ i32, <2 x i64>, <2 x i64>, <2 x i64>, <2 x i64>, <2 x i64>, <2 x i64> }> |
12 | 15 |
|
13 | 16 | #[repr(simd)] |
14 | 17 | pub struct f16x8([f16; 8]); |
15 | 18 |
|
| 19 | +// CHECK-LABEL: @amx_autocast |
| 20 | +#[no_mangle] |
| 21 | +pub unsafe fn amx_autocast(m: u16, n: u16, k: u16, a: Tile, b: Tile, c: Tile) -> Tile { |
| 22 | + extern "unadjusted" { |
| 23 | + #[link_name = "llvm.x86.tdpbuud.internal"] |
| 24 | + fn foo(m: u16, n: u16, k: u16, a: Tile, b: Tile, c: Tile) -> Tile; |
| 25 | + } |
| 26 | + |
| 27 | + // CHECK: %3 = call x86_amx @llvm.x86.cast.vector.to.tile.v1024i8(<1024 x i8> %0) |
| 28 | + // CHECK-NEXT: %4 = call x86_amx @llvm.x86.cast.vector.to.tile.v1024i8(<1024 x i8> %1) |
| 29 | + // CHECK-NEXT: %5 = call x86_amx @llvm.x86.cast.vector.to.tile.v1024i8(<1024 x i8> %2) |
| 30 | + // CHECK-NEXT: %6 = call x86_amx @llvm.x86.tdpbuud.internal(i16 %m, i16 %n, i16 %k, x86_amx %3, x86_amx %4, x86_amx %5) |
| 31 | + // CHECK-NEXT: %7 = call <1024 x i8> @llvm.x86.cast.tile.to.vector.v1024i8(x86_amx %6) |
| 32 | + foo(m, n, k, a, b, c) |
| 33 | +} |
| 34 | + |
16 | 35 | // CHECK-LABEL: @struct_with_i1_vector_autocast |
17 | 36 | #[no_mangle] |
18 | 37 | pub unsafe fn struct_with_i1_vector_autocast(a: i64x2, b: i64x2) -> (u8, u8) { |
@@ -85,6 +104,12 @@ pub unsafe fn i1_vector_autocast(a: f16x8) -> u8 { |
85 | 104 | foo(a, 1) |
86 | 105 | } |
87 | 106 |
|
| 107 | +// CHECK: declare x86_amx @llvm.x86.tdpbuud.internal(i16, i16, i16, x86_amx, x86_amx, x86_amx) |
| 108 | + |
| 109 | +// CHECK: declare x86_amx @llvm.x86.cast.vector.to.tile.v1024i8(<1024 x i8>) |
| 110 | + |
| 111 | +// CHECK: declare <1024 x i8> @llvm.x86.cast.tile.to.vector.v1024i8(x86_amx) |
| 112 | + |
88 | 113 | // CHECK: declare { <2 x i1>, <2 x i1> } @llvm.x86.avx512.vp2intersect.q.128(<2 x i64>, <2 x i64>) |
89 | 114 |
|
90 | 115 | // CHECK: declare <8 x bfloat> @llvm.x86.vcvtneps2bf16128(<4 x float>) |
|
0 commit comments