
Commit 6d2ca1c

binarman and giuseros authored
[AMD] WMMA dot operand conversion to Linear Layout (#5299)
This PR implements conversion of the WMMA dot operand layout to linear layout and adds related tests.

Co-authored-by: Giuseppe Rossini <[email protected]>
1 parent b5af392 commit 6d2ca1c

File tree: 6 files changed, +536 −65 lines

lib/Conversion/TritonGPUToLLVM/ConvertLayoutOpToLLVM.cpp

Lines changed: 1 addition & 2 deletions

@@ -369,8 +369,7 @@ struct ConvertLayoutOpUsingLinearLayoutsConversion
       return !useLegacyMMAConversion;
     }
     if (auto dotOperand = dyn_cast<DotOperandEncodingAttr>(layout)) {
-      if (isa<NvidiaMmaEncodingAttr, AMDMfmaEncodingAttr>(
-              dotOperand.getParent())) {
+      if (isa<MmaEncodingTrait>(dotOperand.getParent())) {
        return !useLegacyMMAConversion;
      }
      return false;

lib/Conversion/TritonGPUToLLVM/MemoryOpToLLVM.cpp

Lines changed: 1 addition & 1 deletion

@@ -159,7 +159,7 @@ struct LocalLoadOpConversion : public ConvertOpToLLVMPattern<LocalLoadOp> {
           srcTy.getShape()[1] >= 4 * kWidth & dstTy.getRank() <= 2;
       return !canUseLdmatrix;
     }
-    if (isa<AMDMfmaEncodingAttr>(dot.getParent()))
+    if (isa<AMDMfmaEncodingAttr, AMDWmmaEncodingAttr>(dot.getParent()))
      return true;
   }
   return false;

lib/Dialect/TritonGPU/IR/Dialect.cpp

Lines changed: 3 additions & 5 deletions

@@ -1120,11 +1120,9 @@ SmallVector<unsigned> DotOperandEncodingAttr::getWarpsPerCTA() const {
 }
 SmallVector<unsigned> DotOperandEncodingAttr::getWarpOrder() const {
   // FIXME(Lezcano): Preexisting. Do we want to have this path at all?
-  if (mlir::isa<AMDMfmaEncodingAttr>(getParent())) {
+  if (mlir::isa<AMDMfmaEncodingAttr, AMDWmmaEncodingAttr>(getParent())) {
     return ::getWarpOrder(getParent());
   }
-  // It's quite weird to talk about warp order when that the warps
-  // are broadcasted along the K dimension
   llvm::report_fatal_error("DotOperandEncoding::getWarpOrder not implemented");
   return {};
 }
@@ -1160,9 +1158,9 @@ LogicalResult DotOperandEncodingAttr::verify(

   if (auto parentAttr = mlir::dyn_cast<AMDWmmaEncodingAttr>(parent)) {
     if (kWidth != 16 && parentAttr.getVersion() == 1 ||
-        kWidth != 8 && parentAttr.getVersion() == 2)
+        kWidth != 8 && kWidth != 16 && parentAttr.getVersion() == 2)
       return emitError() << "ttg.dot_op kWidth parameter must be 16 for "
-                            "gfx11 and 8 for gfx12";
+                            "gfx11 and 8/16 for gfx12";
     return success();
   }
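
For reference, a minimal standalone sketch of the (version, kWidth) combinations the updated verifier accepts. isValidWmmaDotKWidth is a hypothetical helper written here for illustration only, not code from this diff; it simply restates the predicate above: version 1 (gfx11) accepts only kWidth 16, version 2 (gfx12) now accepts 8 or 16.

#include <cassert>

// Hypothetical helper mirroring the verifier predicate above.
static bool isValidWmmaDotKWidth(unsigned version, unsigned kWidth) {
  if (version == 1)                     // gfx11
    return kWidth == 16;
  if (version == 2)                     // gfx12
    return kWidth == 8 || kWidth == 16;
  return false;
}

int main() {
  assert(isValidWmmaDotKWidth(1, 16));
  assert(!isValidWmmaDotKWidth(1, 8));  // rejected, as in the tests below
  assert(isValidWmmaDotKWidth(2, 8));
  assert(isValidWmmaDotKWidth(2, 16));  // newly accepted by this change
  assert(!isValidWmmaDotKWidth(2, 32)); // rejected, as in the updated test
  return 0;
}
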
lib/Dialect/TritonGPU/IR/LinearLayoutConversions.cpp

Lines changed: 127 additions & 47 deletions

@@ -239,6 +239,42 @@ LinearLayout sharedToLinearLayoutLeadingOffset(ArrayRef<int64_t> shape,
   return combineCtaCgaWithShape(tileLayout, shared.getCTALayout(), shape);
 }

+LinearLayout warpsDotOperand(MLIRContext *ctx, ArrayRef<unsigned> warpShape,
+                             ArrayRef<unsigned> warpOrder, unsigned inner) {
+  // Let warpsPerCTAMma = {2, 2}, then
+  // warpsPerCTA = {2, 1} for opA and warpsPerCTA = {1, 2} for opB
+  // assume warpOrder = {1, 0}
+  // Assume that C is tiled by 2x2 tiles. Since warpOrder={1, 0}, we have that
+  // the C is owned as per the following layout:
+  // C: 0 | 1
+  //    - | -
+  //    2 | 3
+  // In order to be able to compute C, we need the following warp tiling of
+  // A and B:
+  // A: 0 1 | 0 1    B: 0 2 | 1 3
+  //    - - | - -       - - | - -
+  //    2 3 | 2 3       0 2 | 1 3
+  // In other words, we need to broadcast along K
+  auto rank = warpShape.size();
+  auto dimNames = standardOutDimNames(ctx, rank);
+  LinearLayout warpLayout = LinearLayout::empty();
+
+  // We have to broadcast along the inner dimension
+  // For A, when moving along M we go from 0 to 2.
+  // For B, when moving along N we go from 0 to 1.
+  // As such, choosing the order of A {1, 0}, gives us the correct broadcasting
+  // Same happens if the warpOrder is {0, 1}, like in Hopper
+  for (auto d : warpOrder) {
+    if (d == inner) {
+      warpLayout *= LinearLayout::zeros1D(warpShape[d], S("warp"), dimNames[d]);
+    } else {
+      warpLayout *=
+          LinearLayout::identity1D(warpShape[d], S("warp"), dimNames[d]);
+    }
+  }
+  return warpLayout;
+}
+
 } // anonymous namespace

 std::optional<LinearLayout>
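
To make the broadcast described in the comment above concrete, here is a small self-contained sketch (plain C++, not the LinearLayout API; the decomposition loop and printout are illustrative assumptions) for operand A with warpsPerCTAMma = {2, 2}, warpOrder = {1, 0}, and inner = 1 (the K dimension). Each warp keeps only its M coordinate and its K coordinate collapses to 0, reproducing the "A: 0 1 | 0 1 / 2 3 | 2 3" tiling.

#include <array>
#include <cstdio>

int main() {
  const unsigned warpsPerCTAMma[2] = {2, 2}; // warps over C: {M, N}
  const unsigned warpOrder[2] = {1, 0};      // dim 1 varies fastest
  const unsigned inner = 1;                  // K dimension of operand A

  for (unsigned warpId = 0; warpId < 4; ++warpId) {
    std::array<unsigned, 2> coord = {0, 0};
    unsigned bits = warpId;
    // Decompose the warp id along warpOrder; drop the step for the inner
    // (K) dimension, mimicking zeros1D vs. identity1D above.
    for (unsigned d : warpOrder) {
      unsigned step = bits % warpsPerCTAMma[d];
      bits /= warpsPerCTAMma[d];
      if (d != inner)
        coord[d] = step;
    }
    std::printf("warp %u -> A tile (m=%u, k=%u)\n", warpId, coord[0], coord[1]);
  }
  // Prints: warps 0,1 -> m=0; warps 2,3 -> m=1; k is always 0 (broadcast
  // along K), matching the diagram in the comment above.
  return 0;
}
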
@@ -470,7 +506,9 @@ AMDWmmaEncodingAttr::toLinearLayout(ArrayRef<int64_t> shape) const {

   // We use the order from fastest varying to slowest varying. So each base
   // vector is a tuple of values mapping to matrix C's (N, M[, B]) indices.
-  SmallVector<unsigned> order = triton::gpu::getOrder(*this);
+  SmallVector<unsigned> threadOrder = getThreadOrder();
+  assert(threadOrder[0] == mIndex || threadOrder[0] == nIndex);
+  assert(threadOrder[1] == mIndex || threadOrder[1] == nIndex);

   // For wmma with 16x16 output, each of the 32 threads holds 8 elements.
   //
@@ -498,29 +536,106 @@ AMDWmmaEncodingAttr::toLinearLayout(ArrayRef<int64_t> shape) const {
           ? LinearLayout(
                 {{kRegister, {/*gap*/ {0, 2}, {0, 4}, {0, 8}}},
                  {kLane, {{1, 0}, {2, 0}, {4, 0}, {8, 0}, /*gap*/ {0, 1}}}},
-                {outDimNames[order[0]], outDimNames[order[1]]})
+                {outDimNames[threadOrder[0]], outDimNames[threadOrder[1]]})
           : LinearLayout(
                 {{kRegister, {{0, 1}, {0, 2}, {0, 4}}},
                  {kLane, {{1, 0}, {2, 0}, {4, 0}, {8, 0}, /*gap*/ {0, 8}}}},
-                {outDimNames[order[0]], outDimNames[order[1]]});
+                {outDimNames[threadOrder[0]], outDimNames[threadOrder[1]]});

   if (hasBatchDim) {
-    assert(order[2] == 0);
+    int batchIndex = 0;
     // Extend the base vector with one value to accomodate for the batch
     // dimension, which appears at the last.
-    tileLayout *= LinearLayout::identity1D(1, kRegister, outDimNames[order[2]]);
-    tileLayout *= LinearLayout::identity1D(1, kLane, outDimNames[order[2]]);
+    tileLayout *=
+        LinearLayout::identity1D(1, kRegister, outDimNames[batchIndex]);
+    tileLayout *= LinearLayout::identity1D(1, kLane, outDimNames[batchIndex]);
   }

   // And each warp takes the same register and lane sub-layout. So mulitply with
   // an identity layout for the warp.
+  auto warpOrder = getWarpOrder();
   LinearLayout warpLayout =
-      identityStandardND(S("warp"), getWarpsPerCTA(), order);
-  LinearLayout ctaLayout = tileLayout * warpLayout;
+      identityStandardND(S("warp"), getWarpsPerCTA(), warpOrder);
+  // reorder dim names in rep order, so combineCtaCgaWithShape generate proper
+  // extension of layout
+  auto repOrder = getRepOrder();
+  SmallVector<StringAttr> repDimNames;
+  for (auto dim : repOrder)
+    repDimNames.push_back(outDimNames[dim]);
+  LinearLayout ctaLayout = tileLayout.transposeOuts(repDimNames) *
+                           warpLayout.transposeOuts(repDimNames);

   return combineCtaCgaWithShape(ctaLayout, getCTALayout(), shape);
 }

+std::optional<LinearLayout>
+wmmaDotOperandToLinearLayout(DotOperandEncodingAttr dotWmmaLayout,
+                             ArrayRef<int64_t> shape) {
+  auto wmmaLayout = llvm::cast<AMDWmmaEncodingAttr>(dotWmmaLayout.getParent());
+  auto rank = shape.size();
+  bool hasBatchDim = rank == 3;
+  auto kDim = dotWmmaLayout.getOpIdx() == 0 ? rank - 1 : rank - 2;
+  int32_t kSize = shape[kDim];
+  MLIRContext *ctx = dotWmmaLayout.getContext();
+  SmallVector<StringAttr> outDimNames = standardOutDimNames(ctx, rank);
+  StringAttr kRegister = S("register");
+  StringAttr kLane = S("lane");
+  StringAttr kWarp = S("warp");
+  // lane order
+  // operand A: [1, 0] / [2, 1, 0]
+  // operand B: [0, 1] / [1, 2, 0]
+  // for both cases it is [k, nonk]/[k, nonk, batch]
+  SmallVector<unsigned> laneOrder = triton::gpu::getOrder(dotWmmaLayout);
+  // generate continuous part of register bases(i.e. kWidth)
+  std::vector<std::vector<int32_t>> registerBase;
+  const int32_t kWidth = dotWmmaLayout.getKWidth();
+  for (int i = 1; i < kWidth; i *= 2)
+    registerBase.push_back(std::vector<int32_t>{i, 0});
+  std::vector<std::vector<int32_t>> laneBase = {{0, 1}, {0, 2}, {0, 4}, {0, 8}};
+  switch (wmmaLayout.getVersion()) {
+  case 1:
+    // WMMA version 1 duplicates values in lanes 0-15 and 16-31
+    laneBase.push_back({0, 0});
+    break;
+  case 2:
+    // WMMA version 2 offset values in lanes 0-15 and 16-31 across k dimensions
+    laneBase.push_back({kWidth, 0});
+    break;
+  default:
+    assert(false && "unexpected version");
+  }
+  // Generate layout for one wmma instruction
+  LinearLayout tileLayout(
+      {{kRegister, registerBase}, {kLane, laneBase}},
+      {outDimNames[laneOrder[0]], outDimNames[laneOrder[1]]});
+  if (hasBatchDim) {
+    assert(laneOrder[2] == 0);
+    // Extend the base vector with one value to accomodate for the batch
+    // dimension, which appears at the last.
+    tileLayout *=
+        LinearLayout::identity1D(1, kRegister, outDimNames[laneOrder[2]]);
+    tileLayout *= LinearLayout::identity1D(1, kLane, outDimNames[laneOrder[2]]);
+  }
+
+  // Generate warp layout
+  auto warpsPerCTA = wmmaLayout.getWarpsPerCTA();
+  auto warpOrder = triton::gpu::getWarpOrder(dotWmmaLayout);
+  LinearLayout warpLayout = warpsDotOperand(ctx, warpsPerCTA, warpOrder, kDim);
+
+  // reorder dim names in rep order, so combineCtaCgaWithShape generate proper
+  // extension of layout
+  auto repOrder = wmmaLayout.getRepOrderForOperand(dotWmmaLayout.getOpIdx());
+  SmallVector<StringAttr> repDimNames;
+  for (auto dim : repOrder)
+    repDimNames.push_back(outDimNames[dim]);
+
+  // join instruction layout and warps using repetition order of dimensions
+  LinearLayout ctaLayout = tileLayout.transposeOuts(repDimNames) *
+                           warpLayout.transposeOuts(repDimNames);
+
+  return combineCtaCgaWithShape(ctaLayout, wmmaLayout.getCTALayout(), shape);
+}
+
 std::optional<LinearLayout>
 BlockedEncodingAttr::toLinearLayout(ArrayRef<int64_t> shape) const {
   assert(shape.size() == getOrder().size());
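
As an illustration of the bases the new wmmaDotOperandToLinearLayout builds before warps and CTA shape are folded in, here is a standalone sketch (plain C++, not the LinearLayout API; the assumed configuration is a gfx12 / WMMA version 2 operand with kWidth = 8) that prints the register and lane bases in (k, nonk) order.

#include <cstdio>
#include <utility>
#include <vector>

int main() {
  const int kWidth = 8;  // assumed gfx12 operand kWidth
  const int version = 2; // WMMA version 2 == gfx12

  // Contiguous kWidth elements per thread along K -> register bases.
  std::vector<std::pair<int, int>> registerBase;
  for (int i = 1; i < kWidth; i *= 2)
    registerBase.push_back({i, 0});

  // 16 lanes walk the non-K dimension; the fifth lane bit depends on version.
  std::vector<std::pair<int, int>> laneBase = {{0, 1}, {0, 2}, {0, 4}, {0, 8}};
  if (version == 1)
    laneBase.push_back({0, 0});      // gfx11: lanes 16-31 duplicate lanes 0-15
  else
    laneBase.push_back({kWidth, 0}); // gfx12: lanes 16-31 cover the next K chunk

  std::printf("register bases:");
  for (auto [k, nonk] : registerBase)
    std::printf(" (%d,%d)", k, nonk);
  std::printf("\nlane bases:");
  for (auto [k, nonk] : laneBase)
    std::printf(" (%d,%d)", k, nonk);
  std::printf("\n");
  // Expected output:
  // register bases: (1,0) (2,0) (4,0)
  // lane bases: (0,1) (0,2) (0,4) (0,8) (8,0)
  return 0;
}
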
@@ -604,44 +719,6 @@ NvidiaMmaEncodingAttr::toLinearLayout(ArrayRef<int64_t> shape) const {
   return combineCtaCgaWithShape(ctaLayout, getCTALayout(), shape);
 }

-LinearLayout warpsNvidiaDot(MLIRContext *ctx, ArrayRef<unsigned> mmaWarpShape,
-                            ArrayRef<unsigned> mmaWarpOrder, bool isA) {
-  // Let warpsPerCTAMma = {2, 2}, then
-  // warpsPerCTA = {2, 1} for opA and warpsPerCTA = {1, 2} for opB
-  // assume warpOrder = {1, 0}
-  // Assume that C is tiled by 2x2 tiles. Since warpOrder={1, 0}, we have that
-  // the C is owned as per the following layout:
-  // C: 0 | 1
-  //    - | -
-  //    2 | 3
-  // In order to be able to compute C, we need the following warp tiling of
-  // A and B:
-  // A: 0 1 | 0 1    B: 0 2 | 1 3
-  //    - - | - -       - - | - -
-  //    2 3 | 2 3       0 2 | 1 3
-  // In other words, we need to broadcast along K
-  auto rank = mmaWarpOrder.size();
-  auto inner = isA ? rank - 1 : rank - 2;
-  auto dimNames = standardOutDimNames(ctx, rank);
-  LinearLayout warpLayout = LinearLayout::empty();
-
-  // We have to broadcast along the inner dimension
-  // For A, when moving along M we go from 0 to 2.
-  // For B, when moving along N we go from 0 to 1.
-  // As such, choosing the order of A {1, 0}, gives us the correct broadcasting
-  // Same happens if the mmaWarpOrder is {0, 1}, like in Hopper
-  for (auto d : mmaWarpOrder) {
-    if (d == inner) {
-      warpLayout *=
-          LinearLayout::zeros1D(mmaWarpShape[d], S("warp"), dimNames[d]);
-    } else {
-      warpLayout *=
-          LinearLayout::identity1D(mmaWarpShape[d], S("warp"), dimNames[d]);
-    }
-  }
-  return warpLayout;
-}
-
 LinearLayout nvidiaDotToLinearLayout(ArrayRef<int64_t> shape,
                                      DotOperandEncodingAttr dot) {
   int rank = shape.size();
@@ -662,8 +739,9 @@ LinearLayout nvidiaDotToLinearLayout(ArrayRef<int64_t> shape,
   }
   auto ctaLayout =
       nvidiaMmaTile(ctx, tileShape, kWidth, getOrder(dot), dot.getRepOrder());
+  auto kDim = isA ? rank - 1 : rank - 2;
   ctaLayout *=
-      warpsNvidiaDot(ctx, mma.getWarpsPerCTA(), mma.getWarpOrder(), isA)
+      warpsDotOperand(ctx, mma.getWarpsPerCTA(), mma.getWarpOrder(), kDim)
          .transposeOuts(llvm::to_vector(ctaLayout.getOutDimNames()));

   return combineCtaCgaWithShape(ctaLayout, getCTALayout(dot), shape);
@@ -674,6 +752,8 @@ DotOperandEncodingAttr::toLinearLayout(ArrayRef<int64_t> shape) const {
   auto parent = getParent();
   if (auto mfmaLayout = llvm::dyn_cast<AMDMfmaEncodingAttr>(parent)) {
    return mfmaDotToLinearLayout(*this, shape);
+  } else if (auto wmmaLayout = llvm::dyn_cast<AMDWmmaEncodingAttr>(parent)) {
+    return wmmaDotOperandToLinearLayout(*this, shape);
  } else if (auto mma = mlir::dyn_cast<NvidiaMmaEncodingAttr>(parent)) {
    return nvidiaDotToLinearLayout(shape, *this);
  }

test/TritonGPU/invalid-attributes.mlir

Lines changed: 5 additions & 5 deletions

@@ -42,23 +42,23 @@

 // -----

-// expected-error@+2 {{ttg.dot_op kWidth parameter must be 16 for gfx11 and 8 for gfx12}}
+// expected-error@+2 {{ttg.dot_op kWidth parameter must be 16 for gfx11 and 8/16 for gfx12}}
 #wmma = #ttg.amd_wmma<{version = 1, warpsPerCTA = [1, 4]}>
 #dot_op = #ttg.dot_op<{opIdx = 1, parent = #wmma}>

 // -----

-// expected-error@+2 {{ttg.dot_op kWidth parameter must be 16 for gfx11 and 8 for gfx12}}
+// expected-error@+2 {{ttg.dot_op kWidth parameter must be 16 for gfx11 and 8/16 for gfx12}}
 #wmma = #ttg.amd_wmma<{version = 1, warpsPerCTA = [1, 4]}>
 #dot_op = #ttg.dot_op<{opIdx = 1, parent = #wmma, kWidth = 8}>

 // -----
-// expected-error@+2 {{ttg.dot_op kWidth parameter must be 16 for gfx11 and 8 for gfx12}}
+// expected-error@+2 {{ttg.dot_op kWidth parameter must be 16 for gfx11 and 8/16 for gfx12}}
 #wmma = #ttg.amd_wmma<{version = 2, warpsPerCTA = [1, 4]}>
-#dot_op = #ttg.dot_op<{opIdx = 1, parent = #wmma, kWidth = 16}>
+#dot_op = #ttg.dot_op<{opIdx = 1, parent = #wmma, kWidth = 32}>

 // -----
-// expected-error@+2 {{ttg.dot_op kWidth parameter must be 16 for gfx11 and 8 for gfx12}}
+// expected-error@+2 {{ttg.dot_op kWidth parameter must be 16 for gfx11 and 8/16 for gfx12}}
 #wmma = #ttg.amd_wmma<{version = 2, warpsPerCTA = [1, 4]}>
 #dot_op = #ttg.dot_op<{opIdx = 1, parent = #wmma, kWidth = 4}>
