From b840317a0b12d9dbd580e7a96ee640e2f5bb5b0e Mon Sep 17 00:00:00 2001 From: Gao Yanfeng Date: Fri, 11 Jul 2025 11:37:52 +0800 Subject: [PATCH 1/4] Replace FenceAsyncSharedOp --- test/Conversion/nvgpu_to_llvm.mlir | 10 ---------- .../include/Dialect/NVGPU/IR/NVGPUOps.td | 5 ----- .../lib/NVGPUToLLVM/NVGPUToLLVMPass.cpp | 20 ++++--------------- .../TritonNVIDIAGPUToLLVM/BarrierOpToLLVM.cpp | 9 ++++++--- 4 files changed, 10 insertions(+), 34 deletions(-) diff --git a/test/Conversion/nvgpu_to_llvm.mlir b/test/Conversion/nvgpu_to_llvm.mlir index 00a4eb33b3af..f98164f5965a 100644 --- a/test/Conversion/nvgpu_to_llvm.mlir +++ b/test/Conversion/nvgpu_to_llvm.mlir @@ -1,15 +1,5 @@ // RUN: triton-opt %s --convert-nv-gpu-to-llvm -allow-unregistered-dialect -split-input-file | FileCheck %s -// CHECK-LABEL: @nvvm_syncs -llvm.func @nvvm_syncs() { - // CHECK: fence.proxy.async.shared::cta; - nvgpu.fence_async_shared {bCluster = false} - // CHECK: fence.proxy.async.shared::cluster; - nvgpu.fence_async_shared {bCluster = true} - - llvm.return -} - // CHECK-LABEL: @cluster_id llvm.func @cluster_id() -> i32 { // CHECK: %cluster_ctaid.x; diff --git a/third_party/nvidia/include/Dialect/NVGPU/IR/NVGPUOps.td b/third_party/nvidia/include/Dialect/NVGPU/IR/NVGPUOps.td index ec7ffae8b9f5..00412ad27790 100644 --- a/third_party/nvidia/include/Dialect/NVGPU/IR/NVGPUOps.td +++ b/third_party/nvidia/include/Dialect/NVGPU/IR/NVGPUOps.td @@ -105,11 +105,6 @@ def NVGPU_WGMMAOp : NVGPU_Op<"wgmma", []> { let assemblyFormat = "$opA `,` $opB `,` $useC (`,` $opC^)? attr-dict `:` functional-type(operands, $res)"; } -def NVGPU_FenceAsyncSharedOp : NVGPU_Op<"fence_async_shared", []> { - let arguments = (ins BoolAttr:$bCluster); - let assemblyFormat = "attr-dict"; -} - def NVGPU_StoreMatrixOp : NVGPU_Op<"stmatrix", [MemoryEffects<[MemWrite]>]> { let arguments = ( ins LLVM_PointerShared:$addr, diff --git a/third_party/nvidia/lib/NVGPUToLLVM/NVGPUToLLVMPass.cpp b/third_party/nvidia/lib/NVGPUToLLVM/NVGPUToLLVMPass.cpp index 593ce3ee0847..7010d4930f0f 100644 --- a/third_party/nvidia/lib/NVGPUToLLVM/NVGPUToLLVMPass.cpp +++ b/third_party/nvidia/lib/NVGPUToLLVM/NVGPUToLLVMPass.cpp @@ -206,19 +206,6 @@ class NVGPUOpGenericPattern : public OpRewritePattern { Constraints inputConstraints; }; -class FenceAsyncSharedOpPattern - : public OpRewritePattern { -public: - using OpRewritePattern::OpRewritePattern; - - LogicalResult matchAndRewrite(ttn::FenceAsyncSharedOp op, - PatternRewriter &rewriter) const override { - std::string ptxAsm = op.getBCluster() ? 
"fence.proxy.async.shared::cluster;" - : "fence.proxy.async.shared::cta;"; - return rewriteAsPtxAsm(op, rewriter, std::move(ptxAsm)); - } -}; - class WarpIdOpPattern : public OpRewritePattern { public: using OpRewritePattern::OpRewritePattern; @@ -774,9 +761,10 @@ class ConvertNVGPUToLLVM patterns.add>( context, kClusterCtaIdOp, Constraints({"=r"}), Constraints()); - patterns.add(context); + patterns + .add( + context); if (applyPatternsGreedily(mod, std::move(patterns)).failed()) signalPassFailure(); diff --git a/third_party/nvidia/lib/TritonNVIDIAGPUToLLVM/BarrierOpToLLVM.cpp b/third_party/nvidia/lib/TritonNVIDIAGPUToLLVM/BarrierOpToLLVM.cpp index 145509ef07d9..1cb5791d76f8 100644 --- a/third_party/nvidia/lib/TritonNVIDIAGPUToLLVM/BarrierOpToLLVM.cpp +++ b/third_party/nvidia/lib/TritonNVIDIAGPUToLLVM/BarrierOpToLLVM.cpp @@ -42,9 +42,12 @@ struct FenceAsyncSharedOpConversion LogicalResult matchAndRewrite(triton::nvidia_gpu::FenceAsyncSharedOp op, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override { - Location loc = op->getLoc(); - rewriter.replaceOpWithNewOp( - op, adaptor.getBCluster()); + auto kind = NVVM::ProxyKind::async_shared; + auto space = op.getBCluster() ? NVVM::SharedSpace::shared_cluster + : NVVM::SharedSpace::shared_cta; + auto ctx = rewriter.getContext(); + auto spaceAttr = NVVM::SharedSpaceAttr::get(ctx, space); + rewriter.replaceOpWithNewOp(op, kind, spaceAttr); return success(); } }; From e408e1a2b7e3f94a175731362f73d2ef3522315d Mon Sep 17 00:00:00 2001 From: Gao Yanfeng Date: Fri, 11 Jul 2025 11:38:09 +0800 Subject: [PATCH 2/4] Add ConvertNVVMToLLVMPass to the pipeline --- lib/Target/LLVMIR/CMakeLists.txt | 1 + python/src/passes.cc | 1 + third_party/nvidia/backend/compiler.py | 1 + 3 files changed, 3 insertions(+) diff --git a/lib/Target/LLVMIR/CMakeLists.txt b/lib/Target/LLVMIR/CMakeLists.txt index f2f9adf8f493..d036d5949733 100644 --- a/lib/Target/LLVMIR/CMakeLists.txt +++ b/lib/Target/LLVMIR/CMakeLists.txt @@ -13,6 +13,7 @@ add_triton_library(TritonLLVMIR MLIRIndexToLLVM MLIRIR MLIRLLVMDialect + MLIRNVVMToLLVM MLIRLLVMToLLVMIRTranslation MLIRNVVMToLLVMIRTranslation MLIRROCDLToLLVMIRTranslation diff --git a/python/src/passes.cc b/python/src/passes.cc index 4f8c6d9b2d97..2d240f4b20fa 100644 --- a/python/src/passes.cc +++ b/python/src/passes.cc @@ -105,6 +105,7 @@ void init_triton_passes_convert(py::module &&m) { ADD_PASS_WRAPPER_0("add_cf_to_llvmir", createConvertControlFlowToLLVMPass); ADD_PASS_WRAPPER_0("add_index_to_llvmir", createConvertIndexToLLVMPass); ADD_PASS_WRAPPER_0("add_arith_to_llvmir", createArithToLLVMConversionPass); + ADD_PASS_WRAPPER_0("add_nvvm_to_llvm", createConvertNVVMToLLVMPass); } void init_triton_passes_llvmir(py::module &&m) { diff --git a/third_party/nvidia/backend/compiler.py b/third_party/nvidia/backend/compiler.py index 1e5a72b46f5b..36426a9c116a 100644 --- a/third_party/nvidia/backend/compiler.py +++ b/third_party/nvidia/backend/compiler.py @@ -351,6 +351,7 @@ def make_llir(self, src, metadata, options, capability): passes.common.add_canonicalizer(pm) passes.common.add_cse(pm) passes.common.add_symbol_dce(pm) + passes.convert.add_nvvm_to_llvm(pm) if not knobs.compilation.disable_line_info: passes.llvmir.add_di_scope(pm) pm.run(mod) From 838195c3ef752b0a918a1b4f844f6321834aa386 Mon Sep 17 00:00:00 2001 From: Gao Yanfeng Date: Fri, 11 Jul 2025 15:11:46 +0800 Subject: [PATCH 3/4] Replace LoadMatrixOp and StoreMatrixOp --- test/Conversion/nvgpu_to_llvm.mlir | 24 --- test/Conversion/tritongpu_to_llvm.mlir | 31 ++-- 
test/Conversion/tritongpu_to_llvm_hopper.mlir | 22 +-- .../include/Dialect/NVGPU/IR/NVGPUOps.td | 18 --- .../lib/NVGPUToLLVM/NVGPUToLLVMPass.cpp | 144 +----------------- .../TritonNVIDIAGPUToLLVM/MemoryOpToLLVM.cpp | 9 +- .../lib/TritonNVIDIAGPUToLLVM/TargetInfo.cpp | 2 +- 7 files changed, 34 insertions(+), 216 deletions(-) diff --git a/test/Conversion/nvgpu_to_llvm.mlir b/test/Conversion/nvgpu_to_llvm.mlir index f98164f5965a..ecf224a438c4 100644 --- a/test/Conversion/nvgpu_to_llvm.mlir +++ b/test/Conversion/nvgpu_to_llvm.mlir @@ -13,30 +13,6 @@ llvm.func @cluster_id() -> i32 { // ----- -// CHECK-LABEL: @stmatrix -llvm.func @stmatrix(%i: i32, %ptr: !llvm.ptr<3>) { - // CHECK: stmatrix.sync.aligned.m8n8.x4.shared.b16 [$0], {$1, $2, $3, $4}; - nvgpu.stmatrix %ptr, %i, %i, %i, %i : !llvm.ptr<3>, i32, i32, i32, i32 - // CHECK: stmatrix.sync.aligned.m8n8.x4.trans.shared.b16 [$0], {$1, $2, $3, $4}; - nvgpu.stmatrix %ptr, %i, %i, %i, %i {trans} : !llvm.ptr<3>, i32, i32, i32, i32 - llvm.return -} - -// ----- - -// CHECK-LABEL: @ldmatrix -llvm.func @ldmatrix(%ptr: !llvm.ptr<3>) -> !llvm.struct<(i32, i32, i32, i32)> { - // CHECK: ldmatrix.sync.aligned.m8n8.x4.shared.b16 {$0, $1, $2, $3}, [$4]; - %0 = nvgpu.ldmatrix %ptr : (!llvm.ptr<3>) -> !llvm.struct<(i32, i32, i32, i32)> - // CHECK: ldmatrix.sync.aligned.m8n8.x4.trans.shared.b16 {$0, $1, $2, $3}, [$4]; - %1 = nvgpu.ldmatrix %ptr {trans} : (!llvm.ptr<3>) -> !llvm.struct<(i32, i32, i32, i32)> - %2 = llvm.extractvalue %1[0] : !llvm.struct<(i32, i32, i32, i32)> - %3 = llvm.insertvalue %2, %0[0] : !llvm.struct<(i32, i32, i32, i32)> - llvm.return %3 : !llvm.struct<(i32, i32, i32, i32)> -} - -// ----- - !struct_128xf32 = !llvm.struct<( f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, diff --git a/test/Conversion/tritongpu_to_llvm.mlir b/test/Conversion/tritongpu_to_llvm.mlir index d005888bb163..e97fd8164be9 100644 --- a/test/Conversion/tritongpu_to_llvm.mlir +++ b/test/Conversion/tritongpu_to_llvm.mlir @@ -880,9 +880,9 @@ module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 1 : i32} { tt.func @convert_dot_ldmatrix(%A: tensor<16x16xf16, #blocked0>, %B: tensor<16x16xf16, #blocked0>) { %AA = ttg.local_alloc %A : (tensor<16x16xf16, #blocked0>) -> !ttg.memdesc<16x16xf16, #shared0, #smem> %BB = ttg.local_alloc %B : (tensor<16x16xf16, #blocked0>) -> !ttg.memdesc<16x16xf16, #shared0, #smem> - // CHECK: nvgpu.ldmatrix %{{.*}} : (!llvm.ptr<3>) -> !llvm.struct<(i32, i32, i32, i32)> - // CHECK: nvgpu.ldmatrix %{{.*}} {trans} : (!llvm.ptr<3>) -> !llvm.struct<(i32, i32, i32, i32)> - // CHECK-NOT: nvgpu.ldmatrix + // CHECK: nvvm.ldmatrix %{{.*}} {layout = #nvvm.mma_layout, num = 4 : i32} : (!llvm.ptr<3>) -> !llvm.struct<(i32, i32, i32, i32)> + // CHECK: nvvm.ldmatrix %{{.*}} {layout = #nvvm.mma_layout, num = 4 : i32} : (!llvm.ptr<3>) -> !llvm.struct<(i32, i32, i32, i32)> + // CHECK-NOT: nvvm.ldmatrix %AA_DOT = ttg.local_load %AA : !ttg.memdesc<16x16xf16, #shared0, #smem> -> tensor<16x16xf16, #dot_operand_a> %BB_DOT = ttg.local_load %BB : !ttg.memdesc<16x16xf16, #shared0, #smem> -> tensor<16x16xf16, #dot_operand_b> %cst0 = arith.constant dense<0.000000e+00> : tensor<16x16xf32, #mma0> @@ -910,9 +910,9 @@ module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 1 : i32} { tt.func @convert_dot_ldmatrix_swizzle(%A: tensor<16x16xf16, #blocked0>, %B: tensor<16x16xf16, #blocked0>) { %AA = ttg.local_alloc %A : (tensor<16x16xf16, #blocked0>) -> 
!ttg.memdesc<16x16xf16, #shared0, #smem> %BB = ttg.local_alloc %B : (tensor<16x16xf16, #blocked0>) -> !ttg.memdesc<16x16xf16, #shared0, #smem> - // CHECK: nvgpu.ldmatrix %{{.*}} : (!llvm.ptr<3>) -> !llvm.struct<(i32, i32, i32, i32)> - // CHECK: nvgpu.ldmatrix %{{.*}} {trans} : (!llvm.ptr<3>) -> !llvm.struct<(i32, i32, i32, i32)> - // CHECK-NOT: nvgpu.ldmatrix + // CHECK: nvvm.ldmatrix %{{.*}} {layout = #nvvm.mma_layout, num = 4 : i32} : (!llvm.ptr<3>) -> !llvm.struct<(i32, i32, i32, i32)> + // CHECK: nvvm.ldmatrix %{{.*}} {layout = #nvvm.mma_layout, num = 4 : i32} : (!llvm.ptr<3>) -> !llvm.struct<(i32, i32, i32, i32)> + // CHECK-NOT: nvvm.ldmatrix %AA_DOT = ttg.local_load %AA : !ttg.memdesc<16x16xf16, #shared0, #smem> -> tensor<16x16xf16, #dot_operand_a> %BB_DOT = ttg.local_load %BB : !ttg.memdesc<16x16xf16, #shared0, #smem> -> tensor<16x16xf16, #dot_operand_b> %cst0 = arith.constant dense<0.000000e+00> : tensor<16x16xf32, #mma0> @@ -940,7 +940,7 @@ module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 1 : i32} { tt.func @convert_dot(%A: tensor<16x16xf16, #blocked0>, %B: tensor<16x16xf16, #blocked0>) { %AA = ttg.local_alloc %A : (tensor<16x16xf16, #blocked0>) -> !ttg.memdesc<16x16xf16, #shared0, #smem> %BB = ttg.local_alloc %B : (tensor<16x16xf16, #blocked0>) -> !ttg.memdesc<16x16xf16, #shared0, #smem> - // CHECK-NOT: nvgpu.ldmatrix + // CHECK-NOT: nvvm.ldmatrix %AA_DOT = ttg.local_load %AA : !ttg.memdesc<16x16xf16, #shared0, #smem> -> tensor<16x16xf16, #dot_operand_a> %BB_DOT = ttg.local_load %BB : !ttg.memdesc<16x16xf16, #shared0, #smem> -> tensor<16x16xf16, #dot_operand_b> %cst0 = arith.constant dense<0.000000e+00> : tensor<16x16xf32, #mma0> @@ -968,7 +968,8 @@ module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 1 : i32} { tt.func @convert_dot_mmav3_shared(%A: tensor<64x64xf16, #blocked0>, %B: tensor<64x64xf16, #blocked0>) { %AA = ttg.local_alloc %A : (tensor<64x64xf16, #blocked0>) -> !ttg.memdesc<64x64xf16, #shared0, #smem> %BB = ttg.local_alloc %B : (tensor<64x64xf16, #blocked0>) -> !ttg.memdesc<64x64xf16, #shared0, #smem> - // CHECK-COUNT-32: nvgpu.ldmatrix %{{.*}} : (!llvm.ptr<3>) -> !llvm.struct<(i32, i32, i32, i32)> + // CHECK-COUNT-16: nvvm.ldmatrix %{{.*}} {layout = #nvvm.mma_layout, num = 4 : i32} : (!llvm.ptr<3>) -> !llvm.struct<(i32, i32, i32, i32)> + // CHECK-COUNT-16: nvvm.ldmatrix %{{.*}} {layout = #nvvm.mma_layout, num = 4 : i32} : (!llvm.ptr<3>) -> !llvm.struct<(i32, i32, i32, i32)> %AA_DOT = ttg.local_load %AA : !ttg.memdesc<64x64xf16, #shared0, #smem> -> tensor<64x64xf16, #dot_operand_a> %BB_DOT = ttg.local_load %BB : !ttg.memdesc<64x64xf16, #shared0, #smem> -> tensor<64x64xf16, #dot_operand_b> %cst0 = arith.constant dense<0.000000e+00> : tensor<64x64xf32, #mma0> @@ -992,8 +993,8 @@ module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 1 : i32} { tt.func @convert_dot_fp8(%A: tensor<16x16xf8E5M2, #blocked0>, %B: tensor<16x16xf8E5M2, #blocked0>) { %AA = ttg.local_alloc %A : (tensor<16x16xf8E5M2, #blocked0>) -> !ttg.memdesc<16x16xf8E5M2, #shared0, #smem> %BB = ttg.local_alloc %B : (tensor<16x16xf8E5M2, #blocked0>) -> !ttg.memdesc<16x16xf8E5M2, #shared0, #smem> - // CHECK: nvgpu.ldmatrix %{{.*}} : (!llvm.ptr<3>) -> !llvm.struct<(i32, i32)> - // CHECK-NOT: nvgpu.ldmatrix + // CHECK: nvvm.ldmatrix %{{.*}} {layout = #nvvm.mma_layout, num = 2 : i32} : (!llvm.ptr<3>) -> !llvm.struct<(i32, i32)> + // CHECK-NOT: nvvm.ldmatrix %AA_DOT = ttg.local_load %AA : !ttg.memdesc<16x16xf8E5M2, #shared0, #smem> -> tensor<16x16xf8E5M2, #dot_operand_a> %BB_DOT = 
ttg.local_load %BB : !ttg.memdesc<16x16xf8E5M2, #shared0, #smem> -> tensor<16x16xf8E5M2, #dot_operand_b> %cst0 = arith.constant dense<0.000000e+00> : tensor<16x16xf32, #mma0> @@ -1308,7 +1309,7 @@ module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32} { tt.func @matmul_kernel_dot_operand_layout(%ptr:!tt.ptr {tt.divisibility = 16 : i32}, %a:!ttg.memdesc<128x32xf16, #shared, #smem>, %b:!ttg.memdesc<32x256xf16, #shared, #smem>) { %cst = arith.constant dense<0.000000e+00> : tensor<128x256xf32, #mma> - // CHECK: nvgpu.ldmatrix + // CHECK: nvvm.ldmatrix %a_mat = ttg.local_load %a : !ttg.memdesc<128x32xf16, #shared, #smem> -> tensor<128x32xf16, #dot_operand_a> %b_mat = ttg.local_load %b : !ttg.memdesc<32x256xf16, #shared, #smem> -> tensor<32x256xf16, #dot_operand_b> @@ -1384,9 +1385,9 @@ module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32} { tt.func @matmul_tf32dot(%ptr:!tt.ptr {tt.divisibility = 16 : i32}, %a:!ttg.memdesc<32x16xf32, #shared, #smem>, %b:!ttg.memdesc<16x32xf32, #shared, #smem>) { %cst = arith.constant dense<0.000000e+00> : tensor<32x32xf32, #mma> - // CHECK: nvgpu.ldmatrix + // CHECK: nvvm.ldmatrix // CHECK-SAME: (i32, i32, i32, i32) - // CHECK: nvgpu.ldmatrix + // CHECK: nvvm.ldmatrix // CHECK-SAME: (i32, i32, i32, i32) %a_mat = ttg.local_load %a : !ttg.memdesc<32x16xf32, #shared, #smem> -> tensor<32x16xf32, #dot_operand_a> %b_mat = ttg.local_load %b : !ttg.memdesc<16x32xf32, #shared, #smem> -> tensor<16x32xf32, #dot_operand_b> @@ -1875,8 +1876,8 @@ module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 1 : i32, "ttg.thr %f16_shared = ttg.local_alloc %f16_inp : (tensor<16x16xf16, #blocked0>) -> !ttg.memdesc<16x16xf16, #shared0, #smem> %i16_shared = ttg.local_alloc %i16_inp : (tensor<16x16xi16, #blocked0>) -> !ttg.memdesc<16x16xi16, #shared0, #smem> - // CHECK: nvgpu.ldmatrix - // CHECK: nvgpu.ldmatrix + // CHECK: nvvm.ldmatrix + // CHECK: nvvm.ldmatrix %f16_dot = ttg.local_load %f16_shared : !ttg.memdesc<16x16xf16, #shared0, #smem> -> tensor<16x16xf16, #dot_operand_a> %i16_dot = ttg.local_load %i16_shared : !ttg.memdesc<16x16xi16, #shared0, #smem> -> tensor<16x16xi16, #dot_operand_b> diff --git a/test/Conversion/tritongpu_to_llvm_hopper.mlir b/test/Conversion/tritongpu_to_llvm_hopper.mlir index c21189082062..7fcec1ca7396 100644 --- a/test/Conversion/tritongpu_to_llvm_hopper.mlir +++ b/test/Conversion/tritongpu_to_llvm_hopper.mlir @@ -203,7 +203,7 @@ module attributes {"ttg.target" = "cuda:90", "ttg.num-ctas" = 1 : i32, "ttg.num- // CHECK-LABEL: convert_mma_to_blocked module attributes {"ttg.target" = "cuda:90", "ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 8 : i32, "ttg.threads-per-warp" = 32 : i32} { tt.func @convert_mma_to_blocked(%a: tensor<128x256xf16, #mma>) { - // CHECK-COUNT-16: nvgpu.stmatrix + // CHECK-COUNT-16: nvvm.stmatrix // CHECK: nvvm.barrier0 %c = ttg.convert_layout %a : tensor<128x256xf16, #mma> -> tensor<128x256xf16, #blocked> tt.return @@ -254,7 +254,7 @@ module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32} { // CHECK-LABEL: distribute_to_shared_st_matrix module attributes {"ttg.target" = "cuda:90", "ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, "ttg.threads-per-warp" = 32 : i32} { tt.func @distribute_to_shared_st_matrix(%a: tensor<128x128xf16, #mma>) { - // CHECK-COUNT-16: nvgpu.stmatrix + // CHECK-COUNT-16: nvvm.stmatrix // CHECK: llvm.return %b = ttg.local_alloc %a {allocation.offset = 0 : i32} : (tensor<128x128xf16, #mma>) -> !ttg.memdesc<128x128xf16, #shared, #smem, mutable> tt.return @@ 
-269,7 +269,7 @@ module attributes {"ttg.target" = "cuda:90", "ttg.num-ctas" = 1 : i32, "ttg.num- // CHECK-LABEL: distribute_to_shared_st_matrix_local_store module attributes {"ttg.target" = "cuda:90", "ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, "ttg.threads-per-warp" = 32 : i32} { tt.func @distribute_to_shared_st_matrix_local_store(%a: tensor<128x128xf16, #mma>) { - // CHECK-COUNT-16: nvgpu.stmatrix + // CHECK-COUNT-16: nvvm.stmatrix // CHECK: llvm.return %b = ttg.local_alloc {allocation.offset = 0 : i32} : () -> !ttg.memdesc<128x128xf16, #shared, #smem, mutable> ttg.local_store %a, %b : tensor<128x128xf16, #mma> -> !ttg.memdesc<128x128xf16, #shared, #smem, mutable> @@ -285,7 +285,7 @@ module attributes {"ttg.target" = "cuda:90", "ttg.num-ctas" = 1 : i32, "ttg.num- // CHECK-LABEL: distribute_to_shared_st_matrix_local_store module attributes {"ttg.target" = "cuda:90", "ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, "ttg.threads-per-warp" = 32 : i32} { tt.func @distribute_to_shared_st_matrix_local_store(%a: tensor<64x128xf16, #linear>) { - // CHECK-COUNT-8: nvgpu.stmatrix %{{.*}}, %{{.*}}, %{{.*}}, %{{.*}}, %{{.*}} {trans} + // CHECK-COUNT-8: nvvm.stmatrix %{{.*}}, %{{.*}}, %{{.*}}, %{{.*}}, %{{.*}} {layout = #nvvm.mma_layout} // CHECK: llvm.return %b = ttg.local_alloc {allocation.offset = 0 : i32} : () -> !ttg.memdesc<64x128xf16, #shared, #smem, mutable> ttg.local_store %a, %b : tensor<64x128xf16, #linear> -> !ttg.memdesc<64x128xf16, #shared, #smem, mutable> @@ -301,7 +301,7 @@ module attributes {"ttg.target" = "cuda:90", "ttg.num-ctas" = 1 : i32, "ttg.num- // CHECK-LABEL: distribute_to_swizzled_st_matrix_local_store module attributes {"ttg.target" = "cuda:90", "ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, "ttg.threads-per-warp" = 32 : i32} { tt.func @distribute_to_swizzled_st_matrix_local_store(%a: tensor<8x64xf16, #mma>) { - // CHECK-COUNT-2: nvgpu.stmatrix + // CHECK-COUNT-2: nvvm.stmatrix // CHECK: llvm.return %b = ttg.local_alloc {allocation.offset = 0 : i32} : () -> !ttg.memdesc<8x64xf16, #shared, #smem, mutable> ttg.local_store %a, %b : tensor<8x64xf16, #mma> -> !ttg.memdesc<8x64xf16, #shared, #smem, mutable> @@ -317,7 +317,7 @@ module attributes {"ttg.target" = "cuda:90", "ttg.num-ctas" = 1 : i32, "ttg.num- // CHECK-LABEL: linear_to_swizzled_st_matrix_local_store module attributes {"ttg.target" = "cuda:90", "ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, "ttg.threads-per-warp" = 32 : i32} { tt.func @linear_to_swizzled_st_matrix_local_store(%a: tensor<64x32xf16, #linear>) { - // CHECK-COUNT-2: nvgpu.stmatrix %{{.*}}, %{{.*}}, %{{.*}}, %{{.*}}, %{{.*}} + // CHECK-COUNT-2: nvvm.stmatrix %{{.*}}, %{{.*}}, %{{.*}}, %{{.*}}, %{{.*}} {layout = #nvvm.mma_layout} // CHECK: llvm.return %b = ttg.local_alloc {allocation.offset = 0 : i32} : () -> !ttg.memdesc<64x32xf16, #shared, #smem, mutable> ttg.local_store %a, %b : tensor<64x32xf16, #linear> -> !ttg.memdesc<64x32xf16, #shared, #smem, mutable> @@ -339,7 +339,7 @@ module attributes {"ttg.target" = "cuda:90", "ttg.num-ctas" = 1 : i32, "ttg.num- // CHECK-LABEL: linear_to_swizzled_st_matrix_local_store module attributes {"ttg.target" = "cuda:90", "ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, "ttg.threads-per-warp" = 32 : i32} { tt.func @linear_to_swizzled_st_matrix_local_store(%a: tensor<32x32xf16, #linear>) { - // CHECK-COUNT-2: nvgpu.stmatrix %{{.*}}, %{{.*}}, %{{.*}}, %{{.*}}, %{{.*}} + // CHECK-COUNT-2: nvvm.stmatrix %{{.*}}, %{{.*}}, %{{.*}}, %{{.*}}, %{{.*}} {layout = #nvvm.mma_layout} // CHECK: llvm.return 
%b = ttg.local_alloc {allocation.offset = 0 : i32} : () -> !ttg.memdesc<32x32xf16, #shared, #smem, mutable> ttg.local_store %a, %b : tensor<32x32xf16, #linear> -> !ttg.memdesc<32x32xf16, #shared, #smem, mutable> @@ -355,7 +355,7 @@ module attributes {"ttg.target" = "cuda:90", "ttg.num-ctas" = 1 : i32, "ttg.num- // CHECK-LABEL: linear_to_swizzled_st_matrix_x2_local_store_fp8 module attributes {"ttg.target" = "cuda:90", "ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, "ttg.threads-per-warp" = 32 : i32} { tt.func @linear_to_swizzled_st_matrix_x2_local_store_fp8(%a: tensor<64x16xf8E4M3FNUZ, #linear>) { - // CHECK-COUNT-1: nvgpu.stmatrix %{{.*}}, %{{.*}}, %{{.*}} : + // CHECK-COUNT-1: nvvm.stmatrix %{{.*}}, %{{.*}}, %{{.*}} {layout = #nvvm.mma_layout} : // CHECK: llvm.return %b = ttg.local_alloc {allocation.offset = 0 : i32} : () -> !ttg.memdesc<64x16xf8E4M3FNUZ, #shared, #smem, mutable> ttg.local_store %a, %b : tensor<64x16xf8E4M3FNUZ, #linear> -> !ttg.memdesc<64x16xf8E4M3FNUZ, #shared, #smem, mutable> @@ -371,7 +371,7 @@ module attributes {"ttg.target" = "cuda:90", "ttg.num-ctas" = 1 : i32, "ttg.num- // CHECK-LABEL: linear_to_swizzled_st_matrix_local_store_fp32 module attributes {"ttg.target" = "cuda:90", "ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, "ttg.threads-per-warp" = 32 : i32} { tt.func @linear_to_swizzled_st_matrix_local_store_fp32(%a: tensor<64x16xf32, #linear>) { - // CHECK-COUNT-2: nvgpu.stmatrix %{{.*}}, %{{.*}}, %{{.*}}, %{{.*}}, %{{.*}} + // CHECK-COUNT-2: nvvm.stmatrix %{{.*}}, %{{.*}}, %{{.*}}, %{{.*}}, %{{.*}} {layout = #nvvm.mma_layout} // CHECK: llvm.return %b = ttg.local_alloc {allocation.offset = 0 : i32} : () -> !ttg.memdesc<64x16xf32, #shared, #smem, mutable> ttg.local_store %a, %b : tensor<64x16xf32, #linear> -> !ttg.memdesc<64x16xf32, #shared, #smem, mutable> @@ -388,7 +388,7 @@ module attributes {"ttg.target" = "cuda:90", "ttg.num-ctas" = 1 : i32, "ttg.num- // CHECK-LABEL: linear_to_swizzled_st_matrix_trans_local_store module attributes {"ttg.target" = "cuda:90", "ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, "ttg.threads-per-warp" = 32 : i32} { tt.func @linear_to_swizzled_st_matrix_trans_local_store(%a: tensor<64x32xf16, #linear>) { - // CHECK-COUNT-2: nvgpu.stmatrix %{{.*}}, %{{.*}}, %{{.*}}, %{{.*}}, %{{.*}} {trans} + // CHECK-COUNT-2: nvvm.stmatrix %{{.*}}, %{{.*}}, %{{.*}}, %{{.*}}, %{{.*}} {layout = #nvvm.mma_layout} // CHECK: llvm.return %b = ttg.local_alloc {allocation.offset = 0 : i32} : () -> !ttg.memdesc<64x32xf16, #shared, #smem, mutable> ttg.local_store %a, %b : tensor<64x32xf16, #linear> -> !ttg.memdesc<64x32xf16, #shared, #smem, mutable> @@ -410,7 +410,7 @@ module attributes {"ttg.target" = "cuda:90", "ttg.num-ctas" = 1 : i32, "ttg.num- // CHECK-LABEL: linear_to_swizzled_st_matrix_trans_local_store module attributes {"ttg.target" = "cuda:90", "ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, "ttg.threads-per-warp" = 32 : i32} { tt.func @linear_to_swizzled_st_matrix_trans_local_store(%a: tensor<16x32xf16, #linear>) { - // CHECK-COUNT-2: nvgpu.stmatrix %{{.*}}, %{{.*}}, %{{.*}}, %{{.*}}, %{{.*}} {trans} + // CHECK-COUNT-2: nvvm.stmatrix %{{.*}}, %{{.*}}, %{{.*}}, %{{.*}}, %{{.*}} {layout = #nvvm.mma_layout} // CHECK: llvm.return %b = ttg.local_alloc {allocation.offset = 0 : i32} : () -> !ttg.memdesc<16x32xf16, #shared, #smem, mutable> ttg.local_store %a, %b : tensor<16x32xf16, #linear> -> !ttg.memdesc<16x32xf16, #shared, #smem, mutable> diff --git a/third_party/nvidia/include/Dialect/NVGPU/IR/NVGPUOps.td 
b/third_party/nvidia/include/Dialect/NVGPU/IR/NVGPUOps.td index 00412ad27790..8eca21375fd1 100644 --- a/third_party/nvidia/include/Dialect/NVGPU/IR/NVGPUOps.td +++ b/third_party/nvidia/include/Dialect/NVGPU/IR/NVGPUOps.td @@ -105,24 +105,6 @@ def NVGPU_WGMMAOp : NVGPU_Op<"wgmma", []> { let assemblyFormat = "$opA `,` $opB `,` $useC (`,` $opC^)? attr-dict `:` functional-type(operands, $res)"; } -def NVGPU_StoreMatrixOp : NVGPU_Op<"stmatrix", [MemoryEffects<[MemWrite]>]> { - let arguments = ( - ins LLVM_PointerShared:$addr, - Variadic:$vals, - UnitAttr:$trans - ); - let assemblyFormat = "operands attr-dict `:` type(operands)"; -} - -def NVGPU_LoadMatrixOp : NVGPU_Op<"ldmatrix", [MemoryEffects<[MemRead]>]> { - let arguments = ( - ins LLVM_PointerShared:$addr, - UnitAttr:$trans - ); - let results = (outs AnyTypeOf<[LLVM_AnyStruct, I32]>:$result); - let assemblyFormat = "$addr attr-dict `:` functional-type($addr, $result)"; -} - def NVGPU_ClusterCTAIdOp : NVGPU_Op<"cluster_id", [Pure]> { let results = (outs I32:$result); let assemblyFormat = "attr-dict"; diff --git a/third_party/nvidia/lib/NVGPUToLLVM/NVGPUToLLVMPass.cpp b/third_party/nvidia/lib/NVGPUToLLVM/NVGPUToLLVMPass.cpp index 7010d4930f0f..3a40a61c81e6 100644 --- a/third_party/nvidia/lib/NVGPUToLLVM/NVGPUToLLVMPass.cpp +++ b/third_party/nvidia/lib/NVGPUToLLVM/NVGPUToLLVMPass.cpp @@ -238,144 +238,6 @@ class WarpIdOpPattern : public OpRewritePattern { } }; -// Base class for Matrix Operation Patterns -template -class MatrixOpPattern : public OpRewritePattern { -public: - using OpRewritePattern::OpRewritePattern; - - LogicalResult matchAndRewrite(MatrixOpType op, - PatternRewriter &rewriter) const override { - unsigned vecSize = getVectorSize(op); - bool trans = op.getTrans(); - // Template method for PTX assembly generation - std::string ptxAsm = - (llvm::Twine(ConcreteMatrixOpPattern::kOpCode) + - getPtxModifiers(vecSize, trans) + " " + getOperands(op, vecSize) + ";") - .str(); - - OperandsAndConstraints operandAndConstraints = - getOperandsAndConstraints(op, vecSize); - Constraints outputConstraints = getOutputConstraints(op, vecSize); - - return rewriteAsPtxAsm(op, rewriter, ptxAsm, operandAndConstraints, - outputConstraints); - } - -protected: - // Shared helper methods - std::string getPtxModifiers(unsigned vecSize, bool trans) const { - auto ptxAsmBase = llvm::Twine(".sync.aligned.m8n8"); - const std::string suffix = trans ? 
".trans.shared.b16" : ".shared.b16"; - switch (vecSize) { - case 1: - return (ptxAsmBase + ".x1" + suffix).str(); - case 2: - return (ptxAsmBase + ".x2" + suffix).str(); - case 4: - return (ptxAsmBase + ".x4" + suffix).str(); - default: - llvm_unreachable("Invalid vector size"); - } - } - - std::string getPtxRegOperands(unsigned startIdx, unsigned count) const { - llvm::SmallString<20> regOperands; - llvm::raw_svector_ostream stream(regOperands); - stream << "{"; - for (unsigned i = 0; i < count; i++) { - stream << "$" + llvm::utostr(startIdx + i); - if (i != count - 1) - stream << ", "; - } - stream << "}"; - return std::string(regOperands.str()); - } - - std::string getPtxAddrOperand(unsigned idx) const { - return (llvm::Twine("[$") + llvm::utostr(idx) + "]").str(); - } - - virtual std::string getOperands(MatrixOpType op, unsigned vecSize) const = 0; - virtual OperandsAndConstraints - getOperandsAndConstraints(MatrixOpType op, unsigned vecSize) const = 0; - virtual Constraints getOutputConstraints(MatrixOpType op, - unsigned vecSize) const = 0; - virtual unsigned getVectorSize(MatrixOpType op) const = 0; -}; - -// StoreMatrixOp Pattern -class StoreMatrixOpPattern - : public MatrixOpPattern { -public: - using MatrixOpPattern::MatrixOpPattern; - static constexpr const char *kOpCode = "stmatrix"; - -protected: - unsigned getVectorSize(ttn::StoreMatrixOp op) const override { - return op.getVals().size(); - } - - std::string getOperands(ttn::StoreMatrixOp op, - unsigned vecSize) const override { - return (llvm::Twine(getPtxAddrOperand(0)) + ", " + - getPtxRegOperands(1, vecSize)) - .str(); - } - - OperandsAndConstraints - getOperandsAndConstraints(ttn::StoreMatrixOp op, - unsigned vecSize) const override { - OperandsAndConstraints constraints = {{op.getAddr(), "r"}}; - for (unsigned i = 0; i < vecSize; i++) { - constraints.push_back({op.getVals()[i], "r"}); - } - return constraints; - } - - Constraints getOutputConstraints(ttn::StoreMatrixOp op, - unsigned vecSize) const override { - return {}; // No output constraints for StoreMatrixOp - } -}; - -// LoadMatrixOp Pattern -class LoadMatrixOpPattern - : public MatrixOpPattern { -public: - using MatrixOpPattern::MatrixOpPattern; - static constexpr const char *kOpCode = "ldmatrix"; - -protected: - unsigned getVectorSize(ttn::LoadMatrixOp op) const override { - auto resultType = op.getType(); - if (auto structTy = dyn_cast(resultType)) { - return structTy.getBody().size(); - } - return 1; - } - - std::string getOperands(ttn::LoadMatrixOp op, - unsigned vecSize) const override { - return (llvm::Twine(getPtxRegOperands(0, vecSize)) + ", " + - getPtxAddrOperand(vecSize)) - .str(); - } - - OperandsAndConstraints - getOperandsAndConstraints(ttn::LoadMatrixOp op, - unsigned vecSize) const override { - return {{op.getAddr(), "r"}}; - } - - Constraints getOutputConstraints(ttn::LoadMatrixOp op, - unsigned vecSize) const override { - return Constraints(vecSize, "=r"); - } -}; - class LoadAcquireOpPattern : public OpRewritePattern { public: using OpRewritePattern::OpRewritePattern; @@ -761,10 +623,8 @@ class ConvertNVGPUToLLVM patterns.add>( context, kClusterCtaIdOp, Constraints({"=r"}), Constraints()); - patterns - .add( - context); + patterns.add(context); if (applyPatternsGreedily(mod, std::move(patterns)).failed()) signalPassFailure(); diff --git a/third_party/nvidia/lib/TritonNVIDIAGPUToLLVM/MemoryOpToLLVM.cpp b/third_party/nvidia/lib/TritonNVIDIAGPUToLLVM/MemoryOpToLLVM.cpp index 336630009a4b..eaa8fc287161 100644 --- 
a/third_party/nvidia/lib/TritonNVIDIAGPUToLLVM/MemoryOpToLLVM.cpp +++ b/third_party/nvidia/lib/TritonNVIDIAGPUToLLVM/MemoryOpToLLVM.cpp @@ -2,6 +2,7 @@ #include "TargetInfo.h" #include "Utility.h" #include "mlir/Conversion/LLVMCommon/Pattern.h" +#include "mlir/Dialect/LLVMIR/NVVMDialect.h" #include "mlir/IR/PatternMatch.h" #include "triton/Analysis/Allocation.h" #include "triton/Conversion/TritonGPUToLLVM/PatternTritonGPUOpToLLVM.h" @@ -192,6 +193,7 @@ LogicalResult lowerLdStMatrix( auto vecAddr = b.gep(smemPtrTy, llvmElemTy, smemBase, offset, LLVM::GEPNoWrapFlags::inbounds); Type packedTy = vec_ty(llvmElemTy, 32 / bitwidth); + auto layout = transpose ? NVVM::MMALayout::col : NVVM::MMALayout::row; if (isStore) { // Pack into vector of i32 SmallVector inputs; @@ -203,17 +205,14 @@ LogicalResult lowerLdStMatrix( } inputs.push_back(b.bitcast(input, i32_ty)); } - rewriter.create(loc, vecAddr, inputs, - /*needTrans=*/transpose); + rewriter.create(loc, vecAddr, inputs, layout); } else { Type matTy = nVecs == 1 ? i32_ty : static_cast(LLVM::LLVMStructType::getLiteral( ctx, SmallVector(nVecs, i32_ty))); auto res = - rewriter - .create(loc, matTy, vecAddr, - /*needTrans=*/transpose) + rewriter.create(loc, matTy, vecAddr, nVecs, layout) .getResult(); // Extract result into srcVals for (int j = 0; j < nVecs; j++) { diff --git a/third_party/nvidia/lib/TritonNVIDIAGPUToLLVM/TargetInfo.cpp b/third_party/nvidia/lib/TritonNVIDIAGPUToLLVM/TargetInfo.cpp index ea3bf556f20d..857d7d04f76c 100644 --- a/third_party/nvidia/lib/TritonNVIDIAGPUToLLVM/TargetInfo.cpp +++ b/third_party/nvidia/lib/TritonNVIDIAGPUToLLVM/TargetInfo.cpp @@ -534,7 +534,7 @@ void TargetInfo::storeMatrixShared(RewriterBase &rewriter, Location loc, } inputs.push_back(b.bitcast(input, i32_ty)); } - rewriter.create(loc, ptr, inputs); + rewriter.create(loc, ptr, inputs, NVVM::MMALayout::row); } std::string TargetInfo::getMulhiFuncName(Type resultElementTy) const { From be2ce55c4e0d19e5fd4ef130094ad52d1d8788ab Mon Sep 17 00:00:00 2001 From: Gao Yanfeng Date: Sat, 12 Jul 2025 00:26:07 +0800 Subject: [PATCH 4/4] Update the test --- python/test/unit/language/test_tensor_descriptor.py | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/python/test/unit/language/test_tensor_descriptor.py b/python/test/unit/language/test_tensor_descriptor.py index cdcae2d9123a..1932b7cd2234 100644 --- a/python/test/unit/language/test_tensor_descriptor.py +++ b/python/test/unit/language/test_tensor_descriptor.py @@ -609,7 +609,8 @@ def alloc_fn(size: int, align: int, stream: Optional[int]): if BLOCK_M >= 64 * num_ctas and BLOCK_N >= 64 and is_hopper(): # TODO: The use of stmatrix for Blackwell is currently not supported. # Only a subset of TMEM and stmatrix layout pairs are compatible, for example 16x256bx2 and m8n8x4. - assert "stmatrix.sync.aligned.m8n8.x4.shared.b16" in kernel.asm["ptx"] + assert "stmatrix.sync.aligned.m8n8.x4.shared.b16" in kernel.asm[ + "ptx"] or "stmatrix.sync.aligned.x4.m8n8.shared.b16" in kernel.asm["ptx"] @triton.jit @@ -1668,4 +1669,5 @@ def test_host_tensor_descriptor_matmul(num_stages, num_ctas, BLOCK_M, BLOCK_N, B if BLOCK_M >= 64 * num_ctas and BLOCK_N >= 64 and is_cuda() and is_hopper(): # TODO: The use of stmatrix for Blackwell is currently not supported. # Only a subset of TMEM and stmatrix layout pairs are compatible, for example 16x256bx2 and m8n8x4. 
- assert "stmatrix.sync.aligned.m8n8.x4.shared.b16" in kernel.asm["ptx"] + assert "stmatrix.sync.aligned.m8n8.x4.shared.b16" in kernel.asm[ + "ptx"] or "stmatrix.sync.aligned.x4.m8n8.shared.b16" in kernel.asm["ptx"]