Skip to content

Commit 6e16e01

Browse files
committed
Simplified based on MLIR-side changes.
1 parent 3eb537b commit 6e16e01

File tree

1 file changed

+8
-53
lines changed

1 file changed

+8
-53
lines changed

lib/Dialect/Triton/Transforms/LoopUnroll.cpp

Lines changed: 8 additions & 53 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,6 @@
11
#include <memory>
22

33
#include "mlir/Dialect/SCF/Utils/Utils.h"
4-
#include "mlir/IR/Attributes.h"
54
#include "mlir/IR/BuiltinAttributes.h"
65
#include "mlir/IR/Matchers.h"
76
#include "mlir/IR/PatternMatch.h"
@@ -36,52 +35,12 @@ class LoopUnrollPass : public TritonLoopUnrollBase<LoopUnrollPass> {
3635
return 1;
3736
}
3837

39-
int getUnrollIdOrDefault(scf::ForOp forOp) {
40-
// Use the attribute attached to the loop if it exists otherwise set the
41-
// factor to 1 to suppress the unrolling.
42-
if (auto factor = forOp->getAttrOfType<IntegerAttr>(unrolledLoopIdAttrName))
43-
return factor.getInt();
44-
return 0;
45-
}
46-
4738
const char *loopUnrollFactorAttrName = "tt.loop_unroll_factor";
48-
const char *unrolledLoopIdAttrName = "tt.unrolled_loop_id";
4939
const char *pipelineStagesAttrName = "tt.num_stages";
5040

5141
public:
5242
LoopUnrollPass() = default;
5343
LoopUnrollPass(const LoopUnrollPass &) {}
54-
55-
SmallVector<scf::ForOp, 2> getUnrolledLoopsAndClearAttrs(unsigned loopId) {
56-
SmallVector<scf::ForOp, 2> loops;
57-
getOperation()->walk([&](scf::ForOp forOp) {
58-
if (getUnrollIdOrDefault(forOp) == loopId)
59-
loops.push_back(forOp);
60-
});
61-
62-
assert(loops.size() <= 2 && "Expect at most 2 loops, one for the main loop "
63-
"and one for the prolog/epilog");
64-
SmallVector<int, 2> loopInstructionCounts;
65-
for (auto loop : loops) {
66-
loop->removeAttr(loopUnrollFactorAttrName);
67-
loop->removeAttr(unrolledLoopIdAttrName);
68-
int count = 0;
69-
loop->walk([&](Operation *op) { count++; });
70-
loopInstructionCounts.push_back(count);
71-
}
72-
if (loops.size() == 2) {
73-
// check which one is the unrolled loop and which one is the prolog/epilog
74-
// loop. A simple heuristic is to check the number of instructions in the
75-
// loop. The unrolled main loop should have the most instructions.
76-
// sort the loops by the number of instructions. The unrolled main loop
77-
// should go first.
78-
if (loopInstructionCounts[0] < loopInstructionCounts[1])
79-
std::swap(loops[0], loops[1]);
80-
}
81-
82-
return loops;
83-
}
84-
8544
void runOnOperation() override {
8645
LDBG("Loop unroll pass");
8746
SmallVector<scf::ForOp, 4> loops;
@@ -92,20 +51,16 @@ class LoopUnrollPass : public TritonLoopUnrollBase<LoopUnrollPass> {
9251
});
9352

9453
auto ctx = getOperation()->getContext();
95-
for (unsigned i = 0; i < loops.size(); i++) {
96-
auto loop = loops[i];
54+
for (auto loop : loops) {
9755
auto unrollFactor = getUnrollFactorOrDefault(loop);
98-
loop->setAttr(unrolledLoopIdAttrName,
99-
mlir::IntegerAttr::get(IntegerType::get(ctx, 32), i + 1));
56+
loop->removeAttr(loopUnrollFactorAttrName);
10057
LDBG("Unrolling loop by " << unrollFactor << " times\n" << loop);
101-
(void)loopUnrollByFactor(loop, unrollFactor);
102-
auto unrolledLoops = getUnrolledLoopsAndClearAttrs(i + 1);
103-
// Do not pipeline the prolog/epilog loop.
104-
if (unrolledLoops.size() == 2) {
105-
auto prologEpilogLoop = unrolledLoops[1];
106-
prologEpilogLoop->setAttr(
107-
pipelineStagesAttrName,
108-
mlir::IntegerAttr::get(IntegerType::get(ctx, 32), 1));
58+
auto resultLoops = loopUnrollByFactor(loop, unrollFactor);
59+
// Do not pipeline the epilog loop.
60+
if (succeeded(resultLoops) && resultLoops->epilogueLoopOp) {
61+
(*resultLoops->epilogueLoopOp)
62+
->setAttr(pipelineStagesAttrName,
63+
mlir::IntegerAttr::get(IntegerType::get(ctx, 32), 1));
10964
}
11065
}
11166
}

0 commit comments

Comments
 (0)