@@ -52,36 +52,6 @@ class LoopUnrollPass : public TritonLoopUnrollBase<LoopUnrollPass> {
5252 LoopUnrollPass () = default ;
5353 LoopUnrollPass (const LoopUnrollPass &) {}
5454
55- SmallVector<scf::ForOp, 2 > getUnrolledLoopsAndClearAttrs (unsigned loopId) {
56- SmallVector<scf::ForOp, 2 > loops;
57- getOperation ()->walk ([&](scf::ForOp forOp) {
58- if (getUnrollIdOrDefault (forOp) == loopId)
59- loops.push_back (forOp);
60- });
61-
62- assert (loops.size () <= 2 && " Expect at most 2 loops, one for the main loop "
63- " and one for the prolog/epilog" );
64- SmallVector<int , 2 > loopInstructionCounts;
65- for (auto loop : loops) {
66- loop->removeAttr (loopUnrollFactorAttrName);
67- loop->removeAttr (unrolledLoopIdAttrName);
68- int count = 0 ;
69- loop->walk ([&](Operation *op) { count++; });
70- loopInstructionCounts.push_back (count);
71- }
72- if (loops.size () == 2 ) {
73- // check which one is the unrolled loop and which one is the prolog/epilog
74- // loop. A simple heuristic is to check the number of instructions in the
75- // loop. The unrolled main loop should have the most instructions.
76- // sort the loops by the number of instructions. The unrolled main loop
77- // should go first.
78- if (loopInstructionCounts[0 ] < loopInstructionCounts[1 ])
79- std::swap (loops[0 ], loops[1 ]);
80- }
81-
82- return loops;
83- }
84-
8555 void runOnOperation () override {
8656 LDBG (" Loop unroll pass" );
8757 SmallVector<scf::ForOp, 4 > loops;
@@ -95,17 +65,14 @@ class LoopUnrollPass : public TritonLoopUnrollBase<LoopUnrollPass> {
9565 for (unsigned i = 0 ; i < loops.size (); i++) {
9666 auto loop = loops[i];
9767 auto unrollFactor = getUnrollFactorOrDefault (loop);
98- loop->setAttr (unrolledLoopIdAttrName,
99- mlir::IntegerAttr::get (IntegerType::get (ctx, 32 ), i + 1 ));
68+ loop->removeAttr (loopUnrollFactorAttrName);
10069 LDBG (" Unrolling loop by " << unrollFactor << " times\n " << loop);
101- (void )loopUnrollByFactor (loop, unrollFactor);
102- auto unrolledLoops = getUnrolledLoopsAndClearAttrs (i + 1 );
103- // Do not pipeline the prolog/epilog loop.
104- if (unrolledLoops.size () == 2 ) {
105- auto prologEpilogLoop = unrolledLoops[1 ];
106- prologEpilogLoop->setAttr (
107- pipelineStagesAttrName,
108- mlir::IntegerAttr::get (IntegerType::get (ctx, 32 ), 1 ));
70+ auto resultLoops = loopUnrollByFactor (loop, unrollFactor);
71+ // Do not pipeline the epilog loop.
72+ if (succeeded (resultLoops) && resultLoops->epilogueLoopOp ) {
73+ (*resultLoops->epilogueLoopOp )
74+ ->setAttr (pipelineStagesAttrName,
75+ mlir::IntegerAttr::get (IntegerType::get (ctx, 32 ), 1 ));
10976 }
11077 }
11178 }
0 commit comments