@@ -24,36 +24,6 @@ namespace {
 // Roughly, whether op is elementwise and thus threads don't need
 // to exchange elements. But some ops are not currently supported even though
 // they meet that criterion.
-bool canHoistDotOpEncV2(Operation *op, DotOperandEncodingAttr &dotOpEnc) {
-  // Only consider custom conversions or arith ops.
-  // TODO(jlebar): Is this too restrictive?
-  if (!isa<FpToFpOp, BitcastOp>(op) && !isPureUnaryInlineAsm(op) &&
-      !isa<arith::ArithDialect>(op->getDialect()))
-    return false;
-
-  // Quick fix for loading issues: the original-bitwidth computation cannot
-  // see that this is a mixed-precision dot (hence kWidth = 1) but would
-  // still hoist through the type conversion.
-  if (isa<arith::ExtFOp>(op) && dotOpEnc.getKWidth() == 1)
-    return false;
-
-  // Currently, these instructions are not supported during lowering of
-  // shared -> dot_operand layout. Not all types and type conversions are
-  // supported.
-  if (isa<arith::TruncIOp, arith::TruncFOp, arith::SelectOp>(op))
-    return false;
-
-  // Don't hoist through u1 -> fp casts as they aren't supported in
-  // ElementwiseOpToLLVM::reorderValues().
-  if (isa<arith::UIToFPOp>(op)) {
-    Type opType = getElementTypeOrSelf(op->getOperand(0));
-    if (opType.isInteger(1))
-      return false;
-  }
-
-  return true;
-}
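
For illustration (not part of the patch): a minimal sketch of how a caller could apply this predicate along a def chain, stopping at the first op it rejects. The helper name lastHoistableProducer is hypothetical, and following only each op's first operand is a simplifying assumption.

    // Sketch only: walk producers above `cvt`, following each op's first
    // operand, for as long as canHoistDotOpEncV2 allows hoisting. Returns
    // the furthest hoistable producer, or nullptr if the chain is blocked
    // immediately.
    static Operation *lastHoistableProducer(ConvertLayoutOp cvt,
                                            DotOperandEncodingAttr dotOpEnc) {
      Operation *producer = cvt.getSrc().getDefiningOp();
      Operation *last = nullptr;
      while (producer && producer->getNumResults() == 1 &&
             producer->getNumOperands() != 0 &&
             canHoistDotOpEncV2(producer, dotOpEnc)) {
        last = producer;
        producer = producer->getOperand(0).getDefiningOp();
      }
      return last;
    }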
-
 // Analog of canHoistDotOpEncV2, but for MMAv3 (WGMMA where operand A
 // is in registers).
 bool canHoistDotOpEncV3(Operation *op) {
@@ -195,116 +165,6 @@ class SwizzleShmemConvert : public OpRewritePattern<ConvertLayoutOp> {
   }
 };

-// Move convert-to-dot-operand "up" past elementwise ops:
-//
-//   convert(elementwise(x)) #dot_operand ->
-//   elementwise(convert(x, #dot_operand)).
-//
-// The goal is to put the convert right next to the originating load. If we can
-// accomplish this, then we can save a shmem round-trip:
-//
-// Before:
-//
-//  - Load from global into shmem using an async copy.
-//  - Load from shmem into a #blocked layout.
-//  - Do elementwise ops over #blocked layout.
-//  - Convert to #dot_operand (round-trip through shmem).
-//  - Do dot.
-//
-// After:
-//
-//  - Load from global into shmem using an async copy (same as before).
-//  - Load from shmem into a #dot_operand layout.
-//  - Do elementwise ops over #dot_operand layout.
-//  - Do dot.
-//
-// This also applies when the chain starts at a constant rather than a load.
-//
-// Eliminating the shmem round-trip is such a big win that we're willing to
-// do it even if it duplicates work: some of the elementwise ops may have
-// uses that don't flow into the dot. On the other hand, we only want to do
-// this if it actually reduces shmem round-trips. For example, simply moving
-// a convert up above an `add` means we now have *two* converts. That's
-// worse, unless we can keep moving the converts upwards and eventually merge
-// them. So we check that the hoist will be beneficial before changing anything.
-class HoistLayoutConversion : public OpRewritePattern<ConvertLayoutOp> {
-public:
-  using OpRewritePattern::OpRewritePattern;
-
-  LogicalResult matchAndRewrite(ConvertLayoutOp cvt,
-                                PatternRewriter &rewriter) const override {
-    // Only consider conversions to dot operand.
-    auto cvtTy = cast<RankedTensorType>(cvt.getType());
-    auto dotOpEnc = dyn_cast<DotOperandEncodingAttr>(cvtTy.getEncoding());
-    if (!dotOpEnc)
-      return failure();
-
-    auto src = cvt.getSrc().getDefiningOp();
-    if (!src || src->getNumOperands() == 0 || src->getNumResults() != 1)
-      return failure();
-
-    auto srcTy = dyn_cast<RankedTensorType>(src->getResult(0).getType());
-    if (!srcTy)
-      return failure();
-
-    if (!all_of(src->getOperandTypes(),
-                [](Type ty) { return isa<RankedTensorType>(ty); }))
-      return failure();
-
-    if (!canHoistDotOpEncV2(src, dotOpEnc))
-      return failure();
-
-    // Check that the conversion is transitively dependent on a load or a
-    // constant, and all operations between it and the convert are layout
-    // preserving.
-    //
-    // TODO(jlebar): This is accidentally quadratic; we iterate over the whole
-    // slice but then at the end we only modify one op!
-    SetVector<Operation *> slice;
-    BackwardSliceOptions opt;
-    opt.omitBlockArguments = true;
-    getBackwardSlice(cvt.getOperation(), &slice, opt);
-
-    // TODO(jlebar): This is too conservative when there are multiple loads in
-    // the chain. If one of the loads has a non-layout-preserving op and the
-    // other does not, then we may or may not accept the chain, depending on
-    // which load gets hit first by getBackwardSlice. For example:
-    //   cvt(broadcast(load(x)) + load(y)) // accepted & load(y) will benefit.
-    //   cvt(load(y) + broadcast(load(x))) // rejected & load(y) will not benefit.
-    bool foundInitializer = false;
-    // Reverse the slice so that we start directly above the convert and check
-    // that every op allows hoisting until we find a load or a constant.
-    for (Operation *currOp : llvm::reverse(slice)) {
-      if (isa<LoadOp>(currOp) || isa<arith::ConstantOp>(currOp)) {
-        foundInitializer = true;
-        break;
-      }
-      if (!canHoistDotOpEncV2(currOp, dotOpEnc))
-        return failure();
-    }
-    if (!foundInitializer)
-      return failure();
-
-    SmallVector<ConvertLayoutOp> newOperands;
-    for (auto operand : src->getOperands()) {
-      // We checked earlier that all operands are ranked tensors.
-      auto operandTy = cast<RankedTensorType>(operand.getType());
-      Type newCvtTy = RankedTensorType::get(
-          srcTy.getShape(), operandTy.getElementType(), cvtTy.getEncoding());
-      newOperands.push_back(
-          rewriter.create<ConvertLayoutOp>(cvt.getLoc(), newCvtTy, operand));
-    }
-    auto newRet = rewriter.clone(*src);
-    for (int i = 0; i < newOperands.size(); i++)
-      newRet->setOperand(i, newOperands[i]);
-    newRet->getResult(0).setType(RankedTensorType::get(
-        srcTy.getShape(), srcTy.getElementType(), cvtTy.getEncoding()));
-
-    rewriter.replaceOp(cvt, newRet->getResults());
-    return success();
-  }
-};
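
For context, a sketch (standard MLIR, not code from this commit) of how a pattern like this is registered and driven; `module` and `context` are assumed to be in scope:

    // Sketch only: hand the pattern to MLIR's stock greedy rewrite driver.
    mlir::RewritePatternSet patterns(context);
    patterns.add<HoistLayoutConversion>(context);
    if (failed(mlir::applyPatternsAndFoldGreedily(module, std::move(patterns))))
      signalPassFailure(); // assumes we are inside a Pass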
-
 // Rewrite
 //
 //   dot(alloc(trans() #shared1) ->
@@ -699,8 +559,6 @@ class TritonGPUOptimizeDotOperandsPass
     mlir::RewritePatternSet patterns(context);
     patterns.add<MMAV3HoistLayoutConversion>(context);
     patterns.add<SwizzleShmemConvert>(context);
-    if (this->hoistLayoutConversion.getValue())
-      patterns.add<HoistLayoutConversion>(context);
     patterns.add<FuseTransMMAV3Plus>(context);
     patterns.add<MMAV3UseRegOperand>(context);
     patterns.add<InjectTMemCopy>(context);
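
For reference, a sketch of the mechanism behind the hoistLayoutConversion option removed above. The flag spelling and default are illustrative assumptions (Triton generates its pass options from TableGen):

    // Sketch only: a boolean pass option, declared as a member of the pass
    // class, gating whether the pattern is registered.
    Option<bool> hoistLayoutConversion{
        *this, "hoist-layout-conversion",
        llvm::cl::desc("Hoist #dot_operand conversions above elementwise ops"),
        llvm::cl::init(false)};

    // Then, inside runOnOperation():
    if (hoistLayoutConversion.getValue())
      patterns.add<HoistLayoutConversion>(context);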