
Commit 6f4f943

[BACKEND] Add missing waits in WGMMA rhs in register pipelining (#8964)
We had special-case logic for rsDotNeedsWait that skipped the normal code path for inserting a wait before the accumulator is accessed. The skip is necessary because we split one `warp_group_dot` into many ops acting on the same accumulator, so we don't want a wait 0 inserted between those wgmma ops. However, it broke cases where the accumulator genuinely is accessed, e.g. in the epilogue of a persistent matmul. This PR instead makes the rsDotNeedsWait logic a completely separate loop over the dots, and adds a condition to the generic code path so it does not emit waits before wgmmas that consume the accumulator of another wgmma, since those are pipelined in hardware.
1 parent 5784490 commit 6f4f943
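To make the new generic-path rule concrete, here is a dependency-free C++ sketch of the use-selection logic in the patch: uses of an async dot's result are grouped per block and sorted into program order, wgmmas that consume the result as their own accumulator (operand 2 of `warp_group_dot`) are skipped because the hardware pipelines them, and the wait is inserted before the first remaining use. The `Use` struct, the block/position fields, and the example data are illustrative stand-ins for MLIR's `Operation`/`OpOperand` machinery, not types from the patch itself:

// Sketch: which users of an async warp_group_dot result get a wait?
// Toy stand-ins for MLIR's Operation/OpOperand; nothing here is the real API.
#include <algorithm>
#include <cstdio>
#include <map>
#include <string>
#include <vector>

struct Use {
  int block;           // block the user lives in
  int posInBlock;      // program order within that block
  bool isWarpGroupDot; // the user is itself a wgmma
  int operandNumber;   // operand index at which it consumes the dot result
  std::string name;
};

// Operand 2 of ttng.warp_group_dot is the accumulator C; a wgmma consuming
// the result there is pipelined in hardware and needs no wait (this mirrors
// the `operand->getOperandNumber() == 2` check in the patch).
static bool isHardwarePipelinedUse(const Use &u) {
  return u.isWarpGroupDot && u.operandNumber == 2;
}

int main() {
  // One split dot's users: two follow-up wgmmas that feed the result back in
  // as their accumulator, plus a genuine epilogue read (e.g. arith.addf).
  std::vector<Use> uses = {
      {0, 3, false, 0, "epilogue addf"},
      {0, 1, true, 2, "wgmma (chained accumulator)"},
      {0, 2, true, 2, "wgmma (chained accumulator)"},
  };

  // Group uses by block, as the patch does with blockToUses.
  std::map<int, std::vector<Use>> blockToUses;
  for (const Use &u : uses)
    blockToUses[u.block].push_back(u);

  for (auto &[block, blockUses] : blockToUses) {
    // Program order; the patch uses Operation::isBeforeInBlock instead.
    std::sort(blockUses.begin(), blockUses.end(),
              [](const Use &a, const Use &b) {
                return a.posInBlock < b.posInBlock;
              });
    // Skip the hardware-pipelined accumulator uses.
    auto firstUse = std::find_if_not(blockUses.begin(), blockUses.end(),
                                     isHardwarePipelinedUse);
    if (firstUse == blockUses.end())
      continue; // every use is pipelined; no wait emitted in this block
    // The real pass creates ttng.warp_group_dot_wait with pendings = 0 here
    // and threads the remaining uses through it.
    std::printf("block %d: wait(pendings = 0) before '%s'\n", block,
                firstUse->name.c_str());
  }
}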

File tree: 2 files changed, +122 −25 lines

lib/Dialect/TritonGPU/Transforms/Pipeliner/WGMMAPipeline.cpp

Lines changed: 45 additions & 25 deletions

@@ -348,7 +348,7 @@ std::vector<ttng::WarpGroupDotOp> splitRSDot(ttng::WarpGroupDotOp dotOp) {
         dotOp.getInputPrecision(), numImpreciseAcc, dotOp.getIsAsync());
     dots.push_back(dot);
     C = dot.getResult();
-    useC = mlir::arith::ConstantIntOp::create(builder, loc, 1, 1);
+    useC = {};
   }
   dotOp.replaceAllUsesWith(dots.back().getResult());
   dotOp.erase();
@@ -588,44 +588,64 @@ static void insertAsyncWarpGroupDotWaitInLoop(
   // Insert waits before the users of the properly async dots other than loop
   // yield.
   for (auto asyncDot : llvm::make_first_range(properlyAsyncDots)) {
-    // If the dot takes the LHS on registers i, we add a wait for the number
-    // of properly async dots in the loop minus one.
-    // This makes sure that the dot will wait until itself from the previous
-    // iteration has completed, as to avoid rewriting the registers.
-    if (rsDotNeedsWait(asyncDot, forOp)) {
-      OpBuilder builder(asyncDot);
-      builder.setInsertionPointAfter(asyncDot);
-      auto newWait = ttng::WarpGroupDotWaitOp::create(
-          builder, asyncDot->getLoc(), ArrayRef<Value>{},
-          properlyAsyncDots.size() - 1);
-      SmallVector<Value> waitOperands = {asyncDot->getResult(0)};
-      threadValuesThroughWait(newWait, waitOperands);
-      continue;
-    }
-
-    SmallVector<OpOperand *> uses;
+    DenseMap<Block *, SmallVector<OpOperand *>> blockToUses;
     for (auto &use : asyncDot->getUses()) {
       if (auto yieldOp = dyn_cast<scf::YieldOp>(use.getOwner())) {
         continue;
       }
-      uses.push_back(&use);
-    }
 
-    DenseMap<Block *, SmallVector<Value>> blockToUsers;
-    for (auto use : uses) {
-      auto block = use->getOwner()->getBlock();
-      blockToUsers[block].push_back(use->get());
+      auto block = use.getOwner()->getBlock();
+      blockToUses[block].push_back(&use);
     }
 
-    for (auto [block, users] : blockToUsers) {
-      OpBuilder builder(block, block->begin());
+    for (auto [block, uses] : blockToUses) {
+      // Insert a wait before the first use in the block
+      std::sort(uses.begin(), uses.end(), [](OpOperand *lhs, OpOperand *rhs) {
+        Operation *lhsOp = lhs->getOwner();
+        Operation *rhsOp = rhs->getOwner();
+        return lhsOp->isBeforeInBlock(rhsOp);
+      });
+
+      // If a wgmma uses the same accumulator registers, it will be implicitly
+      // pipelined by the hardware and doesn't need a wait.
+      auto firstUse =
+          std::find_if_not(uses.begin(), uses.end(), [](OpOperand *operand) {
+            return (isa<ttng::WarpGroupDotOp>(operand->getOwner()) &&
                    operand->getOperandNumber() == 2);
+          });
+      if (firstUse == uses.end()) {
+        continue;
+      }
+
+      OpBuilder builder((*firstUse)->getOwner());
       auto newWait = ttng::WarpGroupDotWaitOp::create(
           builder, asyncDot->getLoc(), ArrayRef<Value>{}, 0);
 
+      SmallVector<Value> users;
+      for (; firstUse != uses.end(); ++firstUse) {
+        users.push_back((*firstUse)->get());
+      }
       threadValuesThroughWait(newWait, users);
     }
   }
 
+  for (auto asyncDot : llvm::make_first_range(properlyAsyncDots)) {
+    // If the dot takes the LHS on registers i, we add a wait for the number
+    // of properly async dots in the loop minus one.
+    // This makes sure that the dot will wait until itself from the previous
+    // iteration has completed, as to avoid rewriting the registers.
+    if (!rsDotNeedsWait(asyncDot, forOp))
+      continue;
+
+    OpBuilder builder(asyncDot);
+    builder.setInsertionPointAfter(asyncDot);
+    auto newWait = ttng::WarpGroupDotWaitOp::create(
+        builder, asyncDot->getLoc(), ArrayRef<Value>{},
+        properlyAsyncDots.size() - 1);
+    SmallVector<Value> waitOperands = {asyncDot->getResult(0)};
+    threadValuesThroughWait(newWait, waitOperands);
+  }
+
   // Add the wait right after the last properly-async dot. This only needs to
   // wait for all properly-async dots from the i-1'th iteration to complete, IOW
   // we wait until there are most `asyncDots.size()` dots in flight.
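The register-LHS rule, now hoisted into its own loop above, still emits a wait with `pendings = properlyAsyncDots.size() - 1` after each such wgmma. A minimal worked example of why that count protects the LHS registers; the split count of 4 (and hence `pendings = 3`) matches the test below, the rest is an assumption for illustration:

// With N properly-async dots issued per iteration, allowing N - 1 of them to
// remain pending guarantees the oldest outstanding dot, i.e. this same dot
// from the previous iteration, has retired before its LHS registers are
// overwritten by the current iteration.
#include <cstdio>

int main() {
  const int numProperlyAsyncDots = 4; // e.g. one warp_group_dot split into 4
  const int pendings = numProperlyAsyncDots - 1;
  // The test below checks exactly four ttng.warp_group_dot ops, each followed
  // by ttng.warp_group_dot_wait {pendings = 3 : i32}.
  std::printf("wait until at most %d dots remain in flight\n", pendings);
}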

test/TritonGPU/loop-pipeline-hopper.mlir

Lines changed: 77 additions & 0 deletions

@@ -816,6 +816,83 @@ module attributes {"ttg.target" = "cuda:90", "ttg.num-ctas" = 1 : i32, "ttg.num-
 
 // -----
 
+#blocked = #ttg.blocked<{sizePerThread = [8, 1], threadsPerWarp = [8, 4], warpsPerCTA = [1, 4], order = [0, 1]}>
+#blocked1 = #ttg.blocked<{sizePerThread = [1, 8], threadsPerWarp = [4, 8], warpsPerCTA = [4, 1], order = [1, 0]}>
+#mma = #ttg.nvidia_mma<{versionMajor = 3, versionMinor = 0, warpsPerCTA = [4, 1], instrShape = [16, 16, 16]}>
+#shared = #ttg.nvmma_shared<{swizzlingByteWidth = 128, transposed = true, elementBitWidth = 16}>
+#smem = #ttg.shared_memory
+module attributes {"ttg.target" = "cuda:90", "ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, "ttg.threads-per-warp" = 32 : i32} {
+  // CHECK-LABEL: dot_lhs_in_reg_with_epilogue
+  tt.func @dot_lhs_in_reg_with_epilogue(%arg0: !tt.ptr<f16> {tt.divisibility = 16 : i32}, %arg1: !tt.ptr<f16> {tt.divisibility = 16 : i32}, %arg2: i1) -> tensor<128x16xf32, #mma> {
+    %cst = arith.constant dense<0> : tensor<128x64xi32, #blocked1>
+    %cst1 = arith.constant dense<0> : tensor<64x16xi32, #blocked>
+    %c0_i32 = arith.constant 0 : i32
+    %cst_0 = arith.constant dense<0> : tensor<1x16xi32, #blocked>
+    %cst_1 = arith.constant dense<0> : tensor<128x1xi32, #blocked1>
+    %c0_i64 = arith.constant 0 : i64
+    %cst_2 = arith.constant dense<0.000000e+00> : tensor<128x16xf32, #mma>
+    %cst_3 = arith.constant dense<0> : tensor<128x64xi32, #blocked1>
+    %cst_4 = arith.constant dense<2.0> : tensor<128x64xf16, #ttg.dot_op<{opIdx = 0, parent = #mma, kWidth = 2}>>
+    %c1_i32 = arith.constant 1 : i32
+    %c8_i32 = arith.constant 8 : i32
+    %0 = tt.addptr %arg0, %c0_i64 : !tt.ptr<f16>, i64
+    %1 = tt.addptr %arg1, %c0_i64 : !tt.ptr<f16>, i64
+    %2 = tt.splat %1 : !tt.ptr<f16> -> tensor<128x1x!tt.ptr<f16>, #blocked1>
+    %3 = tt.addptr %2, %cst_1 : tensor<128x1x!tt.ptr<f16>, #blocked1>, tensor<128x1xi32, #blocked1>
+    %4 = tt.make_range {end = 64 : i32, start = 0 : i32} : tensor<64xi32, #ttg.slice<{dim = 0, parent = #blocked1}>>
+    %5 = tt.expand_dims %4 {axis = 0 : i32} : tensor<64xi32, #ttg.slice<{dim = 0, parent = #blocked1}>> -> tensor<1x64xi32, #blocked1>
+    %6 = tt.broadcast %3 : tensor<128x1x!tt.ptr<f16>, #blocked1> -> tensor<128x64x!tt.ptr<f16>, #blocked1>
+    %7 = tt.broadcast %5 : tensor<1x64xi32, #blocked1> -> tensor<128x64xi32, #blocked1>
+    %8 = tt.addptr %6, %7 : tensor<128x64x!tt.ptr<f16>, #blocked1>, tensor<128x64xi32, #blocked1>
+    %10 = tt.splat %0 : !tt.ptr<f16> -> tensor<1x16x!tt.ptr<f16>, #blocked>
+    %11 = tt.addptr %10, %cst_0 : tensor<1x16x!tt.ptr<f16>, #blocked>, tensor<1x16xi32, #blocked>
+    %12 = tt.make_range {end = 64 : i32, start = 0 : i32} : tensor<64xi32, #ttg.slice<{dim = 1, parent = #blocked}>>
+    %13 = tt.expand_dims %12 {axis = 1 : i32} : tensor<64xi32, #ttg.slice<{dim = 1, parent = #blocked}>> -> tensor<64x1xi32, #blocked>
+    %14 = tt.broadcast %11 : tensor<1x16x!tt.ptr<f16>, #blocked> -> tensor<64x16x!tt.ptr<f16>, #blocked>
+    %15 = tt.broadcast %13 : tensor<64x1xi32, #blocked> -> tensor<64x16xi32, #blocked>
+    %16 = tt.addptr %14, %15 : tensor<64x16x!tt.ptr<f16>, #blocked>, tensor<64x16xi32, #blocked>
+    // CHECK: scf.for
+    // CHECK: ttg.async_wait {{.*}} {num = 2 : i32}
+    // CHECK: ttng.warp_group_dot
+    // CHECK-NEXT: ttng.warp_group_dot_wait {{.*}} {pendings = 3 : i32}
+    // CHECK: ttng.warp_group_dot
+    // CHECK-NEXT: ttng.warp_group_dot_wait {{.*}} {pendings = 3 : i32}
+    // CHECK: ttng.warp_group_dot
+    // CHECK-NEXT: ttng.warp_group_dot_wait {{.*}} {pendings = 3 : i32}
+    // CHECK: ttng.warp_group_dot
+    // CHECK-NEXT: ttng.warp_group_dot_wait {{.*}} {pendings = 3 : i32}
+    // CHECK: ttg.async_copy_global_to_local
+    // CHECK: ttg.async_copy_global_to_local
+    // CHECK: ttg.async_commit_group
+    // CHECK: scf.if
+    // CHECK-NEXT: ttng.warp_group_dot_wait {{.*}} {pendings = 0 : i32}
+    // CHECK: } else {
+    // CHECK-NOT: ttng.warp_group_dot_wait
+    // CHECK: scf.yield
+    %17:3 = scf.for %arg3 = %c0_i32 to %c8_i32 step %c1_i32 iter_args(%arg4 = %cst_2, %arg5 = %8, %arg6 = %16) -> (tensor<128x16xf32, #mma>, tensor<128x64x!tt.ptr<f16>, #blocked1>,
+            tensor<64x16x!tt.ptr<f16>, #blocked>) : i32 {
+      %a_block = tt.load %arg5 : tensor<128x64x!tt.ptr<f16>, #blocked1>
+      %b_block = tt.load %arg6 : tensor<64x16x!tt.ptr<f16>, #blocked>
+      %a_dotop = ttg.convert_layout %a_block : tensor<128x64xf16, #blocked1> -> tensor<128x64xf16, #ttg.dot_op<{opIdx = 0, parent = #mma, kWidth = 2}>>
+      %a_dotop_mul = arith.mulf %a_dotop, %cst_4 : tensor<128x64xf16, #ttg.dot_op<{opIdx = 0, parent = #mma, kWidth = 2}>>
+      %b_smem = ttg.local_alloc %b_block : (tensor<64x16xf16, #blocked>) -> !ttg.memdesc<64x16xf16, #shared, #smem>
+      %25 = ttng.warp_group_dot %a_dotop_mul, %b_smem, %arg4 : tensor<128x64xf16, #ttg.dot_op<{opIdx = 0, parent = #mma, kWidth = 2}>> * !ttg.memdesc<64x16xf16, #shared, #smem> -> tensor<128x16xf32, #mma>
+      %26 = tt.addptr %arg5, %cst : tensor<128x64x!tt.ptr<f16>, #blocked1>, tensor<128x64xi32, #blocked1>
+      %27 = tt.addptr %arg6, %cst1 : tensor<64x16x!tt.ptr<f16>, #blocked>, tensor<64x16xi32, #blocked>
+      %28 = scf.if %arg2 -> tensor<128x16xf32, #mma> {
+        %29 = arith.addf %25, %25 : tensor<128x16xf32, #mma>
+        scf.yield %29: tensor<128x16xf32, #mma>
+      } else {
+        scf.yield %25: tensor<128x16xf32, #mma>
+      }
+      scf.yield %28, %26, %27 : tensor<128x16xf32, #mma>, tensor<128x64x!tt.ptr<f16>, #blocked1>, tensor<64x16x!tt.ptr<f16>, #blocked>
+    }
+    tt.return %17#0 : tensor<128x16xf32, #mma>
+  }
+}
+
+// -----
+
 #blocked = #ttg.blocked<{sizePerThread = [16, 1], threadsPerWarp = [16, 2], warpsPerCTA = [1, 8], order = [0, 1]}>
 #blocked1 = #ttg.blocked<{sizePerThread = [1, 8], threadsPerWarp = [2, 16], warpsPerCTA = [8, 1], order = [1, 0]}>
 #blocked2 = #ttg.blocked<{sizePerThread = [1, 4], threadsPerWarp = [1, 32], warpsPerCTA = [8, 1], order = [1, 0]}>
