diff --git a/include/triton/Analysis/Membar.h b/include/triton/Analysis/Membar.h index c76289801762..05f7a0f84d50 100644 --- a/include/triton/Analysis/Membar.h +++ b/include/triton/Analysis/Membar.h @@ -15,7 +15,83 @@ class OpBuilder; using MembarFilterFn = std::function; struct BlockInfo { - using IntervalMapT = std::map, std::set>; + // Union-Find Disjoint Sets to represent cross-CTA reads/writes + struct CTA_UFDS { + SmallVector parent; + SmallVector rank; + // Invariant: At the root of a class, minRep[i] is the smallest element in + // the class + SmallVector minRep; + + CTA_UFDS() = default; + explicit CTA_UFDS(unsigned n) : rank(n, 0), minRep(n) { + assert(llvm::isPowerOf2_32(n) && n != 0); + parent = llvm::to_vector(llvm::seq(n)); + for (unsigned i = 0; i < n; ++i) + minRep[i] = i; + } + + unsigned find(unsigned x) const { + unsigned p = parent[x]; + while (p != parent[p]) + p = parent[p]; + return p; + } + + unsigned findMin(unsigned x) const { return minRep[find(x)]; } + + void unite(unsigned x, unsigned y) { + x = find(x); + y = find(y); + if (x == y) + return; + + if (rank[x] < rank[y]) + std::swap(x, y); + + parent[y] = x; + minRep[x] = std::min(minRep[x], minRep[y]); + + if (rank[x] == rank[y]) + ++rank[x]; + } + + CTA_UFDS join(const CTA_UFDS &other) const { + // Transitive closure of two UFDS + CTA_UFDS result = *this; + for (unsigned i = 0; i < size(); ++i) + result.unite(i, other.find(i)); + return result; + } + + SmallVector canonical() const { + SmallVector reps(size()); + for (unsigned i = 0; i < size(); ++i) + reps[i] = findMin(i); + return reps; + } + + bool isDistributed() const { return *this != CTA_UFDS(parent.size()); } + + bool operator<(const CTA_UFDS &other) const { + return canonical() < other.canonical(); + } + bool operator==(const CTA_UFDS &other) const { + return canonical() == other.canonical(); + } + bool operator!=(const CTA_UFDS &other) const { return !(*this == other); } + + void print(raw_ostream &os) const { + os << "UFDS("; + llvm::interleaveComma(canonical(), os, [&](unsigned x) { os << x; }); + os << ")"; + } + + size_t size() const { return parent.size(); } + }; + + using IntervalMapT = + std::map, CTA_UFDS>, std::set>; IntervalMapT syncReadIntervals; IntervalMapT syncWriteIntervals; @@ -24,28 +100,38 @@ struct BlockInfo { /// Unions two BlockInfo objects. 
BlockInfo &join(const BlockInfo &other) { - for (auto &interval : other.syncReadIntervals) - syncReadIntervals[interval.first].insert(interval.second.begin(), - interval.second.end()); - for (auto &interval : other.syncWriteIntervals) - syncWriteIntervals[interval.first].insert(interval.second.begin(), - interval.second.end()); + // We don't fold the intervals (we could tho) + for (auto &[key, ops] : other.syncReadIntervals) + syncReadIntervals[key].insert(ops.begin(), ops.end()); + for (auto &[key, ops] : other.syncWriteIntervals) + syncWriteIntervals[key].insert(ops.begin(), ops.end()); return *this; } void dump() { auto &err = llvm::errs(); + + auto printKey = [&](const std::pair, CTA_UFDS> &key) { + const auto &[interval, ufds] = key; + err << " [" << interval.start() << ", " << interval.end() << "] "; + if (ufds.isDistributed()) { + ufds.print(err); + err << " "; + } else if (ufds.size() == 1) { + err << " (CTA local) "; + } + }; err << "Block Interval:\n"; err << " Read Intervals:\n"; - for (auto &[interval, ops] : syncReadIntervals) { - err << " [" << interval.start() << ", " << interval.end() << "] "; + for (auto &[key, ops] : syncReadIntervals) { + printKey(key); for (auto &op : ops) err << op->getName() << " "; err << "\n"; } err << " Write Intervals:\n"; - for (auto &[interval, ops] : syncWriteIntervals) { - err << " [" << interval.start() << ", " << interval.end() << "] "; + for (auto &[key, ops] : syncWriteIntervals) { + printKey(key); for (auto &op : ops) err << op->getName() << " "; err << "\n"; @@ -53,19 +139,46 @@ struct BlockInfo { } /// Returns true if intervals in two BlockInfo objects are intersected. - bool isIntersected(const BlockInfo &other, MembarFilterFn filter) const { - return /*RAW*/ isIntersected(syncWriteIntervals, other.syncReadIntervals, - filter) || - /*WAR*/ - isIntersected(syncReadIntervals, other.syncWriteIntervals, filter) || - /*WAW*/ - isIntersected(syncWriteIntervals, other.syncWriteIntervals, filter); + std::optional isIntersected(const BlockInfo &other, + MembarFilterFn filter, + Allocation *allocation) const { + auto raw = isIntersected(syncWriteIntervals, other.syncReadIntervals, + filter, allocation); + auto war = isIntersected(syncReadIntervals, other.syncWriteIntervals, + filter, allocation); + auto waw = isIntersected(syncWriteIntervals, other.syncWriteIntervals, + filter, allocation); + auto maybeJoin = + [](const std::optional &lhs, + const std::optional &rhs) -> std::optional { + if (!lhs.has_value()) + return rhs; + if (!rhs.has_value()) + return lhs; + return lhs.value().join(rhs.value()); + }; + return maybeJoin(raw, maybeJoin(war, waw)); } /// Clears the intervals because a barrier is inserted. - void sync() { - syncReadIntervals.clear(); - syncWriteIntervals.clear(); + /// If `cluster` is true, the barrier synchronizes all CTAs in the cluster and + /// we can drop every pending dependency. Otherwise only CTA-local + /// dependencies are cleared; distributed ones remain until a cluster barrier + /// is observed. + void sync(bool cluster) { + if (cluster) { + syncReadIntervals.clear(); + syncWriteIntervals.clear(); + } else { + auto eraseNotDistributed = [](auto &map) { + for (auto &[key, _] : llvm::make_early_inc_range(map)) { + if (!key.second.isDistributed()) + map.erase(key); + } + }; + eraseNotDistributed(syncReadIntervals); + eraseNotDistributed(syncWriteIntervals); + } } /// Compares two BlockInfo objects. 
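[Reviewer note, not part of the patch] For readers who want to experiment with the equivalence-class bookkeeping in isolation, below is a minimal standalone sketch in plain C++ (no LLVM/MLIR dependencies). ToyUFDS is a made-up name for illustration only; it mirrors the intended behaviour of CTA_UFDS above: unite() merges CTA classes, join() is the transitive closure of two partitions, canonical() (smallest element per class) gives an order-independent form for comparison, and isDistributed() reports whether any class spans more than one CTA.

// Toy sketch mirroring CTA_UFDS: union-find keeping the minimum element of
// each class at the root, plus join/canonical/isDistributed helpers.
#include <algorithm>
#include <cassert>
#include <numeric>
#include <utility>
#include <vector>

struct ToyUFDS {
  std::vector<unsigned> parent, rank, minRep;

  explicit ToyUFDS(unsigned n) : parent(n), rank(n, 0), minRep(n) {
    std::iota(parent.begin(), parent.end(), 0u);
    std::iota(minRep.begin(), minRep.end(), 0u);
  }

  // No path compression so find() can stay const, as in the patch.
  unsigned find(unsigned x) const {
    while (x != parent[x])
      x = parent[x];
    return x;
  }

  void unite(unsigned x, unsigned y) {
    x = find(x);
    y = find(y);
    if (x == y)
      return;
    if (rank[x] < rank[y])
      std::swap(x, y);
    parent[y] = x;
    // Keep the invariant: the root records the smallest element of its class.
    minRep[x] = std::min(minRep[x], minRep[y]);
    if (rank[x] == rank[y])
      ++rank[x];
  }

  // Transitive closure of two partitions.
  ToyUFDS join(const ToyUFDS &other) const {
    ToyUFDS result = *this;
    for (unsigned i = 0; i < parent.size(); ++i)
      result.unite(i, other.find(i));
    return result;
  }

  // Order-independent representation: the minimum element of each class.
  std::vector<unsigned> canonical() const {
    std::vector<unsigned> reps(parent.size());
    for (unsigned i = 0; i < parent.size(); ++i)
      reps[i] = minRep[find(i)];
    return reps;
  }

  // Equivalent to comparing against the identity partition (the CTA-local
  // case): distributed iff some class holds more than one CTA.
  bool isDistributed() const {
    for (unsigned i = 0; i < parent.size(); ++i)
      if (minRep[find(i)] != i)
        return true;
    return false;
  }
};

int main() {
  // 4 CTAs: one partition pairs (0,1), another pairs (1,2).
  ToyUFDS a(4), b(4);
  a.unite(0, 1);
  b.unite(1, 2);
  // The transitive closure merges {0,1,2} and leaves {3} alone.
  ToyUFDS c = a.join(b);
  assert((c.canonical() == std::vector<unsigned>{0, 0, 0, 3}));
  assert(c.isDistributed());
  // The identity partition is the "CTA local" case.
  assert(!ToyUFDS(4).isDistributed());
}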
@@ -77,18 +190,47 @@ struct BlockInfo { bool operator!=(const BlockInfo &other) const { return !(*this == other); } private: - bool isIntersected(const IntervalMapT &lhsIntervalSet, - const IntervalMapT &rhsIntervalSet, - MembarFilterFn filter) const { - for (auto &lhs : lhsIntervalSet) - for (auto &rhs : rhsIntervalSet) - if (lhs.first.intersects(rhs.first)) - for (auto lhsOp : lhs.second) - for (auto rhsOp : rhs.second) - if (!filter || !filter(lhsOp, rhsOp)) - return true; - - return false; + static bool haveSameAlloc(Operation *lhs, Operation *rhs, + Allocation *allocation); + + std::optional isIntersected(const IntervalMapT &lhsIntervalSet, + const IntervalMapT &rhsIntervalSet, + MembarFilterFn filter, + Allocation *allocation) const { + // They intersect whenever the intervals intersect. If they do, collect the + // union of CTA sets for any op pair that is not filtered out and does not + // share the exact same explicit shared value. + std::optional ret = std::nullopt; + for (const auto &[lhsKey, lhsOps] : lhsIntervalSet) { + const auto &[intervalLhs, ctasLhs] = lhsKey; + for (const auto &[rhsKey, rhsOps] : rhsIntervalSet) { + const auto &[intervalRhs, ctasRhs] = rhsKey; + if (!intervalLhs.intersects(intervalRhs)) + continue; + + auto joined = ctasLhs.join(ctasRhs); + bool skipBarrier = + llvm::all_of(lhsOps, [&, rhsOpsPtr = &rhsOps](const auto &lhsOp) { + return llvm::all_of(*rhsOpsPtr, [&](const auto &rhsOp) { + return (filter && filter(lhsOp, rhsOp)) || + (joined.isDistributed() && + haveSameAlloc(lhsOp, rhsOp, allocation)); + }); + }); + if (skipBarrier) + continue; + if (!ret.has_value()) { + ret = joined; + } else { + ret = ret->join(joined); + } + // Single CTA case, we can early exit + if (ret->size() == 1) { + return ret; + } + } + } + return ret; } }; @@ -170,7 +312,8 @@ class MembarAnalysis : public MembarOrFenceAnalysis { FuncBlockInfoMapT *funcBlockInfoMap, OpBuilder *builder) override; - void insertBarrier(Operation *operation, OpBuilder *builder); + void insertBarrier(Operation *operation, OpBuilder *builder, + const BlockInfo::CTA_UFDS &ctaClasses); }; /// Postorder traversal on the callgraph to insert membar instructions diff --git a/include/triton/Analysis/Utility.h b/include/triton/Analysis/Utility.h index 6dc310a22b65..d24e76441170 100644 --- a/include/triton/Analysis/Utility.h +++ b/include/triton/Analysis/Utility.h @@ -414,7 +414,6 @@ std::unique_ptr createDataFlowSolver(); bool isCvtWarpSync(const triton::LinearLayout &srcLayout, const triton::LinearLayout &dstLayout); - } // namespace mlir #endif // TRITON_ANALYSIS_UTILITY_H diff --git a/include/triton/Dialect/TritonGPU/Transforms/Utility.h b/include/triton/Dialect/TritonGPU/Transforms/Utility.h index d9c950dd8158..bb005b9f34b4 100644 --- a/include/triton/Dialect/TritonGPU/Transforms/Utility.h +++ b/include/triton/Dialect/TritonGPU/Transforms/Utility.h @@ -218,14 +218,6 @@ getSharedEncIfAllUsersAreDotEnc(Value val, bool &incompatible); // Skips operands if they're in shared encoding. Operation *convertDistributedOpEncoding(Attribute encoding, Operation *op); -// Returns the original memory allocation for a memdesc value -triton::gpu::LocalAllocOp findShmemAlloc(Value operand); - -// Returns MMAs inside a for loop that are multi-buffered for pipeline analysis -SmallVector -getMMAsWithMultiBufferredOperands(scf::ForOp forOp, - SmallVector &mmaOps); - // Given a list of ops, find the naerest common dominator of all ops or return // null if one could not be found. 
The ops are allowed to be in different // regions. The result op is not necessarily one of the ops in the list. diff --git a/include/triton/Dialect/TritonNvidiaGPU/IR/TritonNvidiaGPUOps.td b/include/triton/Dialect/TritonNvidiaGPU/IR/TritonNvidiaGPUOps.td index fde58b6bfc78..319528b4abc5 100644 --- a/include/triton/Dialect/TritonNvidiaGPU/IR/TritonNvidiaGPUOps.td +++ b/include/triton/Dialect/TritonNvidiaGPU/IR/TritonNvidiaGPUOps.td @@ -293,7 +293,8 @@ def TTNG_AsyncTMACopyGlobalToLocalOp : TTNG_Op<"async_tma_copy_global_to_local"> I1:$pred, DefaultValuedAttr:$cache, DefaultValuedAttr:$evict, - DefaultValuedAttr:$isVolatile + DefaultValuedAttr:$isVolatile, + DefaultValuedAttr:$multicast ); let assemblyFormat = [{ diff --git a/lib/Analysis/Membar.cpp b/lib/Analysis/Membar.cpp index d06f3c2e99aa..0e9a5c4e4924 100644 --- a/lib/Analysis/Membar.cpp +++ b/lib/Analysis/Membar.cpp @@ -5,10 +5,112 @@ #include "mlir/Dialect/GPU/IR/GPUDialect.h" #include "mlir/Interfaces/ControlFlowInterfaces.h" +#include "mlir/Interfaces/SideEffectInterfaces.h" +#include "llvm/ADT/DenseSet.h" +#include "llvm/ADT/STLExtras.h" +#include "llvm/ADT/SetVector.h" #include namespace mlir { +namespace { + +llvm::SmallDenseSet +getSharedBufferIds(Operation *op, Allocation *allocation) { + auto opEffects = dyn_cast(op); + if (!opEffects) + return {}; + + llvm::SmallDenseSet bufferIds; + SmallVector> effects; + opEffects.getEffects(effects); + for (auto &effect : effects) { + if (effect.getResource() != triton::gpu::SharedMemory::get()) + continue; + Value value = effect.getValue(); + auto memDescTy = cast(value.getType()); + // Hacky way to skip barriers + if (memDescTy.getNumElements() == 1) + continue; + for (auto bufferId : allocation->getBufferIds(value)) { + if (bufferId == Allocation::InvalidBufferId) + continue; + bufferIds.insert(bufferId); + } + } + return bufferIds; +} + +} // namespace + +static std::pair +getCTAEquivalenceSets(Operation *op) { + auto numCTAs = triton::gpu::lookupNumCTAs(op); + if (numCTAs == 1) { + return {BlockInfo::CTA_UFDS(1), BlockInfo::CTA_UFDS(1)}; + } + auto *ctx = op->getContext(); + auto kBlock = StringAttr::get(ctx, "block"); + if (isa(op)) { + auto srcTy = cast(op->getOperand(0).getType()); + auto dstTy = cast(op->getResult(0).getType()); + auto srcCTALayout = triton::gpu::getCTALayout(srcTy.getEncoding()); + auto dstCTALayout = triton::gpu::getCTALayout(dstTy.getEncoding()); + auto ctaLl = dstCTALayout.getLinearLayout().invertAndCompose( + srcCTALayout.getLinearLayout()); + auto readsUFDS = BlockInfo::CTA_UFDS(numCTAs); + for (int i = 0; i < numCTAs; i++) { + auto res = ctaLl.apply({{kBlock, i}}); + assert(res.size() == 1); + assert(res.front().first == kBlock); + readsUFDS.unite(i, res.front().second); + } + // The writes are just each writing to their own shmem + auto writesUFDS = BlockInfo::CTA_UFDS(numCTAs); + return {readsUFDS, writesUFDS}; + } else if (auto tma = + dyn_cast( + op)) { + if (tma.getMulticast()) { + auto ctaLl = + triton::gpu::getCTALayout(tma.getResult().getType().getEncoding()) + .getLinearLayout() + .flattenOuts(); + // We build a map that's the identity on the non broadcasted blocks and + // zero in the broadcasted + auto ll = triton::LinearLayout::identity1D(numCTAs, kBlock, kBlock); + auto bases = ll.getBases(); + auto &basesBlock = bases[kBlock]; + auto outDim = *ctaLl.getOutDimNames().begin(); + for (int i = 0; i < llvm::Log2_32(numCTAs); i++) { + if (ctaLl.getBasis(kBlock, i, outDim) == 0) { + basesBlock[i] = {0}; + } + } + ll = triton::LinearLayout(bases, 
{{kBlock, numCTAs}}, false); + auto writesUFDS = BlockInfo::CTA_UFDS(numCTAs); + for (int i = 0; i < numCTAs; i++) { + auto res = ll.apply({{kBlock, i}}); + assert(res.size() == 1); + assert(res.front().first == kBlock); + writesUFDS.unite(i, res.front().second); + } + // It's not going to be used so it's fine + auto defaultUFDS = BlockInfo::CTA_UFDS(numCTAs); + return {defaultUFDS, writesUFDS}; + } + } + return {BlockInfo::CTA_UFDS(numCTAs), BlockInfo::CTA_UFDS(numCTAs)}; +} + +bool BlockInfo::haveSameAlloc(Operation *lhs, Operation *rhs, + Allocation *allocation) { + auto lhsBuffers = getSharedBufferIds(lhs, allocation); + auto rhsBuffers = getSharedBufferIds(rhs, allocation); + return llvm::any_of( + lhsBuffers, [&](auto bufferId) { return rhsBuffers.contains(bufferId); }); +} + void MembarOrFenceAnalysis::run(FuncBlockInfoMapT &funcBlockInfoMap) { FunctionOpInterface funcOp = dyn_cast(allocation->getOperation()); @@ -157,17 +259,27 @@ void MembarOrFenceAnalysis::visitTerminator( llvm_unreachable("Unknown terminator encountered in membar analysis"); } -void MembarAnalysis::insertBarrier(Operation *op, OpBuilder *builder) { +void MembarAnalysis::insertBarrier(Operation *op, OpBuilder *builder, + const BlockInfo::CTA_UFDS &ctaClasses) { OpBuilder::InsertionGuard g(*builder); - auto barrierOp = triton::gpu::LocalBarrierOp::create(*builder, op->getLoc()); + if (ctaClasses.isDistributed()) { + // TODO Insert an membar when there is more than one CTA class to avoid + // synchronising the whole cluster + triton::nvidia_gpu::ClusterArriveOp::create(*builder, op->getLoc(), + /*relaxed=*/false); + triton::nvidia_gpu::ClusterWaitOp::create(*builder, op->getLoc()); + } else { + triton::gpu::LocalBarrierOp::create(*builder, op->getLoc()); + } } void MembarAnalysis::update(Operation *op, BlockInfo *blockInfo, FuncBlockInfoMapT *funcBlockInfoMap, OpBuilder *builder) { - if (isa(op)) { + if (isa(op)) { // If the current op is a barrier, we sync previous reads and writes - blockInfo->sync(); + blockInfo->sync(isa(op)); return; } @@ -176,11 +288,14 @@ void MembarAnalysis::update(Operation *op, BlockInfo *blockInfo, // If the current op is an async wait and the next op is not a barrier we // insert a barrier op and sync builder->setInsertionPointAfter(op); - insertBarrier(op, builder); - blockInfo->sync(); + auto nCTAs = triton::gpu::lookupNumCTAs(op); + insertBarrier(op, builder, BlockInfo::CTA_UFDS(nCTAs)); + blockInfo->sync(false); return; } + auto [readCTAs, writeCTAs] = getCTAEquivalenceSets(op); + BlockInfo curBlockInfo; auto scratchBufferId = Allocation::InvalidBufferId; if (isa(op)) { @@ -200,27 +315,23 @@ void MembarAnalysis::update(Operation *op, BlockInfo *blockInfo, if (auto value = effectInstance.getValue()) { for (auto bufferId : allocation->getBufferIds(value)) { if (bufferId != Allocation::InvalidBufferId) { + auto interval = allocation->getAllocatedInterval(bufferId); if (isa(effectInstance.getEffect())) - curBlockInfo - .syncWriteIntervals[allocation->getAllocatedInterval( - bufferId)] - .insert(op); + curBlockInfo.syncWriteIntervals[{interval, writeCTAs}].insert( + op); else if (isa(effectInstance.getEffect())) - curBlockInfo - .syncReadIntervals[allocation->getAllocatedInterval( - bufferId)] - .insert(op); + curBlockInfo.syncReadIntervals[{interval, readCTAs}].insert(op); } } } } } - // If this op is may be signalling other threads asynchronously, make sure + // If this op may be signalling other threads asynchronously, make sure // all shared memory transactions are complete 
beforehand. if (isa(op)) { Interval allIntervals(0, std::numeric_limits::max()); - curBlockInfo.syncWriteIntervals[allIntervals].insert(op); - curBlockInfo.syncReadIntervals[allIntervals].insert(op); + curBlockInfo.syncWriteIntervals[{allIntervals, writeCTAs}].insert(op); + curBlockInfo.syncReadIntervals[{allIntervals, readCTAs}].insert(op); } scratchBufferId = allocation->getBufferId(op); } @@ -250,21 +361,24 @@ void MembarAnalysis::update(Operation *op, BlockInfo *blockInfo, "dependencies"); } auto interval = allocation->getAllocatedInterval(scratchBufferId); - curBlockInfo.syncWriteIntervals[interval].insert(op); - auto insertCTABarrier = blockInfo->isIntersected(curBlockInfo, filter); - if (insertCTABarrier) { + curBlockInfo.syncWriteIntervals[{interval, writeCTAs}].insert(op); + auto insertCTABarrier = + blockInfo->isIntersected(curBlockInfo, filter, allocation); + if (insertCTABarrier.has_value()) { builder->setInsertionPoint(op); - insertBarrier(op, builder); + insertBarrier(op, builder, *insertCTABarrier); + blockInfo->sync(insertCTABarrier->isDistributed()); + } else if (!isWarpSync) { + // Ops with a scratch buffer that don't use warp.sync internally sync + // read/write on shared memory at the CTA level. + blockInfo->sync(false); } - // Ops with a scratch buffer that don't use warp.sync internally sync - // read/write on shared memory - if (insertCTABarrier || !isWarpSync) - blockInfo->sync(); - curBlockInfo.syncReadIntervals[interval].insert(op); - } else if (blockInfo->isIntersected(curBlockInfo, filter)) { + curBlockInfo.syncReadIntervals[{interval, readCTAs}].insert(op); + } else if (auto ctas = + blockInfo->isIntersected(curBlockInfo, filter, allocation)) { builder->setInsertionPoint(op); - insertBarrier(op, builder); - blockInfo->sync(); + insertBarrier(op, builder, *ctas); + blockInfo->sync(ctas->isDistributed()); } // Update the region info, even if barrier is inserted, we have to maintain // the current op's read/write buffers. diff --git a/lib/Analysis/Utility.cpp b/lib/Analysis/Utility.cpp index 21155285dd82..62e5cad0df99 100644 --- a/lib/Analysis/Utility.cpp +++ b/lib/Analysis/Utility.cpp @@ -131,10 +131,9 @@ unsigned ReduceOpHelper::getScratchSizeInBytes() { } bool ReduceOpHelper::isReduceWithinCTA() { - // TODO: Support reduce across CTAS - // Layout optimization passes such as PlanCTAPass and - // RemoveLayoutConversionPass should avoid cross-CTA reduction - return getCTASplitNum(srcEncoding)[axis] == 1; + // TODO: Implement. + // We allow them to be able to test them in the MembarPass + return true; } bool ReduceOpHelper::isAssociative() { @@ -1133,5 +1132,4 @@ bool isCvtWarpSync(const triton::LinearLayout &srcLayout, srcLayout.getFreeVariableMasks()[kWarp] == 0 && dstLayout.getFreeVariableMasks()[kWarp] == 0; } - } // namespace mlir diff --git a/lib/Dialect/TritonGPU/Transforms/Utility.cpp b/lib/Dialect/TritonGPU/Transforms/Utility.cpp index 9ab16bb97f1b..c341eaf98ce0 100644 --- a/lib/Dialect/TritonGPU/Transforms/Utility.cpp +++ b/lib/Dialect/TritonGPU/Transforms/Utility.cpp @@ -1377,56 +1377,6 @@ void populateForOpDeadArgumentElimination(RewritePatternSet &patterns) { patterns.add(patterns.getContext()); } -ttg::LocalAllocOp findShmemAlloc(Value operand) { - // If it's a shmem operand, it must either be defined outside the loop, or - // come from an MemDescIndex op. Only ConvertLayout and MemdescView ops are - // allowed in between. 
- Value transitiveOperand = operand; - while (isa_and_nonnull( - transitiveOperand.getDefiningOp()) || - isa(transitiveOperand)) { - if (auto blockArg = dyn_cast(transitiveOperand)) { - assert(isa(blockArg.getOwner()->getParentOp()) && - "Block argument must come from a for loop"); - transitiveOperand = - cast(blockArg.getOwner()->getTerminator()) - .getOperand(blockArg.getArgNumber() - 1); - } else { - transitiveOperand = transitiveOperand.getDefiningOp()->getOperand(0); - } - } - if (auto subView = dyn_cast_or_null( - transitiveOperand.getDefiningOp())) { - // Multi-buffered operand - return dyn_cast_or_null( - subView.getSrc().getDefiningOp()); - } else { - // Single bufferred operand that does not require a subview (not loaded in - // the loop) - return dyn_cast_or_null( - transitiveOperand.getDefiningOp()); - } - return nullptr; -} - -SmallVector -getMMAsWithMultiBufferredOperands(scf::ForOp forOp, - SmallVector &mmaOps) { - // The A and B operands of the mmaOp should be multi-buffered - SmallVector eligible; - for (auto mmaOp : mmaOps) { - auto a = findShmemAlloc(mmaOp->getOperand(0)); - auto b = findShmemAlloc(mmaOp->getOperand(1)); - if (a && forOp.isDefinedOutsideOfLoop(a) && b && - forOp.isDefinedOutsideOfLoop(b)) { - eligible.push_back(mmaOp); - } - } - - return eligible; -} - template static Operation *findNearestCommonDominatorImpl( ArrayRef ops, DomInfoT &domInfo, diff --git a/lib/Dialect/TritonNvidiaGPU/Transforms/ProxFenceInsertion.cpp b/lib/Dialect/TritonNvidiaGPU/Transforms/ProxFenceInsertion.cpp index 96b9a4a2ff11..aff1a920bfe0 100644 --- a/lib/Dialect/TritonNvidiaGPU/Transforms/ProxFenceInsertion.cpp +++ b/lib/Dialect/TritonNvidiaGPU/Transforms/ProxFenceInsertion.cpp @@ -98,9 +98,14 @@ void ProxyFenceAnalysis::update(Operation *op, BlockInfo *blockInfo, OpBuilder *builder) { if (isa(op)) { // If the current op is a fence, we clear previous reads and writes - blockInfo->sync(); + blockInfo->sync(false); return; } + // Proxy fences are CTA-local; distributed shmem is not relevant here. + const unsigned numCTAs = triton::gpu::lookupNumCTAs(op); + const BlockInfo::CTA_UFDS readCTAs(numCTAs); + const BlockInfo::CTA_UFDS writeCTAs(numCTAs); + BlockInfo curBlockInfo; BlockInfo proxyBlockInfo; @@ -128,20 +133,21 @@ void ProxyFenceAnalysis::update(Operation *op, BlockInfo *blockInfo, if (isAsyncProxyWrite(op)) { if (value == getSmemDest(op)) { proxyBlockInfo - .syncWriteIntervals[allocation->getAllocatedInterval( - bufferId)] + .syncWriteIntervals[{ + allocation->getAllocatedInterval(bufferId), + writeCTAs}] .insert(op); } } else if (isa( effectInstance.getEffect())) { curBlockInfo - .syncWriteIntervals[allocation->getAllocatedInterval( - bufferId)] + .syncWriteIntervals[{ + allocation->getAllocatedInterval(bufferId), writeCTAs}] .insert(op); } else if (isa(effectInstance.getEffect())) { curBlockInfo - .syncReadIntervals[allocation->getAllocatedInterval( - bufferId)] + .syncReadIntervals[{ + allocation->getAllocatedInterval(bufferId), readCTAs}] .insert(op); } } @@ -157,13 +163,13 @@ void ProxyFenceAnalysis::update(Operation *op, BlockInfo *blockInfo, // read/write operations, mark them as a read. 
if (scratchBufferId != Allocation::InvalidBufferId) { auto interval = allocation->getAllocatedInterval(scratchBufferId); - curBlockInfo.syncReadIntervals[interval].insert(op); + curBlockInfo.syncReadIntervals[{interval, readCTAs}].insert(op); } if (isAsyncProxyWrite(op) || isAsyncProxyRead(op)) { - if (proxyBlockInfo.isIntersected(*blockInfo, filter)) { + if (proxyBlockInfo.isIntersected(*blockInfo, filter, allocation)) { builder->setInsertionPoint(op); insertFence(op, builder); - blockInfo->sync(); + blockInfo->sync(false); } } diff --git a/test/Analysis/test-membar-cluster.mlir b/test/Analysis/test-membar-cluster.mlir new file mode 100644 index 000000000000..dd721e4c580d --- /dev/null +++ b/test/Analysis/test-membar-cluster.mlir @@ -0,0 +1,118 @@ +// RUN: triton-opt %s -split-input-file -test-print-membar | FileCheck %s + +#shared = #ttg.swizzled_shared<{vec = 1, perPhase = 1, maxPhase = 1, order = [0], CGALayout = [[1]]}> +#shared1 = #ttg.nvmma_shared<{swizzlingByteWidth = 0, transposed = false, elementBitWidth = 16, CGALayout = [[0, 0]]}> +#blocked = #ttg.blocked<{sizePerThread = [1, 1], threadsPerWarp = [32, 1], warpsPerCTA = [1, 1], order = [1, 0], CGALayout = [[0, 1]]}> +#smem = #ttg.shared_memory +module attributes {"ttg.num-ctas" = 2 : i32, "ttg.num-warps" = 1 : i32} { + // Two lifetimes alias at offset 0; second is multicast TMA => cluster barrier. + // CHECK-LABEL: alias_async_then_multicast + tt.func @alias_async_then_multicast(%desc: !tt.tensordesc>) { + %pred = arith.constant true + %c0 = arith.constant 0 : i32 + %bar = ttg.local_alloc {allocation.offset = 4096 : i32} : () -> !ttg.memdesc<1xi64, #shared, #smem, mutable> + %dst0 = ttg.local_alloc {allocation.offset = 0 : i32} : () -> !ttg.memdesc<32x32xf16, #shared1, #smem, mutable> + // First lifetime: single-CTA async copy and consume token. + ttng.barrier_expect %bar, 2048, %pred : !ttg.memdesc<1xi64, #shared, #smem, mutable> + ttng.async_tma_copy_global_to_local %desc[%c0, %c0] %dst0, %bar, %pred {multicast = false} : !tt.tensordesc>, !ttg.memdesc<1xi64, #shared, #smem, mutable> -> !ttg.memdesc<32x32xf16, #shared1, #smem, mutable> + ttng.wait_barrier %bar, %c0 : !ttg.memdesc<1xi64, #shared, #smem, mutable> + // CHECK: ttg.local_load + %t0 = ttg.local_load %dst0 : !ttg.memdesc<32x32xf16, #shared1, #smem, mutable> -> tensor<32x32xf16, #blocked> + // Second lifetime aliases offset 0 and writes multicast to all CTAs. 
+ %dst1 = ttg.local_alloc {allocation.offset = 0 : i32} : () -> !ttg.memdesc<32x32xf16, #shared1, #smem, mutable> + // CHECK: ttng.barrier_expect + // CHECK: ttng.cluster_arrive + // CHECK: ttng.cluster_wait + ttng.barrier_expect %bar, 2048, %pred : !ttg.memdesc<1xi64, #shared, #smem, mutable> + // CHECK-NEXT: ttng.async_tma_copy_global_to_local {{.*}} {multicast = true} + ttng.async_tma_copy_global_to_local %desc[%c0, %c0] %dst1, %bar, %pred {multicast = true} : !tt.tensordesc>, !ttg.memdesc<1xi64, #shared, #smem, mutable> -> !ttg.memdesc<32x32xf16, #shared1, #smem, mutable> + tt.return + } +} + +// ----- + +// Async cp then multicast alias (needs cluster barrier) +#shared = #ttg.swizzled_shared<{vec = 1, perPhase = 1, maxPhase = 1, order = [0], CGALayout = [[1]]}> +#shared1 = #ttg.nvmma_shared<{swizzlingByteWidth = 0, transposed = false, elementBitWidth = 16, CGALayout = [[0, 1]]}> +#shared2 = #ttg.nvmma_shared<{swizzlingByteWidth = 0, transposed = false, elementBitWidth = 16, CGALayout = [[0, 0]]}> +#blocked = #ttg.blocked<{sizePerThread = [1, 1], threadsPerWarp = [32, 1], warpsPerCTA = [1, 1], order = [1, 0], CGALayout = [[0, 1]]}> +#smem = #ttg.shared_memory +module attributes {"ttg.num-ctas" = 2 : i32, "ttg.num-warps" = 1 : i32} { + // CHECK-LABEL: async_cp_then_multicast_alias + tt.func @async_cp_then_multicast_alias(%gptr: !tt.ptr, %desc: !tt.tensordesc>, %desc2: !tt.tensordesc>) { + %pred = arith.constant true + %c0 = arith.constant 0 : i32 + %gptr_tensor = tt.splat %gptr : !tt.ptr -> tensor<32x32x!tt.ptr, #blocked> + %bar = ttg.local_alloc {allocation.offset = 4096 : i32} : () -> !ttg.memdesc<1xi64, #shared, #smem, mutable> + %a = ttg.local_alloc {allocation.offset = 0 : i32} : () -> !ttg.memdesc<32x32xf16, #shared1, #smem, mutable> + %cp = ttg.async_copy_global_to_local %gptr_tensor, %a : tensor<32x32x!tt.ptr, #blocked> -> !ttg.memdesc<32x32xf16, #shared1, #smem, mutable> + %tok = ttg.async_commit_group tokens %cp + %tok2 = ttg.async_wait %tok {num = 0 : i32} + // CHECK: ttg.local_load + %ld = ttg.local_load %a token %tok2 : !ttg.memdesc<32x32xf16, #shared1, #smem, mutable> -> tensor<32x32xf16, #blocked> + %b = ttg.local_alloc {allocation.offset = 0 : i32} : () -> !ttg.memdesc<32x32xf16, #shared2, #smem, mutable> + // CHECK: ttng.barrier_expect + // CHECK: ttng.cluster_arrive + // CHECK: ttng.cluster_wait + ttng.barrier_expect %bar, 2048, %pred : !ttg.memdesc<1xi64, #shared, #smem, mutable> + // CHECK-NEXT: ttng.async_tma_copy_global_to_local {{.*}} {multicast = true} + ttng.async_tma_copy_global_to_local %desc2[%c0, %c0] %b, %bar, %pred {multicast = true} : !tt.tensordesc>, !ttg.memdesc<1xi64, #shared, #smem, mutable> -> !ttg.memdesc<32x32xf16, #shared2, #smem, mutable> + ttng.wait_barrier %bar, %c0 : !ttg.memdesc<1xi64, #shared, #smem, mutable> + tt.return + } +} + +// ----- + +// Async cp + reinterpet & multicast on same allocation (no cluster barrier) +#shared = #ttg.swizzled_shared<{vec = 1, perPhase = 1, maxPhase = 1, order = [0], CGALayout = [[1]]}> +#shared1 = #ttg.nvmma_shared<{swizzlingByteWidth = 0, transposed = false, elementBitWidth = 16, CGALayout = [[0, 1]]}> +#shared2 = #ttg.nvmma_shared<{swizzlingByteWidth = 0, transposed = false, elementBitWidth = 16, CGALayout = [[0, 0]]}> +#blocked = #ttg.blocked<{sizePerThread = [1, 1], threadsPerWarp = [32, 1], warpsPerCTA = [1, 1], order = [1, 0], CGALayout = [[0, 1]]}> +#smem = #ttg.shared_memory +module attributes {"ttg.num-ctas" = 2 : i32, "ttg.num-warps" = 1 : i32} { + // CHECK-LABEL: async_cp_and_multicast_same_alloc 
+ tt.func @async_cp_and_multicast_same_alloc(%gptr: !tt.ptr, %desc: !tt.tensordesc>) { + %pred = arith.constant true + %c0 = arith.constant 0 : i32 + %gptr_tensor = tt.splat %gptr : !tt.ptr -> tensor<32x32x!tt.ptr, #blocked> + %bar = ttg.local_alloc {allocation.offset = 4096 : i32} : () -> !ttg.memdesc<1xi64, #shared, #smem, mutable> + %buf = ttg.local_alloc {allocation.offset = 0 : i32} : () -> !ttg.memdesc<32x32xf16, #shared1, #smem, mutable> + %cp = ttg.async_copy_global_to_local %gptr_tensor, %buf : tensor<32x32x!tt.ptr, #blocked> -> !ttg.memdesc<32x32xf16, #shared1, #smem, mutable> + %tok = ttg.async_commit_group tokens %cp + %tok2 = ttg.async_wait %tok {num = 0 : i32} + %ld = ttg.local_load %buf token %tok2 : !ttg.memdesc<32x32xf16, #shared1, #smem, mutable> -> tensor<32x32xf16, #blocked> + %buf_reint = ttg.memdesc_reinterpret %buf : !ttg.memdesc<32x32xf16, #shared1, #smem, mutable> -> !ttg.memdesc<32x32xf16, #shared2, #smem, mutable> + + ttng.barrier_expect %bar, 2048, %pred : !ttg.memdesc<1xi64, #shared, #smem, mutable> + // CHECK: ttng.barrier_expect + // CHECK-NOT: ttng.cluster_arrive + // CHECK-NOT: ttng.cluster_wait + // CHECK: ttng.async_tma_copy_global_to_local {{.*}} {multicast = true} + ttng.async_tma_copy_global_to_local %desc[%c0, %c0] %buf_reint, %bar, %pred {multicast = true} : !tt.tensordesc>, !ttg.memdesc<1xi64, #shared, #smem, mutable> -> !ttg.memdesc<32x32xf16, #shared2, #smem, mutable> + ttng.wait_barrier %bar, %c0 : !ttg.memdesc<1xi64, #shared, #smem, mutable> + tt.return + } +} + +// ----- + +// Distributed convert alias (needs cluster barrier) +#blocked1 = #ttg.blocked<{sizePerThread = [1, 1], threadsPerWarp = [32, 1], warpsPerCTA = [1, 1], order = [1, 0], CGALayout = [[1, 0], [0, 1]]}> +#blocked2 = #ttg.blocked<{sizePerThread = [1, 1], threadsPerWarp = [32, 1], warpsPerCTA = [1, 1], order = [1, 0], CGALayout = [[0, 1], [1, 0]]}> + +module attributes {"ttg.num-ctas" = 4 : i32, "ttg.num-warps" = 1 : i32} { + // CHECK-LABEL: convert_alias_same_offset + tt.func @convert_alias_same_offset() -> (tensor<2x2xf16, #blocked2>, tensor<2x2xf16, #blocked2>) { + %c0 = arith.constant 0.000000e+00 : f16 + %src = tt.splat %c0 : f16 -> tensor<2x2xf16, #blocked1> + // CHECK: ttg.convert_layout + %cvt0 = ttg.convert_layout %src {allocation.offset = 0 : i32} : tensor<2x2xf16, #blocked1> -> tensor<2x2xf16, #blocked2> + // CHECK: ttng.cluster_arrive + // CHECK: ttng.cluster_wait + // CHECK: ttg.convert_layout + %cvt1 = ttg.convert_layout %src {allocation.offset = 0 : i32} : tensor<2x2xf16, #blocked1> -> tensor<2x2xf16, #blocked2> + tt.return %cvt0, %cvt1 : tensor<2x2xf16, #blocked2>, tensor<2x2xf16, #blocked2> + } +} diff --git a/test/lib/Analysis/TestMembar.cpp b/test/lib/Analysis/TestMembar.cpp index e167722a6a73..057e21d3e77f 100644 --- a/test/lib/Analysis/TestMembar.cpp +++ b/test/lib/Analysis/TestMembar.cpp @@ -3,6 +3,8 @@ #include "mlir/Transforms/DialectConversion.h" #include "triton/Analysis/Allocation.h" #include "triton/Analysis/Membar.h" +#include "triton/Dialect/TritonGPU/IR/Dialect.h" +#include "triton/Dialect/TritonNvidiaGPU/IR/Dialect.h" using namespace mlir; @@ -18,6 +20,11 @@ struct TestMembarPass return "print the result of the allocation pass"; } + void getDependentDialects(DialectRegistry ®istry) const override { + registry.insert(); + } + void runOnOperation() override { Operation *operation = getOperation(); ModuleOp moduleOp = cast(operation);
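[Reviewer note, not part of the patch] To summarize the decision logic this change adds to MembarAnalysis::update and BlockInfo::sync, here is a simplified standalone model in plain C++. Barrier, Hazard, Pending, pickBarrier, and syncAfterBarrier are made-up illustrative names; in the actual pass the choice is made by insertBarrier, which creates ttng.cluster_arrive/ttng.cluster_wait when the joined CTA classes are distributed and ttg.local_barrier otherwise, and sync(cluster) keeps distributed entries alive until a cluster-wide barrier is seen.

// Simplified model: which barrier a detected hazard requires, and which
// pending accesses a barrier of that kind discharges.
#include <algorithm>
#include <iostream>
#include <optional>
#include <vector>

enum class Barrier { None, CtaLocal, Cluster };

// Stand-in for the optional CTA_UFDS returned by BlockInfo::isIntersected.
struct Hazard {
  bool distributed; // true if the joined CTA classes span more than one CTA
};

Barrier pickBarrier(const std::optional<Hazard> &hazard) {
  if (!hazard)
    return Barrier::None; // no RAW/WAR/WAW intersection: nothing to insert
  return hazard->distributed ? Barrier::Cluster : Barrier::CtaLocal;
}

// Stand-in for one entry of the (interval, CTA classes) maps in BlockInfo.
struct Pending {
  bool distributed;
};

void syncAfterBarrier(std::vector<Pending> &pending, bool clusterBarrier) {
  if (clusterBarrier) {
    pending.clear(); // a cluster barrier discharges everything
    return;
  }
  // A CTA-local barrier only discharges CTA-local dependencies.
  pending.erase(std::remove_if(pending.begin(), pending.end(),
                               [](const Pending &p) { return !p.distributed; }),
                pending.end());
}

int main() {
  std::cout << static_cast<int>(pickBarrier(std::nullopt)) << "\n";  // 0: no hazard
  std::cout << static_cast<int>(pickBarrier(Hazard{false})) << "\n"; // 1: local barrier
  // e.g. a multicast TMA write or cross-CTA layout conversion aliasing earlier
  // accesses: synchronize the whole cluster.
  std::cout << static_cast<int>(pickBarrier(Hazard{true})) << "\n";  // 2: cluster barrier

  std::vector<Pending> pending = {{false}, {true}, {false}};
  syncAfterBarrier(pending, /*clusterBarrier=*/false);
  std::cout << pending.size() << "\n"; // 1: only the distributed entry survives
  syncAfterBarrier(pending, /*clusterBarrier=*/true);
  std::cout << pending.size() << "\n"; // 0
}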