
Commit fb56ae3

Revert "[LAYOUTS] Generalise HoistLayoutConversion to work with arbit… (#5776)
This reverts PR #5673. It broke the tests on A100, even though CI was green. The CI issue will be resolved by #5775.
1 parent 0ffc67d commit fb56ae3

File tree: 8 files changed, +421 −438 lines

include/triton/Dialect/Triton/IR/TritonOps.td

Lines changed: 1 addition & 2 deletions
@@ -830,8 +830,7 @@ def TT_MakeRangeOp : TT_Op<"make_range", [Pure]> {
 def TT_ElementwiseInlineAsmOp : TT_Op<"elementwise_inline_asm", [
     Elementwise,
     SameOperandsAndResultEncoding,
-    DeclareOpInterfaceMethods<MemoryEffectsOpInterface>,
-    DeclareOpInterfaceMethods<ConditionallySpeculatable>
+    DeclareOpInterfaceMethods<MemoryEffectsOpInterface>
 ]> {
   let summary = "inline assembly applying an elementwise operation to a group of packed elements.";
   let description = [{

include/triton/Dialect/TritonGPU/IR/TritonGPUOps.td

Lines changed: 1 addition & 1 deletion
@@ -225,7 +225,7 @@ def TTG_MemDescSubviewOp : TTG_Op<"memdesc_subview", [Pure]> {
 
 def TTG_MemDescTransOp : TTG_Op<"memdesc_trans", [Pure,
                                                   TransposeOpInterface,
-                                                  InferTypeOpWithLayoutEquivalence,
+                                                  DeclareOpInterfaceMethods<InferTypeOpInterface>,
                                                   SameOperandsAndResultElementType]> {
   let summary = "transpose the descriptor";
 

lib/Dialect/Triton/IR/Ops.cpp

Lines changed: 0 additions & 6 deletions
@@ -1037,12 +1037,6 @@ void ElementwiseInlineAsmOp::getEffects(
                        SideEffects::DefaultResource::get());
 }
 
-Speculation::Speculatability ElementwiseInlineAsmOp::getSpeculatability() {
-  if (getPure())
-    return Speculation::Speculatable;
-  return Speculation::NotSpeculatable;
-}
-
 LogicalResult ElementwiseInlineAsmOp::verify() {
   if (getNumOperands() >= 1) {
     auto tensorType = dyn_cast<RankedTensorType>(getOperand(0).getType());

lib/Dialect/TritonGPU/IR/Ops.cpp

Lines changed: 13 additions & 15 deletions
@@ -463,17 +463,15 @@ OpFoldResult MemDescTransOp::fold(FoldAdaptor adaptor) {
   return {};
 }
 
-LogicalResult
-MemDescTransOp::inferReturnTypes(MLIRContext *context,
-                                 std::optional<Location> location,
-                                 MemDescTransOp::Adaptor adaptor,
-                                 SmallVectorImpl<Type> &inferredReturnTypes) {
-
+LogicalResult MemDescTransOp::inferReturnTypes(
+    MLIRContext *context, std::optional<Location> location, ValueRange operands,
+    DictionaryAttr attributes, OpaqueProperties properties, RegionRange regions,
+    SmallVectorImpl<Type> &inferredReturnTypes) {
   // type is the same as the input
-  auto argTy = cast<MemDescType>(adaptor.getSrc().getType());
-  auto shape = argTy.getShape();
-  auto order = adaptor.getOrder();
-  SmallVector<int64_t> retShape = applyPermutation(shape, order);
+  auto argTy = cast<MemDescType>(operands[0].getType());
+  auto argShape = argTy.getShape();
+  auto order = properties.as<Properties *>()->order.asArrayRef();
+  SmallVector<int64_t> retShape = applyPermutation(argTy.getShape(), order);
 
   auto retEltTy = argTy.getElementType();
   Attribute argEncoding = argTy.getEncoding();
@@ -482,17 +480,17 @@ MemDescTransOp::inferReturnTypes(MLIRContext *context,
     Dialect &dialect = argEncoding.getDialect();
     auto inferLayoutInterface = cast<DialectInferLayoutInterface>(&dialect);
     if (inferLayoutInterface
-            ->inferTransOpEncoding(argEncoding, shape, order, retEncoding)
+            ->inferTransOpEncoding(argEncoding, argShape, order, retEncoding)
             .failed()) {
       return failure();
     }
   }
-  inferredReturnTypes.push_back(
-      MemDescType::get(retShape, retEltTy, retEncoding, argTy.getMemorySpace(),
-                       argTy.getMutableMemory()));
+  auto memDescTy = cast<MemDescType>(argTy);
+  inferredReturnTypes.push_back(MemDescType::get(
+      retShape, retEltTy, retEncoding, memDescTy.getMemorySpace(),
+      memDescTy.getMutableMemory()));
   return success();
 }
-
 // LocalAllocOp
 void LocalAllocOp::getEffects(
     SmallVectorImpl<SideEffects::EffectInstance<MemoryEffects::Effect>>

lib/Dialect/TritonGPU/Transforms/OptimizeDotOperands.cpp

Lines changed: 142 additions & 0 deletions
@@ -24,6 +24,36 @@ namespace {
 // Roughly, whether op is elementwise and thus threads don't need
 // to exchange elements. But some ops are not currently supported even though
 // they meet that criterion.
+bool canHoistDotOpEncV2(Operation *op, DotOperandEncodingAttr &dotOpEnc) {
+  // Only consider custom conversions or arith ops.
+  // TODO(jlebar): Is this too restrictive?
+  if (!isa<FpToFpOp, BitcastOp>(op) && !isPureUnaryInlineAsm(op) &&
+      !isa<arith::ArithDialect>(op->getDialect()))
+    return false;
+
+  // Quick handling to fix loading issues when computing the original
+  // bitwidth is unable to realize that there is a mixed-precision dot
+  // (hence kWidth = 1) but wants to hoist through the type conversion.
+  if (isa<arith::ExtFOp>(op) && dotOpEnc.getKWidth() == 1)
+    return false;
+
+  // Currently, these instructions are not supported during lowering of
+  // shared -> dot_operand layout. Not all types and type conversions are
+  // supported.
+  if (isa<arith::TruncIOp, arith::TruncFOp, arith::SelectOp>(op))
+    return false;
+
+  // Don't hoist through u1 -> fp casts as they aren't supported in
+  // ElementwiseOpToLLVM::reorderValues().
+  if (isa<arith::UIToFPOp>(op)) {
+    Type opType = getElementTypeOrSelf(op->getOperand(0));
+    if (opType.isInteger(1))
+      return false;
+  }
+
+  return true;
+}
+
 // Analog of canHoistDotOpEncV2, but for MMAv3 (WGMMA where operand A
 // is in registers).
 bool canHoistDotOpEncV3(Operation *op) {
@@ -165,6 +195,116 @@ class SwizzleShmemConvert : public OpRewritePattern<ConvertLayoutOp> {
   }
 };
 
+// Move convert-to-dot-operand "up" past elementwise ops:
+//
+//   convert(elementwise(x)) #dot_operand ->
+//   elementwise(convert(x, #dot_operand)).
+//
+// The goal is to put the convert right next to the originating load. If we can
+// accomplish this, then we can save a shmem round-trip:
+//
+//   Before:
+//
+//     - Load from global into shmem using an async copy.
+//     - Load from shmem into a #blocked layout.
+//     - Do elementwise ops over #blocked layout.
+//     - Convert to #dot_operand (round-trip through shmem).
+//     - Do dot.
+//
+//   After:
+//
+//     - Load from global into shmem using an async copy (same as before).
+//     - Load from shmem into a #dot_operand layout.
+//     - Do elementwise ops over #dot_operand layout.
+//     - Do dot.
+//
+// This can also be propagated when we have a constant, instead of a load.
+//
+// Eliminating the shmem round-trip is such a big win, we're willing to do it
+// even if this duplicates work because some of the elementwise ops have uses
+// that don't flow into the dot. On the other hand, we only want to do this if
+// we can in fact reduce shmem round-trips: For example, simply moving a convert
+// up above e.g. an `add` now means we have *two* converts. That's worse,
+// unless we can continue moving the converts upwards and eventually merge them.
+// So we try to check that this will be beneficial before making any changes.
+class HoistLayoutConversion : public OpRewritePattern<ConvertLayoutOp> {
+public:
+  using OpRewritePattern::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(ConvertLayoutOp cvt,
+                                PatternRewriter &rewriter) const override {
+    // Only consider conversions to dot operand.
+    auto cvtTy = cast<RankedTensorType>(cvt.getType());
+    auto dotOpEnc = dyn_cast<DotOperandEncodingAttr>(cvtTy.getEncoding());
+    if (!dotOpEnc)
+      return failure();
+
+    auto src = cvt.getSrc().getDefiningOp();
+    if (!src || src->getNumOperands() == 0 || src->getNumResults() != 1)
+      return failure();
+
+    auto srcTy = dyn_cast<RankedTensorType>(src->getResult(0).getType());
+    if (!srcTy)
+      return failure();
+
+    if (!all_of(src->getOperandTypes(),
+                [](Type ty) { return isa<RankedTensorType>(ty); }))
+      return failure();
+
+    if (!canHoistDotOpEncV2(src, dotOpEnc))
+      return failure();
+
+    // Check that the conversion is transitively dependent on a load or a
+    // constant, and all operations between it and the convert are layout
+    // preserving.
+    //
+    // TODO(jlebar): This is accidentally quadratic; we iterate over the whole
+    // slice but then at the end we only modify one op!
+    SetVector<Operation *> slice;
+    BackwardSliceOptions opt;
+    opt.omitBlockArguments = true;
+    getBackwardSlice(cvt.getOperation(), &slice, opt);
+
+    // TODO(jlebar): This is too conservative when there are multiple loads in
+    // the chain. If one of the loads has a non-layout-preserving op and the
+    // other does not, then we may or may not accept the chain, depending on
+    // which load gets hit first by getBackwardSlice. For example:
+    //   cvt(broadcast(load(x)) + load(y)) // accepted & load(y) will benefit.
+    //   cvt(load(y) + broadcast(load(x))) // rejected & load(y) will not benefit.
+    bool foundInitializer = false;
+    // Reverse the slice so that we start directly above the convert and check
+    // that every op allows hoisting until we find a load or a constant.
+    for (Operation *currOp : llvm::reverse(slice)) {
+      if (isa<LoadOp>(currOp) || isa<arith::ConstantOp>(currOp)) {
+        foundInitializer = true;
+        break;
+      }
+      if (!canHoistDotOpEncV2(currOp, dotOpEnc))
+        return failure();
+    }
+    if (!foundInitializer)
+      return failure();
+
+    SmallVector<ConvertLayoutOp> newOperands;
+    for (auto operand : src->getOperands()) {
+      // We checked earlier that all operands are ranked tensors.
+      auto operandTy = cast<RankedTensorType>(operand.getType());
+      Type newCvtTy = RankedTensorType::get(
+          srcTy.getShape(), operandTy.getElementType(), cvtTy.getEncoding());
+      newOperands.push_back(
+          rewriter.create<ConvertLayoutOp>(cvt.getLoc(), newCvtTy, operand));
+    }
+    auto newRet = rewriter.clone(*src);
+    for (int i = 0; i < newOperands.size(); i++)
+      newRet->setOperand(i, newOperands[i]);
+    newRet->getResult(0).setType(RankedTensorType::get(
+        srcTy.getShape(), srcTy.getElementType(), cvtTy.getEncoding()));
+
+    rewriter.replaceOp(cvt, newRet->getResults());
+    return success();
+  }
+};
+
 // Rewrite
 //
 //   dot(alloc(trans() #shared1) ->
@@ -559,6 +699,8 @@ class TritonGPUOptimizeDotOperandsPass
     mlir::RewritePatternSet patterns(context);
     patterns.add<MMAV3HoistLayoutConversion>(context);
     patterns.add<SwizzleShmemConvert>(context);
+    if (this->hoistLayoutConversion.getValue())
+      patterns.add<HoistLayoutConversion>(context);
    patterns.add<FuseTransMMAV3Plus>(context);
    patterns.add<MMAV3UseRegOperand>(context);
    patterns.add<InjectTMemCopy>(context);
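For readers less familiar with the restored pattern, the rewrite can be pictured on a small Triton GPU IR fragment. This is a schematic sketch only, not a test taken from the repository: #blocked and #dot_operand are shorthand placeholders for the real ttg encoding attributes, the shapes are arbitrary, and the type annotations are abbreviated. It assumes the dot operand's kWidth is not 1, so canHoistDotOpEncV2 allows hoisting through the arith.extf.

  // Before: the extf runs in #blocked, and the convert below it must
  // round-trip through shared memory to reach #dot_operand.
  %a   = tt.load %ptrs : tensor<64x64xf16, #blocked>
  %af  = arith.extf %a : tensor<64x64xf16, #blocked> to tensor<64x64xf32, #blocked>
  %lhs = ttg.convert_layout %af : tensor<64x64xf32, #blocked> -> tensor<64x64xf32, #dot_operand>

  // After HoistLayoutConversion: the convert sits directly on the load result,
  // and the extf runs in #dot_operand, matching the "After" bullets above.
  %a   = tt.load %ptrs : tensor<64x64xf16, #blocked>
  %ad  = ttg.convert_layout %a : tensor<64x64xf16, #blocked> -> tensor<64x64xf16, #dot_operand>
  %lhs = arith.extf %ad : tensor<64x64xf16, #dot_operand> to tensor<64x64xf32, #dot_operand>

With the convert adjacent to the load, the shared-to-register load can produce the #dot_operand layout directly, which is the shmem round-trip saving described in the pattern's Before/After comment.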
