
Commit 307d809

[AMD] getBackwardSlice variant with handling for op regions
mlir::getBackwardSlice does not handle op regions: values used only inside an op's nested regions are not followed back to their defining ops. The resulting backward slice may not be in topological order, which can let the reordering pass move a value's use before its def. This is a temporary local fix until these changes are upstreamed to MLIR.
1 parent d31ccfe commit 307d809
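For context, here is a minimal sketch (not part of the commit) of the gap being worked around: upstream mlir::getBackwardSlice only follows an op's direct operands, so a value that is captured from the enclosing block and used only inside the op's regions (like %12 inside the scf.if ops in the test below) never pulls its defining op into the slice. The helper name collectMissedRegionDeps is purely illustrative; mlir::getBackwardSlice and mlir::visitUsedValuesDefinedAbove are the same MLIR entry points the patch uses, with the getBackwardSlice signature assumed to match the call visible in this diff.

// Illustrative only: list the defining ops that the upstream slice misses for
// an op with regions, e.g. an scf.if that captures values from its enclosing block.
#include "mlir/Analysis/SliceAnalysis.h"
#include "mlir/IR/Operation.h"
#include "mlir/Transforms/RegionUtils.h"
#include "llvm/ADT/SetVector.h"

static llvm::SetVector<mlir::Operation *>
collectMissedRegionDeps(mlir::Operation *op,
                        const mlir::BackwardSliceOptions &options) {
  // What upstream computes: it transitively walks op->getOperands() and never
  // looks inside op's own regions.
  llvm::SetVector<mlir::Operation *> upstreamSlice;
  mlir::getBackwardSlice(op, &upstreamSlice, options);

  // Values used anywhere inside op's regions but defined above them are the
  // ones upstream can miss; collect their defining ops.
  llvm::SetVector<mlir::Operation *> missed;
  for (mlir::Region &region : op->getRegions()) {
    mlir::visitUsedValuesDefinedAbove(
        region, region, [&](mlir::OpOperand *operand) {
          if (mlir::Operation *def = operand->get().getDefiningOp())
            if (!upstreamSlice.contains(def))
              missed.insert(def);
        });
  }
  return missed;
}

If the reordering pass moves an op using only the upstream slice, a defining op reported by this sketch can be reordered relative to its region-level use, which is exactly the def-before-use violation the new test guards against.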

2 files changed: +98 −2 lines changed

test/TritonGPU/amd/amd-reorder-instructions.mlir

Lines changed: 29 additions & 0 deletions
@@ -922,3 +922,32 @@ module attributes {"triton_gpu.num-warps" = 4 : i32, "triton_gpu.threads-per-war
     tt.return
   }
 }
+
+// Check that reordering preserves def-before-use for values used inside control flow regions
+// For example, %12 should not be moved below the scf.if op %22
+// CHECK: %{{.+}} = tt.make_range
+// CHECK: %{{.+}} = scf.if %{{.+}}
+module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 4 : i32, triton_gpu.target = "hip:gfx942", "triton_gpu.threads-per-warp" = 64 : i32} {
+  tt.func public @reoder_across_nested(%arg1: !tt.ptr<f32> {tt.divisibility = 16 : i32}, %arg6: i32, %arg9: i64, %arg10: i64) attributes {noinline = false} {
+    %12 = tt.make_range {end = 512 : i32, start = 0 : i32} : tensor<512xi32, #triton_gpu.blocked<{sizePerThread = [1], threadsPerWarp = [64], warpsPerCTA = [4], order = [0]}>>
+    %21 = arith.cmpi slt, %arg9, %arg10 : i64
+    %22 = scf.if %21 -> (tensor<512x!tt.ptr<f32>, #triton_gpu.blocked<{sizePerThread = [1], threadsPerWarp = [64], warpsPerCTA = [4], order = [0]}>>) {
+      %30 = tt.splat %arg1 : !tt.ptr<f32> -> tensor<512x!tt.ptr<f32>, #triton_gpu.blocked<{sizePerThread = [1], threadsPerWarp = [64], warpsPerCTA = [4], order = [0]}>>
+      %100 = scf.if %21 -> (tensor<512x!tt.ptr<f32>, #triton_gpu.blocked<{sizePerThread = [1], threadsPerWarp = [64], warpsPerCTA = [4], order = [0]}>>) {
+        %31 = tt.addptr %30, %12 : tensor<512x!tt.ptr<f32>, #triton_gpu.blocked<{sizePerThread = [1], threadsPerWarp = [64], warpsPerCTA = [4], order = [0]}>>, tensor<512xi32, #triton_gpu.blocked<{sizePerThread = [1], threadsPerWarp = [64], warpsPerCTA = [4], order = [0]}>>
+        scf.yield %31 : tensor<512x!tt.ptr<f32>, #triton_gpu.blocked<{sizePerThread = [1], threadsPerWarp = [64], warpsPerCTA = [4], order = [0]}>>
+      } else {
+        %31 = tt.addptr %30, %12 : tensor<512x!tt.ptr<f32>, #triton_gpu.blocked<{sizePerThread = [1], threadsPerWarp = [64], warpsPerCTA = [4], order = [0]}>>, tensor<512xi32, #triton_gpu.blocked<{sizePerThread = [1], threadsPerWarp = [64], warpsPerCTA = [4], order = [0]}>>
+        scf.yield %31 : tensor<512x!tt.ptr<f32>, #triton_gpu.blocked<{sizePerThread = [1], threadsPerWarp = [64], warpsPerCTA = [4], order = [0]}>>
+      }
+      scf.yield %100 : tensor<512x!tt.ptr<f32>, #triton_gpu.blocked<{sizePerThread = [1], threadsPerWarp = [64], warpsPerCTA = [4], order = [0]}>>
+    } else {
+      %32 = tt.splat %arg1 : !tt.ptr<f32> -> tensor<512x!tt.ptr<f32>, #triton_gpu.blocked<{sizePerThread = [1], threadsPerWarp = [64], warpsPerCTA = [4], order = [0]}>>
+      scf.yield %32 : tensor<512x!tt.ptr<f32>, #triton_gpu.blocked<{sizePerThread = [1], threadsPerWarp = [64], warpsPerCTA = [4], order = [0]}>>
+    }
+    %23 = tt.splat %arg6 : i32 -> tensor<512xi32, #triton_gpu.blocked<{sizePerThread = [1], threadsPerWarp = [64], warpsPerCTA = [4], order = [0]}>>
+    %24 = arith.cmpi slt, %12, %23 : tensor<512xi32, #triton_gpu.blocked<{sizePerThread = [1], threadsPerWarp = [64], warpsPerCTA = [4], order = [0]}>>
+    %25 = tt.load %22, %24 : tensor<512x!tt.ptr<f32>, #triton_gpu.blocked<{sizePerThread = [1], threadsPerWarp = [64], warpsPerCTA = [4], order = [0]}>>
+    tt.return
+  }
+}

third_party/amd/lib/TritonAMDGPUTransforms/ReorderInstructions.cpp

Lines changed: 69 additions & 2 deletions
@@ -5,6 +5,7 @@
 #include "mlir/IR/Verifier.h"
 #include "mlir/Pass/Pass.h"
 #include "mlir/Pass/PassManager.h"
+#include "mlir/Transforms/RegionUtils.h"
 #include "triton/Dialect/TritonGPU/IR/Dialect.h"
 #include "triton/Dialect/TritonGPU/Transforms/Utility.h"
 #include <deque>
@@ -24,6 +25,73 @@ static bool isLocalLoadOrDotLayoutConversion(Operation *op) {
   return false;
 }

+// Copy of mlir::getBackwardSlice with changes to handle nested regions.
+// This is a temporary local fix until these changes are upstreamed to mlir.
+static void getDeepBackwardSlice(Operation *op,
+                                 SetVector<Operation *> *backwardSlice,
+                                 const BackwardSliceOptions &options) {
+  if (!op || op->hasTrait<OpTrait::IsIsolatedFromAbove>())
+    return;
+
+  // Evaluate whether we should keep this def.
+  // This is useful in particular to implement scoping; i.e. return the
+  // transitive backwardSlice in the current scope.
+  if (options.filter && !options.filter(op))
+    return;
+
+  SetVector<Value> usedValues;
+  Block *opBlock = op->getBlock();
+  auto f = [&](OpOperand *nestedValue) {
+    // Filter out values that are not defined in the block
+    // that contains 'op'. This is to avoid including values
+    // that are defined in the nested regions of 'op'.
+    if (auto *nestedOp = nestedValue->get().getDefiningOp()) {
+      if (opBlock == nestedOp->getBlock()) {
+        usedValues.insert(nestedValue->get());
+      }
+    }
+  };
+
+  // collect all the values used in the nested regions of this op
+  // SetVector<Region*> nestedRegions;
+  for (auto &region : op->getRegions()) {
+    region.walk([&](Region *nestedRegion) {
+      mlir::visitUsedValuesDefinedAbove(*nestedRegion, *nestedRegion, f);
+    });
+  }
+
+  // collect all the values used in the op
+  for (const auto &en : llvm::enumerate(op->getOperands())) {
+    usedValues.insert(en.value());
+  }
+
+  for (const auto &en : llvm::enumerate(usedValues)) {
+    auto operand = en.value();
+    if (auto *definingOp = operand.getDefiningOp()) {
+      if (backwardSlice->count(definingOp) == 0)
+        getDeepBackwardSlice(definingOp, backwardSlice, options);
+    } else if (auto blockArg = dyn_cast<BlockArgument>(operand)) {
+      if (options.omitBlockArguments)
+        continue;
+
+      Block *block = blockArg.getOwner();
+      Operation *parentOp = block->getParentOp();
+      // TODO: determine whether we want to recurse backward into the other
+      // blocks of parentOp, which are not technically backward unless they flow
+      // into us. For now, just bail.
+      if (parentOp && backwardSlice->count(parentOp) == 0) {
+        assert(parentOp->getNumRegions() == 1 &&
+               parentOp->getRegion(0).getBlocks().size() == 1);
+        getDeepBackwardSlice(parentOp, backwardSlice, options);
+      }
+    } else {
+      llvm_unreachable("No definingOp and not a block argument.");
+    }
+  }
+
+  backwardSlice->insert(op);
+}
+
 // Search through block to find earliest insertion point for move op. This can
 // be either an atomic op or last usage of source pointer. Search ends when move
 // op is encountered.
@@ -221,8 +289,7 @@ class TritonAMDGPUReorderInstructionsPass
       // Only move ops residing in the same block.
      return defBlock == block;
     };
-    mlir::getBackwardSlice(op, &backwardSet, options);
-    backwardSet.insert(op);
+    getDeepBackwardSlice(op, &backwardSet, options);

     // Don't move a local_store if its source is a load from
     // the same iteration.
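For orientation, here is a sketch of how the updated call site reads with the surrounding lines filled in. The wrapper name buildBackwardSet and the variable names op and block are assumptions inferred from the visible diff context, not verbatim pass code; the visible change itself is that one getDeepBackwardSlice call replaces the old getBackwardSlice-plus-insert pair, since getDeepBackwardSlice inserts op itself before returning.

// Sketch only, assumed names: how the new helper slots into the pass.
// getDeepBackwardSlice is the file-local helper added earlier in this diff.
#include "mlir/Analysis/SliceAnalysis.h"
#include "llvm/ADT/SetVector.h"

static llvm::SetVector<mlir::Operation *>
buildBackwardSet(mlir::Operation *op, mlir::Block *block) {
  llvm::SetVector<mlir::Operation *> backwardSet;
  mlir::BackwardSliceOptions options;
  options.filter = [&](mlir::Operation *defOp) {
    mlir::Block *defBlock = defOp->getBlock();
    // Only move ops residing in the same block (from the visible context above).
    return defBlock == block;
  };
  // Previously: mlir::getBackwardSlice(op, &backwardSet, options);
  //             backwardSet.insert(op);
  // The new helper already inserts `op`, so a single call is enough.
  getDeepBackwardSlice(op, &backwardSet, options);
  return backwardSet;
}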
