Skip to content

Commit a47ae7b

Browse files
committed
update and clean up
1 parent 17dbc36 commit a47ae7b

File tree

3 files changed

+29
-36
lines changed

3 files changed

+29
-36
lines changed

lib/Dialect/TritonGPU/Transforms/Pipeliner/PipelineExpander.cpp

Lines changed: 5 additions & 11 deletions
Original file line number | Diff line number | Diff line change
@@ -66,8 +66,8 @@ struct LoopPipelinerInternal {
6666
bool dynamicLoop;
6767
bool peelEpilogue;
6868
bool guardEpilogue;
69-
triton::PipeliningOption::PredicateOpFnType predicateFn = nullptr;
7069
triton::PipeliningOption::AnnotationlFnType annotateFn = nullptr;
70+
triton::PipeliningOption::PredicateOpFnType predicateFn = nullptr;
7171

7272
// When peeling the kernel we generate several versions of each value for
7373
// different stages of the prologue. This map tracks the mapping between
@@ -158,7 +158,6 @@ bool LoopPipelinerInternal::initializeLoopInfo(
158158
peelEpilogue = options.peelEpilogue;
159159
predicateFn = options.predicateFn;
160160
guardEpilogue = options.guardEpilogue;
161-
162161
if ((!peelEpilogue || dynamicLoop) && predicateFn == nullptr) {
163162
LDBG("--no epilogue or predicate set -> BAIL");
164163
return false;
@@ -776,16 +775,11 @@ LoopPipelinerInternal::emitEpilogue(RewriterBase &rewriter,
776775
unsigned nextVersion = currentVersion + 1;
777776
Value pred = predicates[currentVersion];
778777
Value prevValue = valueMapping[mapVal][currentVersion];
779-
Value nextValue = pair.value();
780-
781-
if (guardEpilogue) {
782-
nextValue = rewriter.create<arith::SelectOp>(loc, pred, nextValue,
783-
prevValue);
784-
}
785-
786-
returnValues[ri] = nextValue;
778+
auto selOp = rewriter.create<arith::SelectOp>(loc, pred, pair.value(),
779+
prevValue);
780+
returnValues[ri] = selOp;
787781
if (nextVersion <= maxStage)
788-
setValueMapping(mapVal, nextValue, nextVersion);
782+
setValueMapping(mapVal, selOp, nextVersion);
789783
}
790784
}
791785
}

third_party/amd/backend/compiler.py

Lines changed: 1 addition & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -246,7 +246,7 @@ def make_ttgir(mod, metadata, options):
246246
global_prefetch = int(os.getenv("TRITON_HIP_GLOBAL_PREFETCH", "0"))
247247
local_prefetch = int(os.getenv("TRITON_HIP_LOCAL_PREFETCH", "0"))
248248
use_async_copy = int(os.getenv("TRITON_HIP_USE_ASYNC_COPY", "0")) == 1
249-
must_guard_epilogue = 1
249+
must_guard_epilogue = int(os.getenv("TRITON_HIP_GUARD_EPILOGUE", "1"))
250250

251251
# The `local-prefetch` scheduling variant requires turning on buffer ops.
252252
if options.schedule_hint == "local-prefetch":

third_party/amd/lib/TritonAMDGPUTransforms/StreamPipeline.cpp

Lines changed: 23 additions & 24 deletions
Original file line number | Diff line number | Diff line change
@@ -127,14 +127,6 @@ class StreamPipeliner {
127127
stages[SCHED_LOCAL_STORE] = _globalPrefetch;
128128
stages[SCHED_LOCAL_LOAD] = lastStage - _localPrefetch;
129129
stages[SCHED_COMPUTE] = lastStage;
130-
131-
options.supportDynamicLoops = true;
132-
options.peelEpilogue = true;
133-
options.guardEpilogue = _mustGuardEpilogue;
134-
if (_mustGuardEpilogue)
135-
options.predicateFn = streamPredication;
136-
else
137-
options.predicateFn = tt::predicateOp;
138130
}
139131

140132
LogicalResult pipelineLoop();
@@ -210,9 +202,6 @@ class StreamPipeliner {
210202

211203
// Capture list of new shared memory buffers.
212204
SmallVector<Value> sharedMemAllocs;
213-
214-
// Pipelining options for the PipelineExpander
215-
tt::PipeliningOption options;
216205
};
217206

218207
} // namespace
@@ -277,7 +266,8 @@ bool StreamPipeliner::safeDAG(Value v, int index) {
277266
}
278267
return false;
279268
}
280-
269+
// TODO(crobeck): is this valid if we have loop-carried
270+
// results needed for the next epilogue stage?
281271
void StreamPipeliner::checkResultResilience() {
282272
auto yieldVals = forOp.getYieldedValuesMutable().value();
283273
for (auto [index, res] : llvm::enumerate(forOp.getResults())) {
@@ -1016,18 +1006,6 @@ LogicalResult StreamPipeliner::preprocessLoopAndBuildSchedule() {
10161006
schedule.dump();
10171007
});
10181008

1019-
// Create the final schedule for the kernel loop. This will dictate the
1020-
// stages and order of operations to the pipeline expander.
1021-
std::vector<std::pair<Operation *, unsigned>> coarseSchedule =
1022-
schedule.createFinalSchedule(forOp);
1023-
1024-
// Fill out the pipeline options.
1025-
options.getScheduleFn =
1026-
[coarseSchedule](scf::ForOp,
1027-
std::vector<std::pair<Operation *, unsigned>> &s) {
1028-
s = std::move(coarseSchedule);
1029-
};
1030-
10311009
OpBuilder builder(forOp);
10321010
builder.setInsertionPointAfter(forOp);
10331011
// Explicitly deallocate created allocations.
@@ -1041,6 +1019,27 @@ LogicalResult StreamPipeliner::pipelineLoop() {
10411019
if (failed(preprocessLoopAndBuildSchedule()))
10421020
return failure();
10431021
LDBG("Loop before sending to expander:\n" << *forOp);
1022+
// Create the final schedule for the kernel loop. This will dictate the
1023+
// stages and order of operations to the pipeline expander.
1024+
std::vector<std::pair<Operation *, unsigned>> coarseSchedule =
1025+
schedule.createFinalSchedule(forOp);
1026+
1027+
// Pipelining options for the PipelineExpander
1028+
tt::PipeliningOption options;
1029+
options.supportDynamicLoops = true;
1030+
options.peelEpilogue = true;
1031+
options.guardEpilogue = mustGuardEpilogue;
1032+
1033+
if (mustGuardEpilogue)
1034+
options.predicateFn = streamPredication;
1035+
else
1036+
options.predicateFn = tt::predicateOp;
1037+
1038+
options.getScheduleFn =
1039+
[&coarseSchedule](scf::ForOp,
1040+
std::vector<std::pair<Operation *, unsigned>> &s) {
1041+
s = std::move(coarseSchedule);
1042+
};
10441043

10451044
IRRewriter rewriter(forOp->getContext());
10461045
rewriter.setInsertionPoint(forOp);

0 commit comments

Comments
 (0)