Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
22 changes: 11 additions & 11 deletions test/Conversion/tritongpu_to_llvm_blackwell.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -107,9 +107,9 @@ module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, ttg.shar
// CHECK-LABEL: @tensor_memory_ld
// CHECK: nvgpu.tensor_memory_base
// CHECK: tcgen05.st.sync.aligned.32x32b.x128.b32
// CHECK: nvvm.tcgen05.wait <store>
// CHECK: tcgen05.wait::st.sync.aligned
// CHECK: tcgen05.ld.sync.aligned.32x32b.x128.b32
// CHECK: nvvm.tcgen05.wait <load>
// CHECK: tcgen05.wait::ld.sync.aligned
tt.func public @tensor_memory_ld(%arg0: !tt.ptr<f16>, %arg1: !tt.ptr<f16>, %arg2: !tt.ptr<f16>) {
%cst_0 = arith.constant dense<0.000000e+00> : tensor<128x128xf32, #blocked1>
%0 = ttng.tmem_alloc %cst_0 {tensor_memory_col_offset = 0 : i32, tensor_memory_row_offset = 0 : i32} : (tensor<128x128xf32, #blocked1>) -> !ttg.memdesc<128x128xf32, #tmem, #ttng.tensor_memory, mutable>
Expand Down Expand Up @@ -158,10 +158,10 @@ module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, ttg.shar
// CHECK: nvgpu.tensor_memory_base
// CHECK: tcgen05.st.sync.aligned.16x32bx2.x64.b32
// CHECK: tcgen05.st.sync.aligned.16x32bx2.x64.b32
// CHECK: nvvm.tcgen05.wait <store>
// CHECK: tcgen05.wait::st.sync.aligned
// CHECK: tcgen05.ld.sync.aligned.16x32bx2.x64.b32
// CHECK: tcgen05.ld.sync.aligned.16x32bx2.x64.b32
// CHECK: nvvm.tcgen05.wait <load>
// CHECK: tcgen05.wait::ld.sync.aligned
tt.func public @tensor_memory_ld_m64(%arg0: !tt.ptr<f16>, %arg1: !tt.ptr<f16>, %arg2: !tt.ptr<f16>) {
%cst_0 = arith.constant dense<0.000000e+00> : tensor<128x128xf32, #blocked1>
%0 = ttng.tmem_alloc %cst_0 {tensor_memory_col_offset = 0 : i32, tensor_memory_row_offset = 0 : i32} : (tensor<128x128xf32, #blocked1>) -> !ttg.memdesc<128x128xf32, #tmem, #ttng.tensor_memory, mutable>
Expand All @@ -179,9 +179,9 @@ module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, ttg.shar
// CHECK-LABEL: @tensor_memory_unpack_f16
// CHECK: nvgpu.tensor_memory_base
// CHECK: tcgen05.st.sync.aligned.32x32b.x64.unpack::16b.b32
// CHECK: nvvm.tcgen05.wait <store>
// CHECK: tcgen05.wait::st.sync.aligned
// CHECK: tcgen05.ld.sync.aligned.32x32b.x64.pack::16b.b32
// CHECK: nvvm.tcgen05.wait <load>
// CHECK: tcgen05.wait::ld.sync.aligned
tt.func public @tensor_memory_unpack_f16() {
%cst_0 = arith.constant dense<0.000000e+00> : tensor<128x128xf16, #blocked1>
%0 = ttng.tmem_alloc %cst_0 {tensor_memory_col_offset = 0 : i32, tensor_memory_row_offset = 0 : i32} : (tensor<128x128xf16, #blocked1>) -> !ttg.memdesc<128x128xf16, #tmem, #ttng.tensor_memory, mutable>
Expand Down Expand Up @@ -399,10 +399,10 @@ module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, ttg.shar
// CHECK-LABEL: @tensor_memory_ld_128x256
// CHECK-COUNT-4: tcgen05.st.sync.aligned.32x32b.x64.b32
// CHECK-NOT: tcgen05.st
// CHECK: nvvm.tcgen05.wait <store>
// CHECK: tcgen05.wait::st.sync.aligned
// CHECK-COUNT-4: tcgen05.ld.sync.aligned.32x32b.x64.b32
// CHECK-NOT: tcgen05.ld
// CHECK: nvvm.tcgen05.wait <load>
// CHECK: tcgen05.wait::ld.sync.aligned
tt.func public @tensor_memory_ld_128x256(%arg0: !tt.ptr<f16>, %arg1: !tt.ptr<f16>, %arg2: !tt.ptr<f16>) {
%cst_0 = arith.constant dense<0.000000e+00> : tensor<128x256xf32, #blocked>
%0 = ttng.tmem_alloc %cst_0 {tensor_memory_col_offset = 0 : i32, tensor_memory_row_offset = 0 : i32} : (tensor<128x256xf32, #blocked>) -> !ttg.memdesc<128x256xf32, #tmem, #ttng.tensor_memory, mutable>
Expand All @@ -419,9 +419,9 @@ module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, ttg.shar
module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 8 : i32, ttg.shared = 65544 : i32, ttg.target = "cuda:100", ttg.tensor_memory_size = 128 : i32, "ttg.threads-per-warp" = 32 : i32} {
// CHECK-LABEL: @tensor_memory_ld_128x256_8_warps
// CHECK: tcgen05.st.sync.aligned.32x32b.x128.b32
// CHECK: nvvm.tcgen05.wait <store>
// CHECK: tcgen05.wait::st.sync.aligned
// CHECK: tcgen05.ld.sync.aligned.32x32b.x128.b32
// CHECK: nvvm.tcgen05.wait <load>
// CHECK: tcgen05.wait::ld.sync.aligned
tt.func public @tensor_memory_ld_128x256_8_warps(%arg0: !tt.ptr<f16>, %arg1: !tt.ptr<f16>, %arg2: !tt.ptr<f16>) {
%cst_0 = arith.constant dense<0.000000e+00> : tensor<128x256xf32, #blocked>
%0 = ttng.tmem_alloc %cst_0 {tensor_memory_col_offset = 0 : i32, tensor_memory_row_offset = 0 : i32} : (tensor<128x256xf32, #blocked>) -> !ttg.memdesc<128x256xf32, #tmem, #ttng.tensor_memory, mutable>
Expand Down Expand Up @@ -844,7 +844,7 @@ module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, ttg.shar
// CHECK-LABEL: @tensor_memory_st
// CHECK: nvgpu.tensor_memory_base
// CHECK: tcgen05.st.sync.aligned.32x32b.x128.b32
// CHECK: nvvm.tcgen05.wait <store>
// CHECK: tcgen05.wait::st.sync.aligned
tt.func public @tensor_memory_st(%arg0: !tt.ptr<f16>, %arg1: !tt.ptr<f16>, %arg2: !tt.ptr<f16>) {
%cst_0 = arith.constant dense<0.000000e+00> : tensor<128x128xf32, #blocked1>
%0 = ttng.tmem_alloc {tensor_memory_col_offset = 0 : i32, tensor_memory_row_offset = 0 : i32} : () -> !ttg.memdesc<128x128xf32, #tmem, #ttng.tensor_memory, mutable>
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -740,6 +740,22 @@ static FailureOr<SmallVector<Value>> lowerTMemLdStFromTypes(
return resultVals;
}

// Emit an inline-asm `tcgen05.wait::st.sync.aligned` instruction, which blocks
// the issuing thread until all of its prior tcgen05 stores have completed.
// Inline asm is used instead of the NVVM tcgen05 wait op (see callers).
static void createWaitOpSt(Location loc, ConversionPatternRewriter &rewriter) {
  PTXBuilder builder;
  auto &waitInstr = *builder.create<PTXInstr>("tcgen05.wait::st.sync.aligned;");
  waitInstr({}, /*onlyAttachMLIRArgs=*/true);
  builder.launch(rewriter, loc, void_ty(rewriter.getContext()));
}

// Emit an inline-asm `tcgen05.wait::ld.sync.aligned` instruction, which blocks
// the issuing thread until all of its prior tcgen05 loads have completed.
// Inline asm is used instead of the NVVM tcgen05 wait op (see callers).
static void createWaitOpLd(Location loc, ConversionPatternRewriter &rewriter) {
  PTXBuilder builder;
  auto &waitInstr = *builder.create<PTXInstr>("tcgen05.wait::ld.sync.aligned;");
  waitInstr({}, /*onlyAttachMLIRArgs=*/true);
  builder.launch(rewriter, loc, void_ty(rewriter.getContext()));
}

struct TensorMemoryLoadOpConversion
: public ConvertOpToLLVMPattern<triton::nvidia_gpu::TMEMLoadOp> {
using ConvertOpToLLVMPattern::ConvertOpToLLVMPattern;
Expand Down Expand Up @@ -767,7 +783,9 @@ struct TensorMemoryLoadOpConversion
Value resultStruct = packLLElements(loc, getTypeConverter(), *resultValsOr,
rewriter, structTy);
// Wait insertion could be moved to the TTGIR level if needed.
rewriter.create<NVVM::Tcgen05WaitOp>(loc, NVVM::Tcgen05WaitKind::LOAD);
// Use inline asm until tcgen05 NVVM ops are supported for sm_103 by NVPTX.
createWaitOpLd(loc, rewriter);
// rewriter.create<NVVM::Tcgen05WaitOp>(loc, NVVM::Tcgen05WaitKind::LOAD);
rewriter.replaceOp(op, {resultStruct});
return success();
}
Expand Down Expand Up @@ -799,7 +817,9 @@ struct TensorMemoryStoreOpConversion
maxnreg, pred, llvmElemTy, srcValues);
if (failed(lowered))
return failure();
rewriter.create<NVVM::Tcgen05WaitOp>(loc, NVVM::Tcgen05WaitKind::STORE);
// Use inline asm until tcgen05 NVVM ops are supported for sm_103 by NVPTX.
createWaitOpSt(loc, rewriter);
// rewriter.create<NVVM::Tcgen05WaitOp>(loc, NVVM::Tcgen05WaitKind::STORE);

// Emit a barrier to ensure all threads have finished writing to tensor
// memory before any use of the tensor memory.
Expand Down Expand Up @@ -847,7 +867,13 @@ struct TensorMemoryAllocOpConversion
b.i1_val(true), llvmElemTy, srcValues);
if (failed(lowered))
return failure();
rewriter.create<NVVM::Tcgen05WaitOp>(loc, NVVM::Tcgen05WaitKind::STORE);

// Use inline asm until tcgen05 NVVM ops are supported for sm_103 by
// NVPTX.
createWaitOpSt(loc, rewriter);
// rewriter.create<NVVM::Tcgen05WaitOp>(loc,
// NVVM::Tcgen05WaitKind::STORE);

// Emit a barrier to ensure all threads have finished writing to tensor
// memory before any use of the tensor memory.
b.barrier();
Expand Down