@@ -101,35 +101,20 @@ struct AdvanceBasePointer : public OpRewritePattern<scf::ForOp> {
101101 return Value ();
102102 }
103103
104- // load is a target buffer load
105- // offsetIncrement is a tensor added to offsets on each iteration
106- // baseIncrement is a scalar which will be added to base pointer after
107- // optimization offsetInitialized is a value of offset on first loop iteration
108- // incrementOp is an operation that advances offset tensor
109- struct LoadData {
110- Operation *load;
104+ // Description of struct fields:
105+ // - op is a target BufferOp, for example BufferLoadOp or BufferLoadToLocalOp
106+ // - offsetIncrement is a tensor added to offsets on each iteration
107+ // - baseIncrement is a scalar which will be added to base pointer
108+ // - offsetInitializer is a value of offset on first loop iteration
109+ // - incrementOp is an operation that advances offset tensor
110+ struct BufferOpInfo {
111+ amdttg::BufferOpAddressinInterface op;
111112 Value offsetIncrement;
112113 Value baseIncrement;
113114 Value offsetInitializer;
114115 Operation *incrementOp;
115116 };
116117
117- static Value getOffset (Operation *load) {
118- if (auto specific = dyn_cast<amdttg::BufferLoadOp>(load))
119- return specific.getOffsets ();
120- if (auto specific = dyn_cast<amdttg::BufferLoadToLocalOp>(load))
121- return specific.getOffsets ();
122- assert (false && " unsupported operation type" );
123- }
124-
125- static Value getBasePtr (Operation *load) {
126- if (auto specific = dyn_cast<amdttg::BufferLoadOp>(load))
127- return specific.getPtr ();
128- if (auto specific = dyn_cast<amdttg::BufferLoadToLocalOp>(load))
129- return specific.getPtr ();
130- assert (false && " unsupported operation type" );
131- }
132-
133118 static void setOffset (Operation *load, Value newOffset) {
134119 assert ((isa<amdttg::BufferLoadOp, amdttg::BufferLoadToLocalOp>(load)));
135120 const int offsetIdx = isa<amdttg::BufferLoadOp>(load) ? 1 : 2 ;
@@ -143,11 +128,11 @@ struct AdvanceBasePointer : public OpRewritePattern<scf::ForOp> {
143128 }
144129
145130 // Perform series of checks to decide if given operation could be optimized.
146- // If optimization is possible, return filled LoadData
147- static std::optional<LoadData> analyzeLoad (Operation *loadOp,
148- scf::ForOp targetFor) {
149- LDBG (" Analyzing: " << *loadOp );
150- Value maybeOffsetsBlockArg = getOffset (loadOp );
131+ // If optimization is possible, return filled BufferOpInfo
132+ static std::optional<BufferOpInfo>
133+ analyzeBufferOp (amdttg::BufferOpAddressinInterface op, scf::ForOp targetFor) {
134+ LDBG (" Analyzing: " << *op );
135+ Value maybeOffsetsBlockArg = op. getOffsets ( );
151136 auto maybeOffsetDefOp = maybeOffsetsBlockArg.getDefiningOp ();
152137 if (maybeOffsetDefOp && isa<arith::AddIOp>(maybeOffsetDefOp)) {
153138 for (auto &use : maybeOffsetDefOp->getUses ()) {
@@ -174,7 +159,7 @@ struct AdvanceBasePointer : public OpRewritePattern<scf::ForOp> {
174159 LDBG (" Rejected: expect load offset to be a target loop argument" );
175160 return {};
176161 }
177- auto basePtr = getBasePtr (loadOp );
162+ auto basePtr = op. getPtr ( );
178163 auto defOpBlock = basePtr.getParentBlock ();
179164 if (!defOpBlock->getParentOp ()->isProperAncestor (targetFor)) {
180165 LDBG (" Rejected: expect load base Ptr to be invariant to the loop" );
@@ -207,33 +192,33 @@ struct AdvanceBasePointer : public OpRewritePattern<scf::ForOp> {
207192 return {};
208193 }
209194 Value offsetInitializer = forOp.getInitArgs ()[offsetOperandNo];
210- LoadData data = {loadOp , advanceStep, Value (), offsetInitializer,
211- incrementOp};
195+ BufferOpInfo data = {op , advanceStep, Value (), offsetInitializer,
196+ incrementOp};
212197 LDBG (" Load is suitable for offset pointer optimization" );
213198 return data;
214199 }
215200
216201 // Create scalar values which will increment load base ptr
217- // Fills appropriate fields in given LoadData structures
202+ // Fills appropriate fields in given BufferOpInfo structures
218203 static void createScalarIncrements (PatternRewriter &rewriter,
219- SmallVector<LoadData > &loads) {
220- for (auto &loadData : loads) {
221- auto scalarStep = scalarizeValue (rewriter, loadData .offsetIncrement );
222- loadData .baseIncrement = scalarStep;
204+ SmallVector<BufferOpInfo > &loads) {
205+ for (auto &BufferOpInfo : loads) {
206+ auto scalarStep = scalarizeValue (rewriter, BufferOpInfo .offsetIncrement );
207+ BufferOpInfo .baseIncrement = scalarStep;
223208 }
224209 }
225210
226- static bool isAddFirst (LoadData &ld ) {
227- return getOffset (ld. load ).getDefiningOp () == ld .incrementOp ;
211+ static bool isAddFirst (BufferOpInfo &info ) {
212+ return info. op . getOffsets ( ).getDefiningOp () == info .incrementOp ;
228213 }
229214
230215 static scf::ForOp
231216 cloneLoopWithBasePtrIncrements (PatternRewriter &rewriter, scf::ForOp forOp,
232- SmallVector<LoadData > &loads ) {
217+ SmallVector<BufferOpInfo > &infoList ) {
233218 // Create new loop with additional arguments
234219 llvm::SmallVector<Value> newLoopArgs (forOp.getInitArgs ());
235- for (auto loadData : loads ) {
236- newLoopArgs.push_back (getBasePtr (loadData. load ));
220+ for (auto info : infoList ) {
221+ newLoopArgs.push_back (info. op . getPtr ( ));
237222 }
238223 rewriter.setInsertionPoint (forOp);
239224 auto newForOp = rewriter.create <scf::ForOp>(
@@ -251,19 +236,19 @@ struct AdvanceBasePointer : public OpRewritePattern<scf::ForOp> {
251236 rewriter.clone (op, mapping);
252237 }
253238 // Create base pointer increment operations
254- auto basePtrs = newBlock->getArguments ().take_back (loads .size ());
239+ auto basePtrs = newBlock->getArguments ().take_back (infoList .size ());
255240 llvm::SmallVector<Value> nextIterBasePtrs;
256- for (auto [loadData , basePtr] : llvm::zip (loads , basePtrs)) {
257- if (isAddFirst (loadData )) {
241+ for (auto [info , basePtr] : llvm::zip (infoList , basePtrs)) {
242+ if (isAddFirst (info )) {
258243 rewriter.setInsertionPoint (newBlock, newBlock->begin ());
259244 } else {
260245 rewriter.setInsertionPoint (newBlock, newBlock->end ());
261246 }
262- Value step = loadData .baseIncrement ;
263- if (mapping.contains (loadData .baseIncrement )) {
264- step = mapping.lookup (loadData .baseIncrement );
247+ Value step = info .baseIncrement ;
248+ if (mapping.contains (info .baseIncrement )) {
249+ step = mapping.lookup (info .baseIncrement );
265250 }
266- auto loc = loadData .incrementOp ->getLoc ();
251+ auto loc = info .incrementOp ->getLoc ();
267252 auto ptrType = basePtr.getType ();
268253 auto nextIterBasePtr =
269254 rewriter.create <triton::AddPtrOp>(loc, ptrType, basePtr, step);
@@ -282,46 +267,46 @@ struct AdvanceBasePointer : public OpRewritePattern<scf::ForOp> {
282267 newYieldOperands);
283268 // Replace dynamic load offsets with invariant value
284269 // Replace base ptr with incrementing value
285- for (auto [loadData , basePtr, nextBasePtr] :
286- llvm::zip (loads , basePtrs, nextIterBasePtrs)) {
287- auto newLoad = mapping.lookup <Operation *>(loadData. load );
288- setOffset (newLoad, loadData .offsetInitializer );
270+ for (auto [info , basePtr, nextBasePtr] :
271+ llvm::zip (infoList , basePtrs, nextIterBasePtrs)) {
272+ auto newLoad = mapping.lookup <Operation *>(info. op . getOperation () );
273+ setOffset (newLoad, info .offsetInitializer );
289274 // two cases:
290275 // 1. first advance pointer, then load
291276 // 2. load uses pointers from loop arguments, advanced pointer used on
292277 // next iteration
293- Value advancingBasePtr = isAddFirst (loadData ) ? nextBasePtr : basePtr;
278+ Value advancingBasePtr = isAddFirst (info ) ? nextBasePtr : basePtr;
294279 setBasePtr (newLoad, advancingBasePtr);
295280 }
296281 return newForOp;
297282 }
298283
299- template < typename OpType>
300- static void collectLoads ( SmallVector<LoadData> &loads, scf::ForOp forOp) {
301- forOp.walk ([&loads , forOp](OpType loadOp ) {
302- auto loadData = analyzeLoad (loadOp , forOp);
303- if (loadData .has_value ()) {
304- loads .push_back (loadData .value ());
284+ static SmallVector<BufferOpInfo> collectBufferOps (scf::ForOp forOp) {
285+ SmallVector<BufferOpInfo> list;
286+ forOp.walk ([&list , forOp](amdttg::BufferOpAddressinInterface op ) {
287+ auto info = analyzeBufferOp (op , forOp);
288+ if (info .has_value ()) {
289+ list .push_back (info .value ());
305290 }
306291 });
292+ return list;
307293 }
308294
309295 LogicalResult matchAndRewrite (scf::ForOp forOp,
310296 PatternRewriter &rewriter) const override {
311- LDBG (" Analyzing ForOp for for offset pointer optimization: " << forOp);
312- // Gather buffer loads which could be optimized
313- SmallVector<LoadData> loads;
314- collectLoads<triton::amdgpu::BufferLoadOp>(loads, forOp);
315- collectLoads<triton::amdgpu::BufferLoadToLocalOp>(loads, forOp);
297+ LDBG (" Analyzing ForOp for offset pointer optimization: " << forOp);
298+ // Gather buffer operations which could be optimized
299+ SmallVector<BufferOpInfo> infoList = collectBufferOps (forOp);
316300
317- if (loads.empty ())
318- return rewriter.notifyMatchFailure (forOp, " no suitable buffer loads" );
301+ if (infoList.empty ())
302+ return rewriter.notifyMatchFailure (forOp,
303+ " no suitable buffer operations" );
319304
320305 // Perform IR transformation
321- createScalarIncrements (rewriter, loads );
322- auto newForOp = cloneLoopWithBasePtrIncrements (rewriter, forOp, loads );
323- rewriter.replaceAllUsesWith (forOp. getResults (),
324- newForOp.getResults ().drop_back (loads .size ()));
306+ createScalarIncrements (rewriter, infoList );
307+ auto newForOp = cloneLoopWithBasePtrIncrements (rewriter, forOp, infoList );
308+ rewriter.replaceAllUsesWith (
309+ forOp. getResults (), newForOp.getResults ().drop_back (infoList .size ()));
325310 rewriter.eraseOp (forOp);
326311 return success ();
327312 }
0 commit comments