11#include < memory>
22
33#include " mlir/Dialect/SCF/Utils/Utils.h"
4+ #include " mlir/IR/Attributes.h"
45#include " mlir/IR/BuiltinAttributes.h"
56#include " mlir/IR/Matchers.h"
67#include " mlir/IR/PatternMatch.h"
2223
2324namespace mlir ::triton {
2425
25- static const char *loopUnrollFactorAttrName = " tt.loop_unroll_factor" ;
26-
2726namespace {
2827
2928class LoopUnrollPass : public TritonLoopUnrollBase <LoopUnrollPass> {
3029
3130 int getUnrollFactorOrDefault (scf::ForOp forOp) {
3231 // Use the attribute attached to the loop if it exists otherwise set the
3332 // factor to 1 to suppress the unrolling.
34- if (auto factor = forOp-> getAttrOfType <IntegerAttr>(
35- mlir::triton:: loopUnrollFactorAttrName))
33+ if (auto factor =
34+ forOp-> getAttrOfType <IntegerAttr>( loopUnrollFactorAttrName))
3635 return factor.getInt ();
3736 return 1 ;
3837 }
3938
39+ int getUnrollIdOrDefault (scf::ForOp forOp) {
40+ // Use the attribute attached to the loop if it exists otherwise set the
41+ // factor to 1 to suppress the unrolling.
42+ if (auto factor = forOp->getAttrOfType <IntegerAttr>(unrolledLoopIdAttrName))
43+ return factor.getInt ();
44+ return 0 ;
45+ }
46+
// Attribute names used by this pass. They are compile-time constants shared
// by every instance, so they are declared `static constexpr` rather than as
// mutable per-object pointers.

// Integer attribute on an scf.for holding the requested unroll factor.
static constexpr const char *loopUnrollFactorAttrName = "tt.loop_unroll_factor";
// Integer attribute used to tag a loop before unrolling so that the loops
// carrying the same tag can be located again afterwards.
static constexpr const char *unrolledLoopIdAttrName = "tt.unrolled_loop_id";
// Integer attribute set to 1 on the prolog/epilog loop so that it is not
// pipelined.
static constexpr const char *pipelineStagesAttrName = "tt.num_stages";
50+
4051public:
4152 LoopUnrollPass () = default ;
4253 LoopUnrollPass (const LoopUnrollPass &) {}
54+
55+ SmallVector<scf::ForOp, 2 > getUnrolledLoopsAndClearAttrs (unsigned loopId) {
56+ SmallVector<scf::ForOp, 2 > loops;
57+ getOperation ()->walk ([&](scf::ForOp forOp) {
58+ if (getUnrollIdOrDefault (forOp) == loopId)
59+ loops.push_back (forOp);
60+ });
61+
62+ // check which one is the unrolled loop and which one is the prolog/epilog
63+ // loop. A simple heuristic is to check the number of instructions in the
64+ // loop. The unrolled main loop should have the most instructions.
65+ assert (loops.size () == 2 && " only support unrolling one loop at a time" );
66+ SmallVector<int , 2 > loopInstructionCounts;
67+ for (auto loop : loops) {
68+ loop->removeAttr (loopUnrollFactorAttrName);
69+ loop->removeAttr (unrolledLoopIdAttrName);
70+ int count = 0 ;
71+ loop->walk ([&](Operation *op) { count++; });
72+ loopInstructionCounts.push_back (count);
73+ }
74+
75+ // sort the loops by the number of instructions. The unrolled main loop
76+ // should go first.
77+ if (loopInstructionCounts[0 ] < loopInstructionCounts[1 ])
78+ std::swap (loops[0 ], loops[1 ]);
79+
80+ return loops;
81+ }
82+
4383 void runOnOperation () override {
4484 LDBG (" Loop unroll pass" );
4585 SmallVector<scf::ForOp, 4 > loops;
@@ -49,11 +89,22 @@ class LoopUnrollPass : public TritonLoopUnrollBase<LoopUnrollPass> {
4989 loops.push_back (forOp);
5090 });
5191
52- for (auto loop : loops) {
92+ auto ctx = getOperation ()->getContext ();
93+ for (unsigned i = 0 ; i < loops.size (); i++) {
94+ auto loop = loops[i];
5395 auto unrollFactor = getUnrollFactorOrDefault (loop);
54- loop->removeAttr (mlir::triton::loopUnrollFactorAttrName);
96+ loop->setAttr (unrolledLoopIdAttrName,
97+ mlir::IntegerAttr::get (IntegerType::get (ctx, 32 ), i + 1 ));
5598 LDBG (" Unrolling loop by " << unrollFactor << " times\n " << loop);
5699 (void )loopUnrollByFactor (loop, unrollFactor);
100+ auto unrolledLoops = getUnrolledLoopsAndClearAttrs (i + 1 );
101+ // Do not pipeline the prolog/epilog loop.
102+ if (unrolledLoops.size () == 2 ) {
103+ auto prologEpilogLoop = unrolledLoops[1 ];
104+ prologEpilogLoop->setAttr (
105+ pipelineStagesAttrName,
106+ mlir::IntegerAttr::get (IntegerType::get (ctx, 32 ), 1 ));
107+ }
57108 }
58109 }
59110};
0 commit comments