Skip to content

Commit f6c38b2

Browse files
committed
[WIP][BACKEND] Generalize the MemBar to consider cross-CTA ops
The semantics here are that it's the user's/compiler's responsibility to add the relevant synchronization if they reuse the same shmem buffer, but otherwise the compiler will do so.
1 parent 0f235ee commit f6c38b2

File tree

10 files changed

+502
-131
lines changed

10 files changed

+502
-131
lines changed

include/triton/Analysis/Membar.h

Lines changed: 208 additions & 34 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,8 @@
22
#define TRITON_ANALYSIS_MEMBAR_H
33

44
#include "Allocation.h"
5+
#include "mlir/Interfaces/SideEffectInterfaces.h"
6+
#include "triton/Analysis/Utility.h"
57

68
#include <set>
79

@@ -15,7 +17,83 @@ class OpBuilder;
1517
using MembarFilterFn = std::function<bool(Operation *, Operation *)>;
1618

1719
struct BlockInfo {
18-
using IntervalMapT = std::map<Interval<size_t>, std::set<Operation *>>;
20+
// UFDS to represent cross-CTA reads/writes
21+
struct UFDS {
22+
SmallVector<unsigned> parent;
23+
SmallVector<unsigned> rank;
24+
// Invariant: At the root of a class, minRep[i] is the smallest element in
25+
// the class
26+
SmallVector<unsigned> minRep;
27+
28+
UFDS() = default;
29+
explicit UFDS(unsigned n) : rank(n, 0), minRep(n) {
30+
assert(llvm::isPowerOf2_32(n) && n != 0);
31+
parent = llvm::to_vector(llvm::seq(n));
32+
for (unsigned i = 0; i < n; ++i)
33+
minRep[i] = i;
34+
}
35+
36+
unsigned find(unsigned x) const {
37+
unsigned p = parent[x];
38+
while (p != parent[p])
39+
p = parent[p];
40+
return p;
41+
}
42+
43+
unsigned findMin(unsigned x) const { return minRep[find(x)]; }
44+
45+
void unite(unsigned x, unsigned y) {
46+
x = find(x);
47+
y = find(y);
48+
if (x == y)
49+
return;
50+
51+
if (rank[x] < rank[y])
52+
std::swap(x, y);
53+
54+
parent[y] = x;
55+
minRep[x] = std::min(minRep[x], minRep[y]);
56+
57+
if (rank[x] == rank[y])
58+
++rank[x];
59+
}
60+
61+
UFDS join(const UFDS &other) const {
62+
// Transitive closure of two UFDS
63+
UFDS result = *this;
64+
for (unsigned i = 0; i < size(); ++i)
65+
result.unite(i, other.find(i));
66+
return result;
67+
}
68+
69+
SmallVector<unsigned> canonical() const {
70+
SmallVector<unsigned> reps(size());
71+
for (unsigned i = 0; i < size(); ++i)
72+
reps[i] = findMin(i);
73+
return reps;
74+
}
75+
76+
bool isDistributed() const { return *this != UFDS(parent.size()); }
77+
78+
bool operator<(const UFDS &other) const {
79+
return canonical() < other.canonical();
80+
}
81+
bool operator==(const UFDS &other) const {
82+
return canonical() == other.canonical();
83+
}
84+
bool operator!=(const UFDS &other) const { return !(*this == other); }
85+
86+
void print(raw_ostream &os) const {
87+
os << "UFDS(";
88+
llvm::interleaveComma(canonical(), os, [&](unsigned x) { os << x; });
89+
os << ")";
90+
}
91+
92+
size_t size() const { return parent.size(); }
93+
};
94+
95+
using IntervalMapT =
96+
std::map<std::pair<Interval<size_t>, UFDS>, std::set<Operation *>>;
1997

2098
IntervalMapT syncReadIntervals;
2199
IntervalMapT syncWriteIntervals;
@@ -24,48 +102,82 @@ struct BlockInfo {
24102

25103
/// Unions two BlockInfo objects.
26104
BlockInfo &join(const BlockInfo &other) {
27-
for (auto &interval : other.syncReadIntervals)
28-
syncReadIntervals[interval.first].insert(interval.second.begin(),
29-
interval.second.end());
30-
for (auto &interval : other.syncWriteIntervals)
31-
syncWriteIntervals[interval.first].insert(interval.second.begin(),
32-
interval.second.end());
105+
// We don't fold the intervals (though we could)
106+
for (auto &[key, ops] : other.syncReadIntervals)
107+
syncReadIntervals[key].insert(ops.begin(), ops.end());
108+
for (auto &[key, ops] : other.syncWriteIntervals)
109+
syncWriteIntervals[key].insert(ops.begin(), ops.end());
33110
return *this;
34111
}
35112

36113
void dump() {
37114
auto &err = llvm::errs();
115+
116+
auto printKey = [&](const std::pair<Interval<size_t>, UFDS> &key) {
117+
const auto &[interval, ufds] = key;
118+
err << " [" << interval.start() << ", " << interval.end() << "] ";
119+
if (ufds.isDistributed())
120+
ufds.print(err);
121+
else if (ufds.size() == 1)
122+
err << " (CTA local)";
123+
};
38124
err << "Block Interval:\n";
39125
err << " Read Intervals:\n";
40-
for (auto &[interval, ops] : syncReadIntervals) {
41-
err << " [" << interval.start() << ", " << interval.end() << "] ";
126+
for (auto &[key, ops] : syncReadIntervals) {
127+
printKey(key);
42128
for (auto &op : ops)
43129
err << op->getName() << " ";
44130
err << "\n";
45131
}
46132
err << " Write Intervals:\n";
47-
for (auto &[interval, ops] : syncWriteIntervals) {
48-
err << " [" << interval.start() << ", " << interval.end() << "] ";
133+
for (auto &[key, ops] : syncWriteIntervals) {
134+
printKey(key);
49135
for (auto &op : ops)
50136
err << op->getName() << " ";
51137
err << "\n";
52138
}
53139
}
54140

55141
/// Returns true if intervals in two BlockInfo objects are intersected.
56-
bool isIntersected(const BlockInfo &other, MembarFilterFn filter) const {
57-
return /*RAW*/ isIntersected(syncWriteIntervals, other.syncReadIntervals,
58-
filter) ||
59-
/*WAR*/
60-
isIntersected(syncReadIntervals, other.syncWriteIntervals, filter) ||
61-
/*WAW*/
62-
isIntersected(syncWriteIntervals, other.syncWriteIntervals, filter);
142+
std::optional<UFDS> isIntersected(const BlockInfo &other,
143+
MembarFilterFn filter) const {
144+
auto raw =
145+
isIntersected(syncWriteIntervals, other.syncReadIntervals, filter);
146+
auto war =
147+
isIntersected(syncReadIntervals, other.syncWriteIntervals, filter);
148+
auto waw =
149+
isIntersected(syncWriteIntervals, other.syncWriteIntervals, filter);
150+
auto maybeJoin = [](const std::optional<UFDS> &lhs,
151+
const std::optional<UFDS> &rhs) -> std::optional<UFDS> {
152+
if (!lhs.has_value())
153+
return rhs;
154+
if (!rhs.has_value())
155+
return lhs;
156+
return lhs.value().join(rhs.value());
157+
};
158+
return maybeJoin(raw, maybeJoin(war, waw));
63159
}
64160

65161
/// Clears the intervals because a barrier is inserted.
66-
void sync() {
67-
syncReadIntervals.clear();
68-
syncWriteIntervals.clear();
162+
/// If `cluster` is true, the barrier synchronizes all CTAs in the cluster and
163+
/// we can drop every pending dependency. Otherwise only CTA-local
164+
/// dependencies are cleared; distributed ones remain until a cluster barrier
165+
/// is observed.
166+
void sync(bool cluster) {
167+
if (cluster) {
168+
syncReadIntervals.clear();
169+
syncWriteIntervals.clear();
170+
return;
171+
} else {
172+
auto eraseNotDistributed = [](auto &map) {
173+
for (auto &[key, _] : llvm::make_early_inc_range(map)) {
174+
if (!key.second.isDistributed())
175+
map.erase(key);
176+
}
177+
};
178+
eraseNotDistributed(syncReadIntervals);
179+
eraseNotDistributed(syncWriteIntervals);
180+
}
69181
}
70182

71183
/// Compares two BlockInfo objects.
@@ -77,18 +189,79 @@ struct BlockInfo {
77189
bool operator!=(const BlockInfo &other) const { return !(*this == other); }
78190

79191
private:
80-
bool isIntersected(const IntervalMapT &lhsIntervalSet,
81-
const IntervalMapT &rhsIntervalSet,
82-
MembarFilterFn filter) const {
83-
for (auto &lhs : lhsIntervalSet)
84-
for (auto &rhs : rhsIntervalSet)
85-
if (lhs.first.intersects(rhs.first))
86-
for (auto lhsOp : lhs.second)
87-
for (auto rhsOp : rhs.second)
88-
if (!filter || !filter(lhsOp, rhsOp))
89-
return true;
90-
91-
return false;
192+
static bool haveSameAlloc(Operation *lhs, Operation *rhs) {
193+
auto collectAllocOwners = [](Operation *op) -> std::optional<Operation *> {
194+
SmallVector<Operation *> allocs;
195+
if (auto mei = dyn_cast<MemoryEffectOpInterface>(op)) {
196+
SmallVector<SideEffects::EffectInstance<MemoryEffects::Effect>> effs;
197+
mei.getEffects(effs);
198+
for (auto &eff : effs) {
199+
if (eff.getResource() != triton::gpu::SharedMemory::get())
200+
continue;
201+
if (auto v = eff.getValue()) {
202+
// Hacky way to skip barriers...
203+
if (cast<triton::gpu::MemDescType>(v.getType()).getNumElements() ==
204+
1)
205+
continue;
206+
auto alloc = findShmemAlloc(v);
207+
assert(alloc && "Expected to find a shmem alloc");
208+
allocs.push_back(alloc.getOperation());
209+
}
210+
}
211+
}
212+
assert(allocs.size() <= 1 && "Expected to find exactly one shmem alloc");
213+
if (allocs.empty()) {
214+
return std::nullopt;
215+
} else {
216+
return allocs[0];
217+
}
218+
};
219+
220+
auto lhsAllocs = collectAllocOwners(lhs);
221+
auto rhsAllocs = collectAllocOwners(rhs);
222+
return lhsAllocs.has_value() && rhsAllocs.has_value() &&
223+
lhsAllocs.value() == rhsAllocs.value();
224+
}
225+
226+
std::optional<UFDS> isIntersected(const IntervalMapT &lhsIntervalSet,
227+
const IntervalMapT &rhsIntervalSet,
228+
MembarFilterFn filter) const {
229+
// They intersect whenever the intervals intersect. If they do, collect the
230+
// union of CTA sets for any op pair that is not filtered out and does not
231+
// share the exact same explicit shared value.
232+
std::optional<UFDS> ret = std::nullopt;
233+
for (const auto &[lhsKey, lhsOps] : lhsIntervalSet) {
234+
const auto &[intervalLhs, ctasLhs] = lhsKey;
235+
for (const auto &[rhsKey, rhsOps] : rhsIntervalSet) {
236+
const auto &[intervalRhs, ctasRhs] = rhsKey;
237+
if (!intervalLhs.intersects(intervalRhs))
238+
continue;
239+
240+
auto joined = ctasLhs.join(ctasRhs);
241+
bool needsBarrier = llvm::any_of(lhsOps, [&, rhsOpsPtr = &rhsOps](
242+
const auto &lhsOp) {
243+
return llvm::any_of(*rhsOpsPtr, [&](const auto &rhsOp) {
244+
// Skip if filtered or both ops touch the same explicit shared
245+
// allocation (same local_alloc).
246+
return !((filter && filter(lhsOp, rhsOp)) ||
247+
(joined.isDistributed() && haveSameAlloc(lhsOp, rhsOp)));
248+
});
249+
});
250+
if (!needsBarrier)
251+
continue;
252+
253+
if (!ret.has_value()) {
254+
ret = joined;
255+
} else {
256+
ret = ret->join(joined);
257+
}
258+
// Single CTA case, we can early exit
259+
if (ret->size() == 1) {
260+
return ret;
261+
}
262+
}
263+
}
264+
return ret;
92265
}
93266
};
94267

@@ -170,7 +343,8 @@ class MembarAnalysis : public MembarOrFenceAnalysis {
170343
FuncBlockInfoMapT *funcBlockInfoMap,
171344
OpBuilder *builder) override;
172345

173-
void insertBarrier(Operation *operation, OpBuilder *builder);
346+
void insertBarrier(Operation *operation, OpBuilder *builder,
347+
const BlockInfo::UFDS &ctaClasses);
174348
};
175349

176350
/// Postorder traversal on the callgraph to insert membar instructions

include/triton/Analysis/Utility.h

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -415,6 +415,11 @@ std::unique_ptr<DataFlowSolver> createDataFlowSolver();
415415
bool isCvtWarpSync(const triton::LinearLayout &srcLayout,
416416
const triton::LinearLayout &dstLayout);
417417

418+
/// Returns the defining ttg.local_alloc (shared-memory allocation) for a
419+
/// memdesc-derived value by walking through memdesc/convert layout ops and
420+
/// loop-carried block arguments. Returns null if none is found.
421+
triton::gpu::LocalAllocOp findShmemAlloc(Value operand);
422+
418423
} // namespace mlir
419424

420425
#endif // TRITON_ANALYSIS_UTILITY_H

include/triton/Dialect/TritonGPU/Transforms/Utility.h

Lines changed: 0 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -221,11 +221,6 @@ Operation *convertDistributedOpEncoding(Attribute encoding, Operation *op);
221221
// Returns the original memory allocation for a memdesc value
222222
triton::gpu::LocalAllocOp findShmemAlloc(Value operand);
223223

224-
// Returns MMAs inside a for loop that are multi-buffered for pipeline analysis
225-
SmallVector<Operation *>
226-
getMMAsWithMultiBufferredOperands(scf::ForOp forOp,
227-
SmallVector<Operation *> &mmaOps);
228-
229224
// Given a list of ops, find the nearest common dominator of all ops or return
230225
// null if one could not be found. The ops are allowed to be in different
231226
// regions. The result op is not necessarily one of the ops in the list.

include/triton/Dialect/TritonNvidiaGPU/IR/TritonNvidiaGPUOps.td

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -293,7 +293,8 @@ def TTNG_AsyncTMACopyGlobalToLocalOp : TTNG_Op<"async_tma_copy_global_to_local">
293293
I1:$pred,
294294
DefaultValuedAttr<TT_CacheModifierAttr, "triton::CacheModifier::NONE">:$cache,
295295
DefaultValuedAttr<TT_EvictionPolicyAttr, "triton::EvictionPolicy::NORMAL">:$evict,
296-
DefaultValuedAttr<BoolAttr, "false">:$isVolatile
296+
DefaultValuedAttr<BoolAttr, "false">:$isVolatile,
297+
DefaultValuedAttr<BoolAttr, "false">:$multicast
297298
);
298299

299300
let assemblyFormat = [{

0 commit comments

Comments
 (0)