33#include " triton/Dialect/TritonGPU/IR/LinearLayoutConversions.h"
44#include " triton/Dialect/TritonNvidiaGPU/IR/Dialect.h"
55
6- #include " mlir/Analysis/SliceAnalysis.h"
76#include " mlir/Dialect/GPU/IR/GPUDialect.h"
87#include " mlir/Interfaces/ControlFlowInterfaces.h"
8+ #include " mlir/Interfaces/SideEffectInterfaces.h"
9+ #include " llvm/ADT/DenseSet.h"
910#include " llvm/ADT/STLExtras.h"
1011#include " llvm/ADT/SetVector.h"
11- #include " llvm/ADT/SmallPtrSet.h"
1212#include < deque>
1313
1414namespace mlir {
1515
-static SmallPtrSet<Operation *, 2> parentAllocs(Operation *op) {
-  SmallPtrSet<Operation *, 2> owners;
+namespace {
+
+llvm::SmallDenseSet<Allocation::BufferId, 2>
+getSharedBufferIds(Operation *op, Allocation *allocation) {
   auto opEffects = dyn_cast<MemoryEffectOpInterface>(op);
   if (!opEffects)
-    return owners;
+    return {};

+  llvm::SmallDenseSet<Allocation::BufferId, 2> bufferIds;
   SmallVector<SideEffects::EffectInstance<MemoryEffects::Effect>> effects;
   opEffects.getEffects(effects);
   for (auto &effect : effects) {
     if (effect.getResource() != triton::gpu::SharedMemory::get())
       continue;
-
     Value value = effect.getValue();
-    // Hacky way to skip barriers...
-    if (cast<triton::gpu::MemDescType>(value.getType()).getNumElements() == 1)
+    auto memDescTy = cast<triton::gpu::MemDescType>(value.getType());
+    // Hacky way to skip barriers
+    if (memDescTy.getNumElements() == 1)
       continue;
-
-    // Get a slice of all the operations that touch this shared memory
-    // (subslice/index/memdesc views) all the way up to the local_alloc.
-    BackwardSliceOptions options;
-    options.omitUsesFromAbove = false;
-    options.inclusive = true;
-    auto isSharedMemDesc = [](Type ty) {
-      auto memDescTy = dyn_cast<triton::gpu::MemDescType>(ty);
-      if (!memDescTy)
-        return false;
-      return isa<triton::gpu::SharedMemorySpaceAttr>(
-          memDescTy.getMemorySpace());
-    };
-    options.filter = [&](Operation *depOp) -> bool {
-      return llvm::any_of(depOp->getOperandTypes(), isSharedMemDesc) ||
-             // Add ops that have Memdesc in the result types to pick
-             // local_alloc ops as well
-             llvm::any_of(depOp->getResultTypes(), isSharedMemDesc);
-    };
-    llvm::SetVector<Operation *> slice;
-    LogicalResult result = getBackwardSlice(value, &slice, options);
-    assert(succeeded(result) && "backward slice must succeed");
-
-    for (Operation *depOp : slice) {
-      if (auto alloc = dyn_cast<triton::gpu::LocalAllocOp>(depOp))
-        owners.insert(alloc.getOperation());
+    for (auto bufferId : allocation->getBufferIds(value)) {
+      if (bufferId == Allocation::InvalidBufferId)
+        continue;
+      bufferIds.insert(bufferId);
     }
   }
-  return owners;
+  return bufferIds;
 }

+} // namespace
+
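With this helper in place, deciding whether two operations touch the same shared-memory allocation reduces to intersecting their buffer-id sets instead of walking a backward slice up to the owning local_alloc. A minimal usage sketch, assuming only the helper above and an Allocation from the surrounding analysis; the name touchSameSharedBuffer is hypothetical, and the real check is the updated BlockInfo::haveSameAlloc further down in this diff:

// Hypothetical illustration only; mirrors the updated BlockInfo::haveSameAlloc.
static bool touchSameSharedBuffer(Operation *lhs, Operation *rhs,
                                  Allocation *allocation) {
  auto lhsIds = getSharedBufferIds(lhs, allocation);
  auto rhsIds = getSharedBufferIds(rhs, allocation);
  // The two ops conflict when they touch at least one common buffer.
  return llvm::any_of(lhsIds, [&](auto id) { return rhsIds.contains(id); });
}
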
 static std::pair<BlockInfo::CTA_UFDS, BlockInfo::CTA_UFDS>
 getCTAEquivalenceSets(Operation *op) {
   auto numCTAs = triton::gpu::lookupNumCTAs(op);
@@ -68,46 +51,23 @@ getCTAEquivalenceSets(Operation *op) {
   }
   auto *ctx = op->getContext();
   auto kBlock = StringAttr::get(ctx, "block");
-  if (auto cvt = dyn_cast<triton::gpu::ConvertLayoutOp>(op)) {
-    auto srcTy = cast<RankedTensorType>(cvt.getSrc().getType());
-    auto dstTy = cast<RankedTensorType>(cvt.getType());
-    auto cvtLayout = minimalCvtLayout(srcTy, dstTy);
-    if (llvm::is_contained(cvtLayout.getInDimNames(), kBlock)) {
-      auto readsUFDS = BlockInfo::CTA_UFDS(numCTAs);
-      auto blockLayout =
-          cvtLayout.sublayout({kBlock}, to_vector(cvtLayout.getOutDimNames()));
-      for (int i = 0; i < numCTAs; i++) {
-        auto res = blockLayout.apply({{kBlock, i}});
-        assert(res.size() == 4);
-        assert(res.back().first == kBlock);
-        readsUFDS.unite(i, res.back().second);
-      }
-      // The writes are just each writing to their own shmem
-      auto writesUFDS = BlockInfo::CTA_UFDS(numCTAs);
-      return {readsUFDS, writesUFDS};
-    }
-  } else if (auto reduce = dyn_cast<triton::ReduceOp>(op)) {
-    auto srcTy = cast<RankedTensorType>(reduce.getInputTypes()[0]);
-    auto inCTALayout = triton::gpu::getCTALayout(srcTy.getEncoding());
-    auto axis = reduce.getAxis();
-    auto ll = inCTALayout.getLinearLayout();
-    auto outdims = to_vector(ll.getOutDimNames());
-    if (ll.getOutDimSize(outdims[axis]) != 1) {
-      auto outCTALayout = triton::gpu::getCTALayout(
-          cast<RankedTensorType>(reduce.getType(0)).getEncoding());
-      // Maps the reads necessary in the reduction
-      auto ctaLl = outCTALayout.getLinearLayout().invertAndCompose(ll);
-      auto readsUFDS = BlockInfo::CTA_UFDS(numCTAs);
-      for (int i = 0; i < numCTAs; i++) {
-        auto res = ctaLl.apply({{kBlock, i}});
-        assert(res.size() == 1);
-        assert(res.front().first == kBlock);
-        readsUFDS.unite(i, res.front().second);
-      }
-      // The writes are just each writing to their own shmem
-      auto writesUFDS = BlockInfo::CTA_UFDS(numCTAs);
-      return {readsUFDS, writesUFDS};
+  if (isa<triton::gpu::ConvertLayoutOp, triton::ReduceOp>(op)) {
+    auto srcTy = cast<RankedTensorType>(op->getOperand(0).getType());
+    auto dstTy = cast<RankedTensorType>(op->getResult(0).getType());
+    auto srcCTALayout = triton::gpu::getCTALayout(srcTy.getEncoding());
+    auto dstCTALayout = triton::gpu::getCTALayout(dstTy.getEncoding());
+    auto ctaLl = dstCTALayout.getLinearLayout().invertAndCompose(
+        srcCTALayout.getLinearLayout());
+    auto readsUFDS = BlockInfo::CTA_UFDS(numCTAs);
+    for (int i = 0; i < numCTAs; i++) {
+      auto res = ctaLl.apply({{kBlock, i}});
+      assert(res.size() == 1);
+      assert(res.front().first == kBlock);
+      readsUFDS.unite(i, res.front().second);
     }
+    // The writes are just each writing to their own shmem
+    auto writesUFDS = BlockInfo::CTA_UFDS(numCTAs);
+    return {readsUFDS, writesUFDS};
   } else if (auto tma =
                  dyn_cast<triton::nvidia_gpu::AsyncTMACopyGlobalToLocalOp>(
                      op)) {
@@ -143,15 +103,12 @@ getCTAEquivalenceSets(Operation *op) {
   return {BlockInfo::CTA_UFDS(numCTAs), BlockInfo::CTA_UFDS(numCTAs)};
 }

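One way to read the CTA-equivalence loop above: ctaLl maps each destination CTA to the source CTA whose data it reads, and readsUFDS.unite merges the two into one synchronization class. A worked limiting case, under the assumption that composing identical, invertible source and destination CTA layouts yields the identity map on the block dimension:

// Illustration only (assumes src and dst CTA layouts are identical):
//   ctaLl.apply({{kBlock, i}}) == {{kBlock, i}}   for every CTA i
// so readsUFDS.unite(i, i) is a no-op, every CTA stays in its own class, and
// the op needs no cross-CTA synchronization. Only when CTA i reads data owned
// by a different CTA j does unite(i, j) merge their classes.
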
-bool BlockInfo::haveSameAlloc(Operation *lhs, Operation *rhs) {
-  auto lhsAllocs = parentAllocs(lhs);
-  auto rhsAllocs = parentAllocs(rhs);
-  // They can be empty when the buffer is internal, e.g. a convert_layout.
-  if (lhsAllocs.empty() || rhsAllocs.empty())
-    return false;
-
+bool BlockInfo::haveSameAlloc(Operation *lhs, Operation *rhs,
+                              Allocation *allocation) {
+  auto lhsBuffers = getSharedBufferIds(lhs, allocation);
+  auto rhsBuffers = getSharedBufferIds(rhs, allocation);
   return llvm::any_of(
-      lhsAllocs, [&](Operation *alloc) { return rhsAllocs.contains(alloc); });
+      lhsBuffers, [&](auto bufferId) { return rhsBuffers.contains(bufferId); });
 }

 void MembarOrFenceAnalysis::run(FuncBlockInfoMapT &funcBlockInfoMap) {
@@ -405,7 +362,8 @@ void MembarAnalysis::update(Operation *op, BlockInfo *blockInfo,
     }
     auto interval = allocation->getAllocatedInterval(scratchBufferId);
     curBlockInfo.syncWriteIntervals[{interval, writeCTAs}].insert(op);
-    auto insertCTABarrier = blockInfo->isIntersected(curBlockInfo, filter);
+    auto insertCTABarrier =
+        blockInfo->isIntersected(curBlockInfo, filter, allocation);
     if (insertCTABarrier.has_value()) {
       builder->setInsertionPoint(op);
       insertBarrier(op, builder, *insertCTABarrier);
@@ -416,7 +374,8 @@ void MembarAnalysis::update(Operation *op, BlockInfo *blockInfo,
       blockInfo->sync(false);
     }
     curBlockInfo.syncReadIntervals[{interval, readCTAs}].insert(op);
-  } else if (auto ctas = blockInfo->isIntersected(curBlockInfo, filter)) {
+  } else if (auto ctas =
+                 blockInfo->isIntersected(curBlockInfo, filter, allocation)) {
     builder->setInsertionPoint(op);
     insertBarrier(op, builder, *ctas);
     blockInfo->sync(ctas->isDistributed());
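
The two hunks above also thread the Allocation analysis through isIntersected, whose declaration is not shown in this diff. A sketch of what the updated signatures presumably look like, inferred only from the call sites and definitions above; the std::optional return, the CTA_UFDS payload, and the MembarFilterFn parameter type are assumptions, not taken from the real header:

// Sketch only; inferred from the call sites in this diff, not the real header.
class BlockInfo {
public:
  // Defined above: true when lhs and rhs touch at least one common
  // shared-memory buffer according to the Allocation analysis.
  bool haveSameAlloc(Operation *lhs, Operation *rhs, Allocation *allocation);

  // Inferred from the call sites above: returns the set of CTAs that must be
  // synchronized, or std::nullopt when no barrier is needed.
  std::optional<CTA_UFDS> isIntersected(const BlockInfo &other,
                                        MembarFilterFn filter,
                                        Allocation *allocation) const;
};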