@@ -24,6 +24,36 @@ namespace {
// Roughly, whether op is elementwise and thus threads don't need
// to exchange elements. But some ops are not currently supported even though
// they meet that criterion.
27+ bool canHoistDotOpEncV2 (Operation *op, DotOperandEncodingAttr &dotOpEnc) {
28+ // Only consider custom conversions or arith ops.
29+ // TODO(jlebar): Is this too restrictive?
30+ if (!isa<FpToFpOp, BitcastOp>(op) && !isPureUnaryInlineAsm (op) &&
31+ !isa<arith::ArithDialect>(op->getDialect ()))
32+ return false ;
33+
34+ // Quick handling to fix loading issues when computing the original
35+ // bitwidth is unable to realize that there is a mixed-precision dot
36+ // (hence kWidth = 1) but wants to hoist through the type conversion.
37+ if (isa<arith::ExtFOp>(op) && dotOpEnc.getKWidth () == 1 )
38+ return false ;
39+
40+ // Currently, these instructions are not supported during lowering of
41+ // shared -> dot_operand layout. Not all types and type conversions are
42+ // supported.
43+ if (isa<arith::TruncIOp, arith::TruncFOp, arith::SelectOp>(op))
44+ return false ;
45+
46+ // Don't hoist through u1 -> fp casts as they aren't supported in
47+ // ElementwiseOpToLLVM::reorderValues().
48+ if (isa<arith::UIToFPOp>(op)) {
49+ Type opType = getElementTypeOrSelf (op->getOperand (0 ));
50+ if (opType.isInteger (1 ))
51+ return false ;
52+ }
53+
54+ return true ;
55+ }
56+
// Analog of canHoistDotOpEncV2, but for MMAv3 (WGMMA where operand A
// is in registers).
2959bool canHoistDotOpEncV3 (Operation *op) {
@@ -165,6 +195,116 @@ class SwizzleShmemConvert : public OpRewritePattern<ConvertLayoutOp> {
165195 }
166196};
167197
198+ // Move convert-to-dot-operand "up" past elementwise ops:
199+ //
200+ // convert(elementwise(x)) #dot_operand ->
201+ // elementwise(convert(x, #dot_operand)).
202+ //
203+ // The goal is to put the convert right next to the originating load. If we can
204+ // accomplish this, then we can save a shmem round-trip:
205+ //
206+ // Before:
207+ //
208+ // - Load from global into shmem using an async copy.
209+ // - Load from shmem into a #blocked layout.
210+ // - Do elementwise ops over #blocked layout.
211+ // - Convert to #dot_operand (round-trip through shmem).
212+ // - Do dot.
213+ //
214+ // After:
215+ //
216+ // - Load from global into shmem using an async copy (same as before).
217+ // - Load from shmem into a #dot_operand layout.
218+ // - Do elementwise ops over #dot_operand layout.
219+ // - Do dot.
220+ //
221+ // This can also be propagated when we have a constant, instead of a load.
222+ //
223+ // Eliminating the shmem round-trip is such a big win, we're willing to do it
224+ // even if this duplicates work because some of the elementwise ops have uses
225+ // that don't flow into the dot. On the other hand, we only want to do this if
226+ // we can in fact reduce shmem round-trips: For example, simply moving a convert
227+ // up above e.g. an `add` now means we have *two* converts. That's worse,
228+ // unless we can continue moving the converts upwards and eventually merge them.
229+ // So we try to check that this will be beneficial before making any changes.
230+ class HoistLayoutConversion : public OpRewritePattern <ConvertLayoutOp> {
231+ public:
232+ using OpRewritePattern::OpRewritePattern;
233+
234+ LogicalResult matchAndRewrite (ConvertLayoutOp cvt,
235+ PatternRewriter &rewriter) const override {
236+ // Only consider conversions to dot operand.
237+ auto cvtTy = cast<RankedTensorType>(cvt.getType ());
238+ auto dotOpEnc = dyn_cast<DotOperandEncodingAttr>(cvtTy.getEncoding ());
239+ if (!dotOpEnc)
240+ return failure ();
241+
242+ auto src = cvt.getSrc ().getDefiningOp ();
243+ if (!src || src->getNumOperands () == 0 || src->getNumResults () != 1 )
244+ return failure ();
245+
246+ auto srcTy = dyn_cast<RankedTensorType>(src->getResult (0 ).getType ());
247+ if (!srcTy)
248+ return failure ();
249+
250+ if (!all_of (src->getOperandTypes (),
251+ [](Type ty) { return isa<RankedTensorType>(ty); }))
252+ return failure ();
253+
254+ if (!canHoistDotOpEncV2 (src, dotOpEnc))
255+ return failure ();
256+
257+ // Check that the conversion is transitively dependent on a load or a
258+ // constant, and all operations between it and the convert are layout
259+ // preserving.
260+ //
261+ // TODO(jlebar): This is accidentally quadratic; we iterate over the whole
262+ // slice but then at the end we only modify one op!
263+ SetVector<Operation *> slice;
264+ BackwardSliceOptions opt;
265+ opt.omitBlockArguments = true ;
266+ getBackwardSlice (cvt.getOperation (), &slice, opt);
267+
268+ // TODO(jlebar): This is too conservative when there are multiple loads in
269+ // the chain. If one of the loads has a non-layout-preserving op and the
270+ // other does not, then we may or may not accept the chain, depending on
271+ // which load gets hit first by getBackwardSlice. For example:
272+ // cvt(broadcast(load(x)) + load(y)) // accepted & load(y) will benefit.
273+ // cvt(load(y) + broadcast(load(x))) // rejected & load(y) will not benefit.
274+ bool foundInitializer = false ;
275+ // Reverse the slice so that we start directly above the convert and check
276+ // that every op allows hoisting until we find a load or a constant.
277+ for (Operation *currOp : llvm::reverse (slice)) {
278+ if (isa<LoadOp>(currOp) || isa<arith::ConstantOp>(currOp)) {
279+ foundInitializer = true ;
280+ break ;
281+ }
282+ if (!canHoistDotOpEncV2 (currOp, dotOpEnc))
283+ return failure ();
284+ }
285+ if (!foundInitializer)
286+ return failure ();
287+
288+ SmallVector<ConvertLayoutOp> newOperands;
289+ for (auto operand : src->getOperands ()) {
290+ // We checked earlier that all operands are ranked tensors.
291+ auto operandTy = cast<RankedTensorType>(operand.getType ());
292+ Type newCvtTy = RankedTensorType::get (
293+ srcTy.getShape (), operandTy.getElementType (), cvtTy.getEncoding ());
294+ newOperands.push_back (
295+ rewriter.create <ConvertLayoutOp>(cvt.getLoc (), newCvtTy, operand));
296+ }
297+ auto newRet = rewriter.clone (*src);
298+ for (int i = 0 ; i < newOperands.size (); i++)
299+ newRet->setOperand (i, newOperands[i]);
300+ newRet->getResult (0 ).setType (RankedTensorType::get (
301+ srcTy.getShape (), srcTy.getElementType (), cvtTy.getEncoding ()));
302+
303+ rewriter.replaceOp (cvt, newRet->getResults ());
304+ return success ();
305+ }
306+ };
307+
// Rewrite
//
// dot(alloc(trans() #shared1) ->
@@ -559,6 +699,8 @@ class TritonGPUOptimizeDotOperandsPass
559699 mlir::RewritePatternSet patterns (context);
560700 patterns.add <MMAV3HoistLayoutConversion>(context);
561701 patterns.add <SwizzleShmemConvert>(context);
702+ if (this ->hoistLayoutConversion .getValue ())
703+ patterns.add <HoistLayoutConversion>(context);
562704 patterns.add <FuseTransMMAV3Plus>(context);
563705 patterns.add <MMAV3UseRegOperand>(context);
564706 patterns.add <InjectTMemCopy>(context);
0 commit comments