Skip to content

Commit f3fb638

Browse files
Mogballzwu-2025
authored and committed
[Pipeliner] Merge warp specialization and pipeliner scheduling (triton-lang#6887)
This PR refactors warp specialization to share the same scheduling as software pipelining. What this means is that the pipeliner's loop scheduler is used to set the stages and clusters of the ops; then, on top of that, warp specialization performs partition assignment and splits the loop into multiple loops — introducing synchronization between them — that are then individually software pipelined.
1 parent 1aba27a commit f3fb638

Some content is hidden

Large commits have some content hidden by default. Use the searchbox below for content that may be hidden.

42 files changed

+359
-366
lines changed

include/triton/Dialect/TritonGPU/Transforms/Partition.h

Lines changed: 26 additions & 28 deletions
Original file line numberDiff line numberDiff line change
@@ -4,8 +4,6 @@
44
#include "mlir/Support/LLVM.h"
55
#include "llvm/ADT/ArrayRef.h"
66
#include "llvm/ADT/DenseMap.h"
7-
#include "llvm/ADT/GraphTraits.h"
8-
#include "llvm/ADT/MapVector.h"
97
#include "llvm/ADT/SmallVector.h"
108

119
namespace mlir {
@@ -26,39 +24,39 @@ static constexpr char kPartitionStagesAttrName[] = "ttg.partition.stages";
2624
//===----------------------------------------------------------------------===//
2725

2826
namespace mlir::triton::gpu {
27+
// A partition has a stage and contains some operation. The stage of a
28+
// partition determines how many cycles the partition's outputs are buffered
29+
// relative to its consumers.
30+
class Partition {
31+
public:
32+
Partition(int idx, int stage) : idx(idx), stage(stage) {}
33+
34+
int getIndex() const { return idx; }
35+
int getStage() const { return stage; }
36+
ArrayRef<Operation *> getOps() const { return ops; }
37+
38+
void insert(Operation *op) { ops.push_back(op); }
39+
void remove(Operation *op) { ops.erase(llvm::find(ops, op)); }
40+
41+
private:
42+
void setIndex(int idx) { this->idx = idx; }
43+
friend class WarpSchedule;
44+
45+
// The partition number.
46+
int idx;
47+
// The stage of the partition.
48+
int stage;
49+
// The ops in the partition.
50+
SmallVector<Operation *> ops;
51+
};
52+
2953
// A warp schedule divides a loop into multiple partitions. Ops in a loop are
3054
// assigned at most one partition. A warp schedule represents asynchronous
3155
// execution of the loop body, where partitions may execute simultaneously.
3256
class WarpSchedule {
3357
static constexpr int kSentinel = -1;
3458

3559
public:
36-
// A partition has a stage and contains some operation. The stage of a
37-
// partition determines how many cycles the partition's outputs are buffered
38-
// relative to its consumers.
39-
class Partition {
40-
public:
41-
Partition(int idx, int stage) : idx(idx), stage(stage) {}
42-
43-
int getIndex() const { return idx; }
44-
int getStage() const { return stage; }
45-
ArrayRef<Operation *> getOps() const { return ops; }
46-
47-
void insert(Operation *op) { ops.push_back(op); }
48-
void remove(Operation *op) { ops.erase(llvm::find(ops, op)); }
49-
50-
private:
51-
void setIndex(int idx) { this->idx = idx; }
52-
friend class WarpSchedule;
53-
54-
// The partition number.
55-
int idx;
56-
// The stage of the partition.
57-
int stage;
58-
// The ops in the partition.
59-
SmallVector<Operation *> ops;
60-
};
61-
6260
// Create a new partition with a stage.
6361
Partition *addPartition(unsigned stage);
6462
// Update the op to partition mapping.

include/triton/Dialect/TritonGPU/Transforms/Passes.td

Lines changed: 9 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -26,36 +26,27 @@ def TritonGPUPipeline : Pass<"tritongpu-pipeline", "mlir::ModuleOp"> {
2626
];
2727
}
2828

29-
def TritonGPUTestPipelineAssignLatencies : Pass<"tritongpu-test-pipeline-assign-latencies", "mlir::ModuleOp"> {
30-
let summary = "test assigning latencies to interesting ops ahead of pipelining";
29+
def TritonGPUAssignLatencies : Pass<"tritongpu-assign-latencies", "mlir::ModuleOp"> {
30+
let summary = "assign latencies to interesting ops ahead of pipelining";
3131

3232
let description = [{
33-
This is a test pass that tests `assignLatencies` method of `TritonGPUPipeline`.
33+
The `tritongpu-assign-latencies` pass assigns latencies to latency ops based
34+
on the number of stages.
3435
}];
3536

36-
let dependentDialects = ["mlir::triton::gpu::TritonGPUDialect",
37-
"mlir::triton::nvidia_gpu::TritonNvidiaGPUDialect",
38-
"mlir::scf::SCFDialect",
39-
"mlir::arith::ArithDialect"];
40-
4137
let options = [
42-
Option<"numStages", "num-stages",
43-
"int32_t", /*default*/"3",
38+
Option<"numStages", "num-stages", "int32_t", /*default*/"3",
4439
"number of pipeline stages">
4540
];
4641
}
4742

48-
def TritonGPUTestPipelineScheduleLoop : Pass<"tritongpu-test-pipeline-schedule-loop", "mlir::ModuleOp"> {
49-
let summary = "test scheduling a loop for software pipelining";
43+
def TritonGPUScheduleLoops : Pass<"tritongpu-schedule-loops", "mlir::ModuleOp"> {
44+
let summary = "software pipeline loop scheduling";
5045

5146
let description = [{
52-
This is a test pass that tests `scheduleLoop` method of `TritonGPUPipeline`.
47+
The `tritongpu-schedule-loops` pass performs scheduling for loop pipelining
48+
for loops with latency ops.
5349
}];
54-
55-
let dependentDialects = ["mlir::triton::gpu::TritonGPUDialect",
56-
"mlir::triton::nvidia_gpu::TritonNvidiaGPUDialect",
57-
"mlir::scf::SCFDialect",
58-
"mlir::arith::ArithDialect"];
5950
}
6051

6152
def TritonGPUHoistTMEMAlloc : Pass<"tritongpu-hoist-tmem-alloc", "mlir::ModuleOp"> {

include/triton/Dialect/TritonGPU/Transforms/PipeliningUtility.h

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -19,8 +19,6 @@ static const char *kWarpSpecializeAttrName = "tt.warp_specialize";
1919
static const char *kLoopStageAttrName = "loop.stage";
2020
static const char *kLoopClusterAttrName = "loop.cluster";
2121
static const char *kScheduledMaxStageAttrName = "tt.scheduled_max_stage";
22-
static const char *kAssignedStageAttrName = "ttg.assigned_stage";
23-
static const char *kAssignedClusterAttrName = "ttg.assigned_cluster";
2422

2523
//===----------------------------------------------------------------------===//
2624
// Hoisting Utilities

include/triton/Dialect/TritonGPU/Transforms/Schedule.h

Lines changed: 4 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -13,13 +13,6 @@ namespace triton {
1313

1414
namespace gpu {
1515

16-
/// Discover operations that should become async and assign latencies to them
17-
/// based on the numStages value provided by the user.
18-
void assignLatencies(ModuleOp moduleOp, int numStages);
19-
20-
/// Schedule the loops based on the latencies assigned to the operations.
21-
void scheduleLoops(ModuleOp moduleOp);
22-
2316
/// Lower the loops to prepare them for pipeline expansion.
2417
void lowerLoops(ModuleOp moduleOp);
2518

@@ -115,6 +108,10 @@ class CoarseSchedule {
115108
bool insertDepsOfOp(Operation *op, int stage, CoarseSchedule::Cluster cluster,
116109
bool includeArg, bool insertIfEarlier = false);
117110

111+
// Remove empty stages and clusters from the schedule, adjusting the maximum
112+
// number of stages as appropriate.
113+
void shrinkToFit();
114+
118115
void erase(Operation *op) { opToStageAndCluster.erase(op); }
119116

120117
int count(Operation *op) { return opToStageAndCluster.count(op); }

lib/Dialect/TritonGPU/Transforms/CMakeLists.txt

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -16,8 +16,6 @@ add_triton_library(TritonGPUTransforms
1616
Pipeliner/ScheduleLoops.cpp
1717
Pipeliner/WGMMAPipeline.cpp
1818
Pipeliner/PipelineExpander.cpp
19-
Pipeliner/TestPipelineAssignLatencies.cpp
20-
Pipeliner/TestPipelineScheduleLoop.cpp
2119
Pipeliner/TestPipelineLowerLoop.cpp
2220
Pipeliner/SoftwarePipeliner.cpp
2321
Pipeliner/TMAStoresPipeline.cpp
@@ -33,6 +31,7 @@ add_triton_library(TritonGPUTransforms
3331
WarpSpecialization/LoadMMASpecialization.cpp
3432
WarpSpecialization/Partition.cpp
3533
WarpSpecialization/OptimizePartitionWarps.cpp
34+
WarpSpecialization/PartitionBuilder.cpp
3635
WarpSpecialization/PartitionLoops.cpp
3736
WarpSpecialization/PartitionScheduling.cpp
3837
WarpSpecialization/RewritePartitionDependencies.cpp

lib/Dialect/TritonGPU/Transforms/Pipeliner/AssignLatencies.cpp

Lines changed: 39 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -18,12 +18,13 @@ namespace tt = mlir::triton;
1818
namespace ttg = mlir::triton::gpu;
1919
namespace ttng = mlir::triton::nvidia_gpu;
2020

21-
namespace mlir {
22-
namespace triton {
23-
namespace gpu {
24-
21+
namespace mlir::triton::gpu {
2522
namespace {
2623

24+
//===----------------------------------------------------------------------===//
25+
// assignLatencies
26+
//===----------------------------------------------------------------------===//
27+
2728
// Return true if the preconditions for pipelining the loop are met.
2829
bool preCondition(scf::ForOp forOp) {
2930
// Skip loop with distance > 1 for now.
@@ -293,6 +294,15 @@ class AssignMMALatencies {
293294
// MMA's users can be pushed to the next stage
294295
opLatency[&op] = 1;
295296
}
297+
// HACK: A pipelined MMA's latency should equal the number of buffers
298+
// for the accumulator, but when the user is in an `scf.if` in SWP,
299+
// the `scf.if` is pushed to the end of the loop rather than peeled
300+
// before the MMA op, requiring an extra buffer due to liverange
301+
// overlap. WS does not have this problem because the MMA is placed in
302+
// a different partition than the MMA's user, so we can correctly set the
303+
// latency.
304+
if (forOp->hasAttr(kWarpSpecializeAttrName))
305+
opLatency[&op] += 1;
296306
}
297307
}
298308
}
@@ -312,12 +322,13 @@ class AssignMMALatencies {
312322
}
313323
};
314324

315-
} // namespace
316-
317-
// Look for load ops that directly or indirectly feed into dot ops. Based
318-
// on the requested number of stages assign the latencies in a way that
319-
// cover all the stages with the sum of latencies in the chain from the first
320-
// load to the final dot op.
325+
// Discover operations that should become async and assign latencies to them
326+
// based on the numStages value provided by the user.
327+
//
328+
// Look for load ops that directly or indirectly feed into dot ops. Based on the
329+
// requested number of stages assign the latencies in a way that cover all the
330+
// stages with the sum of latencies in the chain from the first load to the
331+
// final dot op.
321332
void assignLatencies(ModuleOp moduleOp, int defaultNumStages) {
322333
SmallVector<scf::ForOp> loops;
323334
moduleOp->walk([&](scf::ForOp forOp) {
@@ -341,6 +352,21 @@ void assignLatencies(ModuleOp moduleOp, int defaultNumStages) {
341352
}
342353
serializeLatencies(moduleOp, opLatency);
343354
}
344-
} // namespace gpu
345-
} // namespace triton
346-
} // namespace mlir
355+
356+
} // namespace
357+
358+
//===----------------------------------------------------------------------===//
359+
// Pass Definition
360+
//===----------------------------------------------------------------------===//
361+
362+
#define GEN_PASS_DEF_TRITONGPUASSIGNLATENCIES
363+
#include "triton/Dialect/TritonGPU/Transforms/Passes.h.inc"
364+
365+
struct AssignLatencies
366+
: public impl::TritonGPUAssignLatenciesBase<AssignLatencies> {
367+
using TritonGPUAssignLatenciesBase::TritonGPUAssignLatenciesBase;
368+
369+
void runOnOperation() override { assignLatencies(getOperation(), numStages); }
370+
};
371+
372+
} // namespace mlir::triton::gpu

lib/Dialect/TritonGPU/Transforms/Pipeliner/PipeliningUtility.cpp

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -161,7 +161,7 @@ Operation *mlir::triton::predicateOp(RewriterBase &rewriter, Operation *op,
161161
OpBuilder::InsertionGuard guard(rewriter);
162162
if (mlir::isMemoryEffectFree(op))
163163
return op;
164-
if (isa<LLVM::AssumeOp>(op))
164+
if (isa<LLVM::AssumeOp, ttng::FenceAsyncSharedOp>(op))
165165
return op;
166166
if (isa<ttg::AsyncCommitGroupOp, ttg::AsyncWaitOp>(op))
167167
return op;
@@ -264,7 +264,7 @@ Operation *mlir::triton::predicateOp(RewriterBase &rewriter, Operation *op,
264264
return op;
265265
}
266266

267-
op->emitError("pipeliner doesn't know how to predicate this op.");
267+
op->emitOpError("pipeliner doesn't know how to predicate this op.");
268268
llvm::report_fatal_error("Fatal pipeliner error");
269269
return op;
270270
}

lib/Dialect/TritonGPU/Transforms/Pipeliner/Schedule.cpp

Lines changed: 15 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -87,6 +87,19 @@ bool tt::CoarseSchedule::insertDepsOfOp(Operation *op, int stage,
8787
return inserted;
8888
}
8989

90+
void tt::CoarseSchedule::shrinkToFit() {
91+
int minStage = std::numeric_limits<int>::max();
92+
int maxStage = std::numeric_limits<int>::min();
93+
for (auto &[op, stageAndCluster] : opToStageAndCluster) {
94+
auto [stage, cluster] = stageAndCluster;
95+
minStage = std::min(minStage, stage);
96+
maxStage = std::max(maxStage, stage);
97+
}
98+
for (auto &[op, stageAndCluster] : opToStageAndCluster)
99+
stageAndCluster.first -= minStage;
100+
numStages = maxStage - minStage + 1;
101+
}
102+
90103
// Split the cluster containing op into two clusters, one containing all
91104
// operations before the op and one containing op and all operations after the
92105
// op. Return the cluster containing op and all operations after the op. Do not
@@ -282,7 +295,8 @@ void tt::scheduleDependencies(scf::ForOp forOp, tt::CoarseSchedule &schedule) {
282295
for (auto [op, stage_, cluster] : opsInOrder) {
283296
if (stage_ != stage)
284297
continue;
285-
schedule.insertDepsOfOp(op, stage, cluster, false);
298+
schedule.insertDepsOfOp(op, stage, cluster, /*includeArg=*/false,
299+
/*insertIfEarlier=*/true);
286300
}
287301
}
288302
}

0 commit comments

Comments
 (0)