Skip to content

Commit c2eacf4

Browse files
committed
address review comments
1 parent c5a189a commit c2eacf4

File tree

1 file changed

+68
-94
lines changed

1 file changed

+68
-94
lines changed

third_party/amd/lib/TritonAMDGPUTransforms/OptimizeBufferOpPtr.cpp

Lines changed: 68 additions & 94 deletions
Original file line numberDiff line numberDiff line change
@@ -36,7 +36,7 @@ namespace {
3636
// /*-----------------Base pointer increment optimization-------------------*/
3737

3838
// Optimization tries to transfer increments from offsets to base pointer in
39-
// buffer loads:
39+
// buffer operations:
4040
//
4141
// for ... (offsets = offsets_init):
4242
// val = buffer_load basePtr [ offsets ]
@@ -101,53 +101,26 @@ struct AdvanceBasePointer : public OpRewritePattern<scf::ForOp> {
101101
return Value();
102102
}
103103

104-
// load is a target buffer load
105-
// offsetIncrement is a tensor added to offsets on each iteration
106-
// baseIncrement is a scalar which will be added to base pointer after
107-
// optimization offsetInitialized is a value of offset on first loop iteration
108-
// incrementOp is an operation that advances offset tensor
109-
struct LoadData {
110-
Operation *load;
104+
// Description of struct fields:
105+
// - op is a target BufferOp, for example BufferLoadOp or BufferLoadToLocalOp
106+
// - offsetIncrement is a tensor added to offsets on each iteration
107+
// - baseIncrement is a scalar which will be added to base pointer
108+
// - offsetInitializer is a value of offset on first loop iteration
109+
// - incrementOp is an operation that advances offset tensor
110+
struct BufferOpInfo {
111+
amdttg::BufferOpAddressinInterface op;
111112
Value offsetIncrement;
112113
Value baseIncrement;
113114
Value offsetInitializer;
114115
Operation *incrementOp;
115116
};
116117

117-
static Value getOffset(Operation *load) {
118-
if (auto specific = dyn_cast<amdttg::BufferLoadOp>(load))
119-
return specific.getOffsets();
120-
if (auto specific = dyn_cast<amdttg::BufferLoadToLocalOp>(load))
121-
return specific.getOffsets();
122-
assert(false && "unsupported operation type");
123-
}
124-
125-
static Value getBasePtr(Operation *load) {
126-
if (auto specific = dyn_cast<amdttg::BufferLoadOp>(load))
127-
return specific.getPtr();
128-
if (auto specific = dyn_cast<amdttg::BufferLoadToLocalOp>(load))
129-
return specific.getPtr();
130-
assert(false && "unsupported operation type");
131-
}
132-
133-
static void setOffset(Operation *load, Value newOffset) {
134-
assert((isa<amdttg::BufferLoadOp, amdttg::BufferLoadToLocalOp>(load)));
135-
const int offsetIdx = isa<amdttg::BufferLoadOp>(load) ? 1 : 2;
136-
load->setOperand(offsetIdx, newOffset);
137-
}
138-
139-
static void setBasePtr(Operation *load, Value newBasePtr) {
140-
assert((isa<amdttg::BufferLoadOp, amdttg::BufferLoadToLocalOp>(load)));
141-
const int ptrIdx = isa<amdttg::BufferLoadOp>(load) ? 0 : 1;
142-
load->setOperand(ptrIdx, newBasePtr);
143-
}
144-
145118
// Perform series of checks to decide if given operation could be optimized.
146-
// If optimization is possible, return filled LoadData
147-
static std::optional<LoadData> analyzeLoad(Operation *loadOp,
148-
scf::ForOp targetFor) {
149-
LDBG("Analyzing: " << *loadOp);
150-
Value maybeOffsetsBlockArg = getOffset(loadOp);
119+
// If optimization is possible, return filled BufferOpInfo
120+
static std::optional<BufferOpInfo>
121+
analyzeBufferOp(amdttg::BufferOpAddressinInterface op, scf::ForOp targetFor) {
122+
LDBG("Analyzing: " << *op);
123+
Value maybeOffsetsBlockArg = op.getOffsets();
151124
auto maybeOffsetDefOp = maybeOffsetsBlockArg.getDefiningOp();
152125
if (maybeOffsetDefOp && isa<arith::AddIOp>(maybeOffsetDefOp)) {
153126
for (auto &use : maybeOffsetDefOp->getUses()) {
@@ -164,20 +137,20 @@ struct AdvanceBasePointer : public OpRewritePattern<scf::ForOp> {
164137
}
165138
}
166139
if (!isa<BlockArgument>(maybeOffsetsBlockArg)) {
167-
LDBG("Rejected: expect load offset to be a loop argument");
140+
LDBG("Rejected: expect buffer op offset to be a loop argument");
168141
return {};
169142
}
170143
auto blockArg = dyn_cast<BlockArgument>(maybeOffsetsBlockArg);
171144
auto loopBlock = blockArg.getOwner();
172145
auto forOp = dyn_cast<scf::ForOp>(loopBlock->getParentOp());
173146
if (!forOp || forOp != targetFor) {
174-
LDBG("Rejected: expect load offset to be a target loop argument");
147+
LDBG("Rejected: expect buffer op offset to be a target loop argument");
175148
return {};
176149
}
177-
auto basePtr = getBasePtr(loadOp);
150+
auto basePtr = op.getPtr();
178151
auto defOpBlock = basePtr.getParentBlock();
179152
if (!defOpBlock->getParentOp()->isProperAncestor(targetFor)) {
180-
LDBG("Rejected: expect load base Ptr to be invariant to the loop");
153+
LDBG("Rejected: expect buffer op base Ptr to be invariant to the loop");
181154
return {};
182155
}
183156
auto yield = dyn_cast<scf::YieldOp>(loopBlock->getTerminator());
@@ -199,41 +172,41 @@ struct AdvanceBasePointer : public OpRewritePattern<scf::ForOp> {
199172
}
200173
if (!advanceStep) {
201174
LDBG("Rejected: expect arith::addi to advance same block argument as "
202-
"used in load");
175+
"used in buffer op");
203176
return {};
204177
}
205178
if (!isScalarizableValue(advanceStep)) {
206179
LDBG("Rejected: ptr increment step is not supported");
207180
return {};
208181
}
209182
Value offsetInitializer = forOp.getInitArgs()[offsetOperandNo];
210-
LoadData data = {loadOp, advanceStep, Value(), offsetInitializer,
211-
incrementOp};
212-
LDBG("Load is suitable for offset pointer optimization");
183+
BufferOpInfo data = {op, advanceStep, Value(), offsetInitializer,
184+
incrementOp};
185+
LDBG("Buffer op is suitable for offset pointer optimization");
213186
return data;
214187
}
215188

216-
// Create scalar values which will increment load base ptr
217-
// Fills appropriate fields in given LoadData structures
189+
// Create scalar values which will increment buffer op base ptr
190+
// Fills appropriate fields in given BufferOpInfo structures
218191
static void createScalarIncrements(PatternRewriter &rewriter,
219-
SmallVector<LoadData> &loads) {
220-
for (auto &loadData : loads) {
221-
auto scalarStep = scalarizeValue(rewriter, loadData.offsetIncrement);
222-
loadData.baseIncrement = scalarStep;
192+
SmallVector<BufferOpInfo> &infoList) {
193+
for (auto &BufferOpInfo : infoList) {
194+
auto scalarStep = scalarizeValue(rewriter, BufferOpInfo.offsetIncrement);
195+
BufferOpInfo.baseIncrement = scalarStep;
223196
}
224197
}
225198

226-
static bool isAddFirst(LoadData &ld) {
227-
return getOffset(ld.load).getDefiningOp() == ld.incrementOp;
199+
static bool isAddFirst(BufferOpInfo &info) {
200+
return info.op.getOffsets().getDefiningOp() == info.incrementOp;
228201
}
229202

230203
static scf::ForOp
231204
cloneLoopWithBasePtrIncrements(PatternRewriter &rewriter, scf::ForOp forOp,
232-
SmallVector<LoadData> &loads) {
205+
SmallVector<BufferOpInfo> &infoList) {
233206
// Create new loop with additional arguments
234207
llvm::SmallVector<Value> newLoopArgs(forOp.getInitArgs());
235-
for (auto loadData : loads) {
236-
newLoopArgs.push_back(getBasePtr(loadData.load));
208+
for (auto info : infoList) {
209+
newLoopArgs.push_back(info.op.getPtr());
237210
}
238211
rewriter.setInsertionPoint(forOp);
239212
auto newForOp = rewriter.create<scf::ForOp>(
@@ -251,19 +224,19 @@ struct AdvanceBasePointer : public OpRewritePattern<scf::ForOp> {
251224
rewriter.clone(op, mapping);
252225
}
253226
// Create base pointer increment operations
254-
auto basePtrs = newBlock->getArguments().take_back(loads.size());
227+
auto basePtrs = newBlock->getArguments().take_back(infoList.size());
255228
llvm::SmallVector<Value> nextIterBasePtrs;
256-
for (auto [loadData, basePtr] : llvm::zip(loads, basePtrs)) {
257-
if (isAddFirst(loadData)) {
229+
for (auto [info, basePtr] : llvm::zip(infoList, basePtrs)) {
230+
if (isAddFirst(info)) {
258231
rewriter.setInsertionPoint(newBlock, newBlock->begin());
259232
} else {
260233
rewriter.setInsertionPoint(newBlock, newBlock->end());
261234
}
262-
Value step = loadData.baseIncrement;
263-
if (mapping.contains(loadData.baseIncrement)) {
264-
step = mapping.lookup(loadData.baseIncrement);
235+
Value step = info.baseIncrement;
236+
if (mapping.contains(info.baseIncrement)) {
237+
step = mapping.lookup(info.baseIncrement);
265238
}
266-
auto loc = loadData.incrementOp->getLoc();
239+
auto loc = info.incrementOp->getLoc();
267240
auto ptrType = basePtr.getType();
268241
auto nextIterBasePtr =
269242
rewriter.create<triton::AddPtrOp>(loc, ptrType, basePtr, step);
@@ -280,48 +253,49 @@ struct AdvanceBasePointer : public OpRewritePattern<scf::ForOp> {
280253
rewriter.setInsertionPoint(newBlock, newBlock->end());
281254
rewriter.create<scf::YieldOp>(oldBlock->getTerminator()->getLoc(),
282255
newYieldOperands);
283-
// Replace dynamic load offsets with invariant value
256+
// Replace dynamic buffer op offsets with invariant value
284257
// Replace base ptr with incrementing value
285-
for (auto [loadData, basePtr, nextBasePtr] :
286-
llvm::zip(loads, basePtrs, nextIterBasePtrs)) {
287-
auto newLoad = mapping.lookup<Operation *>(loadData.load);
288-
setOffset(newLoad, loadData.offsetInitializer);
258+
for (auto [info, basePtr, nextBasePtr] :
259+
llvm::zip(infoList, basePtrs, nextIterBasePtrs)) {
260+
auto newBufferOp = cast<amdttg::BufferOpAddressinInterface>(
261+
mapping.lookup<Operation *>(info.op.getOperation()));
262+
newBufferOp.getOffsetsMutable().assign(info.offsetInitializer);
289263
// two cases:
290-
// 1. first advance pointer, then load
291-
// 2. load uses pointers from loop arguments, advanced pointer used on
292-
// next iteration
293-
Value advancingBasePtr = isAddFirst(loadData) ? nextBasePtr : basePtr;
294-
setBasePtr(newLoad, advancingBasePtr);
264+
// 1. buffer op uses pointer after increment
265+
// 2. buffer op uses pointers from loop arguments,
266+
// incremented pointer is used on next iteration
267+
Value advancingBasePtr = isAddFirst(info) ? nextBasePtr : basePtr;
268+
newBufferOp.getPtrMutable().assign(advancingBasePtr);
295269
}
296270
return newForOp;
297271
}
298272

299-
template <typename OpType>
300-
static void collectLoads(SmallVector<LoadData> &loads, scf::ForOp forOp) {
301-
forOp.walk([&loads, forOp](OpType loadOp) {
302-
auto loadData = analyzeLoad(loadOp, forOp);
303-
if (loadData.has_value()) {
304-
loads.push_back(loadData.value());
273+
static SmallVector<BufferOpInfo> collectBufferOps(scf::ForOp forOp) {
274+
SmallVector<BufferOpInfo> list;
275+
forOp.walk([&list, forOp](amdttg::BufferOpAddressinInterface op) {
276+
auto info = analyzeBufferOp(op, forOp);
277+
if (info.has_value()) {
278+
list.push_back(info.value());
305279
}
306280
});
281+
return list;
307282
}
308283

309284
LogicalResult matchAndRewrite(scf::ForOp forOp,
310285
PatternRewriter &rewriter) const override {
311-
LDBG("Analyzing ForOp for for offset pointer optimization: " << forOp);
312-
// Gather buffer loads which could be optimized
313-
SmallVector<LoadData> loads;
314-
collectLoads<triton::amdgpu::BufferLoadOp>(loads, forOp);
315-
collectLoads<triton::amdgpu::BufferLoadToLocalOp>(loads, forOp);
286+
LDBG("Analyzing ForOp for offset pointer optimization: " << forOp);
287+
// Gather buffer operations which could be optimized
288+
SmallVector<BufferOpInfo> infoList = collectBufferOps(forOp);
316289

317-
if (loads.empty())
318-
return rewriter.notifyMatchFailure(forOp, "no suitable buffer loads");
290+
if (infoList.empty())
291+
return rewriter.notifyMatchFailure(forOp,
292+
"no suitable buffer operations");
319293

320294
// Perform IR transformation
321-
createScalarIncrements(rewriter, loads);
322-
auto newForOp = cloneLoopWithBasePtrIncrements(rewriter, forOp, loads);
323-
rewriter.replaceAllUsesWith(forOp.getResults(),
324-
newForOp.getResults().drop_back(loads.size()));
295+
createScalarIncrements(rewriter, infoList);
296+
auto newForOp = cloneLoopWithBasePtrIncrements(rewriter, forOp, infoList);
297+
rewriter.replaceAllUsesWith(
298+
forOp.getResults(), newForOp.getResults().drop_back(infoList.size()));
325299
rewriter.eraseOp(forOp);
326300
return success();
327301
}

0 commit comments

Comments
 (0)