@@ -99,45 +99,6 @@ class Allocation {
9999 using BufferIdSetT = DenseSet<BufferId>;
100100 using FuncAllocMapT = CallGraph<Allocation>::FuncDataMapT;
101101
102- // / A class that represents a shared memory buffer
103- struct BufferT {
104- // / Explicit: triton_gpu.local_alloc
105- // / Scratch: triton_gpu.convert_layout
106- // / Virtual: triton.call
107- enum class BufferKind { Explicit, Scratch, Virtual };
108-
109- // / MT: thread-safe
110- inline static std::atomic<BufferId> nextId = 0 ;
111-
112- BufferKind kind;
113- BufferId id;
114- size_t size;
115- size_t alignment;
116- size_t offset;
117-
118- bool operator ==(const BufferT &other) const { return id == other.id ; }
119- bool operator <(const BufferT &other) const { return id < other.id ; }
120-
121- BufferT () : BufferT(BufferKind::Explicit, 0 ) {}
122- BufferT (BufferKind kind, size_t size, size_t alignment = 4 ,
123- size_t offset = 0 )
124- : kind(kind), id(nextId++), size(size), alignment(alignment),
125- offset (offset) {}
126-
127- size_t setOffsetAligned (size_t newOffset) {
128- return offset = llvm::alignTo (newOffset, alignment);
129- }
130- };
131-
132- // / Op -> Scratch Buffer
133- using OpScratchMapT = DenseMap<Operation *, BufferT *>;
134- // / Value -> Explicit Buffer
135- using ValueBufferMapT = llvm::MapVector<Value, BufferT *>;
136- // / Value -> Alias Buffer
137- using AliasBufferMapT = llvm::MapVector<Value, llvm::SetVector<BufferT *>>;
138- // / BufferId -> Buffer
139- using BufferSetT = std::map<BufferId, BufferT>;
140-
141102 static constexpr BufferId InvalidBufferId =
142103 std::numeric_limits<BufferId>::max();
143104
@@ -153,12 +114,6 @@ class Allocation {
153114 // / Returns the operation this analysis was constructed from.
154115 Operation *getOperation () const { return operation; }
155116
156- const OpScratchMapT &getOpScratch () const { return opScratch; }
157- const OpScratchMapT &getOpVirtual () const { return opVirtual; }
158- const ValueBufferMapT &getValueBuffer () const { return valueBuffer; }
159- const AliasBufferMapT &getAliasBuffer () const { return aliasBuffer; }
160- void setSharedMemorySize (size_t size) { sharedMemorySize = size; }
161-
162117 // / Returns the offset of the given buffer in the shared memory.
163118 size_t getOffset (BufferId bufferId) const {
164119 return bufferSet.at (bufferId).offset ;
@@ -222,6 +177,47 @@ class Allocation {
222177 // / Returns mapping from operation to list of live LDS buffers
223178 std::map<Operation *, SmallVector<BufferId>> getLiveBuffers ();
224179
180+ private:
181+ // / A class that represents a shared memory buffer
182+ struct BufferT {
183+ // / Explicit: triton_gpu.local_alloc
184+ // / Scratch: triton_gpu.convert_layout
185+ // / Virtual: triton.call
186+ enum class BufferKind { Explicit, Scratch, Virtual };
187+
188+ // / MT: thread-safe
189+ inline static std::atomic<BufferId> nextId = 0 ;
190+
191+ BufferKind kind;
192+ BufferId id;
193+ size_t size;
194+ size_t alignment;
195+ size_t offset;
196+
197+ bool operator ==(const BufferT &other) const { return id == other.id ; }
198+ bool operator <(const BufferT &other) const { return id < other.id ; }
199+
200+ BufferT () : BufferT(BufferKind::Explicit, 0 ) {}
201+ BufferT (BufferKind kind, size_t size, size_t alignment = 4 ,
202+ size_t offset = 0 )
203+ : kind(kind), id(nextId++), size(size), alignment(alignment),
204+ offset (offset) {}
205+
206+ size_t setOffsetAligned (size_t newOffset) {
207+ return offset = llvm::alignTo (newOffset, alignment);
208+ }
209+ };
210+
211+ // / Op -> Scratch Buffer
212+ using OpScratchMapT = DenseMap<Operation *, BufferT *>;
213+ // / Value -> Explicit Buffer
214+ using ValueBufferMapT = llvm::MapVector<Value, BufferT *>;
215+ // / Value -> Alias Buffer
216+ using AliasBufferMapT = llvm::MapVector<Value, llvm::SetVector<BufferT *>>;
217+ // / BufferId -> Buffer
218+ using BufferSetT = std::map<BufferId, BufferT>;
219+
220+ private:
225221 template <BufferT::BufferKind Kind, typename KeyType, typename ... Args>
226222 void addBuffer (KeyType &key, Args &&...args) {
227223 auto buffer = BufferT (Kind, std::forward<Args>(args)...);
@@ -247,6 +243,8 @@ class Allocation {
247243 AliasBufferMapT aliasBuffer;
248244 BufferSetT bufferSet;
249245 size_t sharedMemorySize = 0 ;
246+
247+ friend class triton ::AllocationAnalysis;
250248};
251249
252250// / Static analysis that computes the allocation of shared memory buffers
0 commit comments