Skip to content

Commit 6e41010

Browse files
authored
[Blackwell][Clean up] Introduce interface for MMAv5 ops (triton-lang#5848)
The goal is to let `TCGen5MMAOp` and `TCGen5MMAScaledOp` share an interface so that the rest of the code can work generically with them. The MMA pipelining pass gets cleaned up a lot, and the accum init flag optimization is now automatically enabled for `TCGen5MMAScaledOp` as well. --------- Co-authored-by: Masahiro Masuda <[email protected]>
1 parent d691926 commit 6e41010

File tree

12 files changed

+129
-51
lines changed

12 files changed

+129
-51
lines changed

include/triton/Dialect/TritonNvidiaGPU/IR/CMakeLists.txt

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -15,3 +15,8 @@ mlir_tablegen(TritonNvidiaGPUAttrDefs.cpp.inc -gen-attrdef-defs)
1515
mlir_tablegen(OpsEnums.h.inc -gen-enum-decls)
1616
mlir_tablegen(OpsEnums.cpp.inc -gen-enum-defs)
1717
add_public_tablegen_target(TritonNvidiaGPUAttrDefsIncGen)
18+
19+
set(LLVM_TARGET_DEFINITIONS TritonNvidiaGPUOpInterfaces.td)
20+
mlir_tablegen(TritonNvidiaGPUOpInterfaces.h.inc -gen-op-interface-decls)
21+
mlir_tablegen(TritonNvidiaGPUOpInterfaces.cpp.inc -gen-op-interface-defs)
22+
add_public_tablegen_target(TritonNvidiaGPUOpInterfacesIncGen)

include/triton/Dialect/TritonNvidiaGPU/IR/Dialect.h

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -37,6 +37,8 @@
3737
#define GET_ATTRDEF_CLASSES
3838
#include "triton/Dialect/TritonNvidiaGPU/IR/TritonNvidiaGPUAttrDefs.h.inc"
3939

40+
#include "triton/Dialect/TritonNvidiaGPU/IR/TritonNvidiaGPUOpInterfaces.h.inc"
41+
4042
#define GET_OP_CLASSES
4143
#include "triton/Dialect/TritonNvidiaGPU/IR/Ops.h.inc"
4244

Lines changed: 42 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,42 @@
1+
#ifndef TRITON_NVIDIAGPU_OP_INTERFACES
2+
#define TRITON_NVIDIAGPU_OP_INTERFACES
3+
4+
include "mlir/IR/OpBase.td"
5+
6+
def MMAv5OpInterface : OpInterface<"MMAv5OpInterface"> {
7+
let description = [{
8+
This interface is implemented by MMAv5 dot and dot scaled ops.
9+
}];
10+
11+
let cppNamespace = "::mlir::triton::nvidia_gpu";
12+
13+
// We can add more methods as needed.
14+
let methods = [
15+
InterfaceMethod<"Return the accumulator init flag.",
16+
"::mlir::Value",
17+
"useAccumulator">,
18+
InterfaceMethod<"Set the accumulator init flag.",
19+
"void",
20+
"setUseAccumulator",
21+
(ins "::mlir::Value":$flag)>,
22+
InterfaceMethod<"Associate a new barrier to this MMAv5 op.",
23+
"void",
24+
"setBarrier",
25+
(ins "::mlir::Value":$barrier)>,
26+
InterfaceMethod<"Return the accumulator.",
27+
"::mlir::Value",
28+
"getAccumulator">,
29+
InterfaceMethod<"Set the accumulator.",
30+
"void",
31+
"setAccumulator",
32+
(ins "::mlir::Value":$accum)>,
33+
InterfaceMethod<"Return the predicate of this op.",
34+
"::mlir::Value",
35+
"getPredicate">,
36+
InterfaceMethod<"Set the predicate of this op.",
37+
"void",
38+
"setPredicate",
39+
(ins "::mlir::Value":$pred)>,
40+
];
41+
}
42+
#endif // TRITON_NVIDIAGPU_OP_INTERFACES

include/triton/Dialect/TritonNvidiaGPU/IR/TritonNvidiaGPUOps.td

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,7 @@
2323
#define TRITONNVIDIAGPU_OPS
2424

2525
include "triton/Dialect/TritonNvidiaGPU/IR/TritonNvidiaGPUDialect.td"
26+
include "triton/Dialect/TritonNvidiaGPU/IR/TritonNvidiaGPUOpInterfaces.td"
2627
include "mlir/Dialect/Arith/IR/ArithBase.td"
2728
include "triton/Dialect/Triton/IR/TritonTypes.td"
2829
include "triton/Dialect/Triton/IR/TritonAttrDefs.td"
@@ -326,7 +327,7 @@ def TTNG_TMAStoreWaitOp : TTNG_Op<"async_tma_store_wait"> {
326327
let assemblyFormat = "attr-dict";
327328
}
328329

329-
def TTNG_TCGen5MMAOp : TTNG_Op<"tc_gen5_mma", [DeclareOpInterfaceMethods<MemoryEffectsOpInterface>, DeclareOpInterfaceMethods<DotOpInterface>]> {
330+
def TTNG_TCGen5MMAOp : TTNG_Op<"tc_gen5_mma", [DeclareOpInterfaceMethods<MemoryEffectsOpInterface>, DeclareOpInterfaceMethods<DotOpInterface>, DeclareOpInterfaceMethods<MMAv5OpInterface>]> {
330331
let summary = "block level op mapping to tensorcore gen5 mma";
331332

332333
let description = [{
@@ -349,7 +350,7 @@ def TTNG_TCGen5MMAOp : TTNG_Op<"tc_gen5_mma", [DeclareOpInterfaceMethods<MemoryE
349350
let assemblyFormat = "$a`,` $b`,` $d`,` $useD`,` $pred (`,` $barrier^)? attr-dict `:` functional-type(operands, results)";
350351
}
351352

352-
def TTNG_TCGen5MMAScaledOp : TTNG_Op<"tc_gen5_mma_scaled", [DeclareOpInterfaceMethods<MemoryEffectsOpInterface>, DeclareOpInterfaceMethods<DotOpInterface>]> {
353+
def TTNG_TCGen5MMAScaledOp : TTNG_Op<"tc_gen5_mma_scaled", [DeclareOpInterfaceMethods<MemoryEffectsOpInterface>, DeclareOpInterfaceMethods<DotOpInterface>, DeclareOpInterfaceMethods<MMAv5OpInterface>]> {
353354
let summary = "block level op mapping to tensorcore gen5 mma";
354355

355356
let description = [{

lib/Dialect/TritonGPU/Transforms/OptimizeAccumulatorInit.cpp

Lines changed: 10 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -28,20 +28,21 @@ class TMEMAllocWithUnusedInit
2828
op.getResult().getUsers().end());
2929
if (users.size() > 2)
3030
return failure();
31-
triton::nvidia_gpu::TCGen5MMAOp mmaOp = nullptr;
31+
triton::nvidia_gpu::MMAv5OpInterface mmaOp = nullptr;
3232
triton::nvidia_gpu::TMEMLoadOp tmemLoad = nullptr;
3333
for (auto user : users) {
3434
if (auto load = dyn_cast<triton::nvidia_gpu::TMEMLoadOp>(user)) {
3535
tmemLoad = load;
36-
} else if (auto mma = dyn_cast<triton::nvidia_gpu::TCGen5MMAOp>(user)) {
36+
} else if (auto mma =
37+
dyn_cast<triton::nvidia_gpu::MMAv5OpInterface>(user)) {
3738
mmaOp = mma;
3839
}
3940
}
4041
if (!mmaOp)
4142
return failure();
4243
if (tmemLoad && !mmaOp->isBeforeInBlock(tmemLoad))
4344
return failure();
44-
Value useAccFlag = mmaOp.getUseD();
45+
Value useAccFlag = mmaOp.useAccumulator();
4546
if (!useAccFlag)
4647
return failure();
4748
auto flagConstOp = useAccFlag.getDefiningOp<arith::ConstantOp>();
@@ -63,7 +64,7 @@ bool dotSupportsAccInitFlag(Operation *op) {
6364
// initialization that would degrade the performance.
6465
return !wgDotOp.needsPartialAccumulator();
6566
}
66-
if (isa<triton::nvidia_gpu::TCGen5MMAOp>(op)) {
67+
if (isa<triton::nvidia_gpu::MMAv5OpInterface>(op)) {
6768
return true;
6869
}
6970
return false;
@@ -76,8 +77,8 @@ std::pair<Value, Operation *> getAccumulatorUseAndDef(Operation *op) {
7677
if (auto wgDotOp = dyn_cast<triton::nvidia_gpu::WarpGroupDotOp>(op)) {
7778
return std::make_pair(wgDotOp.getC(), wgDotOp);
7879
}
79-
if (auto tc05MmaOp = dyn_cast<triton::nvidia_gpu::TCGen5MMAOp>(op)) {
80-
auto accVal = tc05MmaOp.getD();
80+
if (auto tc05MmaOp = dyn_cast<triton::nvidia_gpu::MMAv5OpInterface>(op)) {
81+
auto accVal = tc05MmaOp.getAccumulator();
8182
auto tmemAlloc = accVal.getDefiningOp<triton::nvidia_gpu::TMEMAllocOp>();
8283
if (!tmemAlloc ||
8384
tmemAlloc->getParentRegion() != tc05MmaOp->getParentRegion())
@@ -104,8 +105,9 @@ void setUseAccFlag(Operation *op, Value useAcc) {
104105

105106
if (auto wgDotOp = dyn_cast<triton::nvidia_gpu::WarpGroupDotOp>(op)) {
106107
wgDotOp.getUseCMutable().assign(useAcc);
107-
} else if (auto tc05MmaOp = dyn_cast<triton::nvidia_gpu::TCGen5MMAOp>(op)) {
108-
tc05MmaOp.getUseDMutable().assign(useAcc);
108+
} else if (auto tc05MmaOp =
109+
dyn_cast<triton::nvidia_gpu::MMAv5OpInterface>(op)) {
110+
tc05MmaOp.setUseAccumulator(useAcc);
109111
} else {
110112
assert(false && "Unexpected op which implements a DotOpInterface");
111113
}

lib/Dialect/TritonGPU/Transforms/OptimizeDotOperands.cpp

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,7 @@
1010
#include "triton/Dialect/TritonGPU/IR/Dialect.h"
1111
#include "triton/Dialect/TritonGPU/Transforms/Passes.h"
1212
#include "triton/Dialect/TritonGPU/Transforms/Utility.h"
13+
#include "triton/Dialect/TritonNvidiaGPU/IR/Dialect.h"
1314
#include <memory>
1415

1516
namespace mlir::triton::gpu {
@@ -90,8 +91,7 @@ class FuseTransMMAV3Plus : public OpRewritePattern<LocalAllocOp> {
9091
PatternRewriter &rewriter) const override {
9192
if (!allocOp.getSrc() || !allocOp->hasOneUse() ||
9293
!isa<triton::nvidia_gpu::WarpGroupDotOp,
93-
triton::nvidia_gpu::TCGen5MMAOp,
94-
triton::nvidia_gpu::TCGen5MMAScaledOp>(
94+
triton::nvidia_gpu::MMAv5OpInterface>(
9595
*allocOp->getUsers().begin()))
9696
return failure();
9797

lib/Dialect/TritonGPU/Transforms/Pipeliner/PipeliningUtility.cpp

Lines changed: 5 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@
77
#include "triton/Dialect/TritonGPU/IR/Dialect.h"
88
#include "triton/Dialect/TritonGPU/Transforms/Passes.h"
99
#include "triton/Dialect/TritonGPU/Transforms/Utility.h"
10+
#include "triton/Dialect/TritonNvidiaGPU/IR/Dialect.h"
1011
#include "llvm/Support/Casting.h"
1112

1213
using namespace mlir;
@@ -97,18 +98,11 @@ Operation *mlir::triton::predicateOp(RewriterBase &rewriter, Operation *op,
9798
expectOp.getPredMutable().assign(mask);
9899
return op;
99100
}
100-
if (auto mmav5Op = dyn_cast<ttng::TCGen5MMAOp>(op)) {
101+
if (auto mmav5Op = dyn_cast<ttng::MMAv5OpInterface>(op)) {
101102
rewriter.setInsertionPoint(mmav5Op);
102-
Value mask = getPredMask(rewriter, mmav5Op.getPred().getType(),
103-
mmav5Op.getPred(), pred);
104-
mmav5Op.getPredMutable().assign(mask);
105-
return op;
106-
}
107-
if (auto mmav5Op = dyn_cast<ttng::TCGen5MMAScaledOp>(op)) {
108-
rewriter.setInsertionPoint(mmav5Op);
109-
Value mask = getPredMask(rewriter, mmav5Op.getPred().getType(),
110-
mmav5Op.getPred(), pred);
111-
mmav5Op.getPredMutable().assign(mask);
103+
auto currPred = mmav5Op.getPredicate();
104+
Value mask = getPredMask(rewriter, currPred.getType(), currPred, pred);
105+
mmav5Op.setPredicate(mask);
112106
return op;
113107
}
114108
if (auto tmemStoreOp = dyn_cast<ttng::TMEMStoreOp>(op)) {

lib/Dialect/TritonGPU/Transforms/Pipeliner/TC05MMAPipeline.cpp

Lines changed: 15 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -62,7 +62,7 @@ struct MMAInfo {
6262
// accumulator for the given MMA operation. The TMEMAllocOp and TMEMLoadOp must
6363
// be in the same region as the MMA operation.
6464
std::optional<std::pair<ttng::TMEMAllocOp, ttng::TMEMLoadOp>>
65-
getTMemAllocAndLoad(Operation *mmaOp) {
65+
getTMemAllocAndLoad(ttng::MMAv5OpInterface mmaOp) {
6666
auto acc = mmaOp->getOperand(2).getDefiningOp<ttng::TMEMAllocOp>();
6767
if (!acc || acc->getParentRegion() != mmaOp->getParentRegion()) {
6868
return std::nullopt;
@@ -230,20 +230,16 @@ getAccUseFlagFalseInLoop(scf::ForOp forOp, Value useAccFlagUse) {
230230
}
231231

232232
std::optional<MMAInfo::AccOverridePoint>
233-
getAccOverrideOrFlagFalseInLoop(scf::ForOp forOp, Operation *mmaOp) {
233+
getAccOverrideOrFlagFalseInLoop(scf::ForOp forOp,
234+
ttng::MMAv5OpInterface mmaOp) {
234235
auto tmemAllocAndLoad = getTMemAllocAndLoad(mmaOp);
235236
assert(tmemAllocAndLoad.has_value() && "Expected tmem alloc and load");
236237
auto [accAlloc, accLoad] = tmemAllocAndLoad.value();
237238
auto accOverridePoint = getAccOverridePointInLoop(forOp, accAlloc, accLoad);
238239

239240
if (!accOverridePoint.has_value()) {
240-
if (auto op = dyn_cast<ttng::TCGen5MMAOp>(mmaOp)) {
241-
auto useAccFlag = op.getUseD();
242-
accOverridePoint = getAccUseFlagFalseInLoop(forOp, useAccFlag);
243-
} else if (auto op = dyn_cast<ttng::TCGen5MMAScaledOp>(mmaOp)) {
244-
auto useAccFlag = op.getUseD();
245-
accOverridePoint = getAccUseFlagFalseInLoop(forOp, useAccFlag);
246-
}
241+
auto useAccFlag = mmaOp.useAccumulator();
242+
accOverridePoint = getAccUseFlagFalseInLoop(forOp, useAccFlag);
247243
}
248244

249245
return accOverridePoint;
@@ -281,7 +277,7 @@ Value createSingleBufferView(IRRewriter &builder, Value alloc, int idx) {
281277
builder.create<arith::ConstantIntOp>(alloc.getLoc(), idx, 32));
282278
}
283279

284-
Value createBarrierAlloc(scf::ForOp forOp, Operation *mmaOp, int numStages) {
280+
Value createBarrierAlloc(scf::ForOp forOp, int numStages) {
285281
IRRewriter rewriter(forOp->getContext());
286282
rewriter.setInsertionPoint(forOp);
287283
MLIRContext *ctx = forOp.getContext();
@@ -490,7 +486,8 @@ void updateAccDefsInLoop(IRRewriter &builder, scf::ForOp forOp, MMAInfo &info,
490486
// hoisted tmem allocs. Also, update the acc loads and stores to use the new
491487
// tmem allocs.
492488
void hoistAndUseTMemAlloc(IRRewriter &builder, scf::ForOp forOp,
493-
Operation *mmaOp, MMAInfo &info, int numStages) {
489+
ttng::MMAv5OpInterface mmaOp, MMAInfo &info,
490+
int numStages) {
494491
builder.setInsertionPoint(forOp);
495492
Value zero = builder.create<arith::ConstantIntOp>(forOp.getLoc(), 0, 32);
496493
Value one = builder.create<arith::ConstantIntOp>(forOp.getLoc(), 1, 32);
@@ -515,11 +512,7 @@ void hoistAndUseTMemAlloc(IRRewriter &builder, scf::ForOp forOp,
515512
createSingleBufferView(builder, insertSlice, info.accInsertIdx);
516513
}
517514

518-
if (auto op = dyn_cast<ttng::TCGen5MMAOp>(mmaOp)) {
519-
op.getDMutable().assign(insertSlice);
520-
} else if (auto op = dyn_cast<ttng::TCGen5MMAScaledOp>(mmaOp)) {
521-
op.getDMutable().assign(insertSlice);
522-
}
515+
mmaOp.setAccumulator(insertSlice);
523516

524517
updateAccUsesInLoop(builder, forOp, info, newAlloc, numStages);
525518
assert(isa<BlockArgument>(info.accExtractIdx));
@@ -545,26 +538,22 @@ void hoistAndUseTMemAlloc(IRRewriter &builder, scf::ForOp forOp,
545538

546539
// Create multi-buffered barrier allocs and lower the MMA to MMA + wait barrier
547540
void createBarrierAndWaitOps(IRRewriter &builder, scf::ForOp forOp,
548-
Operation *mmaOp, MMAInfo &info, int numStages) {
541+
ttng::MMAv5OpInterface mmaOp, MMAInfo &info,
542+
int numStages) {
549543
builder.setInsertionPoint(forOp);
550544
Value zero = builder.create<arith::ConstantIntOp>(forOp.getLoc(), 0, 32);
551545
Value one = builder.create<arith::ConstantIntOp>(forOp.getLoc(), 1, 32);
552546
Value numStagesVal =
553547
builder.create<arith::ConstantIntOp>(forOp.getLoc(), numStages, 32);
554548

555-
info.barrierAlloc = createBarrierAlloc(forOp, mmaOp, numStages);
549+
info.barrierAlloc = createBarrierAlloc(forOp, numStages);
556550

557551
Location loc = mmaOp->getLoc();
558552
builder.setInsertionPoint(mmaOp);
559553

560554
Value barrierSlice =
561555
createSingleBufferView(builder, info.barrierAlloc, info.barrierIdx);
562-
563-
if (auto op = dyn_cast<ttng::TCGen5MMAOp>(mmaOp)) {
564-
op.getBarrierMutable().assign(barrierSlice);
565-
} else if (auto op = dyn_cast<ttng::TCGen5MMAScaledOp>(mmaOp)) {
566-
op.getBarrierMutable().assign(barrierSlice);
567-
}
556+
mmaOp.setBarrier(barrierSlice);
568557

569558
builder.setInsertionPointAfter(mmaOp);
570559
auto waitOp =
@@ -653,10 +642,11 @@ FailureOr<scf::ForOp> preProcessLoopForTC05MMAPipelining(scf::ForOp forOp,
653642
}
654643

655644
IRRewriter builder(forOp->getContext());
656-
for (auto mmaOp : mmaOps) {
645+
for (auto op : mmaOps) {
657646
// Avoid pipelining if in the backward slice of the mmaOp there is an
658647
// operation that is already assigned a stage, as it would make the pipeline
659648
// deeper than we are prepared for.
649+
auto mmaOp = cast<ttng::MMAv5OpInterface>(op);
660650
SetVector<Operation *> backwardSlice;
661651
BackwardSliceOptions opt;
662652
opt.omitBlockArguments = true;

lib/Dialect/TritonNvidiaGPU/IR/CMakeLists.txt

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@ add_triton_library(TritonNvidiaGPUIR
55
DEPENDS
66
TritonNvidiaGPUTableGen
77
TritonNvidiaGPUAttrDefsIncGen
8+
TritonNvidiaGPUOpInterfacesIncGen
89

910
LINK_LIBS PUBLIC
1011
TritonIR

lib/Dialect/TritonNvidiaGPU/IR/Ops.cpp

Lines changed: 42 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -25,6 +25,8 @@
2525
#include "mlir/Support/LLVM.h"
2626
#include "triton/Dialect/TritonNvidiaGPU/IR/Dialect.h"
2727

28+
#include "triton/Dialect/TritonNvidiaGPU/IR/TritonNvidiaGPUOpInterfaces.cpp.inc"
29+
2830
#define GET_OP_CLASSES
2931
#include "triton/Dialect/TritonNvidiaGPU/IR/Ops.cpp.inc"
3032

@@ -273,6 +275,24 @@ bool TCGen5MMAOp::verifyDims() {
273275
return aShape[aShape.size() - 1] == bShape[aShape.size() - 2];
274276
}
275277

278+
Value TCGen5MMAOp::useAccumulator() { return getUseD(); }
279+
280+
void TCGen5MMAOp::setUseAccumulator(Value flag) {
281+
getUseDMutable().assign(flag);
282+
}
283+
284+
void TCGen5MMAOp::setBarrier(Value barrier) {
285+
getBarrierMutable().assign(barrier);
286+
}
287+
288+
Value TCGen5MMAOp::getAccumulator() { return getD(); }
289+
290+
void TCGen5MMAOp::setAccumulator(Value accum) { getDMutable().assign(accum); }
291+
292+
Value TCGen5MMAOp::getPredicate() { return getPred(); }
293+
294+
void TCGen5MMAOp::setPredicate(Value pred) { getPredMutable().assign(pred); }
295+
276296
// -- TMEMStoreOp --
277297
LogicalResult TMEMStoreOp::verify() {
278298
if (!isa<triton::nvidia_gpu::TensorMemorySpaceAttr>(
@@ -317,6 +337,28 @@ bool TCGen5MMAScaledOp::verifyDims() {
317337
return aKdim == bKdim;
318338
}
319339

340+
Value TCGen5MMAScaledOp::useAccumulator() { return getUseD(); }
341+
342+
void TCGen5MMAScaledOp::setUseAccumulator(Value flag) {
343+
getUseDMutable().assign(flag);
344+
}
345+
346+
void TCGen5MMAScaledOp::setBarrier(Value barrier) {
347+
getBarrierMutable().assign(barrier);
348+
}
349+
350+
Value TCGen5MMAScaledOp::getAccumulator() { return getD(); }
351+
352+
void TCGen5MMAScaledOp::setAccumulator(Value accum) {
353+
getDMutable().assign(accum);
354+
}
355+
356+
Value TCGen5MMAScaledOp::getPredicate() { return getPred(); }
357+
358+
void TCGen5MMAScaledOp::setPredicate(Value pred) {
359+
getPredMutable().assign(pred);
360+
}
361+
320362
// -- TMEMLoadOp --
321363
LogicalResult TMEMLoadOp::verify() {
322364
if (!isa<triton::nvidia_gpu::TensorMemorySpaceAttr>(

0 commit comments

Comments (0)