Commit 81d4eaa

[WIP][BACKEND] Generalize the MemBar to consider cross-CTA ops
The semantics here are that it is the user's/compiler's responsibility to add the relevant synchronization when the same shmem buffer is reused across CTAs; otherwise the compiler will insert it.
1 parent 0f235ee commit 81d4eaa

File tree

10 files changed: +502 −136 lines


include/triton/Analysis/Membar.h

Lines changed: 174 additions & 34 deletions
@@ -15,7 +15,83 @@ class OpBuilder;
 using MembarFilterFn = std::function<bool(Operation *, Operation *)>;
 
 struct BlockInfo {
-  using IntervalMapT = std::map<Interval<size_t>, std::set<Operation *>>;
+  // Union-Find Disjoint Sets to represent cross-CTA reads/writes
+  struct CTA_UFDS {
+    SmallVector<unsigned> parent;
+    SmallVector<unsigned> rank;
+    // Invariant: At the root of a class, minRep[i] is the smallest element in
+    // the class
+    SmallVector<unsigned> minRep;
+
+    CTA_UFDS() = default;
+    explicit CTA_UFDS(unsigned n) : rank(n, 0), minRep(n) {
+      assert(llvm::isPowerOf2_32(n) && n != 0);
+      parent = llvm::to_vector(llvm::seq(n));
+      for (unsigned i = 0; i < n; ++i)
+        minRep[i] = i;
+    }
+
+    unsigned find(unsigned x) const {
+      unsigned p = parent[x];
+      while (p != parent[p])
+        p = parent[p];
+      return p;
+    }
+
+    unsigned findMin(unsigned x) const { return minRep[find(x)]; }
+
+    void unite(unsigned x, unsigned y) {
+      x = find(x);
+      y = find(y);
+      if (x == y)
+        return;
+
+      if (rank[x] < rank[y])
+        std::swap(x, y);
+
+      parent[y] = x;
+      minRep[x] = std::min(minRep[x], minRep[y]);
+
+      if (rank[x] == rank[y])
+        ++rank[x];
+    }
+
+    CTA_UFDS join(const CTA_UFDS &other) const {
+      // Transitive closure of two UFDS
+      CTA_UFDS result = *this;
+      for (unsigned i = 0; i < size(); ++i)
+        result.unite(i, other.find(i));
+      return result;
+    }
+
+    SmallVector<unsigned> canonical() const {
+      SmallVector<unsigned> reps(size());
+      for (unsigned i = 0; i < size(); ++i)
+        reps[i] = findMin(i);
+      return reps;
+    }
+
+    bool isDistributed() const { return *this != CTA_UFDS(parent.size()); }
+
+    bool operator<(const CTA_UFDS &other) const {
+      return canonical() < other.canonical();
+    }
+    bool operator==(const CTA_UFDS &other) const {
+      return canonical() == other.canonical();
+    }
+    bool operator!=(const CTA_UFDS &other) const { return !(*this == other); }
+
+    void print(raw_ostream &os) const {
+      os << "UFDS(";
+      llvm::interleaveComma(canonical(), os, [&](unsigned x) { os << x; });
+      os << ")";
+    }
+
+    size_t size() const { return parent.size(); }
+  };
+
+  using IntervalMapT =
+      std::map<std::pair<Interval<size_t>, CTA_UFDS>, std::set<Operation *>>;
 
   IntervalMapT syncReadIntervals;
   IntervalMapT syncWriteIntervals;
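
For intuition, here is a minimal, dependency-free sketch of the partition semantics that CTA_UFDS encodes (std::vector in place of llvm::SmallVector, asserts omitted; the 4-CTA pairing scenario is purely illustrative, not taken from the patch):

// Simplified re-implementation of CTA_UFDS, for illustration only.
#include <algorithm>
#include <cstdio>
#include <numeric>
#include <vector>

struct CTAPartition {
  std::vector<unsigned> parent, rnk, minRep;
  explicit CTAPartition(unsigned n) : parent(n), rnk(n, 0), minRep(n) {
    std::iota(parent.begin(), parent.end(), 0u);
    std::iota(minRep.begin(), minRep.end(), 0u);
  }
  unsigned find(unsigned x) const {
    while (x != parent[x])
      x = parent[x];
    return x;
  }
  void unite(unsigned x, unsigned y) {
    x = find(x);
    y = find(y);
    if (x == y)
      return;
    if (rnk[x] < rnk[y])
      std::swap(x, y);
    parent[y] = x;
    minRep[x] = std::min(minRep[x], minRep[y]);
    if (rnk[x] == rnk[y])
      ++rnk[x];
  }
  // Transitive closure of two partitions, as in CTA_UFDS::join.
  CTAPartition join(const CTAPartition &o) const {
    CTAPartition r = *this;
    for (unsigned i = 0; i < parent.size(); ++i)
      r.unite(i, o.find(i));
    return r;
  }
  // canonical()[i] is the smallest CTA id in i's class (via minRep).
  std::vector<unsigned> canonical() const {
    std::vector<unsigned> reps(parent.size());
    for (unsigned i = 0; i < parent.size(); ++i)
      reps[i] = minRep[find(i)];
    return reps;
  }
};

int main() {
  CTAPartition a(4);          // {0} {1} {2} {3}: every CTA independent
  a.unite(0, 1);              // e.g. one op couples CTAs 0 and 1
  CTAPartition b(4);
  b.unite(2, 3);              // another op couples CTAs 2 and 3
  CTAPartition c = a.join(b); // closure: {0,1} {2,3}
  for (unsigned r : c.canonical())
    std::printf("%u ", r);    // prints: 0 0 2 2
  std::printf("\n");
}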
@@ -24,48 +100,84 @@ struct BlockInfo {
 
   /// Unions two BlockInfo objects.
   BlockInfo &join(const BlockInfo &other) {
-    for (auto &interval : other.syncReadIntervals)
-      syncReadIntervals[interval.first].insert(interval.second.begin(),
-                                               interval.second.end());
-    for (auto &interval : other.syncWriteIntervals)
-      syncWriteIntervals[interval.first].insert(interval.second.begin(),
-                                                interval.second.end());
+    // We don't fold the intervals (we could tho)
+    for (auto &[key, ops] : other.syncReadIntervals)
+      syncReadIntervals[key].insert(ops.begin(), ops.end());
+    for (auto &[key, ops] : other.syncWriteIntervals)
+      syncWriteIntervals[key].insert(ops.begin(), ops.end());
     return *this;
   }
 
   void dump() {
     auto &err = llvm::errs();
+
+    auto printKey = [&](const std::pair<Interval<size_t>, CTA_UFDS> &key) {
+      const auto &[interval, ufds] = key;
+      err << " [" << interval.start() << ", " << interval.end() << "] ";
+      if (ufds.isDistributed()) {
+        ufds.print(err);
+        err << " ";
+      } else if (ufds.size() == 1) {
+        err << " (CTA local) ";
+      }
+    };
     err << "Block Interval:\n";
     err << "  Read Intervals:\n";
-    for (auto &[interval, ops] : syncReadIntervals) {
-      err << " [" << interval.start() << ", " << interval.end() << "] ";
+    for (auto &[key, ops] : syncReadIntervals) {
+      printKey(key);
       for (auto &op : ops)
         err << op->getName() << " ";
       err << "\n";
     }
     err << "  Write Intervals:\n";
-    for (auto &[interval, ops] : syncWriteIntervals) {
-      err << " [" << interval.start() << ", " << interval.end() << "] ";
+    for (auto &[key, ops] : syncWriteIntervals) {
+      printKey(key);
      for (auto &op : ops)
        err << op->getName() << " ";
      err << "\n";
    }
  }
 
   /// Returns true if intervals in two BlockInfo objects are intersected.
-  bool isIntersected(const BlockInfo &other, MembarFilterFn filter) const {
-    return /*RAW*/ isIntersected(syncWriteIntervals, other.syncReadIntervals,
-                                 filter) ||
-           /*WAR*/
-           isIntersected(syncReadIntervals, other.syncWriteIntervals, filter) ||
-           /*WAW*/
-           isIntersected(syncWriteIntervals, other.syncWriteIntervals, filter);
+  std::optional<CTA_UFDS> isIntersected(const BlockInfo &other,
+                                        MembarFilterFn filter) const {
+    auto raw =
+        isIntersected(syncWriteIntervals, other.syncReadIntervals, filter);
+    auto war =
+        isIntersected(syncReadIntervals, other.syncWriteIntervals, filter);
+    auto waw =
+        isIntersected(syncWriteIntervals, other.syncWriteIntervals, filter);
+    auto maybeJoin =
+        [](const std::optional<CTA_UFDS> &lhs,
+           const std::optional<CTA_UFDS> &rhs) -> std::optional<CTA_UFDS> {
+      if (!lhs.has_value())
+        return rhs;
+      if (!rhs.has_value())
+        return lhs;
+      return lhs.value().join(rhs.value());
+    };
+    return maybeJoin(raw, maybeJoin(war, waw));
   }
 
   /// Clears the intervals because a barrier is inserted.
-  void sync() {
-    syncReadIntervals.clear();
-    syncWriteIntervals.clear();
+  /// If `cluster` is true, the barrier synchronizes all CTAs in the cluster and
+  /// we can drop every pending dependency. Otherwise only CTA-local
+  /// dependencies are cleared; distributed ones remain until a cluster barrier
+  /// is observed.
+  void sync(bool cluster) {
+    if (cluster) {
+      syncReadIntervals.clear();
+      syncWriteIntervals.clear();
+    } else {
+      auto eraseNotDistributed = [](auto &map) {
+        for (auto &[key, _] : llvm::make_early_inc_range(map)) {
+          if (!key.second.isDistributed())
+            map.erase(key);
+        }
+      };
+      eraseNotDistributed(syncReadIntervals);
+      eraseNotDistributed(syncWriteIntervals);
+    }
   }
 
   /// Compares two BlockInfo objects.
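
The contract of the new sync(bool cluster) is worth seeing in isolation. A minimal sketch, where the CTA_UFDS key is collapsed to a plain distributed flag and ops to strings (both simplifications are mine for illustration, not the patch's representation):

#include <cstdio>
#include <map>
#include <set>
#include <string>
#include <utility>

// Key: (byte interval, does the access involve more than one CTA?)
using Key = std::pair<std::pair<size_t, size_t>, bool>;
using IntervalMap = std::map<Key, std::set<std::string>>;

void sync(IntervalMap &reads, IntervalMap &writes, bool cluster) {
  if (cluster) {
    // A cluster-wide barrier orders all CTAs: every pending dep is resolved.
    reads.clear();
    writes.clear();
  } else {
    // A CTA-local barrier resolves only CTA-local dependencies; distributed
    // intervals must survive until a cluster barrier is seen.
    auto isLocal = [](const auto &entry) { return !entry.first.second; };
    std::erase_if(reads, isLocal);
    std::erase_if(writes, isLocal);
  }
}

int main() {
  IntervalMap reads, writes;
  writes[{{0, 128}, false}] = {"local_store"};  // CTA-local write
  writes[{{0, 128}, true}] = {"remote_store"};  // cross-CTA write
  sync(reads, writes, /*cluster=*/false);
  std::printf("pending after CTA barrier: %zu\n", writes.size());     // 1
  sync(reads, writes, /*cluster=*/true);
  std::printf("pending after cluster barrier: %zu\n", writes.size()); // 0
}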
@@ -77,18 +189,45 @@ struct BlockInfo {
   bool operator!=(const BlockInfo &other) const { return !(*this == other); }
 
 private:
-  bool isIntersected(const IntervalMapT &lhsIntervalSet,
-                     const IntervalMapT &rhsIntervalSet,
-                     MembarFilterFn filter) const {
-    for (auto &lhs : lhsIntervalSet)
-      for (auto &rhs : rhsIntervalSet)
-        if (lhs.first.intersects(rhs.first))
-          for (auto lhsOp : lhs.second)
-            for (auto rhsOp : rhs.second)
-              if (!filter || !filter(lhsOp, rhsOp))
-                return true;
-
-    return false;
+  static bool haveSameAlloc(Operation *lhs, Operation *rhs);
+
+  std::optional<CTA_UFDS> isIntersected(const IntervalMapT &lhsIntervalSet,
+                                        const IntervalMapT &rhsIntervalSet,
+                                        MembarFilterFn filter) const {
+    // They intersect whenever the intervals intersect. If they do, collect the
+    // union of CTA sets for any op pair that is not filtered out and does not
+    // share the exact same explicit shared value.
+    std::optional<CTA_UFDS> ret = std::nullopt;
+    for (const auto &[lhsKey, lhsOps] : lhsIntervalSet) {
+      const auto &[intervalLhs, ctasLhs] = lhsKey;
+      for (const auto &[rhsKey, rhsOps] : rhsIntervalSet) {
+        const auto &[intervalRhs, ctasRhs] = rhsKey;
+        if (!intervalLhs.intersects(intervalRhs))
+          continue;
+
+        auto joined = ctasLhs.join(ctasRhs);
+        bool skipBarrier =
+            llvm::all_of(lhsOps, [&, rhsOpsPtr = &rhsOps](const auto &lhsOp) {
+              return llvm::all_of(*rhsOpsPtr, [&](const auto &rhsOp) {
+                return (filter && filter(lhsOp, rhsOp)) ||
+                       (joined.isDistributed() && haveSameAlloc(lhsOp, rhsOp));
+              });
+            });
+        if (skipBarrier)
+          continue;
+
+        if (!ret.has_value()) {
+          ret = joined;
+        } else {
+          ret = ret->join(joined);
+        }
+        // Single CTA case, we can early exit
+        if (ret->size() == 1) {
+          return ret;
+        }
+      }
+    }
+    return ret;
   }
 };
 
@@ -170,7 +309,8 @@ class MembarAnalysis : public MembarOrFenceAnalysis {
                          FuncBlockInfoMapT *funcBlockInfoMap,
                          OpBuilder *builder) override;
 
-  void insertBarrier(Operation *operation, OpBuilder *builder);
+  void insertBarrier(Operation *operation, OpBuilder *builder,
+                     const BlockInfo::CTA_UFDS &ctaClasses);
 };
 
 /// Postorder traversal on the callgraph to insert membar instructions
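
The skip-barrier rule from the private isIntersected above can also be read as a standalone predicate: a hazard between two interval entries is ignored only if every (lhsOp, rhsOp) pair is either filtered out or, when the joined CTA classes are distributed, provably touches the same allocation. This is a sketch under assumed toy types (ints for ops, std::function for the two predicates), not the real Operation* API:

#include <algorithm>
#include <cstdio>
#include <functional>
#include <vector>

using Op = int;
using PairPred = std::function<bool(Op, Op)>;

bool skipBarrier(const std::vector<Op> &lhsOps, const std::vector<Op> &rhsOps,
                 bool joinedIsDistributed, const PairPred &filter,
                 const PairPred &haveSameAlloc) {
  return std::all_of(lhsOps.begin(), lhsOps.end(), [&](Op l) {
    return std::all_of(rhsOps.begin(), rhsOps.end(), [&](Op r) {
      return (filter && filter(l, r)) ||
             (joinedIsDistributed && haveSameAlloc(l, r));
    });
  });
}

int main() {
  auto sameAlloc = [](Op l, Op r) { return l == r; }; // toy alloc identity
  // Cross-CTA hazard on the same alloc: user-managed, no barrier needed.
  std::printf("%d\n", skipBarrier({1}, {1}, true, nullptr, sameAlloc));  // 1
  // Same pair but CTA-local: the compiler must still insert a barrier.
  std::printf("%d\n", skipBarrier({1}, {1}, false, nullptr, sameAlloc)); // 0
}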

include/triton/Analysis/Utility.h

Lines changed: 0 additions & 1 deletion
@@ -414,7 +414,6 @@ std::unique_ptr<DataFlowSolver> createDataFlowSolver();
 
 bool isCvtWarpSync(const triton::LinearLayout &srcLayout,
                    const triton::LinearLayout &dstLayout);
-
 } // namespace mlir
 
 #endif // TRITON_ANALYSIS_UTILITY_H

include/triton/Dialect/TritonGPU/Transforms/Utility.h

Lines changed: 0 additions & 8 deletions
@@ -218,14 +218,6 @@ getSharedEncIfAllUsersAreDotEnc(Value val, bool &incompatible);
 // Skips operands if they're in shared encoding.
 Operation *convertDistributedOpEncoding(Attribute encoding, Operation *op);
 
-// Returns the original memory allocation for a memdesc value
-triton::gpu::LocalAllocOp findShmemAlloc(Value operand);
-
-// Returns MMAs inside a for loop that are multi-buffered for pipeline analysis
-SmallVector<Operation *>
-getMMAsWithMultiBufferredOperands(scf::ForOp forOp,
-                                  SmallVector<Operation *> &mmaOps);
-
 // Given a list of ops, find the nearest common dominator of all ops or return
 // null if one could not be found. The ops are allowed to be in different
 // regions. The result op is not necessarily one of the ops in the list.

include/triton/Dialect/TritonNvidiaGPU/IR/TritonNvidiaGPUOps.td

Lines changed: 2 additions & 1 deletion
@@ -293,7 +293,8 @@ def TTNG_AsyncTMACopyGlobalToLocalOp : TTNG_Op<"async_tma_copy_global_to_local">
     I1:$pred,
     DefaultValuedAttr<TT_CacheModifierAttr, "triton::CacheModifier::NONE">:$cache,
     DefaultValuedAttr<TT_EvictionPolicyAttr, "triton::EvictionPolicy::NORMAL">:$evict,
-    DefaultValuedAttr<BoolAttr, "false">:$isVolatile
+    DefaultValuedAttr<BoolAttr, "false">:$isVolatile,
+    DefaultValuedAttr<BoolAttr, "false">:$multicast
   );
 
   let assemblyFormat = [{
