Skip to content

Commit 4499262

Browse files
Jokeren authored and jataylo committed
[BACKEND] Replace isMmaToDotShortcut with linear layout based logic (#4951)
This PR removes the legacy `isMmaToDotShortcut` and its associated shortcut conversion. (cherry picked from commit 1d5fdfe)
1 parent 7c0257d commit 4499262

File tree

15 files changed

+242
-200
lines changed

15 files changed

+242
-200
lines changed

include/triton/Analysis/Utility.h

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -216,8 +216,6 @@ bool isBlockedToDotShortcut(RankedTensorType srcTy, RankedTensorType dstTy);
216216

217217
bool isMfmaToDotShortcut(RankedTensorType srcTy, RankedTensorType dstTy);
218218

219-
bool isMmaToDotShortcut(RankedTensorType srcTy, RankedTensorType dstTy);
220-
221219
// Return true if the src and dst layout match.
222220
bool matchMmaV3AndDotOperandLayout(RankedTensorType srcTy,
223221
RankedTensorType dstTy);

include/triton/Conversion/TritonGPUToLLVM/ElementwiseOpToLLVMBase.h

Lines changed: 3 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -18,14 +18,6 @@ namespace gpu {
1818
SmallVector<Value> reorderValues(const SmallVector<Value> &values, Type inType,
1919
Type ouType);
2020

21-
SmallVector<Value> unpackI32(const SmallVector<Value> &inValues, Type srcTy,
22-
ConversionPatternRewriter &rewriter, Location loc,
23-
const LLVMTypeConverter *typeConverter);
24-
25-
SmallVector<Value> packI32(const SmallVector<Value> &inValues, Type srcTy,
26-
ConversionPatternRewriter &rewriter, Location loc,
27-
const LLVMTypeConverter *typeConverter);
28-
2921
Type getElementType(Value value);
3022

3123
class MultipleOperandsRange
@@ -187,8 +179,8 @@ class ElementwiseOpConversionBase : public ConvertOpToLLVMPattern<SourceOp> {
187179
for (auto operand : adaptor.getOperands()) {
188180
auto argTy = op->getOperand(0).getType();
189181
auto subOperands = unpackLLElements(loc, operand, rewriter);
190-
subOperands = unpackI32(subOperands, argTy, rewriter, loc,
191-
this->getTypeConverter());
182+
subOperands = unpackI32s(subOperands, argTy, rewriter, loc,
183+
this->getTypeConverter());
192184
allOperands.resize(subOperands.size());
193185
for (auto v : llvm::enumerate(subOperands))
194186
allOperands[v.index()].push_back(v.value());
@@ -215,7 +207,7 @@ class ElementwiseOpConversionBase : public ConvertOpToLLVMPattern<SourceOp> {
215207
}
216208
resultVals = maybeDeduplicate(op, resultVals);
217209
resultVals =
218-
packI32(resultVals, resultTy, rewriter, loc, this->getTypeConverter());
210+
packI32s(resultVals, resultTy, rewriter, loc, this->getTypeConverter());
219211
Value view = packLLElements(loc, this->getTypeConverter(), resultVals,
220212
rewriter, resultTy);
221213
rewriter.replaceOp(op, view);

include/triton/Conversion/TritonGPUToLLVM/Utility.h

Lines changed: 61 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1388,6 +1388,67 @@ inline Value getStructFromSharedMemoryObject(Location loc,
13881388
return llvmStruct;
13891389
}
13901390

1391+
// For some reasons, LLVM's NVPTX backend inserts unnecessary (?) integer
1392+
// instructions to pack & unpack sub-word integers. A workaround is to
1393+
// store the results of tensors with dot operand encodings in i32 to
1394+
// facilitate instructions such as `ldmatrix`.
1395+
//
1396+
// TODO: Confirm if the problem is still there.
1397+
inline bool requiresI32Conversion(Type type) {
1398+
auto tensorTy = dyn_cast<RankedTensorType>(type);
1399+
if (!tensorTy)
1400+
return false;
1401+
auto dotOpEnc = dyn_cast<DotOperandEncodingAttr>(tensorTy.getEncoding());
1402+
if (!dotOpEnc)
1403+
return false;
1404+
auto parent = dyn_cast<NvidiaMmaEncodingAttr>(dotOpEnc.getParent());
1405+
if (!(parent && parent.getVersionMajor() < 3))
1406+
return false;
1407+
return true;
1408+
}
1409+
1410+
inline SmallVector<Value> packI32s(const SmallVector<Value> &inValues,
1411+
Type type, RewriterBase &rewriter,
1412+
Location loc,
1413+
const LLVMTypeConverter *typeConverter) {
1414+
if (!requiresI32Conversion(type))
1415+
return inValues;
1416+
Type eltTy =
1417+
typeConverter->convertType(cast<RankedTensorType>(type).getElementType());
1418+
1419+
SmallVector<Value> outValues;
1420+
int vecWidth = 32 / eltTy.getIntOrFloatBitWidth();
1421+
auto vecTy = vec_ty(eltTy, vecWidth);
1422+
for (int i = 0; i < inValues.size(); i += vecWidth) {
1423+
Value vec = undef(vecTy);
1424+
for (int j = 0; j < vecWidth; j++) {
1425+
vec = insert_element(vec, inValues[i + j], i32_val(j));
1426+
}
1427+
outValues.push_back(bitcast(vec, i32_ty));
1428+
}
1429+
return outValues;
1430+
}
1431+
1432+
inline SmallVector<Value> unpackI32s(const SmallVector<Value> &inValues,
1433+
Type type, RewriterBase &rewriter,
1434+
Location loc,
1435+
const LLVMTypeConverter *typeConverter) {
1436+
if (!requiresI32Conversion(type))
1437+
return inValues;
1438+
Type eltTy =
1439+
typeConverter->convertType(cast<RankedTensorType>(type).getElementType());
1440+
1441+
SmallVector<Value> outValues;
1442+
for (auto v : inValues) {
1443+
auto vecTy = vec_ty(eltTy, 32 / eltTy.getIntOrFloatBitWidth());
1444+
auto vec = bitcast(v, vecTy);
1445+
for (int i = 0; i < 32 / eltTy.getIntOrFloatBitWidth(); i++) {
1446+
outValues.push_back(extract_element(vec, i32_val(i)));
1447+
}
1448+
}
1449+
return outValues;
1450+
}
1451+
13911452
inline SmallVector<Value> unpackLLElements(Location loc, Value llvmStruct,
13921453
RewriterBase &rewriter) {
13931454
assert(bool(llvmStruct) && "can not unpack null values");

lib/Analysis/Utility.cpp

Lines changed: 4 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -731,14 +731,14 @@ bool cvtNeedsWarpShuffle(RankedTensorType srcTy, RankedTensorType dstTy) {
731731
}
732732

733733
bool cvtNeedsSharedMemory(RankedTensorType srcTy, RankedTensorType dstTy) {
734-
// TODO(jlebar): Remove these special cases (`isMmaToDotShortcut`,
735-
// `isBlockedToDotShortcut` and `isMfmaToDotShortcut`) once they're fully
736-
// subsumed by the linear-layout checks.
734+
// TODO(jlebar): Remove these special cases (`isBlockedToDotShortcut` and
735+
// `isMfmaToDotShortcut`) once they're fully subsumed by the linear-layout
736+
// checks.
737737
// TODO(Keren): We didn't check `cvtNeedsWarpShuffle` here because it's not
738738
// supported yet in Triton's backend.
739739
return !cvtReordersRegisters(srcTy, dstTy) &&
740740
!isBlockedToDotShortcut(srcTy, dstTy) &&
741-
!isMmaToDotShortcut(srcTy, dstTy) &&
741+
!matchMmaV3AndDotOperandLayout(srcTy, dstTy) &&
742742
!isMfmaToDotShortcut(srcTy, dstTy);
743743
}
744744

@@ -749,20 +749,6 @@ bool atomicNeedsSharedMemory(Value value) {
749749
return true;
750750
}
751751

752-
bool isMmaToDotShortcut(RankedTensorType srcTy, RankedTensorType dstTy) {
753-
if (matchMmaV3AndDotOperandLayout(srcTy, dstTy))
754-
return true;
755-
// dot_op<opIdx=0, parent=#mma> = #mma
756-
// when #mma = MmaEncoding<version=2, warpsPerCTA=[..., 1]>
757-
auto mmaLayout = dyn_cast<NvidiaMmaEncodingAttr>(srcTy.getEncoding());
758-
auto dotOperandLayout = dyn_cast<DotOperandEncodingAttr>(dstTy.getEncoding());
759-
return mmaLayout && dotOperandLayout && mmaLayout.getVersionMajor() == 2 &&
760-
mmaLayout.getWarpsPerCTA()[1] == 1 &&
761-
dotOperandLayout.getOpIdx() == 0 &&
762-
dotOperandLayout.getParent() == mmaLayout &&
763-
!srcTy.getElementType().isF32();
764-
}
765-
766752
namespace {
767753

768754
/// A data structure similar to SetVector but maintains

lib/Conversion/TritonGPUToLLVM/ConvertLayoutOpToLLVM.cpp

Lines changed: 22 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -328,7 +328,20 @@ struct ConvertLayoutOpUsingLinearLayoutsConversion
328328
} else {
329329
// Cast 5. The two layouts are equivalent. We should probably remove
330330
// these in RemoveLayoutConversion.
331-
rewriter.replaceOp(op, adaptor.getSrc());
331+
auto dstCvt = requiresI32Conversion(dstTy);
332+
auto srcCvt = requiresI32Conversion(srcTy);
333+
if (dstCvt || srcCvt) {
334+
auto inVals = unpackLLElements(op.getLoc(), adaptor.getSrc(), rewriter);
335+
inVals = unpackI32s(inVals, srcTy, rewriter, op.getLoc(),
336+
getTypeConverter());
337+
inVals =
338+
packI32s(inVals, dstTy, rewriter, op.getLoc(), getTypeConverter());
339+
auto res = packLLElements(op.getLoc(), getTypeConverter(), inVals,
340+
rewriter, op.getType());
341+
rewriter.replaceOp(op, res);
342+
} else {
343+
rewriter.replaceOp(op, adaptor.getSrc());
344+
}
332345
return success();
333346
}
334347
}
@@ -342,9 +355,12 @@ struct ConvertLayoutOpUsingLinearLayoutsConversion
342355
StringAttr kRegister = str_attr("register");
343356
assert(!cvtNeedsSharedMemory(op.getSrc().getType(), op.getType()));
344357

358+
auto srcTy = op.getSrc().getType();
359+
auto dstTy = op.getType();
345360
auto inVals = unpackLLElements(loc, adaptor.getSrc(), rewriter);
361+
inVals = unpackI32s(inVals, srcTy, rewriter, loc, getTypeConverter());
346362
SmallVector<Value> outVals(numRegs);
347-
for (int i = 0; i < outVals.size(); i++) {
363+
for (int i = 0; i < numRegs; i++) {
348364
// Remove free masks from the register index
349365
// For example, if idx = 0b00111, and masks = 0b00100, then we get
350366
// 0b00011. It means that register 7 (0b111) has the same value as
@@ -355,6 +371,7 @@ struct ConvertLayoutOpUsingLinearLayoutsConversion
355371
: idx;
356372
outVals[i] = inVals[srcIdx];
357373
}
374+
outVals = packI32s(outVals, dstTy, rewriter, loc, getTypeConverter());
358375
Value result = packLLElements(loc, getTypeConverter(), outVals, rewriter,
359376
op.getType());
360377
rewriter.replaceOp(op, result);
@@ -386,9 +403,6 @@ struct ConvertLayoutOpUsingLinearLayoutsConversion
386403
if (auto dotOperand = dyn_cast<DotOperandEncodingAttr>(layout)) {
387404
if (auto nvidiaMma =
388405
dyn_cast<NvidiaMmaEncodingAttr>(dotOperand.getParent())) {
389-
if (product(getCTAsPerCGA(nvidiaMma)) > 1) {
390-
return false;
391-
}
392406
if (useLegacyMMAConversion) {
393407
return false;
394408
}
@@ -398,6 +412,7 @@ struct ConvertLayoutOpUsingLinearLayoutsConversion
398412
dotOperand.getKWidth() * dstTy.getElementTypeBitWidth() > 64;
399413
return largeKWidth && nvidiaMma.isAmpere();
400414
}
415+
return false;
401416
}
402417
if (isa<BlockedEncodingAttr>(layout)) {
403418
return true;
@@ -439,6 +454,7 @@ struct ConvertLayoutOpUsingLinearLayoutsConversion
439454
inVals[it.index()] = ptrtoint(llvmElemTy, it.value());
440455
}
441456
}
457+
inVals = unpackI32s(inVals, srcTy, rewriter, loc, getTypeConverter());
442458

443459
// Pretty sure this is the identity function ATM
444460
// It'd be better to simply call `quotient({kBlock})` and
@@ -458,22 +474,7 @@ struct ConvertLayoutOpUsingLinearLayoutsConversion
458474
}
459475
}
460476

461-
// FIXME [Dot LL]
462-
// We know it's just for largeKWidth case in Ampere
463-
// In this case, we need to pack the outputs into i32
464-
if (isa<DotOperandEncodingAttr>(dstTy.getEncoding())) {
465-
auto concat = [&](Value a, Value b) {
466-
return or_(zext(i32_ty, bitcast(a, i16_ty)),
467-
shl(zext(i32_ty, bitcast(b, i16_ty)), i32_val(16)));
468-
};
469-
470-
SmallVector<Value> outVals32(outVals.size() / 2);
471-
for (int i = 0; i < outVals32.size(); ++i) {
472-
outVals32[i] = concat(outVals[2 * i], outVals[2 * i + 1]);
473-
}
474-
outVals = outVals32;
475-
}
476-
477+
outVals = packI32s(outVals, dstTy, rewriter, loc, getTypeConverter());
477478
Value result = packLLElements(loc, getTypeConverter(), outVals, rewriter,
478479
op.getType());
479480
rewriter.replaceOp(op, result);

lib/Conversion/TritonGPUToLLVM/ElementwiseOpToLLVM.cpp

Lines changed: 6 additions & 50 deletions
Original file line numberDiff line numberDiff line change
@@ -103,51 +103,6 @@ SmallVector<Value> reorderValues(const SmallVector<Value> &values, Type inType,
103103
llvm_unreachable("unimplemented code path");
104104
}
105105

106-
SmallVector<Value> unpackI32(const SmallVector<Value> &inValues, Type srcTy,
107-
ConversionPatternRewriter &rewriter, Location loc,
108-
const LLVMTypeConverter *typeConverter) {
109-
auto tensorTy = dyn_cast<RankedTensorType>(srcTy);
110-
if (!tensorTy)
111-
return inValues;
112-
auto encoding = dyn_cast<DotOperandEncodingAttr>(tensorTy.getEncoding());
113-
if (!(encoding && isa<NvidiaMmaEncodingAttr>(encoding.getParent())))
114-
return inValues;
115-
SmallVector<Value> outValues;
116-
for (auto v : inValues) {
117-
// cast i32 to appropriate eltType vector and extract elements
118-
auto eltType = typeConverter->convertType(tensorTy.getElementType());
119-
auto vecType = vec_ty(eltType, 32 / eltType.getIntOrFloatBitWidth());
120-
auto vec = bitcast(v, vecType);
121-
for (int i = 0; i < 32 / eltType.getIntOrFloatBitWidth(); i++) {
122-
outValues.push_back(extract_element(vec, i32_val(i)));
123-
}
124-
}
125-
return outValues;
126-
}
127-
128-
SmallVector<Value> packI32(const SmallVector<Value> &inValues, Type srcTy,
129-
ConversionPatternRewriter &rewriter, Location loc,
130-
const LLVMTypeConverter *typeConverter) {
131-
auto tensorTy = dyn_cast<RankedTensorType>(srcTy);
132-
if (!tensorTy)
133-
return inValues;
134-
auto encoding = dyn_cast<DotOperandEncodingAttr>(tensorTy.getEncoding());
135-
if (!(encoding && isa<NvidiaMmaEncodingAttr>(encoding.getParent())))
136-
return inValues;
137-
SmallVector<Value> outValues;
138-
auto eltType = typeConverter->convertType(tensorTy.getElementType());
139-
int vecWidth = 32 / eltType.getIntOrFloatBitWidth();
140-
auto vecType = vec_ty(eltType, vecWidth);
141-
for (int i = 0; i < inValues.size(); i += vecWidth) {
142-
Value vec = undef(vecType);
143-
for (int j = 0; j < vecWidth; j++) {
144-
vec = insert_element(vec, inValues[i + j], i32_val(j));
145-
}
146-
outValues.push_back(bitcast(vec, i32_ty));
147-
}
148-
return outValues;
149-
}
150-
151106
int getNumElementsPerThreads(Type type,
152107
const LLVMTypeConverter *typeConverter) {
153108
int numElemsPerThread = 1;
@@ -500,7 +455,7 @@ struct ElementwiseInlineAsmOpConversion
500455
auto argTy = op->getOperand(0).getType();
501456
auto subOperands = unpackLLElements(loc, operand, rewriter);
502457
unpackedOperands.push_back(
503-
unpackI32(subOperands, argTy, rewriter, loc, getTypeConverter()));
458+
unpackI32s(subOperands, argTy, rewriter, loc, getTypeConverter()));
504459
}
505460

506461
int numElemsPerThread = getNumElementsPerThreads(op->getResult(0).getType(),
@@ -560,10 +515,11 @@ struct ElementwiseInlineAsmOpConversion
560515
unpackedResults[i], /*inType=*/op->getOperand(0).getType(),
561516
/*ouType=*/op->getResult(i).getType());
562517
}
563-
auto packed = packI32(unpackedResults[i], op->getResult(i).getType(),
564-
rewriter, loc, getTypeConverter());
565-
outs.push_back(packLLElements(loc, getTypeConverter(), packed, rewriter,
566-
op->getResult(i).getType()));
518+
auto dstTy = op->getResult(i).getType();
519+
unpackedResults[i] = packI32s(unpackedResults[i], dstTy, rewriter, loc,
520+
getTypeConverter());
521+
outs.push_back(packLLElements(loc, getTypeConverter(), unpackedResults[i],
522+
rewriter, op->getResult(i).getType()));
567523
}
568524

569525
rewriter.replaceOp(op, outs);

lib/Conversion/TritonGPUToLLVM/MemoryOpToLLVM.cpp

Lines changed: 1 addition & 36 deletions
Original file line numberDiff line numberDiff line change
@@ -184,42 +184,7 @@ struct LocalLoadOpConversion : public ConvertOpToLLVMPattern<LocalLoadOp> {
184184
SmallVector<Value> outVals = loadSharedToDistributed(
185185
dstTy, srcTy, elemLlvmTy, smemObj, loc, rewriter, targetInfo);
186186

187-
// FIXME [Dot LL]
188-
// Ampere case
189-
// In this case, we need to pack the outputs into i32
190-
if (auto dotOp = dyn_cast<DotOperandEncodingAttr>(dstTy.getEncoding())) {
191-
if (auto parent = dyn_cast<NvidiaMmaEncodingAttr>(dotOp.getParent())) {
192-
if (parent.isAmpere()) {
193-
if (elemLlvmTy.isInteger(8)) {
194-
auto concat = [&](Value a1, Value a2, Value a3, Value a4) {
195-
return or_(
196-
or_(zext(i32_ty, a1), shl(zext(i32_ty, a2), i32_val(8))),
197-
or_(shl(zext(i32_ty, a3), i32_val(16)),
198-
shl(zext(i32_ty, a4), i32_val(24))));
199-
};
200-
SmallVector<Value> outVals32(outVals.size() / 4);
201-
for (int i = 0; i < outVals32.size(); ++i) {
202-
outVals32[i] = concat(outVals[4 * i], outVals[4 * i + 1],
203-
outVals[4 * i + 2], outVals[4 * i + 3]);
204-
}
205-
outVals = outVals32;
206-
} else {
207-
assert(elemLlvmTy.isBF16() && "Unexpected element type");
208-
auto concat = [&](Value a, Value b) {
209-
return or_(zext(i32_ty, bitcast(a, i16_ty)),
210-
shl(zext(i32_ty, bitcast(b, i16_ty)), i32_val(16)));
211-
};
212-
213-
SmallVector<Value> outVals32(outVals.size() / 2);
214-
for (int i = 0; i < outVals32.size(); ++i) {
215-
outVals32[i] = concat(outVals[2 * i], outVals[2 * i + 1]);
216-
}
217-
outVals = outVals32;
218-
}
219-
}
220-
}
221-
}
222-
187+
outVals = packI32s(outVals, dstTy, rewriter, loc, typeConverter);
223188
Value result = packLLElements(loc, typeConverter, outVals, rewriter, dstTy);
224189
rewriter.replaceOp(op, result);
225190

0 commit comments

Comments (0)