forked from facebookexperimental/triton
-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathMembar.cpp
More file actions
320 lines (296 loc) · 13.6 KB
/
Membar.cpp
File metadata and controls
320 lines (296 loc) · 13.6 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
#include "triton/Analysis/Membar.h"
#include "triton/Dialect/TritonGPU/IR/Dialect.h"
#include "triton/Dialect/TritonGPU/IR/LinearLayoutConversions.h"
#include "triton/Dialect/TritonNvidiaGPU/IR/Dialect.h"
#include "mlir/Dialect/GPU/IR/GPUDialect.h"
#include "mlir/IR/Matchers.h"
#include "mlir/Interfaces/ControlFlowInterfaces.h"
#include <deque>
namespace mlir {
/// Narrow a buffer's allocated interval to the sub-range addressed through a
/// chain of memdesc_index operations that ends at `value`.
///
/// Each memdesc_index selects one contiguous slice along the leading dimension
/// of its source. When the index is a compile-time constant, the slice of an
/// interval [start, end) with leading dimension `dim0` is
/// [start + idx * stride, start + (idx + 1) * stride), where
/// stride = (end - start) / dim0. This avoids false hazards when different
/// indices of the same buffer are accessed (e.g. initializing elements of a
/// barrier array).
///
/// Nested indexing is applied outermost-first. The outermost memdesc_index
/// slices the full interval, and each inner one slices the result of the
/// previous step. (Applying it innermost-first would divide the whole
/// allocation by the leading dimension of an intermediate sub-view and yield a
/// wrong range.) Narrowing stops at the first index that is not a constant, is
/// out of bounds, or does not divide the current interval evenly. The interval
/// computed up to that point still covers the access.
///
/// NOTE(review): this assumes the source of the outermost memdesc_index in the
/// chain spans the whole `interval`. If that source is itself a sub-view
/// reached through something other than memdesc_index (e.g. a block argument),
/// the computed range may not contain the real access. Confirm callers never
/// hit that case.
static Interval<size_t> narrowIntervalForSubview(Value value,
                                                 Interval<size_t> interval) {
  // Collect the chain from the accessed value up to its root (innermost first).
  SmallVector<triton::gpu::MemDescIndexOp> chain;
  while (auto indexOp = value.getDefiningOp<triton::gpu::MemDescIndexOp>()) {
    chain.push_back(indexOp);
    value = indexOp.getSrc();
  }

  // Apply the slices from the outermost one (closest to the root) inward.
  for (triton::gpu::MemDescIndexOp indexOp : llvm::reverse(chain)) {
    auto parentType =
        cast<triton::gpu::MemDescType>(indexOp.getSrc().getType());
    // Only narrow when the index is a compile-time constant.
    APInt indexVal;
    if (!matchPattern(indexOp.getIndex(), m_ConstantInt(&indexVal)))
      break;
    int64_t idx = indexVal.getSExtValue();
    int64_t dim0 = parentType.getShape()[0];
    size_t totalSize = interval.end() - interval.start();
    // Reject out-of-range indices (a negative index would wrap when scaled as
    // size_t) and intervals the leading dimension does not divide evenly
    // (should not happen for well-formed IR).
    if (dim0 <= 0 || idx < 0 || idx >= dim0 ||
        totalSize % static_cast<size_t>(dim0) != 0)
      break;
    size_t stride = totalSize / static_cast<size_t>(dim0);
    size_t newStart = interval.start() + static_cast<size_t>(idx) * stride;
    interval = Interval<size_t>(newStart, newStart + stride);
  }
  return interval;
}
/// Entry point: analyze the function the allocation analysis was computed
/// for, recording its summary into `funcBlockInfoMap`.
void MembarOrFenceAnalysis::run(FuncBlockInfoMapT &funcBlockInfoMap) {
  // The allocation's root operation is expected to be a function. The result
  // is dereferenced unconditionally below, so use `cast` (which asserts)
  // rather than `dyn_cast` (which would yield a null op).
  auto funcOp = cast<FunctionOpInterface>(allocation->getOperation());
  OpBuilder builder(funcOp.getContext());
  resolve(funcOp, &funcBlockInfoMap, &builder);
}
/// Run the dataflow analysis over `funcOp` until a fixed point is reached,
/// letting `update` insert synchronization as needed. Then join the state
/// reaching every `tt.return` into `(*funcBlockInfoMap)[funcOp]`.
void MembarOrFenceAnalysis::resolve(FunctionOpInterface funcOp,
                                    FuncBlockInfoMapT *funcBlockInfoMap,
                                    OpBuilder *builder) {
  // Initialize the blockList. Operations are organized into "virtual blocks",
  // which represent segments of straight-line code analyzed by each iteration
  // of the dataflow analysis. Virtual blocks abstract over both control flow
  // represented by basic blocks and block successors (i.e. `BranchOpInterface`)
  // and control flow represented by regions (i.e. `RegionBranchOpInterface`).
  //
  // A virtual block consists of a parent block and a starting iterator, where
  // the virtual block starts on the operation *after* the starting iterator. A
  // null iterator is used to represent the beginning of the block. The virtual
  // block ends at any region branch operation or the basic block terminator.
  // Thus, basic blocks are broken up into multiple virtual blocks at each
  // region operation.
  //
  // Entry virtual blocks are represented by a null iterator. Populate the
  // blockList with the entry virtual blocks in the function. Then, each
  // iteration scans until a terminator or region branch operation is found.
  DenseMap<VirtualBlock, BlockInfo> inputBlockInfoMap;
  DenseMap<VirtualBlock, BlockInfo> outputBlockInfoMap;
  std::deque<VirtualBlock> blockList;
  funcOp.walk<WalkOrder::PreOrder>([&](Block *block) {
    // Start the analysis from the entry blocks of any nested isolated from
    // above regions.
    if (block->isEntryBlock() &&
        !isa<RegionBranchOpInterface>(block->getParentOp()))
      blockList.emplace_back(block, Block::iterator());
  });

  // A fixed point algorithm
  while (!blockList.empty()) {
    VirtualBlock block = blockList.front();
    blockList.pop_front();
    // Work on a copy of the input state; the map entry itself only grows when
    // predecessors join into it.
    BlockInfo inputBlockInfo = inputBlockInfoMap[block];
    SmallVector<VirtualBlock> successors;
    Block::iterator startIt =
        block.second.isValid() ? std::next(block.second) : block.first->begin();
    for (Operation &op : llvm::make_range(startIt, block.first->end())) {
      if (op.hasTrait<OpTrait::IsTerminator>() ||
          isa<RegionBranchOpInterface>(op)) {
        visitTerminator(&op, successors);
        break;
      }
      update(&op, &inputBlockInfo, funcBlockInfoMap, builder);
    }

    // If we have seen the block before and its output did not change, there
    // is nothing new to propagate to the successors.
    auto outIt = outputBlockInfoMap.find(block);
    if (outIt != outputBlockInfoMap.end() && inputBlockInfo == outIt->second)
      continue;

    // Update the current block. The block transfer function is not monotonic,
    // so overwrite the output state entirely.
    BlockInfo &outputBlockInfo = outputBlockInfoMap[block];
    outputBlockInfo = inputBlockInfo;
    // Propagate to the successors. `outputBlockInfo` stays valid because only
    // `inputBlockInfoMap` is modified in this loop.
    for (VirtualBlock successor : successors) {
      inputBlockInfoMap[successor].join(outputBlockInfo);
      blockList.emplace_back(successor);
    }
  }

  // Update the final dangling buffers that haven't been synced
  BlockInfo &funcBlockInfo = (*funcBlockInfoMap)[funcOp];
  funcOp.walk<WalkOrder::PreOrder>([&](triton::ReturnOp returnOp) {
    // A basic block can be broken into several virtual blocks. Find all virtual
    // blocks that belong to the basic block containing the return. Pointers
    // into `outputBlockInfoMap` are stable: the map is not modified here.
    Block *returnBlock = returnOp->getBlock();
    SmallVector<std::pair<VirtualBlock, const BlockInfo *>> virtualBlocks;
    for (auto &[block, blockInfo] : outputBlockInfoMap) {
      if (block.first == returnBlock)
        virtualBlocks.emplace_back(block, &blockInfo);
    }
    // A return in a block the fixed point never reached (e.g. an unreachable
    // block) carries no state. Without this check, max_element would return
    // end() and it would be dereferenced.
    if (virtualBlocks.empty())
      return;
    // The return is a terminator, so the virtual block that contains this
    // return starts after all other ones. Find it by comparing the start
    // iterators of the virtual blocks.
    auto maxIt = llvm::max_element(virtualBlocks, [&](auto &lhs, auto &rhs) {
      assert(lhs.first.first == rhs.first.first);
      Block::iterator lhsIt = lhs.first.second, rhsIt = rhs.first.second;
      return !lhsIt.isValid() ||
             (rhsIt.isValid() && lhsIt->isBeforeInBlock(&*rhsIt));
    });
    funcBlockInfo.join(*maxIt->second);
  });
}
/// Append to `successors` the virtual blocks control may reach after `op`,
/// which is either a block terminator or a region-branch operation.
///
/// A virtual block is a (block, iterator) pair. A null iterator denotes the
/// start of the block; otherwise the virtual block starts right after the
/// iterator's operation.
///  - `BranchOpInterface`: the start of each successor block.
///  - `RegionBranchOpInterface`: the entry block of each successor region, or
///    right after the op itself when it may branch back to its parent.
///  - Terminator of a region-branch op's region: another region of the parent
///    op, or right after the parent op.
///  - Other `ReturnLike` terminators: no successors.
/// Any other terminator is an internal error (`llvm_unreachable`).
void MembarOrFenceAnalysis::visitTerminator(
    Operation *op, SmallVector<VirtualBlock> &successors) {
  if (isa<BranchOpInterface>(op)) {
    // Collect the block successors of the branch.
    for (Block *successor : op->getSuccessors())
      successors.emplace_back(successor, Block::iterator());
    return;
  }

  // Translate region successors into virtual blocks. A "parent" successor
  // means control continues right after `parentOp`; otherwise it enters the
  // first block of the successor region.
  auto addRegionSuccessors = [&](SmallVectorImpl<RegionSuccessor> &regions,
                                 Operation *parentOp) {
    for (RegionSuccessor &region : regions) {
      if (region.isParent())
        successors.emplace_back(parentOp->getBlock(), parentOp->getIterator());
      else
        successors.emplace_back(&region.getSuccessor()->front(),
                                Block::iterator());
    }
  };

  if (auto br = dyn_cast<RegionBranchOpInterface>(op)) {
    // The successors of an operation with regions can be queried via an
    // interface. The operation branches to the entry blocks of its region
    // successors. It can also branch to after itself.
    SmallVector<RegionSuccessor> regions;
    br.getSuccessorRegions(RegionBranchPoint::parent(), regions);
    addRegionSuccessors(regions, op);
    return;
  }

  // FIXME: `ReturnLike` adds `RegionBranchTerminatorOpInterface` for some
  // reason. Check that the parent is actually a `RegionBranchOpInterface`.
  auto br = dyn_cast<RegionBranchTerminatorOpInterface>(op);
  if (br && isa<RegionBranchOpInterface>(br->getParentOp())) {
    // Check the successors of a region branch terminator. It can branch to
    // another region of its parent operation or to after the parent op.
    // Operands are passed as unknown (null) attributes.
    SmallVector<Attribute> operands(br->getNumOperands());
    SmallVector<RegionSuccessor> regions;
    br.getSuccessorRegions(operands, regions);
    addRegionSuccessors(regions, br->getParentOp());
    return;
  }

  // Otherwise, it could be a return op
  if (op->hasTrait<OpTrait::ReturnLike>())
    return;
  llvm_unreachable("Unknown terminator encountered in membar analysis");
}
/// Create a `gpu::BarrierOp` at the builder's current insertion point, using
/// `op`'s location. Callers position `builder` (before or after `op`) first.
/// The guard restores the insertion point that was active on entry.
void MembarAnalysis::insertBarrier(Operation *op, OpBuilder *builder) {
  OpBuilder::InsertionGuard g(*builder);
  builder->create<gpu::BarrierOp>(op->getLoc());
}
/// Transfer function of the barrier analysis for a single operation.
///
/// Computes the shared-memory intervals `op` reads and writes
/// (`curBlockInfo`). If they intersect the pending accesses in `*blockInfo`
/// (per `isIntersected` with `filter`), inserts a `gpu::BarrierOp` before
/// `op`. Finally, joins `op`'s own accesses into `*blockInfo`.
///
/// Special cases:
///  - an existing `gpu::BarrierOp` clears all pending accesses;
///  - `AsyncWaitOp` / `TMAStoreWaitOp` get a barrier right after them unless
///    one already follows, and clear all pending accesses;
///  - `triton::CallOp` uses the summary already recorded for the callee in
///    `funcBlockInfoMap`;
///  - ops with a scratch buffer are modelled as a write of the scratch buffer
///    followed by a read of it.
void MembarAnalysis::update(Operation *op, BlockInfo *blockInfo,
                            FuncBlockInfoMapT *funcBlockInfoMap,
                            OpBuilder *builder) {
  if (isa<gpu::BarrierOp>(op)) {
    // If the current op is a barrier, we sync previous reads and writes
    blockInfo->sync();
    return;
  }

  // If the current op is an async wait and the next op is not a barrier, we
  // insert a barrier op and sync. `getNextNode()` is null when `op` is the
  // last operation of its block; `isa_and_nonnull` treats that as "no barrier
  // follows" instead of asserting on a null pointer.
  if (isa<triton::gpu::AsyncWaitOp, triton::nvidia_gpu::TMAStoreWaitOp>(op) &&
      !isa_and_nonnull<gpu::BarrierOp>(op->getNextNode())) {
    builder->setInsertionPointAfter(op);
    insertBarrier(op, builder);
    blockInfo->sync();
    return;
  }

  BlockInfo curBlockInfo;
  auto scratchBufferId = Allocation::InvalidBufferId;
  if (isa<triton::CallOp>(op)) {
    // Inter-function dependencies. `resolveCallable()` may return null when
    // the callee cannot be resolved, so use a null-tolerant cast.
    auto callOpInterface = cast<CallOpInterface>(op);
    if (auto callee = dyn_cast_if_present<FunctionOpInterface>(
            callOpInterface.resolveCallable()))
      curBlockInfo = funcBlockInfoMap->lookup(callee);
  } else {
    // Intra-function dependencies
    //
    // For perThread ArriveBarrierOp, skip all SMEM hazard tracking.
    // mbarrier.arrive has release semantics and mbarrier.wait has acquire
    // semantics, so no CTA-wide bar.sync is needed before a perThread arrive.
    // Each thread's program order guarantees its own SMEM ops are visible
    // before its arrive, and the mbarrier accumulates all arrivals before
    // releasing the waiter.
    bool isPerThreadArrive = false;
    if (auto arriveOp = dyn_cast<triton::nvidia_gpu::ArriveBarrierOp>(op))
      isPerThreadArrive = arriveOp.getPerThread();
    if (!isPerThreadArrive) {
      if (auto memoryEffectOpInterface =
              dyn_cast<MemoryEffectOpInterface>(op)) {
        // Explicit buffer
        SmallVector<SideEffects::EffectInstance<MemoryEffects::Effect>>
            effectInstances;
        memoryEffectOpInterface.getEffects(effectInstances);
        for (const auto &effectInstance : effectInstances) {
          auto value = effectInstance.getValue();
          if (!value)
            continue;
          for (auto bufferId : allocation->getBufferIds(value)) {
            if (bufferId == Allocation::InvalidBufferId)
              continue;
            auto interval = allocation->getAllocatedInterval(bufferId);
            interval = narrowIntervalForSubview(value, interval);
            if (isa<MemoryEffects::Write>(effectInstance.getEffect()))
              curBlockInfo.syncWriteIntervals[interval].insert(op);
            else if (isa<MemoryEffects::Read>(effectInstance.getEffect()))
              curBlockInfo.syncReadIntervals[interval].insert(op);
          }
        }
      }
      // If this op may be signalling other threads asynchronously, make sure
      // all shared memory transactions are complete beforehand.
      if (isa<triton::nvidia_gpu::ArriveBarrierOp>(op)) {
        Interval<size_t> allIntervals(0, std::numeric_limits<size_t>::max());
        curBlockInfo.syncWriteIntervals[allIntervals].insert(op);
        curBlockInfo.syncReadIntervals[allIntervals].insert(op);
      }
    }
    scratchBufferId = allocation->getBufferId(op);
  }

  // Scratch buffer operations consist of a series of shared memory operations
  // starting from a shared memory write, followed by a series of shared memory
  // read/write operations, and ending with a shared memory read, i.e., shared
  // memory write -> ... -> shared memory read.
  if (scratchBufferId != Allocation::InvalidBufferId) {
    // Detect warp-synchronous convert-layout operations. These emit a
    // warp-level barrier (warp.sync) rather than a CTA-wide barrier between
    // the internal shared-memory write and read phases. For these ops, we must
    // not globally clear pending dependencies.
    bool isWarpSync = false;
    if (auto cvt = dyn_cast<triton::gpu::ConvertLayoutOp>(op)) {
      auto srcTy = cast<RankedTensorType>(cvt.getSrc().getType());
      auto dstTy = cast<RankedTensorType>(cvt.getType());
      auto srcLayout = triton::gpu::toLinearLayout(srcTy);
      auto dstLayout = triton::gpu::toLinearLayout(dstTy);
      isWarpSync = mlir::isCvtWarpSync(srcLayout, dstLayout);
    }

    if (!curBlockInfo.syncReadIntervals.empty() ||
        !curBlockInfo.syncWriteIntervals.empty()) {
      llvm::report_fatal_error(
          "scratch buffer operations should not have any shared memory "
          "dependencies");
    }
    auto interval = allocation->getAllocatedInterval(scratchBufferId);
    curBlockInfo.syncWriteIntervals[interval].insert(op);
    auto insertCTABarrier = blockInfo->isIntersected(curBlockInfo, filter);
    if (insertCTABarrier) {
      builder->setInsertionPoint(op);
      insertBarrier(op, builder);
    }
    // Ops with a scratch buffer that don't use warp.sync internally sync
    // read/write on shared memory
    if (insertCTABarrier || !isWarpSync)
      blockInfo->sync();
    curBlockInfo.syncReadIntervals[interval].insert(op);
  } else if (blockInfo->isIntersected(curBlockInfo, filter)) {
    builder->setInsertionPoint(op);
    insertBarrier(op, builder);
    blockInfo->sync();
  }
  // Update the region info, even if barrier is inserted, we have to maintain
  // the current op's read/write buffers.
  blockInfo->join(curBlockInfo);
}
} // namespace mlir