 #include "mlir/IR/Verifier.h"
 #include "mlir/Pass/Pass.h"
 #include "mlir/Pass/PassManager.h"
+#include "mlir/Transforms/RegionUtils.h"
 #include "triton/Dialect/TritonGPU/IR/Dialect.h"
 #include "triton/Dialect/TritonGPU/Transforms/Utility.h"
 #include <deque>
@@ -24,6 +25,73 @@ static bool isLocalLoadOrDotLayoutConversion(Operation *op) {
   return false;
 }
 
+// Copy of mlir::getBackwardSlice with changes to handle nested regions.
+// This is a temporary local fix until these changes are upstreamed to mlir.
+static void getDeepBackwardSlice(Operation *op,
+                                 SetVector<Operation *> *backwardSlice,
+                                 const BackwardSliceOptions &options) {
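+  // Like mlir::getBackwardSlice, this walks def-use edges backwards from
+  // 'op', but it additionally treats values that are used inside op's nested
+  // regions (and defined in op's own block) as roots of the slice.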
+  if (!op || op->hasTrait<OpTrait::IsIsolatedFromAbove>())
+    return;
+
+  // Evaluate whether we should keep this def.
+  // This is useful in particular to implement scoping; i.e., return the
+  // transitive backwardSlice in the current scope.
+  if (options.filter && !options.filter(op))
+    return;
+
+  SetVector<Value> usedValues;
+  Block *opBlock = op->getBlock();
+  auto f = [&](OpOperand *nestedValue) {
+    // Filter out values that are not defined in the block that contains
+    // 'op'. This is to avoid including values that are defined in the
+    // nested regions of 'op'.
+    if (auto *nestedOp = nestedValue->get().getDefiningOp()) {
+      if (opBlock == nestedOp->getBlock()) {
+        usedValues.insert(nestedValue->get());
+      }
+    }
+  };
+
+  // Collect all the values used in the nested regions of this op.
+  for (auto &region : op->getRegions()) {
+    region.walk([&](Region *nestedRegion) {
+      mlir::visitUsedValuesDefinedAbove(*nestedRegion, *nestedRegion, f);
+    });
+  }
+
+  // Collect all the values used directly by the op.
+  for (const auto &en : llvm::enumerate(op->getOperands())) {
+    usedValues.insert(en.value());
+  }
+
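+  // Recurse into the producer of each collected value (or, for a block
+  // argument, into the op that owns its block) before adding 'op' itself.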
+  for (const auto &en : llvm::enumerate(usedValues)) {
+    auto operand = en.value();
+    if (auto *definingOp = operand.getDefiningOp()) {
+      if (backwardSlice->count(definingOp) == 0)
+        getDeepBackwardSlice(definingOp, backwardSlice, options);
+    } else if (auto blockArg = dyn_cast<BlockArgument>(operand)) {
+      if (options.omitBlockArguments)
+        continue;
+
+      Block *block = blockArg.getOwner();
+      Operation *parentOp = block->getParentOp();
+      // TODO: determine whether we want to recurse backward into the other
+      // blocks of parentOp, which are not technically backward unless they
+      // flow into us. For now, just bail.
+      if (parentOp && backwardSlice->count(parentOp) == 0) {
+        assert(parentOp->getNumRegions() == 1 &&
+               parentOp->getRegion(0).getBlocks().size() == 1);
+        getDeepBackwardSlice(parentOp, backwardSlice, options);
+      }
+    } else {
+      llvm_unreachable("No definingOp and not a block argument.");
+    }
+  }
+
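+  // Insert 'op' only after all of its transitive definitions have been
+  // visited, so the slice ends up in def-before-use (topological) order.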
+  backwardSlice->insert(op);
+}
+
 // Search through block to find earliest insertion point for move op. This can
 // be either an atomic op or last usage of source pointer. Search ends when move
 // op is encountered.
@@ -221,8 +289,7 @@ class TritonAMDGPUReorderInstructionsPass
         // Only move ops residing in the same block.
         return defBlock == block;
       };
-      mlir::getBackwardSlice(op, &backwardSet, options);
-      backwardSet.insert(op);
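+      // getDeepBackwardSlice already inserts 'op' into the slice, so the
+      // explicit insert that followed mlir::getBackwardSlice is dropped.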
+      getDeepBackwardSlice(op, &backwardSet, options);
 
       // Don't move a local_store if its source is a load from
       // the same iteration.