@@ -36,7 +36,7 @@ namespace {
3636// /*-----------------Base pointer increment optimization-------------------*/
3737
3838// Optimization tries to transfer increments from offsets to base pointer in
39- // buffer loads :
39+ // buffer operations :
4040//
4141// for ... (offsets = offsets_init):
4242// val = buffer_load basePtr [ offsets ]
@@ -101,53 +101,26 @@ struct AdvanceBasePointer : public OpRewritePattern<scf::ForOp> {
101101 return Value ();
102102 }
103103
104- // load is a target buffer load
105- // offsetIncrement is a tensor added to offsets on each iteration
106- // baseIncrement is a scalar which will be added to base pointer after
107- // optimization offsetInitialized is a value of offset on first loop iteration
108- // incrementOp is an operation that advances offset tensor
109- struct LoadData {
110- Operation *load;
104+ // Description of struct fields:
105+ // - op is a target BufferOp, for example BufferLoadOp or BufferLoadToLocalOp
106+ // - offsetIncrement is a tensor added to offsets on each iteration
107+ // - baseIncrement is a scalar which will be added to base pointer
108+ // - offsetInitializer is a value of offset on first loop iteration
109+ // - incrementOp is an operation that advances offset tensor
110+ struct BufferOpInfo {
111+ amdttg::BufferOpAddressinInterface op;
111112 Value offsetIncrement;
112113 Value baseIncrement;
113114 Value offsetInitializer;
114115 Operation *incrementOp;
115116 };
116117
117- static Value getOffset (Operation *load) {
118- if (auto specific = dyn_cast<amdttg::BufferLoadOp>(load))
119- return specific.getOffsets ();
120- if (auto specific = dyn_cast<amdttg::BufferLoadToLocalOp>(load))
121- return specific.getOffsets ();
122- assert (false && " unsupported operation type" );
123- }
124-
125- static Value getBasePtr (Operation *load) {
126- if (auto specific = dyn_cast<amdttg::BufferLoadOp>(load))
127- return specific.getPtr ();
128- if (auto specific = dyn_cast<amdttg::BufferLoadToLocalOp>(load))
129- return specific.getPtr ();
130- assert (false && " unsupported operation type" );
131- }
132-
133- static void setOffset (Operation *load, Value newOffset) {
134- assert ((isa<amdttg::BufferLoadOp, amdttg::BufferLoadToLocalOp>(load)));
135- const int offsetIdx = isa<amdttg::BufferLoadOp>(load) ? 1 : 2 ;
136- load->setOperand (offsetIdx, newOffset);
137- }
138-
139- static void setBasePtr (Operation *load, Value newBasePtr) {
140- assert ((isa<amdttg::BufferLoadOp, amdttg::BufferLoadToLocalOp>(load)));
141- const int ptrIdx = isa<amdttg::BufferLoadOp>(load) ? 0 : 1 ;
142- load->setOperand (ptrIdx, newBasePtr);
143- }
144-
145118 // Perform series of checks to decide if given operation could be optimized.
146- // If optimization is possible, return filled LoadData
147- static std::optional<LoadData> analyzeLoad (Operation *loadOp,
148- scf::ForOp targetFor) {
149- LDBG (" Analyzing: " << *loadOp );
150- Value maybeOffsetsBlockArg = getOffset (loadOp );
119+ // If optimization is possible, return filled BufferOpInfo
120+ static std::optional<BufferOpInfo>
121+ analyzeBufferOp (amdttg::BufferOpAddressinInterface op, scf::ForOp targetFor) {
122+ LDBG (" Analyzing: " << *op );
123+ Value maybeOffsetsBlockArg = op. getOffsets ( );
151124 auto maybeOffsetDefOp = maybeOffsetsBlockArg.getDefiningOp ();
152125 if (maybeOffsetDefOp && isa<arith::AddIOp>(maybeOffsetDefOp)) {
153126 for (auto &use : maybeOffsetDefOp->getUses ()) {
@@ -164,20 +137,20 @@ struct AdvanceBasePointer : public OpRewritePattern<scf::ForOp> {
164137 }
165138 }
166139 if (!isa<BlockArgument>(maybeOffsetsBlockArg)) {
167- LDBG (" Rejected: expect load offset to be a loop argument" );
140+ LDBG (" Rejected: expect buffer op offset to be a loop argument" );
168141 return {};
169142 }
170143 auto blockArg = dyn_cast<BlockArgument>(maybeOffsetsBlockArg);
171144 auto loopBlock = blockArg.getOwner ();
172145 auto forOp = dyn_cast<scf::ForOp>(loopBlock->getParentOp ());
173146 if (!forOp || forOp != targetFor) {
174- LDBG (" Rejected: expect load offset to be a target loop argument" );
147+ LDBG (" Rejected: expect buffer op offset to be a target loop argument" );
175148 return {};
176149 }
177- auto basePtr = getBasePtr (loadOp );
150+ auto basePtr = op. getPtr ( );
178151 auto defOpBlock = basePtr.getParentBlock ();
179152 if (!defOpBlock->getParentOp ()->isProperAncestor (targetFor)) {
180- LDBG (" Rejected: expect load base Ptr to be invariant to the loop" );
153+ LDBG (" Rejected: expect buffer op base Ptr to be invariant to the loop" );
181154 return {};
182155 }
183156 auto yield = dyn_cast<scf::YieldOp>(loopBlock->getTerminator ());
@@ -199,41 +172,41 @@ struct AdvanceBasePointer : public OpRewritePattern<scf::ForOp> {
199172 }
200173 if (!advanceStep) {
201174 LDBG (" Rejected: expect arith::addi to advance same block argument as "
202- " used in load " );
175+ " used in buffer op " );
203176 return {};
204177 }
205178 if (!isScalarizableValue (advanceStep)) {
206179 LDBG (" Rejected: ptr increment step is not supported" );
207180 return {};
208181 }
209182 Value offsetInitializer = forOp.getInitArgs ()[offsetOperandNo];
210- LoadData data = {loadOp , advanceStep, Value (), offsetInitializer,
211- incrementOp};
212- LDBG (" Load is suitable for offset pointer optimization" );
183+ BufferOpInfo data = {op , advanceStep, Value (), offsetInitializer,
184+ incrementOp};
185+ LDBG (" Buffer op is suitable for offset pointer optimization" );
213186 return data;
214187 }
215188
216- // Create scalar values which will increment load base ptr
217- // Fills appropriate fields in given LoadData structures
189+ // Create scalar values which will increment buffer op base ptr
190+ // Fills appropriate fields in given BufferOpInfo structures
218191 static void createScalarIncrements (PatternRewriter &rewriter,
219- SmallVector<LoadData > &loads ) {
220- for (auto &loadData : loads ) {
221- auto scalarStep = scalarizeValue (rewriter, loadData .offsetIncrement );
222- loadData .baseIncrement = scalarStep;
192+ SmallVector<BufferOpInfo > &infoList ) {
193+ for (auto &BufferOpInfo : infoList ) {
194+ auto scalarStep = scalarizeValue (rewriter, BufferOpInfo .offsetIncrement );
195+ BufferOpInfo .baseIncrement = scalarStep;
223196 }
224197 }
225198
226- static bool isAddFirst (LoadData &ld ) {
227- return getOffset (ld. load ).getDefiningOp () == ld .incrementOp ;
199+ static bool isAddFirst (BufferOpInfo &info ) {
200+ return info. op . getOffsets ( ).getDefiningOp () == info .incrementOp ;
228201 }
229202
230203 static scf::ForOp
231204 cloneLoopWithBasePtrIncrements (PatternRewriter &rewriter, scf::ForOp forOp,
232- SmallVector<LoadData > &loads ) {
205+ SmallVector<BufferOpInfo > &infoList ) {
233206 // Create new loop with additional arguments
234207 llvm::SmallVector<Value> newLoopArgs (forOp.getInitArgs ());
235- for (auto loadData : loads ) {
236- newLoopArgs.push_back (getBasePtr (loadData. load ));
208+ for (auto info : infoList ) {
209+ newLoopArgs.push_back (info. op . getPtr ( ));
237210 }
238211 rewriter.setInsertionPoint (forOp);
239212 auto newForOp = rewriter.create <scf::ForOp>(
@@ -251,19 +224,19 @@ struct AdvanceBasePointer : public OpRewritePattern<scf::ForOp> {
251224 rewriter.clone (op, mapping);
252225 }
253226 // Create base pointer increment operations
254- auto basePtrs = newBlock->getArguments ().take_back (loads .size ());
227+ auto basePtrs = newBlock->getArguments ().take_back (infoList .size ());
255228 llvm::SmallVector<Value> nextIterBasePtrs;
256- for (auto [loadData , basePtr] : llvm::zip (loads , basePtrs)) {
257- if (isAddFirst (loadData )) {
229+ for (auto [info , basePtr] : llvm::zip (infoList , basePtrs)) {
230+ if (isAddFirst (info )) {
258231 rewriter.setInsertionPoint (newBlock, newBlock->begin ());
259232 } else {
260233 rewriter.setInsertionPoint (newBlock, newBlock->end ());
261234 }
262- Value step = loadData .baseIncrement ;
263- if (mapping.contains (loadData .baseIncrement )) {
264- step = mapping.lookup (loadData .baseIncrement );
235+ Value step = info .baseIncrement ;
236+ if (mapping.contains (info .baseIncrement )) {
237+ step = mapping.lookup (info .baseIncrement );
265238 }
266- auto loc = loadData .incrementOp ->getLoc ();
239+ auto loc = info .incrementOp ->getLoc ();
267240 auto ptrType = basePtr.getType ();
268241 auto nextIterBasePtr =
269242 rewriter.create <triton::AddPtrOp>(loc, ptrType, basePtr, step);
@@ -280,48 +253,49 @@ struct AdvanceBasePointer : public OpRewritePattern<scf::ForOp> {
280253 rewriter.setInsertionPoint (newBlock, newBlock->end ());
281254 rewriter.create <scf::YieldOp>(oldBlock->getTerminator ()->getLoc (),
282255 newYieldOperands);
283- // Replace dynamic load offsets with invariant value
256+ // Replace dynamic buffer op offsets with invariant value
284257 // Replace base ptr with incrementing value
285- for (auto [loadData, basePtr, nextBasePtr] :
286- llvm::zip (loads, basePtrs, nextIterBasePtrs)) {
287- auto newLoad = mapping.lookup <Operation *>(loadData.load );
288- setOffset (newLoad, loadData.offsetInitializer );
258+ for (auto [info, basePtr, nextBasePtr] :
259+ llvm::zip (infoList, basePtrs, nextIterBasePtrs)) {
260+ auto newBufferOp = cast<amdttg::BufferOpAddressinInterface>(
261+ mapping.lookup <Operation *>(info.op .getOperation ()));
262+ newBufferOp.getOffsetsMutable ().assign (info.offsetInitializer );
289263 // two cases:
290- // 1. first advance pointer, then load
291- // 2. load uses pointers from loop arguments, advanced pointer used on
292- // next iteration
293- Value advancingBasePtr = isAddFirst (loadData ) ? nextBasePtr : basePtr;
294- setBasePtr (newLoad, advancingBasePtr);
264+ // 1. buffer op uses pointer after increment
265+ // 2. buffer op uses pointers from loop arguments,
266+ // incremented pointer is used on next iteration
267+ Value advancingBasePtr = isAddFirst (info ) ? nextBasePtr : basePtr;
268+ newBufferOp. getPtrMutable (). assign ( advancingBasePtr);
295269 }
296270 return newForOp;
297271 }
298272
299- template < typename OpType>
300- static void collectLoads ( SmallVector<LoadData> &loads, scf::ForOp forOp) {
301- forOp.walk ([&loads , forOp](OpType loadOp ) {
302- auto loadData = analyzeLoad (loadOp , forOp);
303- if (loadData .has_value ()) {
304- loads .push_back (loadData .value ());
273+ static SmallVector<BufferOpInfo> collectBufferOps (scf::ForOp forOp) {
274+ SmallVector<BufferOpInfo> list;
275+ forOp.walk ([&list , forOp](amdttg::BufferOpAddressinInterface op ) {
276+ auto info = analyzeBufferOp (op , forOp);
277+ if (info .has_value ()) {
278+ list .push_back (info .value ());
305279 }
306280 });
281+ return list;
307282 }
308283
309284 LogicalResult matchAndRewrite (scf::ForOp forOp,
310285 PatternRewriter &rewriter) const override {
311- LDBG (" Analyzing ForOp for for offset pointer optimization: " << forOp);
312- // Gather buffer loads which could be optimized
313- SmallVector<LoadData> loads;
314- collectLoads<triton::amdgpu::BufferLoadOp>(loads, forOp);
315- collectLoads<triton::amdgpu::BufferLoadToLocalOp>(loads, forOp);
286+ LDBG (" Analyzing ForOp for offset pointer optimization: " << forOp);
287+ // Gather buffer operations which could be optimized
288+ SmallVector<BufferOpInfo> infoList = collectBufferOps (forOp);
316289
317- if (loads.empty ())
318- return rewriter.notifyMatchFailure (forOp, " no suitable buffer loads" );
290+ if (infoList.empty ())
291+ return rewriter.notifyMatchFailure (forOp,
292+ " no suitable buffer operations" );
319293
320294 // Perform IR transformation
321- createScalarIncrements (rewriter, loads );
322- auto newForOp = cloneLoopWithBasePtrIncrements (rewriter, forOp, loads );
323- rewriter.replaceAllUsesWith (forOp. getResults (),
324- newForOp.getResults ().drop_back (loads .size ()));
295+ createScalarIncrements (rewriter, infoList );
296+ auto newForOp = cloneLoopWithBasePtrIncrements (rewriter, forOp, infoList );
297+ rewriter.replaceAllUsesWith (
298+ forOp. getResults (), newForOp.getResults ().drop_back (infoList .size ()));
325299 rewriter.eraseOp (forOp);
326300 return success ();
327301 }
0 commit comments