Skip to content

Commit b3dcc32

Browse files
authored
[LAYOUTS] Generalise HoistLayoutConversion to work with arbitrary layouts and chains of ops (#5673)
We generalise `HoistLayoutConversion` to lift a given `convert_layout dot_operand` above any chain of operations that do not require data movement. We could totally generalise this in the future to lift it over other ops. We do this as a first step to keep the code somewhat similar to the previous one. Regarding the previous limitations of `canHoistDotOpEncV2`, I did a bit of archeology: - The "don't hoist past select" restriction was added in issue #2857. I ran the repro and, with the recent layout fixes, it now passes. - The skipping of TruncOps comes from #2181. I think this is related to the hack that was removed in #5044, so it should work now. - The same goes for `UIToFpOp`; this is now supported after #5044. - The mixed-dtype hack is not necessary either, as everything now works as expected with the `convert_layout` rework. We also add proper support for `isPure` for `elementwise_inline_asm` ops. On the location of the code, we just leave it in `RemoveLayoutConversion.cpp` to take advantage of the rather generic implementation of `rewriteSlice`. We could totally move this pass outside of `remove-layout-conversion`, as it's probably enough to run it once. This code will go through further changes in the near future, so we'll assess this then.
1 parent 4ce54b5 commit b3dcc32

File tree

8 files changed

+438
-421
lines changed

8 files changed

+438
-421
lines changed

include/triton/Dialect/Triton/IR/TritonOps.td

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -830,7 +830,8 @@ def TT_MakeRangeOp : TT_Op<"make_range", [Pure]> {
830830
def TT_ElementwiseInlineAsmOp : TT_Op<"elementwise_inline_asm", [
831831
Elementwise,
832832
SameOperandsAndResultEncoding,
833-
DeclareOpInterfaceMethods<MemoryEffectsOpInterface>
833+
DeclareOpInterfaceMethods<MemoryEffectsOpInterface>,
834+
DeclareOpInterfaceMethods<ConditionallySpeculatable>
834835
]> {
835836
let summary = "inline assembly applying an elementwise operation to a group of packed elements.";
836837
let description = [{

include/triton/Dialect/TritonGPU/IR/TritonGPUOps.td

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -225,7 +225,7 @@ def TTG_MemDescSubviewOp : TTG_Op<"memdesc_subview", [Pure]> {
225225

226226
def TTG_MemDescTransOp : TTG_Op<"memdesc_trans", [Pure,
227227
TransposeOpInterface,
228-
DeclareOpInterfaceMethods<InferTypeOpInterface>,
228+
InferTypeOpWithLayoutEquivalence,
229229
SameOperandsAndResultElementType]> {
230230
let summary = "transpose the descriptor";
231231

lib/Dialect/Triton/IR/Ops.cpp

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1037,6 +1037,12 @@ void ElementwiseInlineAsmOp::getEffects(
10371037
SideEffects::DefaultResource::get());
10381038
}
10391039

1040+
Speculation::Speculatability ElementwiseInlineAsmOp::getSpeculatability() {
1041+
if (getPure())
1042+
return Speculation::Speculatable;
1043+
return Speculation::NotSpeculatable;
1044+
}
1045+
10401046
LogicalResult ElementwiseInlineAsmOp::verify() {
10411047
if (getNumOperands() >= 1) {
10421048
auto tensorType = dyn_cast<RankedTensorType>(getOperand(0).getType());

lib/Dialect/TritonGPU/IR/Ops.cpp

Lines changed: 15 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -463,15 +463,17 @@ OpFoldResult MemDescTransOp::fold(FoldAdaptor adaptor) {
463463
return {};
464464
}
465465

466-
LogicalResult MemDescTransOp::inferReturnTypes(
467-
MLIRContext *context, std::optional<Location> location, ValueRange operands,
468-
DictionaryAttr attributes, OpaqueProperties properties, RegionRange regions,
469-
SmallVectorImpl<Type> &inferredReturnTypes) {
466+
LogicalResult
467+
MemDescTransOp::inferReturnTypes(MLIRContext *context,
468+
std::optional<Location> location,
469+
MemDescTransOp::Adaptor adaptor,
470+
SmallVectorImpl<Type> &inferredReturnTypes) {
471+
470472
// type is the same as the input
471-
auto argTy = cast<MemDescType>(operands[0].getType());
472-
auto argShape = argTy.getShape();
473-
auto order = properties.as<Properties *>()->order.asArrayRef();
474-
SmallVector<int64_t> retShape = applyPermutation(argTy.getShape(), order);
473+
auto argTy = cast<MemDescType>(adaptor.getSrc().getType());
474+
auto shape = argTy.getShape();
475+
auto order = adaptor.getOrder();
476+
SmallVector<int64_t> retShape = applyPermutation(shape, order);
475477

476478
auto retEltTy = argTy.getElementType();
477479
Attribute argEncoding = argTy.getEncoding();
@@ -480,17 +482,17 @@ LogicalResult MemDescTransOp::inferReturnTypes(
480482
Dialect &dialect = argEncoding.getDialect();
481483
auto inferLayoutInterface = cast<DialectInferLayoutInterface>(&dialect);
482484
if (inferLayoutInterface
483-
->inferTransOpEncoding(argEncoding, argShape, order, retEncoding)
485+
->inferTransOpEncoding(argEncoding, shape, order, retEncoding)
484486
.failed()) {
485487
return failure();
486488
}
487489
}
488-
auto memDescTy = cast<MemDescType>(argTy);
489-
inferredReturnTypes.push_back(MemDescType::get(
490-
retShape, retEltTy, retEncoding, memDescTy.getMemorySpace(),
491-
memDescTy.getMutableMemory()));
490+
inferredReturnTypes.push_back(
491+
MemDescType::get(retShape, retEltTy, retEncoding, argTy.getMemorySpace(),
492+
argTy.getMutableMemory()));
492493
return success();
493494
}
495+
494496
// LocalAllocOp
495497
void LocalAllocOp::getEffects(
496498
SmallVectorImpl<SideEffects::EffectInstance<MemoryEffects::Effect>>

lib/Dialect/TritonGPU/Transforms/OptimizeDotOperands.cpp

Lines changed: 0 additions & 142 deletions
Original file line numberDiff line numberDiff line change
@@ -24,36 +24,6 @@ namespace {
2424
// Roughly, whether op is elementwise and thus threads don't need
2525
// to exchange elements. But some ops are not currently supported even though
2626
// they meet that criterion.
27-
bool canHoistDotOpEncV2(Operation *op, DotOperandEncodingAttr &dotOpEnc) {
28-
// Only consider custom conversions or arith ops.
29-
// TODO(jlebar): Is this too restrictive?
30-
if (!isa<FpToFpOp, BitcastOp>(op) && !isPureUnaryInlineAsm(op) &&
31-
!isa<arith::ArithDialect>(op->getDialect()))
32-
return false;
33-
34-
// Quick handling to fix loading issues when computing the original
35-
// bitwidth is unable to realize that there is a mixed-precision dot
36-
// (hence kWidth = 1) but wants to hoist through the type conversion.
37-
if (isa<arith::ExtFOp>(op) && dotOpEnc.getKWidth() == 1)
38-
return false;
39-
40-
// Currently, these instructions are not supported during lowering of
41-
// shared -> dot_operand layout. Not all types and type conversions are
42-
// supported.
43-
if (isa<arith::TruncIOp, arith::TruncFOp, arith::SelectOp>(op))
44-
return false;
45-
46-
// Don't hoist through u1 -> fp casts as they aren't supported in
47-
// ElementwiseOpToLLVM::reorderValues().
48-
if (isa<arith::UIToFPOp>(op)) {
49-
Type opType = getElementTypeOrSelf(op->getOperand(0));
50-
if (opType.isInteger(1))
51-
return false;
52-
}
53-
54-
return true;
55-
}
56-
5727
// Analog of canHoistDotOpEncV2, but for MMAv3 (WGMMA where operand A
5828
// is in registers).
5929
bool canHoistDotOpEncV3(Operation *op) {
@@ -195,116 +165,6 @@ class SwizzleShmemConvert : public OpRewritePattern<ConvertLayoutOp> {
195165
}
196166
};
197167

198-
// Move convert-to-dot-operand "up" past elementwise ops:
199-
//
200-
// convert(elementwise(x)) #dot_operand ->
201-
// elementwise(convert(x, #dot_operand)).
202-
//
203-
// The goal is to put the convert right next to the originating load. If we can
204-
// accomplish this, then we can save a shmem round-trip:
205-
//
206-
// Before:
207-
//
208-
// - Load from global into shmem using an async copy.
209-
// - Load from shmem into a #blocked layout.
210-
// - Do elementwise ops over #blocked layout.
211-
// - Convert to #dot_operand (round-trip through shmem).
212-
// - Do dot.
213-
//
214-
// After:
215-
//
216-
// - Load from global into shmem using an async copy (same as before).
217-
// - Load from shmem into a #dot_operand layout.
218-
// - Do elementwise ops over #dot_operand layout.
219-
// - Do dot.
220-
//
221-
// This can also be propagated when we have a constant, instead of a load.
222-
//
223-
// Eliminating the shmem round-trip is such a big win, we're willing to do it
224-
// even if this duplicates work because some of the elementwise ops have uses
225-
// that don't flow into the dot. On the other hand, we only want to do this if
226-
// we can in fact reduce shmem round-trips: For example, simply moving a convert
227-
// up above e.g. an `add` now means we have *two* converts. That's worse,
228-
// unless we can continue moving the converts upwards and eventually merge them.
229-
// So we try to check that this will be beneficial before making any changes.
230-
class HoistLayoutConversion : public OpRewritePattern<ConvertLayoutOp> {
231-
public:
232-
using OpRewritePattern::OpRewritePattern;
233-
234-
LogicalResult matchAndRewrite(ConvertLayoutOp cvt,
235-
PatternRewriter &rewriter) const override {
236-
// Only consider conversions to dot operand.
237-
auto cvtTy = cast<RankedTensorType>(cvt.getType());
238-
auto dotOpEnc = dyn_cast<DotOperandEncodingAttr>(cvtTy.getEncoding());
239-
if (!dotOpEnc)
240-
return failure();
241-
242-
auto src = cvt.getSrc().getDefiningOp();
243-
if (!src || src->getNumOperands() == 0 || src->getNumResults() != 1)
244-
return failure();
245-
246-
auto srcTy = dyn_cast<RankedTensorType>(src->getResult(0).getType());
247-
if (!srcTy)
248-
return failure();
249-
250-
if (!all_of(src->getOperandTypes(),
251-
[](Type ty) { return isa<RankedTensorType>(ty); }))
252-
return failure();
253-
254-
if (!canHoistDotOpEncV2(src, dotOpEnc))
255-
return failure();
256-
257-
// Check that the conversion is transitively dependent on a load or a
258-
// constant, and all operations between it and the convert are layout
259-
// preserving.
260-
//
261-
// TODO(jlebar): This is accidentally quadratic; we iterate over the whole
262-
// slice but then at the end we only modify one op!
263-
SetVector<Operation *> slice;
264-
BackwardSliceOptions opt;
265-
opt.omitBlockArguments = true;
266-
getBackwardSlice(cvt.getOperation(), &slice, opt);
267-
268-
// TODO(jlebar): This is too conservative when there are multiple loads in
269-
// the chain. If one of the loads has a non-layout-preserving op and the
270-
// other does not, then we may or may not accept the chain, depending on
271-
// which load gets hit first by getBackwardSlice. For example:
272-
// cvt(broadcast(load(x)) + load(y)) // accepted & load(y) will benefit.
273-
// cvt(load(y) + broadcast(load(x))) // rejected & load(y) will not benefit.
274-
bool foundInitializer = false;
275-
// Reverse the slice so that we start directly above the convert and check
276-
// that every op allows hoisting until we find a load or a constant.
277-
for (Operation *currOp : llvm::reverse(slice)) {
278-
if (isa<LoadOp>(currOp) || isa<arith::ConstantOp>(currOp)) {
279-
foundInitializer = true;
280-
break;
281-
}
282-
if (!canHoistDotOpEncV2(currOp, dotOpEnc))
283-
return failure();
284-
}
285-
if (!foundInitializer)
286-
return failure();
287-
288-
SmallVector<ConvertLayoutOp> newOperands;
289-
for (auto operand : src->getOperands()) {
290-
// We checked earlier that all operands are ranked tensors.
291-
auto operandTy = cast<RankedTensorType>(operand.getType());
292-
Type newCvtTy = RankedTensorType::get(
293-
srcTy.getShape(), operandTy.getElementType(), cvtTy.getEncoding());
294-
newOperands.push_back(
295-
rewriter.create<ConvertLayoutOp>(cvt.getLoc(), newCvtTy, operand));
296-
}
297-
auto newRet = rewriter.clone(*src);
298-
for (int i = 0; i < newOperands.size(); i++)
299-
newRet->setOperand(i, newOperands[i]);
300-
newRet->getResult(0).setType(RankedTensorType::get(
301-
srcTy.getShape(), srcTy.getElementType(), cvtTy.getEncoding()));
302-
303-
rewriter.replaceOp(cvt, newRet->getResults());
304-
return success();
305-
}
306-
};
307-
308168
// Rewrite
309169
//
310170
// dot(alloc(trans() #shared1) ->
@@ -699,8 +559,6 @@ class TritonGPUOptimizeDotOperandsPass
699559
mlir::RewritePatternSet patterns(context);
700560
patterns.add<MMAV3HoistLayoutConversion>(context);
701561
patterns.add<SwizzleShmemConvert>(context);
702-
if (this->hoistLayoutConversion.getValue())
703-
patterns.add<HoistLayoutConversion>(context);
704562
patterns.add<FuseTransMMAV3Plus>(context);
705563
patterns.add<MMAV3UseRegOperand>(context);
706564
patterns.add<InjectTMemCopy>(context);

0 commit comments

Comments
 (0)