Skip to content

Commit 7afcbd1

Browse files
committed
Update
1 parent be81f0a commit 7afcbd1

File tree

1 file changed

+16
-13
lines changed

1 file changed

+16
-13
lines changed

lib/Conversion/TritonGPUToLLVM/MemoryOpToLLVM.cpp

Lines changed: 16 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -138,18 +138,21 @@ struct LocalLoadOpConversion : public ConvertOpToLLVMPattern<LocalLoadOp> {
138138

139139
// FIXME [Dot LL]
140140
// Do for all DotOperandEncodingAttr once we have LLs for all of them
141-
static bool isSupportedDotOpLayout(RankedTensorType type) {
142-
auto layout = type.getEncoding();
143-
auto bitwidth = type.getElementType().getIntOrFloatBitWidth();
144-
if (auto dot = dyn_cast<DotOperandEncodingAttr>(layout)) {
141+
static bool isSupportedDotOpLayout(RankedTensorType srcTy,
142+
RankedTensorType dstTy) {
143+
auto srcLayout = cast<SharedEncodingAttr>(srcTy.getEncoding());
144+
auto dstLayout = dstTy.getEncoding();
145+
auto bitwidth = dstTy.getElementType().getIntOrFloatBitWidth();
146+
auto rank = dstTy.getRank();
147+
if (auto dot = dyn_cast<DotOperandEncodingAttr>(dstLayout)) {
148+
auto vecWidth = 32 / bitwidth;
145149
auto kWidth = dot.getKWidth();
146-
// Use when the SharedToDotOperandMMAv2OrV3 is known to be buggy:
147-
// - kWidth == 8
148-
// - kWidth == 4, bitwidth = 32
150+
auto kOrder = dot.getOpIdx() == 0 ? rank - 1 : rank - 2;
149151
if (auto mma = dyn_cast<NvidiaMmaEncodingAttr>(dot.getParent())) {
150-
bool legacyLoweringIsBuggy =
151-
kWidth >= 8 || (kWidth == 4 && bitwidth == 32);
152-
return legacyLoweringIsBuggy && mma.isAmpere();
152+
auto needTrans = kOrder != srcLayout.getOrder()[0];
153+
auto canUseLdmatrix =
154+
(bitwidth == 16 || (!needTrans)) && (kWidth == vecWidth);
155+
return !canUseLdmatrix && mma.isAmpere();
153156
}
154157
if (isa<AMDMfmaEncodingAttr>(dot.getParent()))
155158
return true;
@@ -164,10 +167,10 @@ struct LocalLoadOpConversion : public ConvertOpToLLVMPattern<LocalLoadOp> {
164167
RankedTensorType dstTy = op.getType();
165168
Attribute srcLayout = srcTy.getEncoding();
166169
Attribute dstLayout = dstTy.getEncoding();
167-
if (isa<SharedEncodingAttr>(srcLayout) &&
168-
(isa<BlockedEncodingAttr, MmaEncodingTrait, SliceEncodingAttr>(
170+
assert(isa<SharedEncodingAttr>(srcLayout) && "Unexpected src layout");
171+
if ((isa<BlockedEncodingAttr, MmaEncodingTrait, SliceEncodingAttr>(
169172
dstLayout) ||
170-
isSupportedDotOpLayout(dstTy))) {
173+
isSupportedDotOpLayout(srcTy, dstTy))) {
171174
return lowerSharedToDistributed(op, adaptor, getTypeConverter(),
172175
rewriter);
173176
}

0 commit comments

Comments
 (0)