1 change: 1 addition & 0 deletions lib/Target/LLVMIR/CMakeLists.txt
@@ -13,6 +13,7 @@ add_triton_library(TritonLLVMIR
MLIRIndexToLLVM
MLIRIR
MLIRLLVMDialect
MLIRNVVMToLLVM
MLIRLLVMToLLVMIRTranslation
MLIRNVVMToLLVMIRTranslation
MLIRROCDLToLLVMIRTranslation
1 change: 1 addition & 0 deletions python/src/passes.cc
@@ -105,6 +105,7 @@ void init_triton_passes_convert(py::module &&m) {
ADD_PASS_WRAPPER_0("add_cf_to_llvmir", createConvertControlFlowToLLVMPass);
ADD_PASS_WRAPPER_0("add_index_to_llvmir", createConvertIndexToLLVMPass);
ADD_PASS_WRAPPER_0("add_arith_to_llvmir", createArithToLLVMConversionPass);
ADD_PASS_WRAPPER_0("add_nvvm_to_llvm", createConvertNVVMToLLVMPass);
}

void init_triton_passes_llvmir(py::module &&m) {
6 changes: 4 additions & 2 deletions python/test/unit/language/test_tensor_descriptor.py
@@ -609,7 +609,8 @@ def alloc_fn(size: int, align: int, stream: Optional[int]):
if BLOCK_M >= 64 * num_ctas and BLOCK_N >= 64 and is_hopper():
# TODO: The use of stmatrix for Blackwell is currently not supported.
# Only a subset of TMEM and stmatrix layout pairs are compatible, for example 16x256bx2 and m8n8x4.
assert "stmatrix.sync.aligned.m8n8.x4.shared.b16" in kernel.asm["ptx"]
assert "stmatrix.sync.aligned.m8n8.x4.shared.b16" in kernel.asm[
"ptx"] or "stmatrix.sync.aligned.x4.m8n8.shared.b16" in kernel.asm["ptx"]


@triton.jit
@@ -1668,4 +1669,5 @@ def test_host_tensor_descriptor_matmul(num_stages, num_ctas, BLOCK_M, BLOCK_N, B
if BLOCK_M >= 64 * num_ctas and BLOCK_N >= 64 and is_cuda() and is_hopper():
# TODO: The use of stmatrix for Blackwell is currently not supported.
# Only a subset of TMEM and stmatrix layout pairs are compatible, for example 16x256bx2 and m8n8x4.
assert "stmatrix.sync.aligned.m8n8.x4.shared.b16" in kernel.asm["ptx"]
assert "stmatrix.sync.aligned.m8n8.x4.shared.b16" in kernel.asm[
"ptx"] or "stmatrix.sync.aligned.x4.m8n8.shared.b16" in kernel.asm["ptx"]
34 changes: 0 additions & 34 deletions test/Conversion/nvgpu_to_llvm.mlir
@@ -1,15 +1,5 @@
// RUN: triton-opt %s --convert-nv-gpu-to-llvm -allow-unregistered-dialect -split-input-file | FileCheck %s

// CHECK-LABEL: @nvvm_syncs
llvm.func @nvvm_syncs() {
// CHECK: fence.proxy.async.shared::cta;
nvgpu.fence_async_shared {bCluster = false}
// CHECK: fence.proxy.async.shared::cluster;
nvgpu.fence_async_shared {bCluster = true}

llvm.return
}

// CHECK-LABEL: @cluster_id
llvm.func @cluster_id() -> i32 {
// CHECK: %cluster_ctaid.x;
@@ -23,30 +13,6 @@ llvm.func @cluster_id() -> i32 {

// -----

// CHECK-LABEL: @stmatrix
llvm.func @stmatrix(%i: i32, %ptr: !llvm.ptr<3>) {
// CHECK: stmatrix.sync.aligned.m8n8.x4.shared.b16 [$0], {$1, $2, $3, $4};
nvgpu.stmatrix %ptr, %i, %i, %i, %i : !llvm.ptr<3>, i32, i32, i32, i32
// CHECK: stmatrix.sync.aligned.m8n8.x4.trans.shared.b16 [$0], {$1, $2, $3, $4};
nvgpu.stmatrix %ptr, %i, %i, %i, %i {trans} : !llvm.ptr<3>, i32, i32, i32, i32
llvm.return
}

// -----

// CHECK-LABEL: @ldmatrix
llvm.func @ldmatrix(%ptr: !llvm.ptr<3>) -> !llvm.struct<(i32, i32, i32, i32)> {
// CHECK: ldmatrix.sync.aligned.m8n8.x4.shared.b16 {$0, $1, $2, $3}, [$4];
%0 = nvgpu.ldmatrix %ptr : (!llvm.ptr<3>) -> !llvm.struct<(i32, i32, i32, i32)>
// CHECK: ldmatrix.sync.aligned.m8n8.x4.trans.shared.b16 {$0, $1, $2, $3}, [$4];
%1 = nvgpu.ldmatrix %ptr {trans} : (!llvm.ptr<3>) -> !llvm.struct<(i32, i32, i32, i32)>
%2 = llvm.extractvalue %1[0] : !llvm.struct<(i32, i32, i32, i32)>
%3 = llvm.insertvalue %2, %0[0] : !llvm.struct<(i32, i32, i32, i32)>
llvm.return %3 : !llvm.struct<(i32, i32, i32, i32)>
}

// -----

!struct_128xf32 = !llvm.struct<(
f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32,
f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32,
31 changes: 16 additions & 15 deletions test/Conversion/tritongpu_to_llvm.mlir
@@ -880,9 +880,9 @@ module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 1 : i32} {
tt.func @convert_dot_ldmatrix(%A: tensor<16x16xf16, #blocked0>, %B: tensor<16x16xf16, #blocked0>) {
%AA = ttg.local_alloc %A : (tensor<16x16xf16, #blocked0>) -> !ttg.memdesc<16x16xf16, #shared0, #smem>
%BB = ttg.local_alloc %B : (tensor<16x16xf16, #blocked0>) -> !ttg.memdesc<16x16xf16, #shared0, #smem>
// CHECK: nvgpu.ldmatrix %{{.*}} : (!llvm.ptr<3>) -> !llvm.struct<(i32, i32, i32, i32)>
// CHECK: nvgpu.ldmatrix %{{.*}} {trans} : (!llvm.ptr<3>) -> !llvm.struct<(i32, i32, i32, i32)>
// CHECK-NOT: nvgpu.ldmatrix
// CHECK: nvvm.ldmatrix %{{.*}} {layout = #nvvm.mma_layout<row>, num = 4 : i32} : (!llvm.ptr<3>) -> !llvm.struct<(i32, i32, i32, i32)>
// CHECK: nvvm.ldmatrix %{{.*}} {layout = #nvvm.mma_layout<col>, num = 4 : i32} : (!llvm.ptr<3>) -> !llvm.struct<(i32, i32, i32, i32)>
// CHECK-NOT: nvvm.ldmatrix
%AA_DOT = ttg.local_load %AA : !ttg.memdesc<16x16xf16, #shared0, #smem> -> tensor<16x16xf16, #dot_operand_a>
%BB_DOT = ttg.local_load %BB : !ttg.memdesc<16x16xf16, #shared0, #smem> -> tensor<16x16xf16, #dot_operand_b>
%cst0 = arith.constant dense<0.000000e+00> : tensor<16x16xf32, #mma0>
@@ -910,9 +910,9 @@ module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 1 : i32} {
tt.func @convert_dot_ldmatrix_swizzle(%A: tensor<16x16xf16, #blocked0>, %B: tensor<16x16xf16, #blocked0>) {
%AA = ttg.local_alloc %A : (tensor<16x16xf16, #blocked0>) -> !ttg.memdesc<16x16xf16, #shared0, #smem>
%BB = ttg.local_alloc %B : (tensor<16x16xf16, #blocked0>) -> !ttg.memdesc<16x16xf16, #shared0, #smem>
// CHECK: nvgpu.ldmatrix %{{.*}} : (!llvm.ptr<3>) -> !llvm.struct<(i32, i32, i32, i32)>
// CHECK: nvgpu.ldmatrix %{{.*}} {trans} : (!llvm.ptr<3>) -> !llvm.struct<(i32, i32, i32, i32)>
// CHECK-NOT: nvgpu.ldmatrix
// CHECK: nvvm.ldmatrix %{{.*}} {layout = #nvvm.mma_layout<row>, num = 4 : i32} : (!llvm.ptr<3>) -> !llvm.struct<(i32, i32, i32, i32)>
// CHECK: nvvm.ldmatrix %{{.*}} {layout = #nvvm.mma_layout<col>, num = 4 : i32} : (!llvm.ptr<3>) -> !llvm.struct<(i32, i32, i32, i32)>
// CHECK-NOT: nvvm.ldmatrix
%AA_DOT = ttg.local_load %AA : !ttg.memdesc<16x16xf16, #shared0, #smem> -> tensor<16x16xf16, #dot_operand_a>
%BB_DOT = ttg.local_load %BB : !ttg.memdesc<16x16xf16, #shared0, #smem> -> tensor<16x16xf16, #dot_operand_b>
%cst0 = arith.constant dense<0.000000e+00> : tensor<16x16xf32, #mma0>
@@ -940,7 +940,7 @@ module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 1 : i32} {
tt.func @convert_dot(%A: tensor<16x16xf16, #blocked0>, %B: tensor<16x16xf16, #blocked0>) {
%AA = ttg.local_alloc %A : (tensor<16x16xf16, #blocked0>) -> !ttg.memdesc<16x16xf16, #shared0, #smem>
%BB = ttg.local_alloc %B : (tensor<16x16xf16, #blocked0>) -> !ttg.memdesc<16x16xf16, #shared0, #smem>
// CHECK-NOT: nvgpu.ldmatrix
// CHECK-NOT: nvvm.ldmatrix
%AA_DOT = ttg.local_load %AA : !ttg.memdesc<16x16xf16, #shared0, #smem> -> tensor<16x16xf16, #dot_operand_a>
%BB_DOT = ttg.local_load %BB : !ttg.memdesc<16x16xf16, #shared0, #smem> -> tensor<16x16xf16, #dot_operand_b>
%cst0 = arith.constant dense<0.000000e+00> : tensor<16x16xf32, #mma0>
@@ -968,7 +968,8 @@ module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 1 : i32} {
tt.func @convert_dot_mmav3_shared(%A: tensor<64x64xf16, #blocked0>, %B: tensor<64x64xf16, #blocked0>) {
%AA = ttg.local_alloc %A : (tensor<64x64xf16, #blocked0>) -> !ttg.memdesc<64x64xf16, #shared0, #smem>
%BB = ttg.local_alloc %B : (tensor<64x64xf16, #blocked0>) -> !ttg.memdesc<64x64xf16, #shared0, #smem>
// CHECK-COUNT-32: nvgpu.ldmatrix %{{.*}} : (!llvm.ptr<3>) -> !llvm.struct<(i32, i32, i32, i32)>
// CHECK-COUNT-16: nvvm.ldmatrix %{{.*}} {layout = #nvvm.mma_layout<row>, num = 4 : i32} : (!llvm.ptr<3>) -> !llvm.struct<(i32, i32, i32, i32)>
// CHECK-COUNT-16: nvvm.ldmatrix %{{.*}} {layout = #nvvm.mma_layout<col>, num = 4 : i32} : (!llvm.ptr<3>) -> !llvm.struct<(i32, i32, i32, i32)>
%AA_DOT = ttg.local_load %AA : !ttg.memdesc<64x64xf16, #shared0, #smem> -> tensor<64x64xf16, #dot_operand_a>
%BB_DOT = ttg.local_load %BB : !ttg.memdesc<64x64xf16, #shared0, #smem> -> tensor<64x64xf16, #dot_operand_b>
%cst0 = arith.constant dense<0.000000e+00> : tensor<64x64xf32, #mma0>
@@ -992,8 +993,8 @@ module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 1 : i32} {
tt.func @convert_dot_fp8(%A: tensor<16x16xf8E5M2, #blocked0>, %B: tensor<16x16xf8E5M2, #blocked0>) {
%AA = ttg.local_alloc %A : (tensor<16x16xf8E5M2, #blocked0>) -> !ttg.memdesc<16x16xf8E5M2, #shared0, #smem>
%BB = ttg.local_alloc %B : (tensor<16x16xf8E5M2, #blocked0>) -> !ttg.memdesc<16x16xf8E5M2, #shared0, #smem>
// CHECK: nvgpu.ldmatrix %{{.*}} : (!llvm.ptr<3>) -> !llvm.struct<(i32, i32)>
// CHECK-NOT: nvgpu.ldmatrix
// CHECK: nvvm.ldmatrix %{{.*}} {layout = #nvvm.mma_layout<row>, num = 2 : i32} : (!llvm.ptr<3>) -> !llvm.struct<(i32, i32)>
// CHECK-NOT: nvvm.ldmatrix
%AA_DOT = ttg.local_load %AA : !ttg.memdesc<16x16xf8E5M2, #shared0, #smem> -> tensor<16x16xf8E5M2, #dot_operand_a>
%BB_DOT = ttg.local_load %BB : !ttg.memdesc<16x16xf8E5M2, #shared0, #smem> -> tensor<16x16xf8E5M2, #dot_operand_b>
%cst0 = arith.constant dense<0.000000e+00> : tensor<16x16xf32, #mma0>
@@ -1308,7 +1309,7 @@ module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32} {
tt.func @matmul_kernel_dot_operand_layout(%ptr:!tt.ptr<f32> {tt.divisibility = 16 : i32},
%a:!ttg.memdesc<128x32xf16, #shared, #smem>, %b:!ttg.memdesc<32x256xf16, #shared, #smem>) {
%cst = arith.constant dense<0.000000e+00> : tensor<128x256xf32, #mma>
// CHECK: nvgpu.ldmatrix
// CHECK: nvvm.ldmatrix
%a_mat = ttg.local_load %a : !ttg.memdesc<128x32xf16, #shared, #smem> -> tensor<128x32xf16, #dot_operand_a>
%b_mat = ttg.local_load %b : !ttg.memdesc<32x256xf16, #shared, #smem> -> tensor<32x256xf16, #dot_operand_b>

@@ -1384,9 +1385,9 @@ module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32} {
tt.func @matmul_tf32dot(%ptr:!tt.ptr<f32> {tt.divisibility = 16 : i32},
%a:!ttg.memdesc<32x16xf32, #shared, #smem>, %b:!ttg.memdesc<16x32xf32, #shared, #smem>) {
%cst = arith.constant dense<0.000000e+00> : tensor<32x32xf32, #mma>
// CHECK: nvgpu.ldmatrix
// CHECK: nvvm.ldmatrix
// CHECK-SAME: (i32, i32, i32, i32)
// CHECK: nvgpu.ldmatrix
// CHECK: nvvm.ldmatrix
// CHECK-SAME: (i32, i32, i32, i32)
%a_mat = ttg.local_load %a : !ttg.memdesc<32x16xf32, #shared, #smem> -> tensor<32x16xf32, #dot_operand_a>
%b_mat = ttg.local_load %b : !ttg.memdesc<16x32xf32, #shared, #smem> -> tensor<16x32xf32, #dot_operand_b>
@@ -1875,8 +1876,8 @@ module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 1 : i32, "ttg.thr
%f16_shared = ttg.local_alloc %f16_inp : (tensor<16x16xf16, #blocked0>) -> !ttg.memdesc<16x16xf16, #shared0, #smem>
%i16_shared = ttg.local_alloc %i16_inp : (tensor<16x16xi16, #blocked0>) -> !ttg.memdesc<16x16xi16, #shared0, #smem>

// CHECK: nvgpu.ldmatrix
// CHECK: nvgpu.ldmatrix
// CHECK: nvvm.ldmatrix
// CHECK: nvvm.ldmatrix

%f16_dot = ttg.local_load %f16_shared : !ttg.memdesc<16x16xf16, #shared0, #smem> -> tensor<16x16xf16, #dot_operand_a>
%i16_dot = ttg.local_load %i16_shared : !ttg.memdesc<16x16xi16, #shared0, #smem> -> tensor<16x16xi16, #dot_operand_b>
22 changes: 11 additions & 11 deletions test/Conversion/tritongpu_to_llvm_hopper.mlir
@@ -203,7 +203,7 @@ module attributes {"ttg.target" = "cuda:90", "ttg.num-ctas" = 1 : i32, "ttg.num-
// CHECK-LABEL: convert_mma_to_blocked
module attributes {"ttg.target" = "cuda:90", "ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 8 : i32, "ttg.threads-per-warp" = 32 : i32} {
tt.func @convert_mma_to_blocked(%a: tensor<128x256xf16, #mma>) {
// CHECK-COUNT-16: nvgpu.stmatrix
// CHECK-COUNT-16: nvvm.stmatrix
// CHECK: nvvm.barrier0
%c = ttg.convert_layout %a : tensor<128x256xf16, #mma> -> tensor<128x256xf16, #blocked>
tt.return
@@ -254,7 +254,7 @@ module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32} {
// CHECK-LABEL: distribute_to_shared_st_matrix
module attributes {"ttg.target" = "cuda:90", "ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, "ttg.threads-per-warp" = 32 : i32} {
tt.func @distribute_to_shared_st_matrix(%a: tensor<128x128xf16, #mma>) {
// CHECK-COUNT-16: nvgpu.stmatrix
// CHECK-COUNT-16: nvvm.stmatrix
// CHECK: llvm.return
%b = ttg.local_alloc %a {allocation.offset = 0 : i32} : (tensor<128x128xf16, #mma>) -> !ttg.memdesc<128x128xf16, #shared, #smem, mutable>
tt.return
@@ -269,7 +269,7 @@ module attributes {"ttg.target" = "cuda:90", "ttg.num-ctas" = 1 : i32, "ttg.num-
// CHECK-LABEL: distribute_to_shared_st_matrix_local_store
module attributes {"ttg.target" = "cuda:90", "ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, "ttg.threads-per-warp" = 32 : i32} {
tt.func @distribute_to_shared_st_matrix_local_store(%a: tensor<128x128xf16, #mma>) {
// CHECK-COUNT-16: nvgpu.stmatrix
// CHECK-COUNT-16: nvvm.stmatrix
// CHECK: llvm.return
%b = ttg.local_alloc {allocation.offset = 0 : i32} : () -> !ttg.memdesc<128x128xf16, #shared, #smem, mutable>
ttg.local_store %a, %b : tensor<128x128xf16, #mma> -> !ttg.memdesc<128x128xf16, #shared, #smem, mutable>
@@ -285,7 +285,7 @@ module attributes {"ttg.target" = "cuda:90", "ttg.num-ctas" = 1 : i32, "ttg.num-
// CHECK-LABEL: distribute_to_shared_st_matrix_local_store
module attributes {"ttg.target" = "cuda:90", "ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, "ttg.threads-per-warp" = 32 : i32} {
tt.func @distribute_to_shared_st_matrix_local_store(%a: tensor<64x128xf16, #linear>) {
// CHECK-COUNT-8: nvgpu.stmatrix %{{.*}}, %{{.*}}, %{{.*}}, %{{.*}}, %{{.*}} {trans}
// CHECK-COUNT-8: nvvm.stmatrix %{{.*}}, %{{.*}}, %{{.*}}, %{{.*}}, %{{.*}} {layout = #nvvm.mma_layout<col>}
// CHECK: llvm.return
%b = ttg.local_alloc {allocation.offset = 0 : i32} : () -> !ttg.memdesc<64x128xf16, #shared, #smem, mutable>
ttg.local_store %a, %b : tensor<64x128xf16, #linear> -> !ttg.memdesc<64x128xf16, #shared, #smem, mutable>
@@ -301,7 +301,7 @@ module attributes {"ttg.target" = "cuda:90", "ttg.num-ctas" = 1 : i32, "ttg.num-
// CHECK-LABEL: distribute_to_swizzled_st_matrix_local_store
module attributes {"ttg.target" = "cuda:90", "ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, "ttg.threads-per-warp" = 32 : i32} {
tt.func @distribute_to_swizzled_st_matrix_local_store(%a: tensor<8x64xf16, #mma>) {
// CHECK-COUNT-2: nvgpu.stmatrix
// CHECK-COUNT-2: nvvm.stmatrix
// CHECK: llvm.return
%b = ttg.local_alloc {allocation.offset = 0 : i32} : () -> !ttg.memdesc<8x64xf16, #shared, #smem, mutable>
ttg.local_store %a, %b : tensor<8x64xf16, #mma> -> !ttg.memdesc<8x64xf16, #shared, #smem, mutable>
@@ -317,7 +317,7 @@ module attributes {"ttg.target" = "cuda:90", "ttg.num-ctas" = 1 : i32, "ttg.num-
// CHECK-LABEL: linear_to_swizzled_st_matrix_local_store
module attributes {"ttg.target" = "cuda:90", "ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, "ttg.threads-per-warp" = 32 : i32} {
tt.func @linear_to_swizzled_st_matrix_local_store(%a: tensor<64x32xf16, #linear>) {
// CHECK-COUNT-2: nvgpu.stmatrix %{{.*}}, %{{.*}}, %{{.*}}, %{{.*}}, %{{.*}}
// CHECK-COUNT-2: nvvm.stmatrix %{{.*}}, %{{.*}}, %{{.*}}, %{{.*}}, %{{.*}} {layout = #nvvm.mma_layout<row>}
// CHECK: llvm.return
%b = ttg.local_alloc {allocation.offset = 0 : i32} : () -> !ttg.memdesc<64x32xf16, #shared, #smem, mutable>
ttg.local_store %a, %b : tensor<64x32xf16, #linear> -> !ttg.memdesc<64x32xf16, #shared, #smem, mutable>
@@ -339,7 +339,7 @@ module attributes {"ttg.target" = "cuda:90", "ttg.num-ctas" = 1 : i32, "ttg.num-
// CHECK-LABEL: linear_to_swizzled_st_matrix_local_store
module attributes {"ttg.target" = "cuda:90", "ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, "ttg.threads-per-warp" = 32 : i32} {
tt.func @linear_to_swizzled_st_matrix_local_store(%a: tensor<32x32xf16, #linear>) {
// CHECK-COUNT-2: nvgpu.stmatrix %{{.*}}, %{{.*}}, %{{.*}}, %{{.*}}, %{{.*}}
// CHECK-COUNT-2: nvvm.stmatrix %{{.*}}, %{{.*}}, %{{.*}}, %{{.*}}, %{{.*}} {layout = #nvvm.mma_layout<row>}
// CHECK: llvm.return
%b = ttg.local_alloc {allocation.offset = 0 : i32} : () -> !ttg.memdesc<32x32xf16, #shared, #smem, mutable>
ttg.local_store %a, %b : tensor<32x32xf16, #linear> -> !ttg.memdesc<32x32xf16, #shared, #smem, mutable>
@@ -355,7 +355,7 @@ module attributes {"ttg.target" = "cuda:90", "ttg.num-ctas" = 1 : i32, "ttg.num-
// CHECK-LABEL: linear_to_swizzled_st_matrix_x2_local_store_fp8
module attributes {"ttg.target" = "cuda:90", "ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, "ttg.threads-per-warp" = 32 : i32} {
tt.func @linear_to_swizzled_st_matrix_x2_local_store_fp8(%a: tensor<64x16xf8E4M3FNUZ, #linear>) {
// CHECK-COUNT-1: nvgpu.stmatrix %{{.*}}, %{{.*}}, %{{.*}} :
// CHECK-COUNT-1: nvvm.stmatrix %{{.*}}, %{{.*}}, %{{.*}} {layout = #nvvm.mma_layout<row>} :
// CHECK: llvm.return
%b = ttg.local_alloc {allocation.offset = 0 : i32} : () -> !ttg.memdesc<64x16xf8E4M3FNUZ, #shared, #smem, mutable>
ttg.local_store %a, %b : tensor<64x16xf8E4M3FNUZ, #linear> -> !ttg.memdesc<64x16xf8E4M3FNUZ, #shared, #smem, mutable>
@@ -371,7 +371,7 @@ module attributes {"ttg.target" = "cuda:90", "ttg.num-ctas" = 1 : i32, "ttg.num-
// CHECK-LABEL: linear_to_swizzled_st_matrix_local_store_fp32
module attributes {"ttg.target" = "cuda:90", "ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, "ttg.threads-per-warp" = 32 : i32} {
tt.func @linear_to_swizzled_st_matrix_local_store_fp32(%a: tensor<64x16xf32, #linear>) {
// CHECK-COUNT-2: nvgpu.stmatrix %{{.*}}, %{{.*}}, %{{.*}}, %{{.*}}, %{{.*}}
// CHECK-COUNT-2: nvvm.stmatrix %{{.*}}, %{{.*}}, %{{.*}}, %{{.*}}, %{{.*}} {layout = #nvvm.mma_layout<row>}
// CHECK: llvm.return
%b = ttg.local_alloc {allocation.offset = 0 : i32} : () -> !ttg.memdesc<64x16xf32, #shared, #smem, mutable>
ttg.local_store %a, %b : tensor<64x16xf32, #linear> -> !ttg.memdesc<64x16xf32, #shared, #smem, mutable>
@@ -388,7 +388,7 @@ module attributes {"ttg.target" = "cuda:90", "ttg.num-ctas" = 1 : i32, "ttg.num-
// CHECK-LABEL: linear_to_swizzled_st_matrix_trans_local_store
module attributes {"ttg.target" = "cuda:90", "ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, "ttg.threads-per-warp" = 32 : i32} {
tt.func @linear_to_swizzled_st_matrix_trans_local_store(%a: tensor<64x32xf16, #linear>) {
// CHECK-COUNT-2: nvgpu.stmatrix %{{.*}}, %{{.*}}, %{{.*}}, %{{.*}}, %{{.*}} {trans}
// CHECK-COUNT-2: nvvm.stmatrix %{{.*}}, %{{.*}}, %{{.*}}, %{{.*}}, %{{.*}} {layout = #nvvm.mma_layout<col>}
// CHECK: llvm.return
%b = ttg.local_alloc {allocation.offset = 0 : i32} : () -> !ttg.memdesc<64x32xf16, #shared, #smem, mutable>
ttg.local_store %a, %b : tensor<64x32xf16, #linear> -> !ttg.memdesc<64x32xf16, #shared, #smem, mutable>
@@ -410,7 +410,7 @@ module attributes {"ttg.target" = "cuda:90", "ttg.num-ctas" = 1 : i32, "ttg.num-
// CHECK-LABEL: linear_to_swizzled_st_matrix_trans_local_store
module attributes {"ttg.target" = "cuda:90", "ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, "ttg.threads-per-warp" = 32 : i32} {
tt.func @linear_to_swizzled_st_matrix_trans_local_store(%a: tensor<16x32xf16, #linear>) {
// CHECK-COUNT-2: nvgpu.stmatrix %{{.*}}, %{{.*}}, %{{.*}}, %{{.*}}, %{{.*}} {trans}
// CHECK-COUNT-2: nvvm.stmatrix %{{.*}}, %{{.*}}, %{{.*}}, %{{.*}}, %{{.*}} {layout = #nvvm.mma_layout<col>}
// CHECK: llvm.return
%b = ttg.local_alloc {allocation.offset = 0 : i32} : () -> !ttg.memdesc<16x32xf16, #shared, #smem, mutable>
ttg.local_store %a, %b : tensor<16x32xf16, #linear> -> !ttg.memdesc<16x32xf16, #shared, #smem, mutable>
1 change: 1 addition & 0 deletions third_party/nvidia/backend/compiler.py
@@ -351,6 +351,7 @@ def make_llir(self, src, metadata, options, capability):
passes.common.add_canonicalizer(pm)
passes.common.add_cse(pm)
passes.common.add_symbol_dce(pm)
passes.convert.add_nvvm_to_llvm(pm)
if not knobs.compilation.disable_line_info:
passes.llvmir.add_di_scope(pm)
pm.run(mod)
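
The new add_nvvm_to_llvm conversion runs after the common cleanup passes and before debug-info scoping, so nvvm.ldmatrix / nvvm.stmatrix ops introduced during lowering are rewritten before translation to LLVM IR. A minimal sketch of exercising just that conversion from Python (assumes Triton's private libtriton bindings; exact module paths and signatures may vary across versions):

from triton._C.libtriton import ir, passes

def lower_nvvm_ops(mlir_path: str):
    # Parse an LLVM-dialect module that still contains NVVM ops and run only
    # the NVVM -> LLVM conversion exposed by this change.
    ctx = ir.context()
    ir.load_dialects(ctx)
    mod = ir.parse_mlir_module(mlir_path, ctx)
    pm = ir.pass_manager(ctx)
    passes.convert.add_nvvm_to_llvm(pm)
    pm.run(mod)
    return mod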