Skip to content

Commit 1e4b9ba

Browse files
committed
[BACKEND] Do not pipeline epilog loops generated by loop unrolling
1 parent 4f6f768 commit 1e4b9ba

File tree

2 files changed

+58
-6
lines changed

2 files changed

+58
-6
lines changed

lib/Dialect/Triton/Transforms/LoopUnroll.cpp

Lines changed: 57 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
#include <memory>
22

33
#include "mlir/Dialect/SCF/Utils/Utils.h"
4+
#include "mlir/IR/Attributes.h"
45
#include "mlir/IR/BuiltinAttributes.h"
56
#include "mlir/IR/Matchers.h"
67
#include "mlir/IR/PatternMatch.h"
@@ -22,24 +23,63 @@
2223

2324
namespace mlir::triton {
2425

25-
static const char *loopUnrollFactorAttrName = "tt.loop_unroll_factor";
26-
2726
namespace {
2827

2928
class LoopUnrollPass : public TritonLoopUnrollBase<LoopUnrollPass> {
3029

3130
// Returns the unroll factor requested for `forOp`: the integer value of the
// attribute named by loopUnrollFactorAttrName when the loop carries it,
// otherwise 1.
int getUnrollFactorOrDefault(scf::ForOp forOp) {
3231
// Use the attribute attached to the loop if it exists; otherwise fall back to
3332
// a factor of 1 to suppress the unrolling.
34-
if (auto factor = forOp->getAttrOfType<IntegerAttr>(
35-
mlir::triton::loopUnrollFactorAttrName))
33+
if (auto factor =
34+
forOp->getAttrOfType<IntegerAttr>(loopUnrollFactorAttrName))
3635
return factor.getInt();
3736
return 1;
3837
}
3938

39+
// Returns the unrolled-loop id attached to `forOp`: the integer value of the
// attribute named by unrolledLoopIdAttrName when the loop carries it,
// otherwise 0.
int getUnrollIdOrDefault(scf::ForOp forOp) {
40+
// Use the id attribute attached to the loop if it exists; otherwise return
41+
// 0. The ids assigned in runOnOperation start at 1, so 0 means "no id".
42+
if (auto factor = forOp->getAttrOfType<IntegerAttr>(unrolledLoopIdAttrName))
43+
return factor.getInt();
44+
return 0;
45+
}
46+
47+
// Name of the integer attribute read by getUnrollFactorOrDefault to obtain a
// loop's requested unroll factor.
const char *loopUnrollFactorAttrName = "tt.loop_unroll_factor";
48+
// Name of the integer attribute that runOnOperation sets on a loop before
// unrolling it and that getUnrolledLoopsAndClearAttrs removes afterwards.
const char *unrolledLoopIdAttrName = "tt.unrolled_loop_id";
49+
// Name of the integer attribute that runOnOperation sets to 1 on the
// prolog/epilog loop so that it is not pipelined.
const char *pipelineStagesAttrName = "tt.num_stages";
50+
4051
public:
4152
// Default constructor.
LoopUnrollPass() = default;
4253
// Copy constructor with an empty body and no member initializers: no state
// is copied from the source pass.
LoopUnrollPass(const LoopUnrollPass &) {}
54+
55+
// Finds every scf.for under the pass's root operation whose unrolled-loop id
// equals `loopId`, removes the unroll-factor and unrolled-loop-id attributes
// from each of them, and returns them.
//
// When exactly two loops are found they are returned as
// {unrolled main loop, prolog/epilog loop}. The heuristic is the number of
// operations visited by walking each loop: the main loop is expected to
// contain the most operations (on a tie the walk order is kept).
//
// For any other count the loops are returned in walk order with no ordering
// implied, so callers must check the size before indexing. Unrolling does not
// necessarily leave exactly two tagged loops (presumably e.g. when no epilog
// is generated -- confirm against loopUnrollByFactor), so this case must not
// abort or index out of bounds.
SmallVector<scf::ForOp, 2> getUnrolledLoopsAndClearAttrs(unsigned loopId) {
  SmallVector<scf::ForOp, 2> loops;
  getOperation()->walk([&](scf::ForOp forOp) {
    // The id helper returns int; convert explicitly rather than relying on
    // the implicit signed-to-unsigned conversion in the comparison.
    if (static_cast<unsigned>(getUnrollIdOrDefault(forOp)) == loopId)
      loops.push_back(forOp);
  });

  // The bookkeeping attributes are always cleared, however many loops matched.
  for (scf::ForOp loop : loops) {
    loop->removeAttr(loopUnrollFactorAttrName);
    loop->removeAttr(unrolledLoopIdAttrName);
  }

  // Only a {main, prolog/epilog} pair can be ordered.
  if (loops.size() != 2)
    return loops;

  // Counts the operations visited by walking `loop`.
  auto countOps = [](scf::ForOp loop) {
    int count = 0;
    loop->walk([&](Operation *) { ++count; });
    return count;
  };

  // The unrolled main loop should go first.
  if (countOps(loops[0]) < countOps(loops[1]))
    std::swap(loops[0], loops[1]);

  return loops;
}
82+
4383
void runOnOperation() override {
4484
LDBG("Loop unroll pass");
4585
SmallVector<scf::ForOp, 4> loops;
@@ -49,11 +89,22 @@ class LoopUnrollPass : public TritonLoopUnrollBase<LoopUnrollPass> {
4989
loops.push_back(forOp);
5090
});
5191

52-
for (auto loop : loops) {
92+
auto ctx = getOperation()->getContext();
93+
for (unsigned i = 0; i < loops.size(); i++) {
94+
auto loop = loops[i];
5395
auto unrollFactor = getUnrollFactorOrDefault(loop);
54-
loop->removeAttr(mlir::triton::loopUnrollFactorAttrName);
96+
loop->setAttr(unrolledLoopIdAttrName,
97+
mlir::IntegerAttr::get(IntegerType::get(ctx, 32), i + 1));
5598
LDBG("Unrolling loop by " << unrollFactor << " times\n" << loop);
5699
(void)loopUnrollByFactor(loop, unrollFactor);
100+
auto unrolledLoops = getUnrolledLoopsAndClearAttrs(i + 1);
101+
// Do not pipeline the prolog/epilog loop.
102+
if (unrolledLoops.size() == 2) {
103+
auto prologEpilogLoop = unrolledLoops[1];
104+
prologEpilogLoop->setAttr(
105+
pipelineStagesAttrName,
106+
mlir::IntegerAttr::get(IntegerType::get(ctx, 32), 1));
107+
}
57108
}
58109
}
59110
};

test/Triton/loop-unroll.mlir

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@ tt.func @add_kernel_unroll(%arg0: tensor<256x!tt.ptr<f32>>, %arg1: i32) {
1313
// CHECK: scf.for
1414
// CHECK: tt.load
1515
// CHECK-NOT: tt.load
16+
// CHECK: tt.num_stages = 1 : i32
1617
%2:2 = scf.for %arg3 = %c1_i32 to %arg1 step %c1_i32 iter_args(%arg4 = %1, %arg5 = %arg0) -> (tensor<256xf32>, tensor<256x!tt.ptr<f32>>) : i32 {
1718
%3 = tt.load %arg5 : tensor<256x!tt.ptr<f32>>
1819
%4 = arith.addf %arg4, %3 : tensor<256xf32>

0 commit comments

Comments
 (0)