Commit f0552a9

optimize conditionals for compute logic in epilogue that is wholly dependent on pipelined loads
1 parent 3eb8501 commit f0552a9
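The gist, sketched below with illustrative names (this is not code from the commit): when everything feeding a loop result in the peeled epilogue comes from loads the pipeliner already predicates, and a masked-off load reads back a neutral value such as zero, recomputing the update for a skipped iteration leaves the result unchanged, so the `arith.select` guard on that result is redundant. A minimal scalar model, assuming zero-fill for masked loads:

```cpp
#include <cassert>

// Toy model of one peeled epilogue step for `acc += a * b`.
// Assumption (illustrative): a masked-off pipelined load yields 0.0f,
// i.e. the predicated load supplies a neutral fill value.
float maskedLoad(const float *ptr, bool inBounds) {
  return inBounds ? *ptr : 0.0f;
}

int main() {
  float aBuf = 2.0f, bBuf = 3.0f;
  float acc = 0.0f;

  for (bool inBounds : {true, false}) {
    float prev = acc;                       // previous version of the result
    float a = maskedLoad(&aBuf, inBounds);  // pipelined, predicated load
    float next = prev + a * bBuf;           // epilogue compute

    float guarded = inBounds ? next : prev; // select(pred, next, prev)
    // When !inBounds, a == 0, so next == prev and the guard changes nothing.
    assert(guarded == next);
    acc = next;
  }
  return 0;
}
```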

8 files changed (+139, -17 lines)

include/triton/Dialect/TritonGPU/Transforms/PipelineExpander.h

Lines changed: 2 additions & 0 deletions
```diff
@@ -51,6 +51,8 @@ struct PipeliningOption {
   /// lambda to generate the predicated version of operations.
   bool peelEpilogue = true;
 
+  bool guardEpilogue = true;
+
   /// Control whether the transformation checks that the number of iterations is
   /// greater or equal to the number of stages and skip the transformation if
   /// this is not the case. If the loop is dynamic and this is set to true the
```

lib/Dialect/TritonGPU/Transforms/Pipeliner/PipelineExpander.cpp

Lines changed: 11 additions & 1 deletion
```diff
@@ -64,9 +64,10 @@ struct LoopPipelinerInternal {
   Value lb;
   Value step;
   bool dynamicLoop;
-  triton::PipeliningOption::AnnotationlFnType annotateFn = nullptr;
   bool peelEpilogue;
+  bool guardEpilogue;
   triton::PipeliningOption::PredicateOpFnType predicateFn = nullptr;
+  triton::PipeliningOption::AnnotationlFnType annotateFn = nullptr;
 
   // When peeling the kernel we generate several version of each value for
   // different stage of the prologue. This map tracks the mapping between
@@ -156,6 +157,8 @@ bool LoopPipelinerInternal::initializeLoopInfo(
   }
   peelEpilogue = options.peelEpilogue;
   predicateFn = options.predicateFn;
+  guardEpilogue = options.guardEpilogue;
+
   if ((!peelEpilogue || dynamicLoop) && predicateFn == nullptr) {
     LDBG("--no epilogue or predicate set -> BAIL");
     return false;
@@ -775,6 +778,13 @@ LoopPipelinerInternal::emitEpilogue(RewriterBase &rewriter,
       Value prevValue = valueMapping[mapVal][currentVersion];
       auto selOp = rewriter.create<arith::SelectOp>(loc, pred, pair.value(),
                                                     prevValue);
+
+      if (guardEpilogue) {
+        Value nextValue = pair.value();
+        selOp = rewriter.create<arith::SelectOp>(loc, pred, nextValue,
+                                                 prevValue);
+      }
+
       returnValues[ri] = selOp;
       if (nextVersion <= maxStage)
         setValueMapping(mapVal, selOp, nextVersion);
```
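A rough scalar model of what the select chain in `emitEpilogue` computes (my sketch, not the pass code): each peeled stage either takes the value recomputed for that stage, when the corresponding iteration actually exists, or carries the previous version forward.

```cpp
#include <cassert>
#include <utility>
#include <vector>

int main() {
  // Value produced by the last kernel stage.
  float prev = 10.0f;
  // One (pred, recomputed value) pair per peeled epilogue stage.
  std::vector<std::pair<bool, float>> stages = {
      {true, 11.0f},  // iteration exists: take the recomputed value
      {false, 99.0f}, // iteration does not exist: keep the previous version
  };
  for (auto [pred, next] : stages)
    prev = pred ? next : prev; // what arith.select(pred, next, prev) encodes
  assert(prev == 11.0f);
  return 0;
}
```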

python/src/passes.h

Lines changed: 5 additions & 0 deletions
```diff
@@ -19,6 +19,11 @@
   m.def(name, [](mlir::PassManager &pm, ty0 val0, ty1 val1, ty2 val2,         \
                  ty3 val3) { pm.addPass(builder(val0, val1, val2, val3)); })
 
+#define ADD_PASS_WRAPPER_5(name, builder, ty0, ty1, ty2, ty3, ty4)            \
+  m.def(name,                                                                 \
+        [](mlir::PassManager &pm, ty0 val0, ty1 val1, ty2 val2, ty3 val3,     \
+           ty4 val4) { pm.addPass(builder(val0, val1, val2, val3, val4)); })
+
 #define ADD_PASS_OPTION_WRAPPER_1(name, builder, ty0)                         \
   m.def(name,                                                                 \
         [](mlir::PassManager &pm, ty0 val0) { pm.addPass(builder({val0})); })
```
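For reference, a hand-expanded sketch of `ADD_PASS_WRAPPER_5` as instantiated for the stream-pipeline binding (see `third_party/amd/python/triton_amd.cc` below); `m` is the `py::module` bound in `init_triton_amd_passes_ttgpuir`:

```cpp
// Inside init_triton_amd_passes_ttgpuir(py::module &&m), the instantiation
//   ADD_PASS_WRAPPER_5("add_stream_pipeline",
//                      mlir::createTritonAMDGPUStreamPipelinePass,
//                      int, int, int, bool, bool);
// expands (roughly) to:
m.def("add_stream_pipeline",
      [](mlir::PassManager &pm, int val0, int val1, int val2, bool val3,
         bool val4) {
        pm.addPass(mlir::createTritonAMDGPUStreamPipelinePass(
            val0, val1, val2, val3, val4));
      });
```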

third_party/amd/backend/compiler.py

Lines changed: 2 additions & 1 deletion
```diff
@@ -246,12 +246,13 @@ def make_ttgir(mod, metadata, options):
     global_prefetch = int(os.getenv("TRITON_HIP_GLOBAL_PREFETCH", "0"))
     local_prefetch = int(os.getenv("TRITON_HIP_LOCAL_PREFETCH", "0"))
     use_async_copy = int(os.getenv("TRITON_HIP_USE_ASYNC_COPY", "0")) == 1
+    must_guard_epilogue = True
 
     # The `local-prefetch` scheduling variant requires turning on buffer ops.
     if options.schedule_hint == "local-prefetch":
         global_prefetch = local_prefetch = 1
 
-    amd.passes.ttgpuir.add_stream_pipeline(pm, options.num_stages, global_prefetch, local_prefetch, use_async_copy)
+    amd.passes.ttgpuir.add_stream_pipeline(pm, options.num_stages, global_prefetch, local_prefetch, use_async_copy, must_guard_epilogue)
     if use_async_copy:
         amd.passes.ttgpuir.add_coalesce_async_copy(pm, options.arch)
     passes.common.add_canonicalizer(pm)
```

third_party/amd/include/TritonAMDGPUTransforms/Passes.h

Lines changed: 3 additions & 4 deletions
```diff
@@ -8,10 +8,9 @@
 
 namespace mlir {
 
-std::unique_ptr<Pass>
-createTritonAMDGPUStreamPipelinePass(int numStages = 2, int globalPrefetch = 0,
-                                     int localPrefetch = 0,
-                                     bool useAsyncCopy = false);
+std::unique_ptr<Pass> createTritonAMDGPUStreamPipelinePass(
+    int numStages = 2, int globalPrefetch = 0, int localPrefetch = 0,
+    bool useAsyncCopy = false, bool mustGuardEpilogue = true);
 
 std::unique_ptr<Pass>
 createTritonAMDGPUAccelerateMatmulPass(std::string archGenName = std::string(),
```

third_party/amd/include/TritonAMDGPUTransforms/Passes.td

Lines changed: 3 additions & 0 deletions
```diff
@@ -28,6 +28,9 @@ def TritonAMDGPUStreamPipeline : Pass<"tritonamdgpu-stream-pipeline", "mlir::ModuleOp">
     Option<"useAsyncCopy", "use_async_copy",
            "bool", /*default*/"false",
            "Use AsyncCopyGlobalToLocal to directly load to shared memory">,
+    Option<"mustGuardEpilogue", "must_guard_epilogue",
+           "bool", /*default*/"true",
+           "Require conditionalized loads in epilogue of pipelined loop">,
   ];
 }
 
```

third_party/amd/lib/TritonAMDGPUTransforms/StreamPipeline.cpp

Lines changed: 111 additions & 9 deletions
```diff
@@ -116,9 +116,11 @@ class StreamPipeliner {
 
 public:
   StreamPipeliner(scf::ForOp _forOp, int _numStages, int _globalPrefetch,
-                  int _localPrefetch, bool _useAsyncCopy)
+                  int _localPrefetch, bool _useAsyncCopy,
+                  bool _mustGuardEpilogue)
       : forOp(_forOp), numStages(_numStages), numBuffers(1),
-        useAsyncCopy(_useAsyncCopy), schedule(numStages),
+        useAsyncCopy(_useAsyncCopy), mustGuardEpilogue(_mustGuardEpilogue),
+        schedule(numStages),
         axisInfoAnalysis(forOp->getParentOfType<ModuleOp>()) {
     int lastStage = numStages - 1;
     stages[SCHED_GLOBAL_LOAD] = 0;
@@ -128,7 +130,11 @@ class StreamPipeliner {
 
     options.supportDynamicLoops = true;
     options.peelEpilogue = true;
-    options.predicateFn = streamPredication;
+    options.guardEpilogue = _mustGuardEpilogue;
+    if (_mustGuardEpilogue)
+      options.predicateFn = streamPredication;
+    else
+      options.predicateFn = tt::predicateOp;
   }
 
   LogicalResult pipelineLoop();
@@ -151,6 +157,10 @@ class StreamPipeliner {
   void createStreamCopy(tt::LoadOp loadOp, Value alloc, Value extractIdx);
   void createStreamOps();
 
+  // Unguard epilogue
+  bool safeDAG(Value v, int index);
+  void checkResultResilience();
+
   void scheduleOp(Operation *op, SchedType type, int stage = -1) {
     if (stage < 0)
       stage = stages[type];
@@ -170,6 +180,8 @@ class StreamPipeliner {
   // Directly store to shared memory with AsyncCopy when pipelining tt.loads
   bool useAsyncCopy;
 
+  bool mustGuardEpilogue;
+
   // Stage for each SchedType Op
   int stages[SCHED_SIZE];
   // Cluster for each SchedType Op
```
```diff
@@ -205,6 +217,80 @@ class StreamPipeliner {
 
 } // namespace
 
+bool StreamPipeliner::safeDAG(Value v, int index) {
+  if (Operation *defOp = v.getDefiningOp()) {
+    if (auto loadOp = dyn_cast<tt::LoadOp>(defOp)) {
+      // Loads in the loop will be guarded
+      return loadOp->getParentOfType<scf::ForOp>() == forOp;
+    } else if (auto cvtOp = dyn_cast<ttg::ConvertLayoutOp>(defOp)) {
+      return safeDAG(cvtOp.getSrc(), index);
+    } else if (auto dotOp = dyn_cast<tt::DotOpInterface>(defOp)) {
+      // At least one of the A/B inputs must be safe
+      if (!safeDAG(dotOp.getA(), -1) && !safeDAG(dotOp.getB(), -1))
+        return false;
+      auto C = dotOp->getOperand(2);
+      return safeDAG(C, index);
+    } else if (auto splatOp = dyn_cast<tt::SplatOp>(defOp)) {
+      // Input must be safe
+      return safeDAG(splatOp.getOperand(), -1);
+    } else if (auto bcastOp = dyn_cast<tt::BroadcastOp>(defOp)) {
+      // Input must be safe
+      return safeDAG(bcastOp.getOperand(), -1);
+    } else if (auto expandOp = dyn_cast<tt::ExpandDimsOp>(defOp)) {
+      // Input must be safe
+      return safeDAG(expandOp.getOperand(), -1);
+    } else if (auto addOp = dyn_cast<arith::AddFOp>(defOp)) {
+      // Both inputs must be safe
+      return safeDAG(addOp.getLhs(), -1) && safeDAG(addOp.getRhs(), -1);
+    } else if (auto subOp = dyn_cast<arith::SubFOp>(defOp)) {
+      // Both inputs must be safe
+      return safeDAG(subOp.getLhs(), -1) && safeDAG(subOp.getRhs(), -1);
+    } else if (auto mulOp = dyn_cast<arith::MulFOp>(defOp)) {
+      // Either input must be safe
+      return safeDAG(mulOp.getLhs(), -1) || safeDAG(mulOp.getRhs(), -1);
+    } else if (auto truncOp = dyn_cast<arith::TruncFOp>(defOp)) {
+      // Input must be safe
+      return safeDAG(truncOp.getOperand(), -1);
+    } else if (auto constOp = dyn_cast<arith::ConstantOp>(defOp)) {
+      // Check for constant zero
+      if (auto attr = dyn_cast<FloatAttr>(constOp.getValue()))
+        return attr.getValue().isZero();
+    } else if (auto exp2Op = dyn_cast<math::Exp2Op>(defOp)) {
+      // Input must be safe
+      return safeDAG(exp2Op.getOperand(), -1);
+    } else if (auto selectOp = dyn_cast<arith::SelectOp>(defOp)) {
+      // Both inputs must be safe
+      return safeDAG(selectOp.getTrueValue(), -1) &&
+             safeDAG(selectOp.getFalseValue(), -1);
+    } else if (auto transposeOp = dyn_cast<tt::TransposeOpInterface>(defOp)) {
+      // Input must be safe
+      return safeDAG(transposeOp.getSrc(), -1);
+    } else {
+      // Unknown op: default to false
+      LDBG("Unknown op for unguard epilogue in stream pipeliner");
+      return false;
+    }
+  } else {
+    // Block argument: safe only if it is the result's own iter-arg
+    auto arg = cast<BlockArgument>(v);
+    return arg.getArgNumber() == index;
+  }
+  return false;
+}
+
+void StreamPipeliner::checkResultResilience() {
+  auto yieldVals = forOp.getYieldedValuesMutable().value();
+  for (auto [index, res] : llvm::enumerate(forOp.getResults())) {
+    if (!res.use_empty()) {
+      // Check init value == 0
+      // Backtrack yield value
+      Value yieldVal = yieldVals[index].get();
+      if (!safeDAG(yieldVal, index + 1)) // +1 skips the induction variable
+        mustGuardEpilogue = true;
+    }
+  }
+}
+
 // Init Schedule Config based on settings and loop characteristics.
 // Create clusters in order of ops in loop. This can interleave ops
 // from different stages in the same cluster to achieve better backend
```
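To make the recursion above easier to follow, here is a standalone toy version of the same resilience rules over a small expression tree (my own sketch, not the pass code; it keeps only loads, constants, add/mul, a dot-like node, and the result's own iteration argument, and drops the Triton-specific layout ops):

```cpp
#include <cassert>
#include <memory>
#include <utility>
#include <vector>

// Toy expression kinds mirroring a subset of safeDAG's cases.
enum class Kind { PipelinedLoad, ZeroConst, NonZeroConst, Add, Mul, Dot, IterArg };

struct Expr {
  Kind kind;
  int argNumber;                               // used only by IterArg
  std::vector<std::shared_ptr<Expr>> operands; // Add/Mul: 2, Dot: A, B, C
};

// "Safe" means: recomputing this value in an unguarded epilogue, with the
// pipelined loads masked off, cannot corrupt the loop result whose iter-arg
// sits at position `index`.
bool safeDAG(const Expr &e, int index) {
  switch (e.kind) {
  case Kind::PipelinedLoad:
    return true; // loads in the loop are already predicated
  case Kind::ZeroConst:
    return true; // additive identity
  case Kind::NonZeroConst:
    return false;
  case Kind::Add: // both addends must be safe
    return safeDAG(*e.operands[0], -1) && safeDAG(*e.operands[1], -1);
  case Kind::Mul: // one safe (zeroed) factor is enough
    return safeDAG(*e.operands[0], -1) || safeDAG(*e.operands[1], -1);
  case Kind::Dot: // A or B safe, and the C accumulator safe at the same index
    return (safeDAG(*e.operands[0], -1) || safeDAG(*e.operands[1], -1)) &&
           safeDAG(*e.operands[2], index);
  case Kind::IterArg: // only the result's own iteration argument
    return e.argNumber == index;
  }
  return false;
}

int main() {
  auto node = [](Kind k, std::vector<std::shared_ptr<Expr>> ops = {}) {
    return std::make_shared<Expr>(Expr{k, -1, std::move(ops)});
  };
  // argNumber 1 mimics the `index + 1` offset that skips the induction var.
  auto acc = std::make_shared<Expr>(Expr{Kind::IterArg, 1, {}});

  // yield = dot(a_load, b_load, acc): the guard can be dropped.
  auto gemm = node(Kind::Dot,
                   {node(Kind::PipelinedLoad), node(Kind::PipelinedLoad), acc});
  assert(safeDAG(*gemm, 1));

  // yield = acc + <nonzero constant>: the guard must stay.
  auto biased = node(Kind::Add, {acc, node(Kind::NonZeroConst)});
  assert(!safeDAG(*biased, 1));
  return 0;
}
```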
```diff
@@ -213,6 +299,10 @@ class StreamPipeliner {
 // can cause invalid schedules to be produced.
 LogicalResult StreamPipeliner::initSchedule(int maxIndirectionLevel) {
 
+  // Check to see if we can unconditionalize epilogue
+  if (!mustGuardEpilogue)
+    checkResultResilience();
+
   bool pairedGlobalLoadLocalStore = stages[SCHED_LOCAL_STORE] == 0;
   stages[SCHED_LOCAL_STORE] += maxIndirectionLevel;
 
@@ -787,11 +877,16 @@ void StreamPipeliner::scheduleRemainingToLastStage() {
   // Assign the rest of the ops to the last stage.
   // Take care of the ordering of the ops - uses cannot be scheduled to the
   // cluster before the definition.
+  int count = 0;
   auto cluster = clusters[SCHED_COMPUTE];
   DenseMap<Operation *, tt::CoarseSchedule::Cluster> opToCluster;
   for (auto &op : forOp.getBody()->without_terminator()) {
-    if (schedule.count(&op) == 0)
+    if (schedule.count(&op) == 0) {
       opToCluster[&op] = cluster;
+    }
+  }
+  if (count != 0) {
+    mustGuardEpilogue = true;
   }
   SmallVector<Operation *> queue;
   for (auto [op, stage, cluster] : schedule.getOpsInOrder(forOp)) {
```
```diff
@@ -1011,13 +1106,16 @@ void labelLoadOpsForTritonDot(scf::ForOp forOp) {
 struct PipelinePass : public TritonAMDGPUStreamPipelineBase<PipelinePass> {
   PipelinePass() = default;
   PipelinePass(int32_t _numStages, int32_t _globalPrefetch,
-               int32_t _localPrefetch, bool _useAsyncCopy) {
+               int32_t _localPrefetch, bool _useAsyncCopy,
+               bool _mustGuardEpilogue) {
     this->numStages = _numStages;
 
     this->globalPrefetch = _globalPrefetch;
     this->localPrefetch = _localPrefetch;
 
     this->useAsyncCopy = _useAsyncCopy;
+
+    this->mustGuardEpilogue = _mustGuardEpilogue;
   }
 
   void runOnOperation() override {
@@ -1047,15 +1145,19 @@ struct PipelinePass : public TritonAMDGPUStreamPipelineBase<PipelinePass> {
       if (!checkPrecondition(forOp))
        continue;
       StreamPipeliner sp(forOp, tt::getNumStagesOrDefault(forOp, numStages),
-                         globalPrefetch, localPrefetch, useAsyncCopy);
+                         globalPrefetch, localPrefetch, useAsyncCopy,
+                         mustGuardEpilogue);
       (void)sp.pipelineLoop();
     }
   }
 };
 } // namespace
 
-std::unique_ptr<Pass> mlir::createTritonAMDGPUStreamPipelinePass(
-    int numStages, int globalPrefetch, int localPrefetch, bool useAsyncCopy) {
+std::unique_ptr<Pass>
+mlir::createTritonAMDGPUStreamPipelinePass(int numStages, int globalPrefetch,
+                                           int localPrefetch, bool useAsyncCopy,
+                                           bool mustGuardEpilogue) {
   return std::make_unique<PipelinePass>(numStages, globalPrefetch,
-                                        localPrefetch, useAsyncCopy);
+                                        localPrefetch, useAsyncCopy,
+                                        mustGuardEpilogue);
 }
```

third_party/amd/python/triton_amd.cc

Lines changed: 2 additions & 2 deletions
```diff
@@ -78,9 +78,9 @@ void init_triton_amd_passes_ttgpuir(py::module &&m) {
                      mlir::createTritonAMDGPUFoldTrueCmpIPass);
   ADD_PASS_WRAPPER_1("add_block_pingpong",
                      mlir::createTritonAMDGPUBlockPingpongPass, int32_t);
-  ADD_PASS_WRAPPER_4("add_stream_pipeline",
+  ADD_PASS_WRAPPER_5("add_stream_pipeline",
                      mlir::createTritonAMDGPUStreamPipelinePass, int, int, int,
-                     bool);
+                     bool, bool);
   ADD_PASS_WRAPPER_1("add_coalesce_async_copy",
                      mlir::createTritonAMDGPUCoalesceAsyncCopyPass,
                      std::string);
```
