diff --git a/test/Conversion/tritongpu_to_llvm_blackwell.mlir b/test/Conversion/tritongpu_to_llvm_blackwell.mlir
index 5fd193bff8bf..f123c9d5980a 100644
--- a/test/Conversion/tritongpu_to_llvm_blackwell.mlir
+++ b/test/Conversion/tritongpu_to_llvm_blackwell.mlir
@@ -107,9 +107,9 @@ module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, ttg.shar
   // CHECK-LABEL: @tensor_memory_ld
   // CHECK: nvgpu.tensor_memory_base
   // CHECK: tcgen05.st.sync.aligned.32x32b.x128.b32
-  // CHECK: nvvm.tcgen05.wait
+  // CHECK: tcgen05.wait::st.sync.aligned
   // CHECK: tcgen05.ld.sync.aligned.32x32b.x128.b32
-  // CHECK: nvvm.tcgen05.wait
+  // CHECK: tcgen05.wait::ld.sync.aligned
   tt.func public @tensor_memory_ld(%arg0: !tt.ptr, %arg1: !tt.ptr, %arg2: !tt.ptr) {
     %cst_0 = arith.constant dense<0.000000e+00> : tensor<128x128xf32, #blocked1>
     %0 = ttng.tmem_alloc %cst_0 {tensor_memory_col_offset = 0 : i32, tensor_memory_row_offset = 0 : i32} : (tensor<128x128xf32, #blocked1>) -> !ttg.memdesc<128x128xf32, #tmem, #ttng.tensor_memory, mutable>
@@ -158,10 +158,10 @@ module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, ttg.shar
   // CHECK: nvgpu.tensor_memory_base
   // CHECK: tcgen05.st.sync.aligned.16x32bx2.x64.b32
   // CHECK: tcgen05.st.sync.aligned.16x32bx2.x64.b32
-  // CHECK: nvvm.tcgen05.wait
+  // CHECK: tcgen05.wait::st.sync.aligned
   // CHECK: tcgen05.ld.sync.aligned.16x32bx2.x64.b32
   // CHECK: tcgen05.ld.sync.aligned.16x32bx2.x64.b32
-  // CHECK: nvvm.tcgen05.wait
+  // CHECK: tcgen05.wait::ld.sync.aligned
   tt.func public @tensor_memory_ld_m64(%arg0: !tt.ptr, %arg1: !tt.ptr, %arg2: !tt.ptr) {
     %cst_0 = arith.constant dense<0.000000e+00> : tensor<128x128xf32, #blocked1>
     %0 = ttng.tmem_alloc %cst_0 {tensor_memory_col_offset = 0 : i32, tensor_memory_row_offset = 0 : i32} : (tensor<128x128xf32, #blocked1>) -> !ttg.memdesc<128x128xf32, #tmem, #ttng.tensor_memory, mutable>
@@ -179,9 +179,9 @@ module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, ttg.shar
   // CHECK-LABEL: @tensor_memory_unpack_f16
   // CHECK: nvgpu.tensor_memory_base
   // CHECK: tcgen05.st.sync.aligned.32x32b.x64.unpack::16b.b32
-  // CHECK: nvvm.tcgen05.wait
+  // CHECK: tcgen05.wait::st.sync.aligned
   // CHECK: tcgen05.ld.sync.aligned.32x32b.x64.pack::16b.b32
-  // CHECK: nvvm.tcgen05.wait
+  // CHECK: tcgen05.wait::ld.sync.aligned
   tt.func public @tensor_memory_unpack_f16() {
     %cst_0 = arith.constant dense<0.000000e+00> : tensor<128x128xf16, #blocked1>
     %0 = ttng.tmem_alloc %cst_0 {tensor_memory_col_offset = 0 : i32, tensor_memory_row_offset = 0 : i32} : (tensor<128x128xf16, #blocked1>) -> !ttg.memdesc<128x128xf16, #tmem, #ttng.tensor_memory, mutable>
@@ -399,10 +399,10 @@ module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, ttg.shar
   // CHECK-LABEL: @tensor_memory_ld_128x256
   // CHECK-COUNT-4: tcgen05.st.sync.aligned.32x32b.x64.b32
   // CHECK-NOT: tcgen05.st
-  // CHECK: nvvm.tcgen05.wait
+  // CHECK: tcgen05.wait::st.sync.aligned
   // CHECK-COUNT-4: tcgen05.ld.sync.aligned.32x32b.x64.b32
   // CHECK-NOT: tcgen05.ld
-  // CHECK: nvvm.tcgen05.wait
+  // CHECK: tcgen05.wait::ld.sync.aligned
   tt.func public @tensor_memory_ld_128x256(%arg0: !tt.ptr, %arg1: !tt.ptr, %arg2: !tt.ptr) {
     %cst_0 = arith.constant dense<0.000000e+00> : tensor<128x256xf32, #blocked>
     %0 = ttng.tmem_alloc %cst_0 {tensor_memory_col_offset = 0 : i32, tensor_memory_row_offset = 0 : i32} : (tensor<128x256xf32, #blocked>) -> !ttg.memdesc<128x256xf32, #tmem, #ttng.tensor_memory, mutable>
@@ -419,9 +419,9 @@ module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, ttg.shar
 module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 8 : i32, ttg.shared = 65544 : i32, ttg.target = "cuda:100", ttg.tensor_memory_size = 128 : i32, "ttg.threads-per-warp" = 32 : i32} {
   // CHECK-LABEL: @tensor_memory_ld_128x256_8_warps
   // CHECK: tcgen05.st.sync.aligned.32x32b.x128.b32
-  // CHECK: nvvm.tcgen05.wait
+  // CHECK: tcgen05.wait::st.sync.aligned
   // CHECK: tcgen05.ld.sync.aligned.32x32b.x128.b32
-  // CHECK: nvvm.tcgen05.wait
+  // CHECK: tcgen05.wait::ld.sync.aligned
   tt.func public @tensor_memory_ld_128x256_8_warps(%arg0: !tt.ptr, %arg1: !tt.ptr, %arg2: !tt.ptr) {
     %cst_0 = arith.constant dense<0.000000e+00> : tensor<128x256xf32, #blocked>
     %0 = ttng.tmem_alloc %cst_0 {tensor_memory_col_offset = 0 : i32, tensor_memory_row_offset = 0 : i32} : (tensor<128x256xf32, #blocked>) -> !ttg.memdesc<128x256xf32, #tmem, #ttng.tensor_memory, mutable>
@@ -844,7 +844,7 @@ module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, ttg.shar
   // CHECK-LABEL: @tensor_memory_st
   // CHECK: nvgpu.tensor_memory_base
   // CHECK: tcgen05.st.sync.aligned.32x32b.x128.b32
-  // CHECK: nvvm.tcgen05.wait
+  // CHECK: tcgen05.wait::st.sync.aligned
   tt.func public @tensor_memory_st(%arg0: !tt.ptr, %arg1: !tt.ptr, %arg2: !tt.ptr) {
     %cst_0 = arith.constant dense<0.000000e+00> : tensor<128x128xf32, #blocked1>
     %0 = ttng.tmem_alloc {tensor_memory_col_offset = 0 : i32, tensor_memory_row_offset = 0 : i32} : () -> !ttg.memdesc<128x128xf32, #tmem, #ttng.tensor_memory, mutable>
diff --git a/third_party/nvidia/lib/TritonNVIDIAGPUToLLVM/TensorMemoryToLLVM.cpp b/third_party/nvidia/lib/TritonNVIDIAGPUToLLVM/TensorMemoryToLLVM.cpp
index 8c210aaf6e4b..24ede0415aef 100644
--- a/third_party/nvidia/lib/TritonNVIDIAGPUToLLVM/TensorMemoryToLLVM.cpp
+++ b/third_party/nvidia/lib/TritonNVIDIAGPUToLLVM/TensorMemoryToLLVM.cpp
@@ -740,6 +740,22 @@ static FailureOr<SmallVector<Value>> lowerTMemLdStFromTypes(
   return resultVals;
 }
 
+static void createWaitOpSt(Location loc, ConversionPatternRewriter &rewriter) {
+  PTXBuilder ptxBuilder;
+  std::string opcode = "tcgen05.wait::st.sync.aligned;";
+  auto &wait = *ptxBuilder.create(opcode);
+  wait({}, /*onlyAttachMLIRArgs=*/true);
+  ptxBuilder.launch(rewriter, loc, void_ty(rewriter.getContext()));
+}
+
+static void createWaitOpLd(Location loc, ConversionPatternRewriter &rewriter) {
+  PTXBuilder ptxBuilder;
+  std::string opcode = "tcgen05.wait::ld.sync.aligned;";
+  auto &wait = *ptxBuilder.create(opcode);
+  wait({}, /*onlyAttachMLIRArgs=*/true);
+  ptxBuilder.launch(rewriter, loc, void_ty(rewriter.getContext()));
+}
+
 struct TensorMemoryLoadOpConversion
     : public ConvertOpToLLVMPattern<triton::nvidia_gpu::TMEMLoadOp> {
   using ConvertOpToLLVMPattern::ConvertOpToLLVMPattern;
@@ -767,7 +783,9 @@ struct TensorMemoryLoadOpConversion
     Value resultStruct = packLLElements(loc, getTypeConverter(), *resultValsOr,
                                         rewriter, structTy);
     // Wait insertion could be moved to the TTGIR level if needed.
-    rewriter.create<NVVM::Tcgen05WaitOp>(loc, NVVM::Tcgen05WaitKind::LOAD);
+    // Use inline asm until tcgen05 NVVM ops are supported for sm_103 by NVPTX.
+    createWaitOpLd(loc, rewriter);
+    // rewriter.create<NVVM::Tcgen05WaitOp>(loc, NVVM::Tcgen05WaitKind::LOAD);
     rewriter.replaceOp(op, {resultStruct});
     return success();
   }
@@ -799,7 +817,11 @@ struct TensorMemoryStoreOpConversion
                                          maxnreg, pred, llvmElemTy, srcValues);
     if (failed(lowered))
       return failure();
-    rewriter.create<NVVM::Tcgen05WaitOp>(loc, NVVM::Tcgen05WaitKind::STORE);
+    // Use inline asm until tcgen05 NVVM ops are supported for sm_103 by NVPTX.
+    createWaitOpSt(loc, rewriter);
+    // rewriter.create<NVVM::Tcgen05WaitOp>(loc, NVVM::Tcgen05WaitKind::STORE);
 
     // Emit a barrier to ensure all threads have finished writing to tensor
     // memory before any use of the tensor memory.
@@ -847,7 +869,13 @@ struct TensorMemoryAllocOpConversion
                                            b.i1_val(true), llvmElemTy, srcValues);
       if (failed(lowered))
         return failure();
-      rewriter.create<NVVM::Tcgen05WaitOp>(loc, NVVM::Tcgen05WaitKind::STORE);
+
+      // Use inline asm until tcgen05 NVVM ops are supported for sm_103 by
+      // NVPTX.
+      createWaitOpSt(loc, rewriter);
+      // rewriter.create<NVVM::Tcgen05WaitOp>(loc,
+      //                                      NVVM::Tcgen05WaitKind::STORE);
+
       // Emit a barrier to ensure all threads have finished writing to tensor
       // memory before any use of the tensor memory.
       b.barrier();