@@ -59,10 +59,8 @@ class LoopUnrollPass : public TritonLoopUnrollBase<LoopUnrollPass> {
5959 loops.push_back (forOp);
6060 });
6161
62- // check which one is the unrolled loop and which one is the prolog/epilog
63- // loop. A simple heuristic is to check the number of instructions in the
64- // loop. The unrolled main loop should have the most instructions.
65- assert (loops.size () == 2 && " only support unrolling one loop at a time" );
62+ assert (loops.size () <= 2 && " Expect at most 2 loops, one for the main loop "
63+ " and one for the prolog/epilog" );
6664 SmallVector<int , 2 > loopInstructionCounts;
6765 for (auto loop : loops) {
6866 loop->removeAttr (loopUnrollFactorAttrName);
@@ -71,11 +69,15 @@ class LoopUnrollPass : public TritonLoopUnrollBase<LoopUnrollPass> {
7169 loop->walk ([&](Operation *op) { count++; });
7270 loopInstructionCounts.push_back (count);
7371 }
74-
75- // sort the loops by the number of instructions. The unrolled main loop
76- // should go first.
77- if (loopInstructionCounts[0 ] < loopInstructionCounts[1 ])
78- std::swap (loops[0 ], loops[1 ]);
72+ if (loops.size () == 2 ) {
73+ // check which one is the unrolled loop and which one is the prolog/epilog
74+ // loop. A simple heuristic is to check the number of instructions in the
75+ // loop. The unrolled main loop should have the most instructions.
76+ // sort the loops by the number of instructions. The unrolled main loop
77+ // should go first.
78+ if (loopInstructionCounts[0 ] < loopInstructionCounts[1 ])
79+ std::swap (loops[0 ], loops[1 ]);
80+ }
7981
8082 return loops;
8183 }
0 commit comments