22#define TRITON_ANALYSIS_MEMBAR_H
33
44#include " Allocation.h"
5+ #include " mlir/Interfaces/SideEffectInterfaces.h"
6+ #include " triton/Analysis/Utility.h"
57
68#include < set>
79
@@ -15,7 +17,83 @@ class OpBuilder;
1517using MembarFilterFn = std::function<bool (Operation *, Operation *)>;
1618
1719struct BlockInfo {
18- using IntervalMapT = std::map<Interval<size_t >, std::set<Operation *>>;
20+ // UFDS to represent cross-CTA reads/writes
21+ struct UFDS {
22+ SmallVector<unsigned > parent;
23+ SmallVector<unsigned > rank;
24+ // Invariant: At the root of a class, minRep[i] is the smallest element in
25+ // the class
26+ SmallVector<unsigned > minRep;
27+
28+ UFDS () = default ;
29+ explicit UFDS (unsigned n) : rank(n, 0 ), minRep(n) {
30+ assert (llvm::isPowerOf2_32 (n) && n != 0 );
31+ parent = llvm::to_vector (llvm::seq (n));
32+ for (unsigned i = 0 ; i < n; ++i)
33+ minRep[i] = i;
34+ }
35+
36+ unsigned find (unsigned x) const {
37+ unsigned p = parent[x];
38+ while (p != parent[p])
39+ p = parent[p];
40+ return p;
41+ }
42+
43+ unsigned findMin (unsigned x) const { return minRep[find (x)]; }
44+
45+ void unite (unsigned x, unsigned y) {
46+ x = find (x);
47+ y = find (y);
48+ if (x == y)
49+ return ;
50+
51+ if (rank[x] < rank[y])
52+ std::swap (x, y);
53+
54+ parent[y] = x;
55+ minRep[x] = std::min (minRep[x], minRep[y]);
56+
57+ if (rank[x] == rank[y])
58+ ++rank[x];
59+ }
60+
61+ UFDS join (const UFDS &other) const {
62+ // Transitive closure of two UFDS
63+ UFDS result = *this ;
64+ for (unsigned i = 0 ; i < size (); ++i)
65+ result.unite (i, other.find (i));
66+ return result;
67+ }
68+
69+ SmallVector<unsigned > canonical () const {
70+ SmallVector<unsigned > reps (size ());
71+ for (unsigned i = 0 ; i < size (); ++i)
72+ reps[i] = findMin (i);
73+ return reps;
74+ }
75+
76+ bool isDistributed () const { return *this != UFDS (parent.size ()); }
77+
78+ bool operator <(const UFDS &other) const {
79+ return canonical () < other.canonical ();
80+ }
81+ bool operator ==(const UFDS &other) const {
82+ return canonical () == other.canonical ();
83+ }
84+ bool operator !=(const UFDS &other) const { return !(*this == other); }
85+
86+ void print (raw_ostream &os) const {
87+ os << " UFDS(" ;
88+ llvm::interleaveComma (canonical (), os, [&](unsigned x) { os << x; });
89+ os << " )" ;
90+ }
91+
92+ size_t size () const { return parent.size (); }
93+ };
94+
95+ using IntervalMapT =
96+ std::map<std::pair<Interval<size_t >, UFDS>, std::set<Operation *>>;
1997
2098 IntervalMapT syncReadIntervals;
2199 IntervalMapT syncWriteIntervals;
@@ -24,48 +102,82 @@ struct BlockInfo {
24102
25103 // / Unions two BlockInfo objects.
26104 BlockInfo &join (const BlockInfo &other) {
27- for (auto &interval : other.syncReadIntervals )
28- syncReadIntervals[interval.first ].insert (interval.second .begin (),
29- interval.second .end ());
30- for (auto &interval : other.syncWriteIntervals )
31- syncWriteIntervals[interval.first ].insert (interval.second .begin (),
32- interval.second .end ());
105+ // We don't fold the intervals (we could tho)
106+ for (auto &[key, ops] : other.syncReadIntervals )
107+ syncReadIntervals[key].insert (ops.begin (), ops.end ());
108+ for (auto &[key, ops] : other.syncWriteIntervals )
109+ syncWriteIntervals[key].insert (ops.begin (), ops.end ());
33110 return *this ;
34111 }
35112
36113 void dump () {
37114 auto &err = llvm::errs ();
115+
116+ auto printKey = [&](const std::pair<Interval<size_t >, UFDS> &key) {
117+ const auto &[interval, ufds] = key;
118+ err << " [" << interval.start () << " , " << interval.end () << " ] " ;
119+ if (ufds.isDistributed ())
120+ ufds.print (err);
121+ else if (ufds.size () == 1 )
122+ err << " (CTA local)" ;
123+ };
38124 err << " Block Interval:\n " ;
39125 err << " Read Intervals:\n " ;
40- for (auto &[interval , ops] : syncReadIntervals) {
41- err << " [ " << interval. start () << " , " << interval. end () << " ] " ;
126+ for (auto &[key , ops] : syncReadIntervals) {
127+ printKey (key) ;
42128 for (auto &op : ops)
43129 err << op->getName () << " " ;
44130 err << " \n " ;
45131 }
46132 err << " Write Intervals:\n " ;
47- for (auto &[interval , ops] : syncWriteIntervals) {
48- err << " [ " << interval. start () << " , " << interval. end () << " ] " ;
133+ for (auto &[key , ops] : syncWriteIntervals) {
134+ printKey (key) ;
49135 for (auto &op : ops)
50136 err << op->getName () << " " ;
51137 err << " \n " ;
52138 }
53139 }
54140
55141 // / Returns true if intervals in two BlockInfo objects are intersected.
56- bool isIntersected (const BlockInfo &other, MembarFilterFn filter) const {
57- return /* RAW*/ isIntersected (syncWriteIntervals, other.syncReadIntervals ,
58- filter) ||
59- /* WAR*/
60- isIntersected (syncReadIntervals, other.syncWriteIntervals , filter) ||
61- /* WAW*/
62- isIntersected (syncWriteIntervals, other.syncWriteIntervals , filter);
142+ std::optional<UFDS> isIntersected (const BlockInfo &other,
143+ MembarFilterFn filter) const {
144+ auto raw =
145+ isIntersected (syncWriteIntervals, other.syncReadIntervals , filter);
146+ auto war =
147+ isIntersected (syncReadIntervals, other.syncWriteIntervals , filter);
148+ auto waw =
149+ isIntersected (syncWriteIntervals, other.syncWriteIntervals , filter);
150+ auto maybeJoin = [](const std::optional<UFDS> &lhs,
151+ const std::optional<UFDS> &rhs) -> std::optional<UFDS> {
152+ if (!lhs.has_value ())
153+ return rhs;
154+ if (!rhs.has_value ())
155+ return lhs;
156+ return lhs.value ().join (rhs.value ());
157+ };
158+ return maybeJoin (raw, maybeJoin (war, waw));
63159 }
64160
65161 // / Clears the intervals because a barrier is inserted.
66- void sync () {
67- syncReadIntervals.clear ();
68- syncWriteIntervals.clear ();
162+ // / If `cluster` is true, the barrier synchronizes all CTAs in the cluster and
163+ // / we can drop every pending dependency. Otherwise only CTA-local
164+ // / dependencies are cleared; distributed ones remain until a cluster barrier
165+ // / is observed.
166+ void sync (bool cluster) {
167+ if (cluster) {
168+ syncReadIntervals.clear ();
169+ syncWriteIntervals.clear ();
170+ return ;
171+ } else {
172+ auto eraseNotDistributed = [](auto &map) {
173+ for (auto &[key, _] : llvm::make_early_inc_range (map)) {
174+ if (!key.second .isDistributed ())
175+ map.erase (key);
176+ }
177+ };
178+ eraseNotDistributed (syncReadIntervals);
179+ eraseNotDistributed (syncWriteIntervals);
180+ }
69181 }
70182
71183 // / Compares two BlockInfo objects.
@@ -77,18 +189,79 @@ struct BlockInfo {
77189 bool operator !=(const BlockInfo &other) const { return !(*this == other); }
78190
79191private:
80- bool isIntersected (const IntervalMapT &lhsIntervalSet,
81- const IntervalMapT &rhsIntervalSet,
82- MembarFilterFn filter) const {
83- for (auto &lhs : lhsIntervalSet)
84- for (auto &rhs : rhsIntervalSet)
85- if (lhs.first .intersects (rhs.first ))
86- for (auto lhsOp : lhs.second )
87- for (auto rhsOp : rhs.second )
88- if (!filter || !filter (lhsOp, rhsOp))
89- return true ;
90-
91- return false ;
192+ static bool haveSameAlloc (Operation *lhs, Operation *rhs) {
193+ auto collectAllocOwners = [](Operation *op) -> std::optional<Operation *> {
194+ SmallVector<Operation *> allocs;
195+ if (auto mei = dyn_cast<MemoryEffectOpInterface>(op)) {
196+ SmallVector<SideEffects::EffectInstance<MemoryEffects::Effect>> effs;
197+ mei.getEffects (effs);
198+ for (auto &eff : effs) {
199+ if (eff.getResource () != triton::gpu::SharedMemory::get ())
200+ continue ;
201+ if (auto v = eff.getValue ()) {
202+ // Hacky way to skip barriers...
203+ if (cast<triton::gpu::MemDescType>(v.getType ()).getNumElements () ==
204+ 1 )
205+ continue ;
206+ auto alloc = findShmemAlloc (v);
207+ assert (alloc && " Expected to find a shmem alloc" );
208+ allocs.push_back (alloc.getOperation ());
209+ }
210+ }
211+ }
212+ assert (allocs.size () <= 1 && " Expected to find exactly one shmem alloc" );
213+ if (allocs.empty ()) {
214+ return std::nullopt ;
215+ } else {
216+ return allocs[0 ];
217+ }
218+ };
219+
220+ auto lhsAllocs = collectAllocOwners (lhs);
221+ auto rhsAllocs = collectAllocOwners (rhs);
222+ return lhsAllocs.has_value () && rhsAllocs.has_value () &&
223+ lhsAllocs.value () == rhsAllocs.value ();
224+ }
225+
226+ std::optional<UFDS> isIntersected (const IntervalMapT &lhsIntervalSet,
227+ const IntervalMapT &rhsIntervalSet,
228+ MembarFilterFn filter) const {
229+ // They intersect whenever the intervals intersect. If they do, collect the
230+ // union of CTA sets for any op pair that is not filtered out and does not
231+ // share the exact same explicit shared value.
232+ std::optional<UFDS> ret = std::nullopt ;
233+ for (const auto &[lhsKey, lhsOps] : lhsIntervalSet) {
234+ const auto &[intervalLhs, ctasLhs] = lhsKey;
235+ for (const auto &[rhsKey, rhsOps] : rhsIntervalSet) {
236+ const auto &[intervalRhs, ctasRhs] = rhsKey;
237+ if (!intervalLhs.intersects (intervalRhs))
238+ continue ;
239+
240+ auto joined = ctasLhs.join (ctasRhs);
241+ bool needsBarrier = llvm::any_of (lhsOps, [&, rhsOpsPtr = &rhsOps](
242+ const auto &lhsOp) {
243+ return llvm::any_of (*rhsOpsPtr, [&](const auto &rhsOp) {
244+ // Skip if filtered or both ops touch the same explicit shared
245+ // allocation (same local_alloc).
246+ return !((filter && filter (lhsOp, rhsOp)) ||
247+ (joined.isDistributed () && haveSameAlloc (lhsOp, rhsOp)));
248+ });
249+ });
250+ if (!needsBarrier)
251+ continue ;
252+
253+ if (!ret.has_value ()) {
254+ ret = joined;
255+ } else {
256+ ret = ret->join (joined);
257+ }
258+ // Single CTA case, we can early exit
259+ if (ret->size () == 1 ) {
260+ return ret;
261+ }
262+ }
263+ }
264+ return ret;
92265 }
93266};
94267
@@ -170,7 +343,8 @@ class MembarAnalysis : public MembarOrFenceAnalysis {
170343 FuncBlockInfoMapT *funcBlockInfoMap,
171344 OpBuilder *builder) override ;
172345
173- void insertBarrier (Operation *operation, OpBuilder *builder);
346+ void insertBarrier (Operation *operation, OpBuilder *builder,
347+ const BlockInfo::UFDS &ctaClasses);
174348};
175349
176350// / Postorder traversal on the callgraph to insert membar instructions
0 commit comments