11#include < memory>
22
33#include " mlir/Dialect/SCF/Utils/Utils.h"
4- #include " mlir/IR/Attributes.h"
54#include " mlir/IR/BuiltinAttributes.h"
65#include " mlir/IR/Matchers.h"
76#include " mlir/IR/PatternMatch.h"
@@ -36,52 +35,12 @@ class LoopUnrollPass : public TritonLoopUnrollBase<LoopUnrollPass> {
3635 return 1 ;
3736 }
3837
39- int getUnrollIdOrDefault (scf::ForOp forOp) {
40- // Use the attribute attached to the loop if it exists otherwise set the
41- // factor to 1 to suppress the unrolling.
42- if (auto factor = forOp->getAttrOfType <IntegerAttr>(unrolledLoopIdAttrName))
43- return factor.getInt ();
44- return 0 ;
45- }
46-
4738 const char *loopUnrollFactorAttrName = " tt.loop_unroll_factor" ;
48- const char *unrolledLoopIdAttrName = " tt.unrolled_loop_id" ;
4939 const char *pipelineStagesAttrName = " tt.num_stages" ;
5040
5141public:
5242 LoopUnrollPass () = default ;
5343 LoopUnrollPass (const LoopUnrollPass &) {}
54-
55- SmallVector<scf::ForOp, 2 > getUnrolledLoopsAndClearAttrs (unsigned loopId) {
56- SmallVector<scf::ForOp, 2 > loops;
57- getOperation ()->walk ([&](scf::ForOp forOp) {
58- if (getUnrollIdOrDefault (forOp) == loopId)
59- loops.push_back (forOp);
60- });
61-
62- assert (loops.size () <= 2 && " Expect at most 2 loops, one for the main loop "
63- " and one for the prolog/epilog" );
64- SmallVector<int , 2 > loopInstructionCounts;
65- for (auto loop : loops) {
66- loop->removeAttr (loopUnrollFactorAttrName);
67- loop->removeAttr (unrolledLoopIdAttrName);
68- int count = 0 ;
69- loop->walk ([&](Operation *op) { count++; });
70- loopInstructionCounts.push_back (count);
71- }
72- if (loops.size () == 2 ) {
73- // check which one is the unrolled loop and which one is the prolog/epilog
74- // loop. A simple heuristic is to check the number of instructions in the
75- // loop. The unrolled main loop should have the most instructions.
76- // sort the loops by the number of instructions. The unrolled main loop
77- // should go first.
78- if (loopInstructionCounts[0 ] < loopInstructionCounts[1 ])
79- std::swap (loops[0 ], loops[1 ]);
80- }
81-
82- return loops;
83- }
84-
8544 void runOnOperation () override {
8645 LDBG (" Loop unroll pass" );
8746 SmallVector<scf::ForOp, 4 > loops;
@@ -95,17 +54,14 @@ class LoopUnrollPass : public TritonLoopUnrollBase<LoopUnrollPass> {
9554 for (unsigned i = 0 ; i < loops.size (); i++) {
9655 auto loop = loops[i];
9756 auto unrollFactor = getUnrollFactorOrDefault (loop);
98- loop->setAttr (unrolledLoopIdAttrName,
99- mlir::IntegerAttr::get (IntegerType::get (ctx, 32 ), i + 1 ));
57+ loop->removeAttr (loopUnrollFactorAttrName);
10058 LDBG (" Unrolling loop by " << unrollFactor << " times\n " << loop);
101- (void )loopUnrollByFactor (loop, unrollFactor);
102- auto unrolledLoops = getUnrolledLoopsAndClearAttrs (i + 1 );
103- // Do not pipeline the prolog/epilog loop.
104- if (unrolledLoops.size () == 2 ) {
105- auto prologEpilogLoop = unrolledLoops[1 ];
106- prologEpilogLoop->setAttr (
107- pipelineStagesAttrName,
108- mlir::IntegerAttr::get (IntegerType::get (ctx, 32 ), 1 ));
59+ auto resultLoops = loopUnrollByFactor (loop, unrollFactor);
60+ // Do not pipeline the epilog loop.
61+ if (succeeded (resultLoops) && resultLoops->epilogueLoopOp ) {
62+ (*resultLoops->epilogueLoopOp )
63+ ->setAttr (pipelineStagesAttrName,
64+ mlir::IntegerAttr::get (IntegerType::get (ctx, 32 ), 1 ));
10965 }
11066 }
11167 }
0 commit comments