Skip to content

Commit 5049304

Browse files
authored
[AMD] Add shared memory encoding to avoid transpose bank conflict (#5797)
This PR introduces a new swizzling pattern for the AMD backend to reduce bank conflicts in cases where shared memory writes and reads are performed on layouts with different order. It's meant for hardware without native shared memory transpose support.
1 parent f3bd7f7 commit 5049304

File tree

6 files changed

+314
-4
lines changed

6 files changed

+314
-4
lines changed

include/triton/Dialect/TritonGPU/IR/LinearLayoutConversions.h

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,7 @@ enum class ScaleDotElemType : uint32_t;
1515
namespace mlir::triton::gpu {
1616
class SwizzledSharedEncodingAttr;
1717
class NVMMASharedEncodingAttr;
18+
class AMDRotatingSharedEncodingAttr;
1819
class AMDMfmaEncodingAttr;
1920

2021
// - BlockedEncodingAttrs have the following input dimensions.

include/triton/Dialect/TritonGPU/IR/TritonGPUAttrDefs.td

Lines changed: 104 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -446,6 +446,110 @@ def NVMMASharedEncodingAttr :
446446
let hasCustomAssemblyFormat = 1;
447447
}
448448

449+
def AMDRotatingSharedEncodingAttr :
450+
TritonGPU_Attr<"AMDRotatingSharedEncoding", "amd_rotating_shared_encoding",
451+
[SharedEncodingTrait, LayoutEncodingTrait]> {
452+
let mnemonic = "amd_rotating_shared";
453+
454+
let description = [{
455+
This shared encoding is similar to SwizzledSharedEncodingAttr, but instead of
456+
repeating swizzling pattern every `maxPhase*perPhase` rows of the memory object,
457+
called a block, this layout changes swizzling pattern `maxPhase` times, then
458+
repeats the pattern. The name "rotating" comes from the fact that first tensor
459+
element of each block is swizzled with different phase, which is equal to
460+
current block number: 0, 1, 2.. maxPhase-1, 0, 1, 2 ...
461+
462+
This layout is used to reduce bank conflicts in cases where shared memory writes
463+
and reads are performed on layouts with different order. It's meant for hardware
464+
without native shared memory transpose support.
465+
466+
Swizzling pattern affects only 2 fastest dimensions of a tensor.
467+
In the following text these two dimensions are called row and column:
468+
- row is the fastest dimension
469+
- column is the second fastest dimension
470+
471+
Elements in a row dimension are stored in memory contiguously.
472+
473+
If a matrix of size [128x64] is stored in this shared layout with order [1, 0],
474+
dim 1 (64) will be stored contiguously and called row, dim 0 (128) will be
475+
called column. If order of shared layout is [0, 1], dim 0 (128) is stored
476+
contiguously and becomes a row, dim 1 (64) becomes a column.
477+
478+
The swizzling pattern is as follows:
479+
480+
Let's consider an element with logical coordinates = (inRowId, inColId).
481+
For simplicity, we do not vectorize memory in examples,
482+
i.e. vec == 1 and the layout swizzles individual elements.
483+
For vec != 1 example, take a look at SwizzledSharedEncodingAttr documentation.
484+
485+
Swizzled coordinates within memory object are (outRowId, outColId):
486+
487+
outRowId = inRowId
488+
phase = (inRowId / perPhase) % maxPhase
489+
blockNo = (inRowId / (perPhase * maxPhase)) % maxPhase
490+
combinedPhase = phase ^ blockNo
491+
outColId = inColId ^ combinedPhase
492+
493+
Actual offset in memory could be computed with following function:
494+
495+
memory_offset = (outColId + outRowId * num_of_elements_in_row) * sizeof(element)
496+
497+
498+
Swizzling examples (matrix is filled with numbers 0, 1, 2, .. columns*rows-1):
499+
500+
#shared<{vec=1, perPhase=1, maxPhase=2, order=[1,0]}>
501+
row elements
502+
0 [ 0, 1, 2, 3], // phase = 0 blockNo = 0 (xor with 0)
503+
1 [ 5, 4, 7, 6], // phase = 1 blockNo = 0 (xor with 1)
504+
2 [ 9, 8, 11, 10], // phase = 0 blockNo = 1 (xor with 1)
505+
3 [12, 13, 14, 15] // phase = 1 blockNo = 1 (xor with 0)
506+
4 [16, 17, 18, 19], // phase = 0 blockNo = 0 (xor with 0)
507+
5 [21, 20, 23, 22], // phase = 1 blockNo = 0 (xor with 1)
508+
6 [25, 24, 27, 26], // phase = 0 blockNo = 1 (xor with 1)
509+
7 [28, 29, 30, 31] // phase = 1 blockNo = 1 (xor with 0)
510+
511+
#shared<{vec=1, perPhase=2, maxPhase=2, order=[1,0]}>
512+
row elements
513+
0 [ 0, 1, 2, 3], // phase = 0 blockNo = 0 (xor with 0)
514+
1 [ 4, 5, 6, 7], // phase = 0 blockNo = 0 (xor with 0)
515+
2 [ 9, 8, 11, 10], // phase = 1 blockNo = 0 (xor with 1)
516+
3 [13, 12, 15, 14] // phase = 1 blockNo = 0 (xor with 1)
517+
4 [17, 16, 19, 18], // phase = 0 blockNo = 1 (xor with 1)
518+
5 [21, 20, 23, 22], // phase = 0 blockNo = 1 (xor with 1)
519+
6 [24, 25, 26, 27], // phase = 1 blockNo = 1 (xor with 0)
520+
7 [28, 29, 30, 31] // phase = 1 blockNo = 1 (xor with 0)
521+
522+
#shared<{vec=1, perPhase=1, maxPhase=4, order=[1,0]}>
523+
row elements
524+
0 [ 0, 1, 2, 3], // phase = 0 blockNo = 0 (xor with 0)
525+
1 [ 5, 4, 7, 6], // phase = 1 blockNo = 0 (xor with 1)
526+
2 [10, 11, 8, 9], // phase = 2 blockNo = 0 (xor with 2)
527+
3 [15, 14, 13, 12] // phase = 3 blockNo = 0 (xor with 3)
528+
4 [17, 16, 19, 18], // phase = 0 blockNo = 1 (xor with 1)
529+
5 [20, 21, 22, 23], // phase = 1 blockNo = 1 (xor with 0)
530+
6 [27, 26, 25, 24], // phase = 2 blockNo = 1 (xor with 3)
531+
7 [30, 31, 28, 29] // phase = 3 blockNo = 1 (xor with 2)
532+
}];
533+
534+
let parameters = (
535+
ins
536+
"unsigned":$vec,
537+
"unsigned":$perPhase,
538+
"unsigned":$maxPhase,
539+
ArrayRefParameter<"unsigned">:$order,
540+
"CTALayoutAttr":$CTALayout
541+
);
542+
543+
let extraClassDeclaration = extraBaseClassDeclaration # [{
544+
int32_t getAlignment() const;
545+
SmallVector<unsigned> getCTAsPerCGA() const;
546+
SmallVector<unsigned> getCTAOrder() const;
547+
SmallVector<unsigned> getCTASplitNum() const;
548+
}];
549+
let hasCustomAssemblyFormat = 1;
550+
}
551+
552+
449553
//===----------------------------------------------------------------------===//
450554
// Distributed Layout Encoding
451555
//===----------------------------------------------------------------------===//

lib/Dialect/TritonGPU/IR/Dialect.cpp

Lines changed: 48 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -201,6 +201,10 @@ SmallVector<unsigned> getOrder(SharedEncodingTrait layout,
201201
if (auto sharedLayout = mlir::dyn_cast<NVMMASharedEncodingAttr>(layout)) {
202202
return sharedLayout.getOrder();
203203
}
204+
if (auto sharedLayout =
205+
mlir::dyn_cast<AMDRotatingSharedEncodingAttr>(layout)) {
206+
return llvm::to_vector(sharedLayout.getOrder());
207+
}
204208
llvm::report_fatal_error("Unimplemented usage of getOrder for MemDescType");
205209
return {};
206210
}
@@ -765,6 +769,18 @@ SmallVector<unsigned> NVMMASharedEncodingAttr::getCTASplitNum() const {
765769
return SmallVector<unsigned>(getCTALayout().getCTASplitNum());
766770
}
767771

772+
int32_t AMDRotatingSharedEncodingAttr::getAlignment() const { return 16; }
773+
774+
SmallVector<unsigned> AMDRotatingSharedEncodingAttr::getCTAsPerCGA() const {
775+
return SmallVector<unsigned>(getCTALayout().getCTAsPerCGA());
776+
}
777+
SmallVector<unsigned> AMDRotatingSharedEncodingAttr::getCTAOrder() const {
778+
return SmallVector<unsigned>(getCTALayout().getCTAOrder());
779+
}
780+
SmallVector<unsigned> AMDRotatingSharedEncodingAttr::getCTASplitNum() const {
781+
return SmallVector<unsigned>(getCTALayout().getCTASplitNum());
782+
}
783+
768784
SmallVector<unsigned> DotOperandEncodingAttr::getCTAsPerCGA() const {
769785
return ::getCTAsPerCGA(getParent());
770786
}
@@ -1637,10 +1653,11 @@ void SliceEncodingAttr::print(mlir::AsmPrinter &printer) const {
16371653
}
16381654

16391655
//===----------------------------------------------------------------------===//
1640-
// SwizzledShared encoding
1656+
// Helper shared encoding functions
16411657
//===----------------------------------------------------------------------===//
16421658

1643-
Attribute SwizzledSharedEncodingAttr::parse(AsmParser &parser, Type type) {
1659+
template <typename SpecificEncoding>
1660+
Attribute parseSwizzledEncoding(AsmParser &parser, Type type) {
16441661
if (parser.parseLess().failed())
16451662
return {};
16461663
// Parse the data as a dictionary
@@ -1694,8 +1711,16 @@ Attribute SwizzledSharedEncodingAttr::parse(AsmParser &parser, Type type) {
16941711
if (!CTALayout.has_value())
16951712
return {};
16961713

1697-
return parser.getChecked<SwizzledSharedEncodingAttr>(
1698-
parser.getContext(), vec, perPhase, maxPhase, order, *CTALayout);
1714+
return parser.getChecked<SpecificEncoding>(parser.getContext(), vec, perPhase,
1715+
maxPhase, order, *CTALayout);
1716+
}
1717+
1718+
//===----------------------------------------------------------------------===//
1719+
// SwizzledShared encoding
1720+
//===----------------------------------------------------------------------===//
1721+
1722+
Attribute SwizzledSharedEncodingAttr::parse(AsmParser &parser, Type type) {
1723+
return parseSwizzledEncoding<SwizzledSharedEncodingAttr>(parser, type);
16991724
}
17001725

17011726
void SwizzledSharedEncodingAttr::print(AsmPrinter &printer) const {
@@ -1787,6 +1812,25 @@ void NVMMASharedEncodingAttr::print(AsmPrinter &printer) const {
17871812
printer << "}>";
17881813
}
17891814

1815+
//===----------------------------------------------------------------------===//
1816+
// AMDRotatingShared encoding
1817+
//===----------------------------------------------------------------------===//
1818+
1819+
Attribute AMDRotatingSharedEncodingAttr::parse(AsmParser &parser, Type type) {
1820+
return parseSwizzledEncoding<AMDRotatingSharedEncodingAttr>(parser, type);
1821+
}
1822+
1823+
void AMDRotatingSharedEncodingAttr::print(AsmPrinter &printer) const {
1824+
printer << "<{"
1825+
<< "vec = " << getVec() //
1826+
<< ", perPhase = " << getPerPhase()
1827+
<< ", maxPhase = " << getMaxPhase() //
1828+
<< ", order = [" << getOrder() << "]";
1829+
maybePrintCTALayout(getContext(), printer, getCTALayout(),
1830+
/*rank=*/getOrder().size());
1831+
printer << "}>";
1832+
}
1833+
17901834
//===----------------------------------------------------------------------===//
17911835
// Mfma encoding
17921836
//===----------------------------------------------------------------------===//

lib/Dialect/TritonGPU/IR/LinearLayoutConversions.cpp

Lines changed: 53 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -168,6 +168,57 @@ sharedToLinearLayoutNoLeadingOffset(ArrayRef<int64_t> shape,
168168
return combineCtaCgaWithShape(ctaLayout, shared.getCTALayout(), shape);
169169
}
170170

171+
LinearLayout
172+
sharedToLinearLayoutAMDRotating(ArrayRef<int64_t> shape,
173+
AMDRotatingSharedEncodingAttr shared) {
174+
MLIRContext *ctx = shared.getContext();
175+
int rank = shape.size();
176+
if (rank == 1) {
177+
return combineCtaCgaWithShape(
178+
LinearLayout::identity1D(shape[0], S("offset"), S("dim0")),
179+
shared.getCTALayout(), shape);
180+
}
181+
182+
auto outDimNames = standardOutDimNames(ctx, rank);
183+
184+
// Construct bases for the 2 most minor dimensions of the layout. These are
185+
// the dims that get swizzled.
186+
assert(shape.size() >= 2);
187+
int colDim = shared.getOrder()[0];
188+
int rowDim = shared.getOrder()[1];
189+
int numCols = shape[colDim];
190+
int numRows = shape[rowDim];
191+
StringAttr colDimName = outDimNames[colDim];
192+
StringAttr rowDimName = outDimNames[rowDim];
193+
194+
std::vector<std::vector<int>> bases2D;
195+
for (int logCol = 0; logCol < llvm::Log2_32(numCols); logCol++) {
196+
bases2D.push_back({0, 1 << logCol});
197+
}
198+
for (int logRow = 0; logRow < llvm::Log2_32(numRows); logRow++) {
199+
int row = 1 << logRow;
200+
int vec = shared.getVec();
201+
int perPhase = shared.getPerPhase();
202+
int maxPhase = shared.getMaxPhase();
203+
204+
int phase = (row / perPhase) % maxPhase;
205+
int blockNo = row / maxPhase / perPhase % maxPhase;
206+
int combinedPhase = phase ^ blockNo;
207+
bases2D.push_back({row, (vec * combinedPhase) % numCols});
208+
}
209+
LinearLayout ctaLayout =
210+
LinearLayout({{S("offset"), bases2D}}, {rowDimName, colDimName});
211+
212+
// Add the remaining dimensions.
213+
for (int i = 2; i < rank; i++) {
214+
int dim = shared.getOrder()[i];
215+
ctaLayout *=
216+
LinearLayout::identity1D(shape[dim], S("offset"), outDimNames[dim]);
217+
}
218+
219+
return combineCtaCgaWithShape(ctaLayout, shared.getCTALayout(), shape);
220+
}
221+
171222
} // namespace
172223

173224
LinearLayout sharedToLinearLayoutLeadingOffset(ArrayRef<int64_t> shape,
@@ -1041,6 +1092,8 @@ LinearLayout TritonGPUDialect::toLinearLayout(ArrayRef<int64_t> shape,
10411092
result = sharedToLinearLayoutNoLeadingOffset(shape, shared);
10421093
} else if (auto shared = dyn_cast<NVMMASharedEncodingAttr>(layout)) {
10431094
result = sharedToLinearLayoutLeadingOffset(shape, shared);
1095+
} else if (auto sbl = dyn_cast<AMDRotatingSharedEncodingAttr>(layout)) {
1096+
result = sharedToLinearLayoutAMDRotating(shape, sbl);
10441097
} else {
10451098
assert(0 && "unknown layout");
10461099
}

test/Conversion/amd/tritongpu_to_llvm.mlir

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -341,3 +341,21 @@ module attributes {"ttg.target" = "hip:gfx942", "ttg.num-ctas" = 1 : i32, "ttg.n
341341
tt.return
342342
}
343343
}
344+
345+
// -----
346+
347+
// CHECK-LABEL: amd_rotating_shared_layout
348+
#blocked = #ttg.blocked<{sizePerThread = [1, 1], threadsPerWarp = [8, 8], warpsPerCTA = [2, 2], order = [1, 0], CTAsPerCGA = [1, 1], CTASplitNum = [1, 1], CTAOrder = [1, 0]}>
349+
#shared = #ttg.amd_rotating_shared<{vec = 1, perPhase = 1, maxPhase = 4, order = [1, 0]}>
350+
#smem = #ttg.shared_memory
351+
module attributes {"ttg.target" = "hip:gfx942", "ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, "ttg.threads-per-warp" = 64 : i32} {
352+
tt.func @amd_rotating_shared_layout(%arg0: tensor<64x64xf16, #blocked>) {
353+
// CHECK-COUNT-16: llvm.store {{.*}} : vector<1xf16>, !llvm.ptr<3>
354+
%0 = ttg.local_alloc %arg0 : (tensor<64x64xf16, #blocked>) -> !ttg.memdesc<64x64xf16, #shared, #smem, mutable>
355+
// CHECK-COUNT-16: llvm.load {{.*}} : !llvm.ptr<3> -> vector<1xf16>
356+
%1 = ttg.local_load %0 : !ttg.memdesc<64x64xf16, #shared, #smem, mutable> -> tensor<64x64xf16, #blocked>
357+
// CHECK-COUNT-16: llvm.store {{.*}} : vector<1xf16>, !llvm.ptr<3>
358+
ttg.local_store %1, %0 : tensor<64x64xf16, #blocked> -> !ttg.memdesc<64x64xf16, #shared, #smem, mutable>
359+
tt.return
360+
}
361+
}

0 commit comments

Comments
 (0)