@@ -116,9 +116,11 @@ class StreamPipeliner {
116116
117117public:
118118 StreamPipeliner (scf::ForOp _forOp, int _numStages, int _globalPrefetch,
119- int _localPrefetch, bool _useAsyncCopy)
119+ int _localPrefetch, bool _useAsyncCopy,
120+ bool _mustGuardEpilogue)
120121 : forOp(_forOp), numStages(_numStages), numBuffers(1 ),
121- useAsyncCopy (_useAsyncCopy), schedule(numStages),
122+ useAsyncCopy (_useAsyncCopy), mustGuardEpilogue(_mustGuardEpilogue),
123+ schedule(numStages),
122124 axisInfoAnalysis(forOp->getParentOfType<ModuleOp>()) {
123125 int lastStage = numStages - 1 ;
124126 stages[SCHED_GLOBAL_LOAD] = 0 ;
@@ -128,7 +130,11 @@ class StreamPipeliner {
128130
129131 options.supportDynamicLoops = true ;
130132 options.peelEpilogue = true ;
131- options.predicateFn = streamPredication;
133+ options.guardEpilogue = _mustGuardEpilogue;
134+ if (_mustGuardEpilogue)
135+ options.predicateFn = streamPredication;
136+ else
137+ options.predicateFn = tt::predicateOp;
132138 }
133139
134140 LogicalResult pipelineLoop ();
@@ -151,6 +157,10 @@ class StreamPipeliner {
151157 void createStreamCopy (tt::LoadOp loadOp, Value alloc, Value extractIdx);
152158 void createStreamOps ();
153159
160+ // Unguard epilogue
161+ bool safeDAG (Value v, int index);
162+ void checkResultResilience ();
163+
154164 void scheduleOp (Operation *op, SchedType type, int stage = -1 ) {
155165 if (stage < 0 )
156166 stage = stages[type];
@@ -170,6 +180,8 @@ class StreamPipeliner {
170180 // Directly store to shared memory with AsyncCopy when pipelining tt.loads
171181 bool useAsyncCopy;
172182
183+ bool mustGuardEpilogue;
184+
173185 // Stage for each SchedType Op
174186 int stages[SCHED_SIZE];
175187 // Cluster for each SchedType Op
@@ -205,6 +217,80 @@ class StreamPipeliner {
205217
206218} // namespace
207219
220+ bool StreamPipeliner::safeDAG (Value v, int index) {
221+ if (Operation *defOp = v.getDefiningOp ()) {
222+ if (auto loadOp = dyn_cast<tt::LoadOp>(defOp)) {
223+ // Loads in the loop will be guarded
224+ return loadOp->getParentOfType <scf::ForOp>() == forOp;
225+ } else if (auto cvtOp = dyn_cast<ttg::ConvertLayoutOp>(defOp)) {
226+ return safeDAG (cvtOp.getSrc (), index);
227+    } else if (auto dotOp = dyn_cast<tt::DotOpInterface>(defOp)) {
228+      // At least one of the A/B operands must be safe; the accumulator
229+      // chain (operand C) is then checked recursively.
230+      if (!safeDAG (dotOp.getA (), -1 ) && !safeDAG (dotOp.getB (), -1 ))
231+        return false ;
232+      auto C = dotOp->getOperand (2 );
233+      return safeDAG (C, index);
234+    } else if (auto splatOp = dyn_cast<tt::SplatOp>(defOp)) {
235+      // the single input must be safe
236+      return safeDAG (splatOp.getOperand (), -1 );
237+    } else if (auto bcastOp = dyn_cast<tt::BroadcastOp>(defOp)) {
238+      // the single input must be safe
239+      return safeDAG (bcastOp.getOperand (), -1 );
240+    } else if (auto expandOp = dyn_cast<tt::ExpandDimsOp>(defOp)) {
241+      // the single input must be safe
242+      return safeDAG (expandOp.getOperand (), -1 );
242+ } else if (auto addOp = dyn_cast<arith::AddFOp>(defOp)) {
243+ // both inputs must be safe
244+ return safeDAG (addOp.getLhs (), -1 ) && safeDAG (addOp.getRhs (), -1 );
245+ } else if (auto subOp = dyn_cast<arith::SubFOp>(defOp)) {
246+ // both inputs must be safe
247+ return safeDAG (subOp.getLhs (), -1 ) && safeDAG (subOp.getRhs (), -1 );
248+ } else if (auto mulOp = dyn_cast<arith::MulFOp>(defOp)) {
249+ // either input must be safe
250+ return safeDAG (mulOp.getLhs (), -1 ) || safeDAG (mulOp.getRhs (), -1 );
251+    } else if (auto truncOp = dyn_cast<arith::TruncFOp>(defOp)) {
252+      // the single input must be safe
253+      return safeDAG (truncOp.getOperand (), -1 );
254+ } else if (auto constOp = dyn_cast<arith::ConstantOp>(defOp)) {
255+ // check for constant zero
256+ if (auto attr = dyn_cast<FloatAttr>(constOp.getValue ()))
257+ return attr.getValue ().isZero ();
258+    } else if (auto exp2Op = dyn_cast<math::Exp2Op>(defOp)) {
259+      // the single input must be safe
260+      return safeDAG (exp2Op.getOperand (), -1 );
261+ } else if (auto selectOp = dyn_cast<arith::SelectOp>(defOp)) {
262+ // both inputs must be safe
263+ return safeDAG (selectOp.getTrueValue (), -1 ) &&
264+ safeDAG (selectOp.getFalseValue (), -1 );
265+ } else if (auto transposeOp = dyn_cast<tt::TransposeOpInterface>(defOp)) {
266+ // input must be safe
267+ return safeDAG (transposeOp.getSrc (), -1 );
268+ } else {
269+ // Unknown op default to false
270+ LDBG (" Unknown op for unguard epilogue in stream pipeliner" );
271+ return false ;
272+ }
273+ } else {
274+ // check block arg
275+ auto arg = cast<BlockArgument>(v);
276+ return arg.getArgNumber () == index;
277+ }
278+ return false ;
279+ }
280+
281+ void StreamPipeliner::checkResultResilience () {
282+ auto yieldVals = forOp.getYieldedValuesMutable ().value ();
283+ for (auto [index, res] : llvm::enumerate (forOp.getResults ())) {
284+ if (!res.use_empty ()) {
285+ // Check init value == 0
286+ // Backtrack yield value
287+ Value yieldVal = yieldVals[index].get ();
288+      // +1: block argument 0 is the induction variable, so result i maps
289+      // to iter-arg block argument i + 1.
290+      if (!safeDAG (yieldVal, index + 1 ))
289+ mustGuardEpilogue = true ;
290+ }
291+ }
292+ }
293+
208294// Init Schedule Config based on settings and loop characteristics.
209295// Create clusters in order of ops in loop. This can interleave ops
210296// from different stages in the same cluster to achieve better backend
@@ -213,6 +299,10 @@ class StreamPipeliner {
213299// can cause invalid schedules to be produced.
214300LogicalResult StreamPipeliner::initSchedule (int maxIndirectionLevel) {
215301
302+  // Check to see if we can leave the epilogue unguarded.
303+  // NOTE(review): if checkResultResilience() flips mustGuardEpilogue to
304+  // true here, options.guardEpilogue and options.predicateFn were already
305+  // set from the old value in the constructor and look stale -- confirm
306+  // they are re-applied before the pipeline expander consumes options.
307+  if (!mustGuardEpilogue)
308+    checkResultResilience ();
305+
216306 bool pairedGlobalLoadLocalStore = stages[SCHED_LOCAL_STORE] == 0 ;
217307 stages[SCHED_LOCAL_STORE] += maxIndirectionLevel;
218308
@@ -787,11 +877,16 @@ void StreamPipeliner::scheduleRemainingToLastStage() {
787877 // Assign the rest of the ops to the last stage.
788878 // Take care of the ordering of the ops - uses cannot be scheduled to the
789879 // cluster before the definition.
880+  // Count ops that fall through to the last stage unscheduled; any such
881+  // op was never proven safe, so the epilogue must stay guarded.
882+  int count = 0 ;
790883  auto cluster = clusters[SCHED_COMPUTE];
791884  DenseMap<Operation *, tt::CoarseSchedule::Cluster> opToCluster;
792885  for (auto &op : forOp.getBody ()->without_terminator ()) {
793-    if (schedule.count (&op) == 0 )
886+    if (schedule.count (&op) == 0 ) {
794887      opToCluster[&op] = cluster;
888+      // BUG FIX: count was never incremented, making the check below
889+      // dead code and mustGuardEpilogue unreachable from this path.
890+      count++;
891+    }
892+  }
893+  if (count != 0 ) {
894+    mustGuardEpilogue = true ;
795895  }
796891 SmallVector<Operation *> queue;
797892 for (auto [op, stage, cluster] : schedule.getOpsInOrder (forOp)) {
@@ -1011,13 +1106,16 @@ void labelLoadOpsForTritonDot(scf::ForOp forOp) {
10111106struct PipelinePass : public TritonAMDGPUStreamPipelineBase <PipelinePass> {
10121107 PipelinePass () = default ;
10131108 PipelinePass (int32_t _numStages, int32_t _globalPrefetch,
1014- int32_t _localPrefetch, bool _useAsyncCopy) {
1109+ int32_t _localPrefetch, bool _useAsyncCopy,
1110+ bool _mustGuardEpilogue) {
10151111 this ->numStages = _numStages;
10161112
10171113 this ->globalPrefetch = _globalPrefetch;
10181114 this ->localPrefetch = _localPrefetch;
10191115
10201116 this ->useAsyncCopy = _useAsyncCopy;
1117+
1118+ this ->mustGuardEpilogue = _mustGuardEpilogue;
10211119 }
10221120
10231121 void runOnOperation () override {
@@ -1047,15 +1145,19 @@ struct PipelinePass : public TritonAMDGPUStreamPipelineBase<PipelinePass> {
10471145 if (!checkPrecondition (forOp))
10481146 continue ;
10491147 StreamPipeliner sp (forOp, tt::getNumStagesOrDefault (forOp, numStages),
1050- globalPrefetch, localPrefetch, useAsyncCopy);
1148+ globalPrefetch, localPrefetch, useAsyncCopy,
1149+ mustGuardEpilogue);
10511150 (void )sp.pipelineLoop ();
10521151 }
10531152 }
10541153};
10551154} // namespace
10561155
1057- std::unique_ptr<Pass> mlir::createTritonAMDGPUStreamPipelinePass (
1058- int numStages, int globalPrefetch, int localPrefetch, bool useAsyncCopy) {
1156+ std::unique_ptr<Pass>
1157+ mlir::createTritonAMDGPUStreamPipelinePass (int numStages, int globalPrefetch,
1158+ int localPrefetch, bool useAsyncCopy,
1159+ bool mustGuardEpilogue) {
10591160 return std::make_unique<PipelinePass>(numStages, globalPrefetch,
1060- localPrefetch, useAsyncCopy);
1161+ localPrefetch, useAsyncCopy,
1162+ mustGuardEpilogue);
10611163}
0 commit comments