@@ -62,20 +62,23 @@ class LoopUnrollPass : public TritonLoopUnrollBase<LoopUnrollPass> {
6262 // check which one is the unrolled loop and which one is the prolog/epilog
6363 // loop. A simple heuristic is to check the number of instructions in the
6464 // loop. The unrolled main loop should have the most instructions.
65- assert (loops.size () == 2 && " only support unrolling one loop at a time" );
66- SmallVector<int , 2 > loopInstructionCounts;
67- for (auto loop : loops) {
68- loop->removeAttr (loopUnrollFactorAttrName);
69- loop->removeAttr (unrolledLoopIdAttrName);
70- int count = 0 ;
71- loop->walk ([&](Operation *op) { count++; });
72- loopInstructionCounts.push_back (count);
73- }
65+ assert (loops.size () <= 2 && " Expect at most 2 loops, one for the main loop "
66+ " and one for the prolog/epilog" );
67+ if (loops.size () == 2 ) {
68+ SmallVector<int , 2 > loopInstructionCounts;
69+ for (auto loop : loops) {
70+ loop->removeAttr (loopUnrollFactorAttrName);
71+ loop->removeAttr (unrolledLoopIdAttrName);
72+ int count = 0 ;
73+ loop->walk ([&](Operation *op) { count++; });
74+ loopInstructionCounts.push_back (count);
75+ }
7476
75- // sort the loops by the number of instructions. The unrolled main loop
76- // should go first.
77- if (loopInstructionCounts[0 ] < loopInstructionCounts[1 ])
78- std::swap (loops[0 ], loops[1 ]);
77+ // sort the loops by the number of instructions. The unrolled main loop
78+ // should go first.
79+ if (loopInstructionCounts[0 ] < loopInstructionCounts[1 ])
80+ std::swap (loops[0 ], loops[1 ]);
81+ }
7982
8083 return loops;
8184 }
0 commit comments