Skip to content

Commit d958c81

Browse files
committed
local_load via bufferid
1 parent 81d4eaa commit d958c81

File tree

4 files changed

+61
-99
lines changed

4 files changed

+61
-99
lines changed

include/triton/Analysis/Membar.h

Lines changed: 14 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -140,13 +140,14 @@ struct BlockInfo {
140140

141141
/// Returns true if intervals in two BlockInfo objects are intersected.
142142
std::optional<CTA_UFDS> isIntersected(const BlockInfo &other,
143-
MembarFilterFn filter) const {
144-
auto raw =
145-
isIntersected(syncWriteIntervals, other.syncReadIntervals, filter);
146-
auto war =
147-
isIntersected(syncReadIntervals, other.syncWriteIntervals, filter);
148-
auto waw =
149-
isIntersected(syncWriteIntervals, other.syncWriteIntervals, filter);
143+
MembarFilterFn filter,
144+
Allocation *allocation) const {
145+
auto raw = isIntersected(syncWriteIntervals, other.syncReadIntervals,
146+
filter, allocation);
147+
auto war = isIntersected(syncReadIntervals, other.syncWriteIntervals,
148+
filter, allocation);
149+
auto waw = isIntersected(syncWriteIntervals, other.syncWriteIntervals,
150+
filter, allocation);
150151
auto maybeJoin =
151152
[](const std::optional<CTA_UFDS> &lhs,
152153
const std::optional<CTA_UFDS> &rhs) -> std::optional<CTA_UFDS> {
@@ -189,11 +190,13 @@ struct BlockInfo {
189190
bool operator!=(const BlockInfo &other) const { return !(*this == other); }
190191

191192
private:
192-
static bool haveSameAlloc(Operation *lhs, Operation *rhs);
193+
static bool haveSameAlloc(Operation *lhs, Operation *rhs,
194+
Allocation *allocation);
193195

194196
std::optional<CTA_UFDS> isIntersected(const IntervalMapT &lhsIntervalSet,
195197
const IntervalMapT &rhsIntervalSet,
196-
MembarFilterFn filter) const {
198+
MembarFilterFn filter,
199+
Allocation *allocation) const {
197200
// They intersect whenever the intervals intersect. If they do, collect the
198201
// union of CTA sets for any op pair that is not filtered out and does not
199202
// share the exact same explicit shared value.
@@ -210,12 +213,12 @@ struct BlockInfo {
210213
llvm::all_of(lhsOps, [&, rhsOpsPtr = &rhsOps](const auto &lhsOp) {
211214
return llvm::all_of(*rhsOpsPtr, [&](const auto &rhsOp) {
212215
return (filter && filter(lhsOp, rhsOp)) ||
213-
(joined.isDistributed() && haveSameAlloc(lhsOp, rhsOp));
216+
(joined.isDistributed() &&
217+
haveSameAlloc(lhsOp, rhsOp, allocation));
214218
});
215219
});
216220
if (skipBarrier)
217221
continue;
218-
219222
if (!ret.has_value()) {
220223
ret = joined;
221224
} else {

lib/Analysis/Membar.cpp

Lines changed: 43 additions & 84 deletions
Original file line numberDiff line numberDiff line change
@@ -3,63 +3,46 @@
33
#include "triton/Dialect/TritonGPU/IR/LinearLayoutConversions.h"
44
#include "triton/Dialect/TritonNvidiaGPU/IR/Dialect.h"
55

6-
#include "mlir/Analysis/SliceAnalysis.h"
76
#include "mlir/Dialect/GPU/IR/GPUDialect.h"
87
#include "mlir/Interfaces/ControlFlowInterfaces.h"
8+
#include "mlir/Interfaces/SideEffectInterfaces.h"
9+
#include "llvm/ADT/DenseSet.h"
910
#include "llvm/ADT/STLExtras.h"
1011
#include "llvm/ADT/SetVector.h"
11-
#include "llvm/ADT/SmallPtrSet.h"
1212
#include <deque>
1313

1414
namespace mlir {
1515

16-
static SmallPtrSet<Operation *, 2> parentAllocs(Operation *op) {
17-
SmallPtrSet<Operation *, 2> owners;
16+
namespace {
17+
18+
llvm::SmallDenseSet<Allocation::BufferId, 2>
19+
getSharedBufferIds(Operation *op, Allocation *allocation) {
1820
auto opEffects = dyn_cast<MemoryEffectOpInterface>(op);
1921
if (!opEffects)
20-
return owners;
22+
return {};
2123

24+
llvm::SmallDenseSet<Allocation::BufferId, 2> bufferIds;
2225
SmallVector<SideEffects::EffectInstance<MemoryEffects::Effect>> effects;
2326
opEffects.getEffects(effects);
2427
for (auto &effect : effects) {
2528
if (effect.getResource() != triton::gpu::SharedMemory::get())
2629
continue;
27-
2830
Value value = effect.getValue();
29-
// Hacky way to skip barriers...
30-
if (cast<triton::gpu::MemDescType>(value.getType()).getNumElements() == 1)
31+
auto memDescTy = cast<triton::gpu::MemDescType>(value.getType());
32+
// Hacky way to skip barriers
33+
if (memDescTy.getNumElements() == 1)
3134
continue;
32-
33-
// Get a slice of all the operations that touch this shared memory
34-
// (subslice/index/memdesc views) all the way up to the local_alloc.
35-
BackwardSliceOptions options;
36-
options.omitUsesFromAbove = false;
37-
options.inclusive = true;
38-
auto isSharedMemDesc = [](Type ty) {
39-
auto memDescTy = dyn_cast<triton::gpu::MemDescType>(ty);
40-
if (!memDescTy)
41-
return false;
42-
return isa<triton::gpu::SharedMemorySpaceAttr>(
43-
memDescTy.getMemorySpace());
44-
};
45-
options.filter = [&](Operation *depOp) -> bool {
46-
return llvm::any_of(depOp->getOperandTypes(), isSharedMemDesc) ||
47-
// Add ops that have Memdesc in the result types to pick
48-
// local_alloc ops as well
49-
llvm::any_of(depOp->getResultTypes(), isSharedMemDesc);
50-
};
51-
llvm::SetVector<Operation *> slice;
52-
LogicalResult result = getBackwardSlice(value, &slice, options);
53-
assert(succeeded(result) && "backward slice must succeed");
54-
55-
for (Operation *depOp : slice) {
56-
if (auto alloc = dyn_cast<triton::gpu::LocalAllocOp>(depOp))
57-
owners.insert(alloc.getOperation());
35+
for (auto bufferId : allocation->getBufferIds(value)) {
36+
if (bufferId == Allocation::InvalidBufferId)
37+
continue;
38+
bufferIds.insert(bufferId);
5839
}
5940
}
60-
return owners;
41+
return bufferIds;
6142
}
6243

44+
} // namespace
45+
6346
static std::pair<BlockInfo::CTA_UFDS, BlockInfo::CTA_UFDS>
6447
getCTAEquivalenceSets(Operation *op) {
6548
auto numCTAs = triton::gpu::lookupNumCTAs(op);
@@ -68,46 +51,23 @@ getCTAEquivalenceSets(Operation *op) {
6851
}
6952
auto *ctx = op->getContext();
7053
auto kBlock = StringAttr::get(ctx, "block");
71-
if (auto cvt = dyn_cast<triton::gpu::ConvertLayoutOp>(op)) {
72-
auto srcTy = cast<RankedTensorType>(cvt.getSrc().getType());
73-
auto dstTy = cast<RankedTensorType>(cvt.getType());
74-
auto cvtLayout = minimalCvtLayout(srcTy, dstTy);
75-
if (llvm::is_contained(cvtLayout.getInDimNames(), kBlock)) {
76-
auto readsUFDS = BlockInfo::CTA_UFDS(numCTAs);
77-
auto blockLayout =
78-
cvtLayout.sublayout({kBlock}, to_vector(cvtLayout.getOutDimNames()));
79-
for (int i = 0; i < numCTAs; i++) {
80-
auto res = blockLayout.apply({{kBlock, i}});
81-
assert(res.size() == 4);
82-
assert(res.back().first == kBlock);
83-
readsUFDS.unite(i, res.back().second);
84-
}
85-
// The writes are just each writing to their own shmem
86-
auto writesUFDS = BlockInfo::CTA_UFDS(numCTAs);
87-
return {readsUFDS, writesUFDS};
88-
}
89-
} else if (auto reduce = dyn_cast<triton::ReduceOp>(op)) {
90-
auto srcTy = cast<RankedTensorType>(reduce.getInputTypes()[0]);
91-
auto inCTALayout = triton::gpu::getCTALayout(srcTy.getEncoding());
92-
auto axis = reduce.getAxis();
93-
auto ll = inCTALayout.getLinearLayout();
94-
auto outdims = to_vector(ll.getOutDimNames());
95-
if (ll.getOutDimSize(outdims[axis]) != 1) {
96-
auto outCTALayout = triton::gpu::getCTALayout(
97-
cast<RankedTensorType>(reduce.getType(0)).getEncoding());
98-
// Maps the reads necessary in the reduction
99-
auto ctaLl = outCTALayout.getLinearLayout().invertAndCompose(ll);
100-
auto readsUFDS = BlockInfo::CTA_UFDS(numCTAs);
101-
for (int i = 0; i < numCTAs; i++) {
102-
auto res = ctaLl.apply({{kBlock, i}});
103-
assert(res.size() == 1);
104-
assert(res.front().first == kBlock);
105-
readsUFDS.unite(i, res.front().second);
106-
}
107-
// The writes are just each writing to their own shmem
108-
auto writesUFDS = BlockInfo::CTA_UFDS(numCTAs);
109-
return {readsUFDS, writesUFDS};
54+
if (isa<triton::gpu::ConvertLayoutOp, triton::ReduceOp>(op)) {
55+
auto srcTy = cast<RankedTensorType>(op->getOperand(0).getType());
56+
auto dstTy = cast<RankedTensorType>(op->getResult(0).getType());
57+
auto srcCTALayout = triton::gpu::getCTALayout(srcTy.getEncoding());
58+
auto dstCTALayout = triton::gpu::getCTALayout(dstTy.getEncoding());
59+
auto ctaLl = dstCTALayout.getLinearLayout().invertAndCompose(
60+
srcCTALayout.getLinearLayout());
61+
auto readsUFDS = BlockInfo::CTA_UFDS(numCTAs);
62+
for (int i = 0; i < numCTAs; i++) {
63+
auto res = ctaLl.apply({{kBlock, i}});
64+
assert(res.size() == 1);
65+
assert(res.front().first == kBlock);
66+
readsUFDS.unite(i, res.front().second);
11067
}
68+
// The writes are just each writing to their own shmem
69+
auto writesUFDS = BlockInfo::CTA_UFDS(numCTAs);
70+
return {readsUFDS, writesUFDS};
11171
} else if (auto tma =
11272
dyn_cast<triton::nvidia_gpu::AsyncTMACopyGlobalToLocalOp>(
11373
op)) {
@@ -143,15 +103,12 @@ getCTAEquivalenceSets(Operation *op) {
143103
return {BlockInfo::CTA_UFDS(numCTAs), BlockInfo::CTA_UFDS(numCTAs)};
144104
}
145105

146-
bool BlockInfo::haveSameAlloc(Operation *lhs, Operation *rhs) {
147-
auto lhsAllocs = parentAllocs(lhs);
148-
auto rhsAllocs = parentAllocs(rhs);
149-
// They can be empty when the buffer is internal, e.g. a convert_layout.
150-
if (lhsAllocs.empty() || rhsAllocs.empty())
151-
return false;
152-
106+
bool BlockInfo::haveSameAlloc(Operation *lhs, Operation *rhs,
107+
Allocation *allocation) {
108+
auto lhsBuffers = getSharedBufferIds(lhs, allocation);
109+
auto rhsBuffers = getSharedBufferIds(rhs, allocation);
153110
return llvm::any_of(
154-
lhsAllocs, [&](Operation *alloc) { return rhsAllocs.contains(alloc); });
111+
lhsBuffers, [&](auto bufferId) { return rhsBuffers.contains(bufferId); });
155112
}
156113

157114
void MembarOrFenceAnalysis::run(FuncBlockInfoMapT &funcBlockInfoMap) {
@@ -405,7 +362,8 @@ void MembarAnalysis::update(Operation *op, BlockInfo *blockInfo,
405362
}
406363
auto interval = allocation->getAllocatedInterval(scratchBufferId);
407364
curBlockInfo.syncWriteIntervals[{interval, writeCTAs}].insert(op);
408-
auto insertCTABarrier = blockInfo->isIntersected(curBlockInfo, filter);
365+
auto insertCTABarrier =
366+
blockInfo->isIntersected(curBlockInfo, filter, allocation);
409367
if (insertCTABarrier.has_value()) {
410368
builder->setInsertionPoint(op);
411369
insertBarrier(op, builder, *insertCTABarrier);
@@ -416,7 +374,8 @@ void MembarAnalysis::update(Operation *op, BlockInfo *blockInfo,
416374
blockInfo->sync(false);
417375
}
418376
curBlockInfo.syncReadIntervals[{interval, readCTAs}].insert(op);
419-
} else if (auto ctas = blockInfo->isIntersected(curBlockInfo, filter)) {
377+
} else if (auto ctas =
378+
blockInfo->isIntersected(curBlockInfo, filter, allocation)) {
420379
builder->setInsertionPoint(op);
421380
insertBarrier(op, builder, *ctas);
422381
blockInfo->sync(ctas->isDistributed());

lib/Dialect/TritonNvidiaGPU/Transforms/ProxFenceInsertion.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -166,7 +166,7 @@ void ProxyFenceAnalysis::update(Operation *op, BlockInfo *blockInfo,
166166
curBlockInfo.syncReadIntervals[{interval, readCTAs}].insert(op);
167167
}
168168
if (isAsyncProxyWrite(op) || isAsyncProxyRead(op)) {
169-
if (proxyBlockInfo.isIntersected(*blockInfo, filter)) {
169+
if (proxyBlockInfo.isIntersected(*blockInfo, filter, allocation)) {
170170
builder->setInsertionPoint(op);
171171
insertFence(op, builder);
172172
blockInfo->sync(false);

test/Analysis/test-membar-cluster.mlir

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -99,10 +99,10 @@ module attributes {"ttg.num-ctas" = 2 : i32, "ttg.num-warps" = 1 : i32} {
9999
// -----
100100

101101
// Distributed convert alias (needs cluster barrier)
102-
#blocked1 = #ttg.blocked<{sizePerThread = [1, 1], threadsPerWarp = [32, 1], warpsPerCTA = [1, 1], order = [1, 0], CGALayout = [[1, 0]]}>
103-
#blocked2 = #ttg.blocked<{sizePerThread = [1, 1], threadsPerWarp = [32, 1], warpsPerCTA = [1, 1], order = [1, 0], CGALayout = [[0, 1]]}>
102+
#blocked1 = #ttg.blocked<{sizePerThread = [1, 1], threadsPerWarp = [32, 1], warpsPerCTA = [1, 1], order = [1, 0], CGALayout = [[1, 0], [0, 1]]}>
103+
#blocked2 = #ttg.blocked<{sizePerThread = [1, 1], threadsPerWarp = [32, 1], warpsPerCTA = [1, 1], order = [1, 0], CGALayout = [[0, 1], [1, 0]]}>
104104

105-
module attributes {"ttg.num-ctas" = 2 : i32, "ttg.num-warps" = 1 : i32} {
105+
module attributes {"ttg.num-ctas" = 4 : i32, "ttg.num-warps" = 1 : i32} {
106106
// CHECK-LABEL: convert_alias_same_offset
107107
tt.func @convert_alias_same_offset() -> (tensor<2x2xf16, #blocked2>, tensor<2x2xf16, #blocked2>) {
108108
%c0 = arith.constant 0.000000e+00 : f16

0 commit comments

Comments
 (0)