Commit ba13f84

Revert "[BACKEND] Add missing waits in WGMMA rhs in register pipelining" (#8970)
Temporarily revert due to performance regressions #8964
1 parent 6f4f943 commit ba13f84

2 files changed: +25 -122 lines changed

lib/Dialect/TritonGPU/Transforms/Pipeliner/WGMMAPipeline.cpp

Lines changed: 25 additions & 45 deletions
@@ -348,7 +348,7 @@ std::vector<ttng::WarpGroupDotOp> splitRSDot(ttng::WarpGroupDotOp dotOp) {
           dotOp.getInputPrecision(), numImpreciseAcc, dotOp.getIsAsync());
     dots.push_back(dot);
     C = dot.getResult();
-    useC = {};
+    useC = mlir::arith::ConstantIntOp::create(builder, loc, 1, 1);
   }
   dotOp.replaceAllUsesWith(dots.back().getResult());
   dotOp.erase();
@@ -588,64 +588,44 @@ static void insertAsyncWarpGroupDotWaitInLoop(
   // Insert waits before the users of the properly async dots other than loop
   // yield.
   for (auto asyncDot : llvm::make_first_range(properlyAsyncDots)) {
-    DenseMap<Block *, SmallVector<OpOperand *>> blockToUses;
+    // If the dot takes the LHS on registers i, we add a wait for the number
+    // of properly async dots in the loop minus one.
+    // This makes sure that the dot will wait until itself from the previous
+    // iteration has completed, as to avoid rewriting the registers.
+    if (rsDotNeedsWait(asyncDot, forOp)) {
+      OpBuilder builder(asyncDot);
+      builder.setInsertionPointAfter(asyncDot);
+      auto newWait = ttng::WarpGroupDotWaitOp::create(
+          builder, asyncDot->getLoc(), ArrayRef<Value>{},
+          properlyAsyncDots.size() - 1);
+      SmallVector<Value> waitOperands = {asyncDot->getResult(0)};
+      threadValuesThroughWait(newWait, waitOperands);
+      continue;
+    }
+
+    SmallVector<OpOperand *> uses;
     for (auto &use : asyncDot->getUses()) {
       if (auto yieldOp = dyn_cast<scf::YieldOp>(use.getOwner())) {
         continue;
       }
-
-      auto block = use.getOwner()->getBlock();
-      blockToUses[block].push_back(&use);
+      uses.push_back(&use);
     }

-    for (auto [block, uses] : blockToUses) {
-      // Insert a wait before the first use in the block
-      std::sort(uses.begin(), uses.end(), [](OpOperand *lhs, OpOperand *rhs) {
-        Operation *lhsOp = lhs->getOwner();
-        Operation *rhsOp = rhs->getOwner();
-        return lhsOp->isBeforeInBlock(rhsOp);
-      });
-
-      // If a wgmma uses the same accumulator registers, it will be implicitly
-      // pipelined by the hardware and doesn't need a wait.
-      auto firstUse =
-          std::find_if_not(uses.begin(), uses.end(), [](OpOperand *operand) {
-            return (isa<ttng::WarpGroupDotOp>(operand->getOwner()) &&
-                    operand->getOperandNumber() == 2);
-          });
-      if (firstUse == uses.end()) {
-        continue;
-      }
+    DenseMap<Block *, SmallVector<Value>> blockToUsers;
+    for (auto use : uses) {
+      auto block = use->getOwner()->getBlock();
+      blockToUsers[block].push_back(use->get());
+    }

-      OpBuilder builder((*firstUse)->getOwner());
+    for (auto [block, users] : blockToUsers) {
+      OpBuilder builder(block, block->begin());
       auto newWait = ttng::WarpGroupDotWaitOp::create(
           builder, asyncDot->getLoc(), ArrayRef<Value>{}, 0);

-      SmallVector<Value> users;
-      for (; firstUse != uses.end(); ++firstUse) {
-        users.push_back((*firstUse)->get());
-      }
       threadValuesThroughWait(newWait, users);
     }
   }

-  for (auto asyncDot : llvm::make_first_range(properlyAsyncDots)) {
-    // If the dot takes the LHS on registers i, we add a wait for the number
-    // of properly async dots in the loop minus one.
-    // This makes sure that the dot will wait until itself from the previous
-    // iteration has completed, as to avoid rewriting the registers.
-    if (!rsDotNeedsWait(asyncDot, forOp))
-      continue;
-
-    OpBuilder builder(asyncDot);
-    builder.setInsertionPointAfter(asyncDot);
-    auto newWait = ttng::WarpGroupDotWaitOp::create(
-        builder, asyncDot->getLoc(), ArrayRef<Value>{},
-        properlyAsyncDots.size() - 1);
-    SmallVector<Value> waitOperands = {asyncDot->getResult(0)};
-    threadValuesThroughWait(newWait, waitOperands);
-  }
-
   // Add the wait right after the last properly-async dot. This only needs to
   // wait for all properly-async dots from the i-1'th iteration to complete, IOW
   // we wait until there are most `asyncDots.size()` dots in flight.

test/TritonGPU/loop-pipeline-hopper.mlir

Lines changed: 0 additions & 77 deletions
@@ -816,83 +816,6 @@ module attributes {"ttg.target" = "cuda:90", "ttg.num-ctas" = 1 : i32, "ttg.num-

 // -----

-#blocked = #ttg.blocked<{sizePerThread = [8, 1], threadsPerWarp = [8, 4], warpsPerCTA = [1, 4], order = [0, 1]}>
-#blocked1 = #ttg.blocked<{sizePerThread = [1, 8], threadsPerWarp = [4, 8], warpsPerCTA = [4, 1], order = [1, 0]}>
-#mma = #ttg.nvidia_mma<{versionMajor = 3, versionMinor = 0, warpsPerCTA = [4, 1], instrShape = [16, 16, 16]}>
-#shared = #ttg.nvmma_shared<{swizzlingByteWidth = 128, transposed = true, elementBitWidth = 16}>
-#smem = #ttg.shared_memory
-module attributes {"ttg.target" = "cuda:90", "ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, "ttg.threads-per-warp" = 32 : i32} {
-// CHECK-LABEL: dot_lhs_in_reg_with_epilogue
-  tt.func @dot_lhs_in_reg_with_epilogue(%arg0: !tt.ptr<f16> {tt.divisibility = 16 : i32}, %arg1: !tt.ptr<f16> {tt.divisibility = 16 : i32}, %arg2: i1) -> tensor<128x16xf32, #mma> {
-    %cst = arith.constant dense<0> : tensor<128x64xi32, #blocked1>
-    %cst1 = arith.constant dense<0> : tensor<64x16xi32, #blocked>
-    %c0_i32 = arith.constant 0 : i32
-    %cst_0 = arith.constant dense<0> : tensor<1x16xi32, #blocked>
-    %cst_1 = arith.constant dense<0> : tensor<128x1xi32, #blocked1>
-    %c0_i64 = arith.constant 0 : i64
-    %cst_2 = arith.constant dense<0.000000e+00> : tensor<128x16xf32, #mma>
-    %cst_3 = arith.constant dense<0> : tensor<128x64xi32, #blocked1>
-    %cst_4 = arith.constant dense<2.0> : tensor<128x64xf16, #ttg.dot_op<{opIdx = 0, parent = #mma, kWidth = 2}>>
-    %c1_i32 = arith.constant 1 : i32
-    %c8_i32 = arith.constant 8 : i32
-    %0 = tt.addptr %arg0, %c0_i64 : !tt.ptr<f16>, i64
-    %1 = tt.addptr %arg1, %c0_i64 : !tt.ptr<f16>, i64
-    %2 = tt.splat %1 : !tt.ptr<f16> -> tensor<128x1x!tt.ptr<f16>, #blocked1>
-    %3 = tt.addptr %2, %cst_1 : tensor<128x1x!tt.ptr<f16>, #blocked1>, tensor<128x1xi32, #blocked1>
-    %4 = tt.make_range {end = 64 : i32, start = 0 : i32} : tensor<64xi32, #ttg.slice<{dim = 0, parent = #blocked1}>>
-    %5 = tt.expand_dims %4 {axis = 0 : i32} : tensor<64xi32, #ttg.slice<{dim = 0, parent = #blocked1}>> -> tensor<1x64xi32, #blocked1>
-    %6 = tt.broadcast %3 : tensor<128x1x!tt.ptr<f16>, #blocked1> -> tensor<128x64x!tt.ptr<f16>, #blocked1>
-    %7 = tt.broadcast %5 : tensor<1x64xi32, #blocked1> -> tensor<128x64xi32, #blocked1>
-    %8 = tt.addptr %6, %7 : tensor<128x64x!tt.ptr<f16>, #blocked1>, tensor<128x64xi32, #blocked1>
-    %10 = tt.splat %0 : !tt.ptr<f16> -> tensor<1x16x!tt.ptr<f16>, #blocked>
-    %11 = tt.addptr %10, %cst_0 : tensor<1x16x!tt.ptr<f16>, #blocked>, tensor<1x16xi32, #blocked>
-    %12 = tt.make_range {end = 64 : i32, start = 0 : i32} : tensor<64xi32, #ttg.slice<{dim = 1, parent = #blocked}>>
-    %13 = tt.expand_dims %12 {axis = 1 : i32} : tensor<64xi32, #ttg.slice<{dim = 1, parent = #blocked}>> -> tensor<64x1xi32, #blocked>
-    %14 = tt.broadcast %11 : tensor<1x16x!tt.ptr<f16>, #blocked> -> tensor<64x16x!tt.ptr<f16>, #blocked>
-    %15 = tt.broadcast %13 : tensor<64x1xi32, #blocked> -> tensor<64x16xi32, #blocked>
-    %16 = tt.addptr %14, %15 : tensor<64x16x!tt.ptr<f16>, #blocked>, tensor<64x16xi32, #blocked>
-    // CHECK: scf.for
-    // CHECK: ttg.async_wait {{.*}} {num = 2 : i32}
-    // CHECK: ttng.warp_group_dot
-    // CHECK-NEXT: ttng.warp_group_dot_wait {{.*}} {pendings = 3 : i32}
-    // CHECK: ttng.warp_group_dot
-    // CHECK-NEXT: ttng.warp_group_dot_wait {{.*}} {pendings = 3 : i32}
-    // CHECK: ttng.warp_group_dot
-    // CHECK-NEXT: ttng.warp_group_dot_wait {{.*}} {pendings = 3 : i32}
-    // CHECK: ttng.warp_group_dot
-    // CHECK-NEXT: ttng.warp_group_dot_wait {{.*}} {pendings = 3 : i32}
-    // CHECK: ttg.async_copy_global_to_local
-    // CHECK: ttg.async_copy_global_to_local
-    // CHECK: ttg.async_commit_group
-    // CHECK: scf.if
-    // CHECK-NEXT: ttng.warp_group_dot_wait {{.*}} {pendings = 0 : i32}
-    // CHECK: } else {
-    // CHECK-NOT: ttng.warp_group_dot_wait
-    // CHECK: scf.yield
-    %17:3 = scf.for %arg3 = %c0_i32 to %c8_i32 step %c1_i32 iter_args(%arg4 = %cst_2, %arg5 = %8, %arg6 = %16) -> (tensor<128x16xf32, #mma>, tensor<128x64x!tt.ptr<f16>, #blocked1>,
-                                                                                                                   tensor<64x16x!tt.ptr<f16>, #blocked>) : i32 {
-      %a_block = tt.load %arg5 : tensor<128x64x!tt.ptr<f16>, #blocked1>
-      %b_block = tt.load %arg6 : tensor<64x16x!tt.ptr<f16>, #blocked>
-      %a_dotop = ttg.convert_layout %a_block : tensor<128x64xf16, #blocked1> -> tensor<128x64xf16, #ttg.dot_op<{opIdx = 0, parent = #mma, kWidth = 2}>>
-      %a_dotop_mul = arith.mulf %a_dotop, %cst_4 : tensor<128x64xf16, #ttg.dot_op<{opIdx = 0, parent = #mma, kWidth = 2}>>
-      %b_smem = ttg.local_alloc %b_block : (tensor<64x16xf16, #blocked>) -> !ttg.memdesc<64x16xf16, #shared, #smem>
-      %25 = ttng.warp_group_dot %a_dotop_mul, %b_smem, %arg4 : tensor<128x64xf16, #ttg.dot_op<{opIdx = 0, parent = #mma, kWidth = 2}>> * !ttg.memdesc<64x16xf16, #shared, #smem> -> tensor<128x16xf32, #mma>
-      %26 = tt.addptr %arg5, %cst : tensor<128x64x!tt.ptr<f16>, #blocked1>, tensor<128x64xi32, #blocked1>
-      %27 = tt.addptr %arg6, %cst1 : tensor<64x16x!tt.ptr<f16>, #blocked>, tensor<64x16xi32, #blocked>
-      %28 = scf.if %arg2 -> tensor<128x16xf32, #mma> {
-        %29 = arith.addf %25, %25 : tensor<128x16xf32, #mma>
-        scf.yield %29: tensor<128x16xf32, #mma>
-      } else {
-        scf.yield %25: tensor<128x16xf32, #mma>
-      }
-      scf.yield %28, %26, %27 : tensor<128x16xf32, #mma>, tensor<128x64x!tt.ptr<f16>, #blocked1>, tensor<64x16x!tt.ptr<f16>, #blocked>
-    }
-    tt.return %17#0 : tensor<128x16xf32, #mma>
-  }
-}
-
-// -----
-
 #blocked = #ttg.blocked<{sizePerThread = [16, 1], threadsPerWarp = [16, 2], warpsPerCTA = [1, 8], order = [0, 1]}>
 #blocked1 = #ttg.blocked<{sizePerThread = [1, 8], threadsPerWarp = [2, 16], warpsPerCTA = [8, 1], order = [1, 0]}>
 #blocked2 = #ttg.blocked<{sizePerThread = [1, 4], threadsPerWarp = [1, 32], warpsPerCTA = [8, 1], order = [1, 0]}>
