@@ -15,7 +15,83 @@ class OpBuilder;
1515using MembarFilterFn = std::function<bool (Operation *, Operation *)>;
1616
1717struct BlockInfo {
18- using IntervalMapT = std::map<Interval<size_t >, std::set<Operation *>>;
18+ // Union-Find Disjoint Sets to represent cross-CTA reads/writes
19+ struct CTA_UFDS {
20+ SmallVector<unsigned > parent;
21+ SmallVector<unsigned > rank;
22+ // Invariant: At the root of a class, minRep[i] is the smallest element in
23+ // the class
24+ SmallVector<unsigned > minRep;
25+
26+ CTA_UFDS () = default ;
27+ explicit CTA_UFDS (unsigned n) : rank(n, 0 ), minRep(n) {
28+ assert (llvm::isPowerOf2_32 (n) && n != 0 );
29+ parent = llvm::to_vector (llvm::seq (n));
30+ for (unsigned i = 0 ; i < n; ++i)
31+ minRep[i] = i;
32+ }
33+
34+ unsigned find (unsigned x) const {
35+ unsigned p = parent[x];
36+ while (p != parent[p])
37+ p = parent[p];
38+ return p;
39+ }
40+
41+ unsigned findMin (unsigned x) const { return minRep[find (x)]; }
42+
43+ void unite (unsigned x, unsigned y) {
44+ x = find (x);
45+ y = find (y);
46+ if (x == y)
47+ return ;
48+
49+ if (rank[x] < rank[y])
50+ std::swap (x, y);
51+
52+ parent[y] = x;
53+ minRep[x] = std::min (minRep[x], minRep[y]);
54+
55+ if (rank[x] == rank[y])
56+ ++rank[x];
57+ }
58+
59+ CTA_UFDS join (const CTA_UFDS &other) const {
60+ // Transitive closure of two UFDS
61+ CTA_UFDS result = *this ;
62+ for (unsigned i = 0 ; i < size (); ++i)
63+ result.unite (i, other.find (i));
64+ return result;
65+ }
66+
67+ SmallVector<unsigned > canonical () const {
68+ SmallVector<unsigned > reps (size ());
69+ for (unsigned i = 0 ; i < size (); ++i)
70+ reps[i] = findMin (i);
71+ return reps;
72+ }
73+
74+ bool isDistributed () const { return *this != CTA_UFDS (parent.size ()); }
75+
76+ bool operator <(const CTA_UFDS &other) const {
77+ return canonical () < other.canonical ();
78+ }
79+ bool operator ==(const CTA_UFDS &other) const {
80+ return canonical () == other.canonical ();
81+ }
82+ bool operator !=(const CTA_UFDS &other) const { return !(*this == other); }
83+
84+ void print (raw_ostream &os) const {
85+ os << " UFDS(" ;
86+ llvm::interleaveComma (canonical (), os, [&](unsigned x) { os << x; });
87+ os << " )" ;
88+ }
89+
90+ size_t size () const { return parent.size (); }
91+ };
92+
93+ using IntervalMapT =
94+ std::map<std::pair<Interval<size_t >, CTA_UFDS>, std::set<Operation *>>;
1995
2096 IntervalMapT syncReadIntervals;
2197 IntervalMapT syncWriteIntervals;
@@ -24,48 +100,84 @@ struct BlockInfo {
24100
25101 // / Unions two BlockInfo objects.
26102 BlockInfo &join (const BlockInfo &other) {
27- for (auto &interval : other.syncReadIntervals )
28- syncReadIntervals[interval.first ].insert (interval.second .begin (),
29- interval.second .end ());
30- for (auto &interval : other.syncWriteIntervals )
31- syncWriteIntervals[interval.first ].insert (interval.second .begin (),
32- interval.second .end ());
103+ // We don't fold the intervals (we could tho)
104+ for (auto &[key, ops] : other.syncReadIntervals )
105+ syncReadIntervals[key].insert (ops.begin (), ops.end ());
106+ for (auto &[key, ops] : other.syncWriteIntervals )
107+ syncWriteIntervals[key].insert (ops.begin (), ops.end ());
33108 return *this ;
34109 }
35110
36111 void dump () {
37112 auto &err = llvm::errs ();
113+
114+ auto printKey = [&](const std::pair<Interval<size_t >, CTA_UFDS> &key) {
115+ const auto &[interval, ufds] = key;
116+ err << " [" << interval.start () << " , " << interval.end () << " ] " ;
117+ if (ufds.isDistributed ()) {
118+ ufds.print (err);
119+ err << " " ;
120+ } else if (ufds.size () == 1 ) {
121+ err << " (CTA local) " ;
122+ }
123+ };
38124 err << " Block Interval:\n " ;
39125 err << " Read Intervals:\n " ;
40- for (auto &[interval , ops] : syncReadIntervals) {
41- err << " [ " << interval. start () << " , " << interval. end () << " ] " ;
126+ for (auto &[key , ops] : syncReadIntervals) {
127+ printKey (key) ;
42128 for (auto &op : ops)
43129 err << op->getName () << " " ;
44130 err << " \n " ;
45131 }
46132 err << " Write Intervals:\n " ;
47- for (auto &[interval , ops] : syncWriteIntervals) {
48- err << " [ " << interval. start () << " , " << interval. end () << " ] " ;
133+ for (auto &[key , ops] : syncWriteIntervals) {
134+ printKey (key) ;
49135 for (auto &op : ops)
50136 err << op->getName () << " " ;
51137 err << " \n " ;
52138 }
53139 }
54140
55141 // / Returns true if intervals in two BlockInfo objects are intersected.
56- bool isIntersected (const BlockInfo &other, MembarFilterFn filter) const {
57- return /* RAW*/ isIntersected (syncWriteIntervals, other.syncReadIntervals ,
58- filter) ||
59- /* WAR*/
60- isIntersected (syncReadIntervals, other.syncWriteIntervals , filter) ||
61- /* WAW*/
62- isIntersected (syncWriteIntervals, other.syncWriteIntervals , filter);
142+ std::optional<CTA_UFDS> isIntersected (const BlockInfo &other,
143+ MembarFilterFn filter) const {
144+ auto raw =
145+ isIntersected (syncWriteIntervals, other.syncReadIntervals , filter);
146+ auto war =
147+ isIntersected (syncReadIntervals, other.syncWriteIntervals , filter);
148+ auto waw =
149+ isIntersected (syncWriteIntervals, other.syncWriteIntervals , filter);
150+ auto maybeJoin =
151+ [](const std::optional<CTA_UFDS> &lhs,
152+ const std::optional<CTA_UFDS> &rhs) -> std::optional<CTA_UFDS> {
153+ if (!lhs.has_value ())
154+ return rhs;
155+ if (!rhs.has_value ())
156+ return lhs;
157+ return lhs.value ().join (rhs.value ());
158+ };
159+ return maybeJoin (raw, maybeJoin (war, waw));
63160 }
64161
65162 // / Clears the intervals because a barrier is inserted.
66- void sync () {
67- syncReadIntervals.clear ();
68- syncWriteIntervals.clear ();
163+ // / If `cluster` is true, the barrier synchronizes all CTAs in the cluster and
164+ // / we can drop every pending dependency. Otherwise only CTA-local
165+ // / dependencies are cleared; distributed ones remain until a cluster barrier
166+ // / is observed.
167+ void sync (bool cluster) {
168+ if (cluster) {
169+ syncReadIntervals.clear ();
170+ syncWriteIntervals.clear ();
171+ } else {
172+ auto eraseNotDistributed = [](auto &map) {
173+ for (auto &[key, _] : llvm::make_early_inc_range (map)) {
174+ if (!key.second .isDistributed ())
175+ map.erase (key);
176+ }
177+ };
178+ eraseNotDistributed (syncReadIntervals);
179+ eraseNotDistributed (syncWriteIntervals);
180+ }
69181 }
70182
71183 // / Compares two BlockInfo objects.
@@ -77,18 +189,45 @@ struct BlockInfo {
77189 bool operator !=(const BlockInfo &other) const { return !(*this == other); }
78190
79191private:
80- bool isIntersected (const IntervalMapT &lhsIntervalSet,
81- const IntervalMapT &rhsIntervalSet,
82- MembarFilterFn filter) const {
83- for (auto &lhs : lhsIntervalSet)
84- for (auto &rhs : rhsIntervalSet)
85- if (lhs.first .intersects (rhs.first ))
86- for (auto lhsOp : lhs.second )
87- for (auto rhsOp : rhs.second )
88- if (!filter || !filter (lhsOp, rhsOp))
89- return true ;
90-
91- return false ;
192+ static bool haveSameAlloc (Operation *lhs, Operation *rhs);
193+
194+ std::optional<CTA_UFDS> isIntersected (const IntervalMapT &lhsIntervalSet,
195+ const IntervalMapT &rhsIntervalSet,
196+ MembarFilterFn filter) const {
197+ // They intersect whenever the intervals intersect. If they do, collect the
198+ // union of CTA sets for any op pair that is not filtered out and does not
199+ // share the exact same explicit shared value.
200+ std::optional<CTA_UFDS> ret = std::nullopt ;
201+ for (const auto &[lhsKey, lhsOps] : lhsIntervalSet) {
202+ const auto &[intervalLhs, ctasLhs] = lhsKey;
203+ for (const auto &[rhsKey, rhsOps] : rhsIntervalSet) {
204+ const auto &[intervalRhs, ctasRhs] = rhsKey;
205+ if (!intervalLhs.intersects (intervalRhs))
206+ continue ;
207+
208+ auto joined = ctasLhs.join (ctasRhs);
209+ bool skipBarrier =
210+ llvm::all_of (lhsOps, [&, rhsOpsPtr = &rhsOps](const auto &lhsOp) {
211+ return llvm::all_of (*rhsOpsPtr, [&](const auto &rhsOp) {
212+ return (filter && filter (lhsOp, rhsOp)) ||
213+ (joined.isDistributed () && haveSameAlloc (lhsOp, rhsOp));
214+ });
215+ });
216+ if (skipBarrier)
217+ continue ;
218+
219+ if (!ret.has_value ()) {
220+ ret = joined;
221+ } else {
222+ ret = ret->join (joined);
223+ }
224+ // Single CTA case, we can early exit
225+ if (ret->size () == 1 ) {
226+ return ret;
227+ }
228+ }
229+ }
230+ return ret;
92231 }
93232};
94233
@@ -170,7 +309,8 @@ class MembarAnalysis : public MembarOrFenceAnalysis {
170309 FuncBlockInfoMapT *funcBlockInfoMap,
171310 OpBuilder *builder) override ;
172311
173- void insertBarrier (Operation *operation, OpBuilder *builder);
312+ void insertBarrier (Operation *operation, OpBuilder *builder,
313+ const BlockInfo::CTA_UFDS &ctaClasses);
174314};
175315
176316// / Postorder traversal on the callgraph to insert membar instructions
0 commit comments