
Commit 5d2a7a9

[BACKEND] Prevent reordering local_load across side-effecting op (#8423)
1 parent b5fea1e

File tree: 2 files changed (+64, −0 lines)


lib/Dialect/TritonGPU/Transforms/ReorderInstructions.cpp

Lines changed: 32 additions & 0 deletions
@@ -40,6 +40,36 @@ static bool willIncreaseRegisterPressure(Operation *op) {
   return false;
 }
 
+// Return true if the op has side effects that are either unknown or writes.
+static bool hasWriteSideEffect(Operation *op) {
+  auto effects = getEffectsRecursively(op);
+  if (!effects)
+    return false;
+  return llvm::any_of(*effects, [](MemoryEffects::EffectInstance effect) {
+    return !isa<MemoryEffects::Read, MemoryEffects::Allocate,
+                MemoryEffects::Free>(effect.getEffect());
+  });
+}
+
+// Return true if there is a write side effect on any path between the start
+// and end ops. This assumes start dominates end.
+static bool crossWriteSideEffectingOp(Operation *start, Operation *end) {
+  auto ancestor = start->getBlock()->findAncestorOpInBlock(*end);
+  // Couldn't find an ancestor in the same block; conservatively assume true.
+  if (!ancestor)
+    return true;
+  Operation *nextOp = start->getNextNode();
+  while (nextOp) {
+    if (hasWriteSideEffect(nextOp))
+      return true;
+    if (nextOp == ancestor)
+      return false;
+    nextOp = nextOp->getNextNode();
+  }
+  assert(false && "op doesn't dominate other");
+  return true;
+}
+
 class TritonGPUReorderInstructionsPass
     : public impl::TritonGPUReorderInstructionsBase<
           TritonGPUReorderInstructionsPass> {
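
To make the two helpers concrete, here is a minimal self-contained sketch in plain C++ (a toy model with hypothetical names such as Effect and crossesWrite, not Triton/MLIR code) of the same classify-then-scan logic:

#include <cassert>
#include <cstddef>
#include <vector>

// Toy stand-in for MLIR's MemoryEffects: one effect kind per op.
enum class Effect { None, Read, Allocate, Free, Write, Unknown };

// Analogue of hasWriteSideEffect: anything that is not a pure op or a
// read/allocate/free (i.e. a write or an unknown effect) counts as a write.
static bool isWriteLike(Effect e) {
  return e == Effect::Write || e == Effect::Unknown;
}

// Analogue of crossWriteSideEffectingOp on a flat op list: true if a
// write-like op sits strictly between start and end. Here `start < end`
// stands in for the "start dominates end" precondition.
static bool crossesWrite(const std::vector<Effect> &ops, std::size_t start,
                         std::size_t end) {
  assert(start < end && "start must come before end");
  for (std::size_t i = start + 1; i < end; ++i)
    if (isWriteLike(ops[i]))
      return true;
  return false;
}

int main() {
  // 0: local_load, 1: pure op, 2: barrier arrive (write), 3: local_load, 4: dot
  std::vector<Effect> ops = {Effect::Read, Effect::None, Effect::Write,
                             Effect::Read, Effect::None};
  assert(crossesWrite(ops, 0, 4));  // sinking op 0 past the barrier: unsafe
  assert(!crossesWrite(ops, 3, 4)); // op 3 -> op 4: nothing writes in between
  return 0;
}
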
@@ -135,6 +165,8 @@ class TritonGPUReorderInstructionsPass
       // after the conversion to OpIdx=0.
       if (!dom.dominates(op.getOperation(), AOp.getOperation()))
         return;
+      if (crossWriteSideEffectingOp(op, AOp))
+        return;
       moveAfter(op, AOp);
     });
     return;
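
The call site above then simply guards the move with that scan. Continuing the toy model from the previous sketch (hypothetical names again, not the pass's real code):

// Toy version of the guarded sink: bail out, as the pass now does, when a
// write-like op sits between the load and its use; otherwise the move
// (moveAfter in the real pass) would be performed here.
bool trySink(const std::vector<Effect> &ops, std::size_t loadIdx,
             std::size_t useIdx) {
  if (crossesWrite(ops, loadIdx, useIdx))
    return false; // the new early return
  // ... move the op at loadIdx next to useIdx ...
  return true;
}
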

test/TritonGPU/reorder-instructions.mlir

Lines changed: 32 additions & 0 deletions
@@ -96,6 +96,38 @@ module attributes {"ttg.num-warps" = 4 : i32, "ttg.threads-per-warp" = 32 : i32}
 
 // -----
 
+// CHECK-LABEL: sink_convert_idx_1_negative
+// CHECK: ttg.local_load %{{.*}} : !ttg.memdesc<32x32xf32, #{{.*}}, #smem> -> tensor<32x32xf32, #ttg.dot_op<{opIdx = 1, parent = #mma, kWidth = 1}>>
+// CHECK: ttng.arrive_barrier
+// CHECK: ttg.local_load %{{.*}} : !ttg.memdesc<32x32xf32, #{{.*}}, #smem> -> tensor<32x32xf32, #ttg.dot_op<{opIdx = 0, parent = #mma, kWidth = 1}>>
+// CHECK: tt.dot
+#blocked = #ttg.blocked<{sizePerThread = [1, 1], threadsPerWarp = [32, 1], warpsPerCTA = [1, 4], order = [0, 1]}>
+#mma = #ttg.nvidia_mma<{versionMajor = 2, versionMinor = 0, warpsPerCTA = [2, 2], instrShape = [16, 8]}>
+#shared = #ttg.swizzled_shared<{vec = 8, perPhase = 1, maxPhase = 4, order = [0, 1]}>
+#shared1 = #ttg.swizzled_shared<{vec = 1, perPhase = 1, maxPhase = 1, order = [0]}>
+#smem = #ttg.shared_memory
+module attributes {"ttg.num-warps" = 4 : i32, "ttg.threads-per-warp" = 32 : i32} {
+  tt.func public @sink_convert_idx_1_negative(%arg0: tensor<32x32x!tt.ptr<f32>, #blocked>) {
+    %bar = ttg.local_alloc : () -> !ttg.memdesc<1xi64, #shared1, #smem, mutable>
+    %true = arith.constant true
+    %cst = arith.constant dense<0.000000e+00> : tensor<32x32xf32, #mma>
+    %B = tt.load %arg0 : tensor<32x32x!tt.ptr<f32>, #blocked>
+    %BS = ttg.local_alloc %B : (tensor<32x32xf32, #blocked>) -> !ttg.memdesc<32x32xf32, #shared, #smem>
+    %BD = ttg.local_load %BS : !ttg.memdesc<32x32xf32, #shared, #smem> -> tensor<32x32xf32, #ttg.dot_op<{opIdx = 1, parent = #mma, kWidth = 1}>>
+    %cst_0 = arith.constant dense<1.230000e+02> : tensor<32x32xf32, #ttg.dot_op<{opIdx = 1, parent = #mma, kWidth = 1}>>
+    %A = tt.load %arg0 : tensor<32x32x!tt.ptr<f32>, #blocked>
+    %AS = ttg.local_alloc %A : (tensor<32x32xf32, #blocked>) -> !ttg.memdesc<32x32xf32, #shared, #smem>
+    ttng.arrive_barrier %bar, 2, %true : !ttg.memdesc<1xi64, #shared1, #smem, mutable>
+    %AD = ttg.local_load %AS : !ttg.memdesc<32x32xf32, #shared, #smem> -> tensor<32x32xf32, #ttg.dot_op<{opIdx = 0, parent = #mma, kWidth = 1}>>
+    %12 = tt.dot %AD, %BD, %cst : tensor<32x32xf32, #ttg.dot_op<{opIdx = 0, parent = #mma, kWidth = 1}>> * tensor<32x32xf32, #ttg.dot_op<{opIdx = 1, parent = #mma, kWidth = 1}>> -> tensor<32x32xf32, #mma>
+    %13 = ttg.convert_layout %12 : tensor<32x32xf32, #mma> -> tensor<32x32xf32, #blocked>
+    tt.store %arg0, %13 : tensor<32x32x!tt.ptr<f32>, #blocked>
+    tt.return
+  }
+}
+
+// -----
+
 // check that we don't sink convert_layout if it has multi users
 // CHECK-LABEL: convert_cannot_sink
 // CHECK: ttg.local_load %{{.*}} : !ttg.memdesc<32x32xf32, #shared, #smem> -> tensor<32x32xf32, #ttg.dot_op<{opIdx = 1, parent = #mma, kWidth = 1}>>
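
One note on exercising the new test: lit tests in this file are driven by a RUN line at the top of the file, which the hunk does not show. It typically looks like the line below; the exact pass flag is an assumption inferred from the pass name, not taken from the commit.

// RUN: triton-opt %s -tritongpu-reorder-instructions | FileCheck %s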
