include/triton/Analysis/Membar.h (177 additions, 34 deletions)
@@ -15,7 +15,83 @@ class OpBuilder;
using MembarFilterFn = std::function<bool(Operation *, Operation *)>;

struct BlockInfo {
using IntervalMapT = std::map<Interval<size_t>, std::set<Operation *>>;
// Union-Find Disjoint Sets to represent cross-CTA reads/writes
struct CTA_UFDS {
SmallVector<unsigned> parent;
SmallVector<unsigned> rank;
// Invariant: At the root of a class, minRep[i] is the smallest element in
// the class
SmallVector<unsigned> minRep;

CTA_UFDS() = default;
explicit CTA_UFDS(unsigned n) : rank(n, 0), minRep(n) {
assert(llvm::isPowerOf2_32(n) && n != 0);
parent = llvm::to_vector(llvm::seq(n));
for (unsigned i = 0; i < n; ++i)
minRep[i] = i;
}
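    // A freshly constructed CTA_UFDS(n) is the identity partition: every CTA
    // starts in its own singleton class.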

unsigned find(unsigned x) const {
unsigned p = parent[x];
while (p != parent[p])
p = parent[p];
return p;
}

unsigned findMin(unsigned x) const { return minRep[find(x)]; }

void unite(unsigned x, unsigned y) {
x = find(x);
y = find(y);
if (x == y)
return;

if (rank[x] < rank[y])
std::swap(x, y);

parent[y] = x;
minRep[x] = std::min(minRep[x], minRep[y]);

if (rank[x] == rank[y])
++rank[x];
}

CTA_UFDS join(const CTA_UFDS &other) const {
// Transitive closure of two UFDS
CTA_UFDS result = *this;
for (unsigned i = 0; i < size(); ++i)
result.unite(i, other.find(i));
return result;
}
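    // Illustrative example: joining the partition {{0,1},{2,3}} with
    // {{0,2},{1,3}} merges all four CTAs into a single class, since class
    // membership is closed transitively across both partitions.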

SmallVector<unsigned> canonical() const {
SmallVector<unsigned> reps(size());
for (unsigned i = 0; i < size(); ++i)
reps[i] = findMin(i);
return reps;
}

bool isDistributed() const { return *this != CTA_UFDS(parent.size()); }

bool operator<(const CTA_UFDS &other) const {
return canonical() < other.canonical();
}
bool operator==(const CTA_UFDS &other) const {
return canonical() == other.canonical();
}
bool operator!=(const CTA_UFDS &other) const { return !(*this == other); }

void print(raw_ostream &os) const {
os << "UFDS(";
llvm::interleaveComma(canonical(), os, [&](unsigned x) { os << x; });
os << ")";
}

size_t size() const { return parent.size(); }
};
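  // Illustrative example (not part of the interface): with n = 4, calling
  // unite(0, 2) and unite(1, 3) yields canonical() == {0, 1, 0, 1}, and
  // isDistributed() is true because that differs from the identity
  // partition {0, 1, 2, 3}.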

using IntervalMapT =
std::map<std::pair<Interval<size_t>, CTA_UFDS>, std::set<Operation *>>;
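  // Entries are keyed by both the byte interval and the CTA partition, so the
  // same range recorded under different CTA partitions yields distinct entries.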

IntervalMapT syncReadIntervals;
IntervalMapT syncWriteIntervals;
@@ -24,48 +100,85 @@ struct BlockInfo

/// Unions two BlockInfo objects.
BlockInfo &join(const BlockInfo &other) {
for (auto &interval : other.syncReadIntervals)
syncReadIntervals[interval.first].insert(interval.second.begin(),
interval.second.end());
for (auto &interval : other.syncWriteIntervals)
syncWriteIntervals[interval.first].insert(interval.second.begin(),
interval.second.end());
// We don't fold (merge) the intervals here, though we could.
for (auto &[key, ops] : other.syncReadIntervals)
syncReadIntervals[key].insert(ops.begin(), ops.end());
for (auto &[key, ops] : other.syncWriteIntervals)
syncWriteIntervals[key].insert(ops.begin(), ops.end());
return *this;
}

void dump() {
auto &err = llvm::errs();

auto printKey = [&](const std::pair<Interval<size_t>, CTA_UFDS> &key) {
const auto &[interval, ufds] = key;
err << " [" << interval.start() << ", " << interval.end() << "] ";
if (ufds.isDistributed()) {
ufds.print(err);
err << " ";
} else if (ufds.size() == 1) {
err << " (CTA local) ";
}
};
err << "Block Interval:\n";
err << " Read Intervals:\n";
for (auto &[interval, ops] : syncReadIntervals) {
err << " [" << interval.start() << ", " << interval.end() << "] ";
for (auto &[key, ops] : syncReadIntervals) {
printKey(key);
for (auto &op : ops)
err << op->getName() << " ";
err << "\n";
}
err << " Write Intervals:\n";
for (auto &[interval, ops] : syncWriteIntervals) {
err << " [" << interval.start() << ", " << interval.end() << "] ";
for (auto &[key, ops] : syncWriteIntervals) {
printKey(key);
for (auto &op : ops)
err << op->getName() << " ";
err << "\n";
}
}

/// Returns true if intervals in two BlockInfo objects are intersected.
bool isIntersected(const BlockInfo &other, MembarFilterFn filter) const {
return /*RAW*/ isIntersected(syncWriteIntervals, other.syncReadIntervals,
filter) ||
/*WAR*/
isIntersected(syncReadIntervals, other.syncWriteIntervals, filter) ||
/*WAW*/
isIntersected(syncWriteIntervals, other.syncWriteIntervals, filter);
std::optional<CTA_UFDS> isIntersected(const BlockInfo &other,
MembarFilterFn filter,
Allocation *allocation) const {
auto raw = isIntersected(syncWriteIntervals, other.syncReadIntervals,
filter, allocation);
auto war = isIntersected(syncReadIntervals, other.syncWriteIntervals,
filter, allocation);
auto waw = isIntersected(syncWriteIntervals, other.syncWriteIntervals,
filter, allocation);
auto maybeJoin =
[](const std::optional<CTA_UFDS> &lhs,
const std::optional<CTA_UFDS> &rhs) -> std::optional<CTA_UFDS> {
if (!lhs.has_value())
return rhs;
if (!rhs.has_value())
return lhs;
return lhs.value().join(rhs.value());
};
return maybeJoin(raw, maybeJoin(war, waw));
}
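  // Note: maybeJoin above treats std::nullopt as "no conflict", so the result
  // is std::nullopt only when none of the RAW/WAR/WAW checks found a conflict;
  // otherwise it is the union of the CTA partitions of all conflicts found.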

/// Clears the intervals because a barrier is inserted.
void sync() {
syncReadIntervals.clear();
syncWriteIntervals.clear();
/// If `cluster` is true, the barrier synchronizes all CTAs in the cluster and
/// we can drop every pending dependency. Otherwise only CTA-local
/// dependencies are cleared; distributed ones remain until a cluster barrier
/// is observed.
void sync(bool cluster) {
if (cluster) {
syncReadIntervals.clear();
syncWriteIntervals.clear();
} else {
auto eraseNotDistributed = [](auto &map) {
for (auto &[key, _] : llvm::make_early_inc_range(map)) {
if (!key.second.isDistributed())
map.erase(key);
}
};
eraseNotDistributed(syncReadIntervals);
eraseNotDistributed(syncWriteIntervals);
}
}
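  // For example, after sync(false) an entry whose CTA partition is the
  // identity (no cross-CTA aliasing) is dropped, while an entry such as
  // {0, 1, 0, 1} (CTAs 0/2 and 1/3 alias) survives until a cluster-wide
  // barrier is seen.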

/// Compares two BlockInfo objects.
@@ -77,18 +190,47 @@ struct BlockInfo {
bool operator!=(const BlockInfo &other) const { return !(*this == other); }

private:
bool isIntersected(const IntervalMapT &lhsIntervalSet,
const IntervalMapT &rhsIntervalSet,
MembarFilterFn filter) const {
for (auto &lhs : lhsIntervalSet)
for (auto &rhs : rhsIntervalSet)
if (lhs.first.intersects(rhs.first))
for (auto lhsOp : lhs.second)
for (auto rhsOp : rhs.second)
if (!filter || !filter(lhsOp, rhsOp))
return true;

return false;
static bool haveSameAlloc(Operation *lhs, Operation *rhs,
Allocation *allocation);

std::optional<CTA_UFDS> isIntersected(const IntervalMapT &lhsIntervalSet,
const IntervalMapT &rhsIntervalSet,
MembarFilterFn filter,
Allocation *allocation) const {
// Two entries conflict when their byte intervals intersect. When they do,
// record the joined CTA partition unless every op pair is either filtered
// out or is a distributed access to the exact same explicit shared
// allocation.
std::optional<CTA_UFDS> ret = std::nullopt;
for (const auto &[lhsKey, lhsOps] : lhsIntervalSet) {
const auto &[intervalLhs, ctasLhs] = lhsKey;
for (const auto &[rhsKey, rhsOps] : rhsIntervalSet) {
const auto &[intervalRhs, ctasRhs] = rhsKey;
if (!intervalLhs.intersects(intervalRhs))
continue;

auto joined = ctasLhs.join(ctasRhs);
bool skipBarrier =
llvm::all_of(lhsOps, [&, rhsOpsPtr = &rhsOps](const auto &lhsOp) {
return llvm::all_of(*rhsOpsPtr, [&](const auto &rhsOp) {
return (filter && filter(lhsOp, rhsOp)) ||
(joined.isDistributed() &&
haveSameAlloc(lhsOp, rhsOp, allocation));
});
});
if (skipBarrier)
continue;
if (!ret.has_value()) {
ret = joined;
} else {
ret = ret->join(joined);
}
// Single-CTA case: nothing more can be joined, so we can exit early.
if (ret->size() == 1) {
return ret;
}
}
}
return ret;
}
};

@@ -170,7 +312,8 @@ class MembarAnalysis : public MembarOrFenceAnalysis {
FuncBlockInfoMapT *funcBlockInfoMap,
OpBuilder *builder) override;

void insertBarrier(Operation *operation, OpBuilder *builder);
void insertBarrier(Operation *operation, OpBuilder *builder,
const BlockInfo::CTA_UFDS &ctaClasses);
};

/// Postorder traversal on the callgraph to insert membar instructions
include/triton/Analysis/Utility.h (0 additions, 1 deletion)
@@ -414,7 +414,6 @@ std::unique_ptr<DataFlowSolver> createDataFlowSolver();

bool isCvtWarpSync(const triton::LinearLayout &srcLayout,
const triton::LinearLayout &dstLayout);

} // namespace mlir

#endif // TRITON_ANALYSIS_UTILITY_H
include/triton/Dialect/TritonGPU/Transforms/Utility.h (0 additions, 8 deletions)
@@ -218,14 +218,6 @@ getSharedEncIfAllUsersAreDotEnc(Value val, bool &incompatible);
// Skips operands if they're in shared encoding.
Operation *convertDistributedOpEncoding(Attribute encoding, Operation *op);

// Returns the original memory allocation for a memdesc value
triton::gpu::LocalAllocOp findShmemAlloc(Value operand);

// Returns MMAs inside a for loop that are multi-buffered for pipeline analysis
SmallVector<Operation *>
getMMAsWithMultiBufferredOperands(scf::ForOp forOp,
SmallVector<Operation *> &mmaOps);

// Given a list of ops, find the nearest common dominator of all ops or return
// null if one could not be found. The ops are allowed to be in different
// regions. The result op is not necessarily one of the ops in the list.
@@ -293,7 +293,8 @@ def TTNG_AsyncTMACopyGlobalToLocalOp : TTNG_Op<"async_tma_copy_global_to_local">
I1:$pred,
DefaultValuedAttr<TT_CacheModifierAttr, "triton::CacheModifier::NONE">:$cache,
DefaultValuedAttr<TT_EvictionPolicyAttr, "triton::EvictionPolicy::NORMAL">:$evict,
DefaultValuedAttr<BoolAttr, "false">:$isVolatile
DefaultValuedAttr<BoolAttr, "false">:$isVolatile,
DefaultValuedAttr<BoolAttr, "false">:$multicast
);

let assemblyFormat = [{