Skip to content

Commit 8fcf1d2

Browse files
committed
address review comments
1 parent 1cfaa6e commit 8fcf1d2

File tree

1 file changed

+55
-70
lines changed

1 file changed

+55
-70
lines changed

third_party/amd/lib/TritonAMDGPUTransforms/OptimizeBufferOpPtr.cpp

Lines changed: 55 additions & 70 deletions
Original file line numberDiff line numberDiff line change
@@ -101,35 +101,20 @@ struct AdvanceBasePointer : public OpRewritePattern<scf::ForOp> {
101101
return Value();
102102
}
103103

104-
// load is a target buffer load
105-
// offsetIncrement is a tensor added to offsets on each iteration
106-
// baseIncrement is a scalar which will be added to base pointer after
107-
// optimization offsetInitialized is a value of offset on first loop iteration
108-
// incrementOp is an operation that advances offset tensor
109-
struct LoadData {
110-
Operation *load;
104+
// Description of struct fields:
105+
// - op is a target BufferOp, for example BufferLoadOp or BufferLoadToLocalOp
106+
// - offsetIncrement is a tensor added to offsets on each iteration
107+
// - baseIncrement is a scalar which will be added to base pointer
108+
// - offsetInitializer is a value of offset on first loop iteration
109+
// - incrementOp is an operation that advances offset tensor
110+
struct BufferOpInfo {
111+
amdttg::BufferOpAddressinInterface op;
111112
Value offsetIncrement;
112113
Value baseIncrement;
113114
Value offsetInitializer;
114115
Operation *incrementOp;
115116
};
116117

117-
static Value getOffset(Operation *load) {
118-
if (auto specific = dyn_cast<amdttg::BufferLoadOp>(load))
119-
return specific.getOffsets();
120-
if (auto specific = dyn_cast<amdttg::BufferLoadToLocalOp>(load))
121-
return specific.getOffsets();
122-
assert(false && "unsupported operation type");
123-
}
124-
125-
static Value getBasePtr(Operation *load) {
126-
if (auto specific = dyn_cast<amdttg::BufferLoadOp>(load))
127-
return specific.getPtr();
128-
if (auto specific = dyn_cast<amdttg::BufferLoadToLocalOp>(load))
129-
return specific.getPtr();
130-
assert(false && "unsupported operation type");
131-
}
132-
133118
static void setOffset(Operation *load, Value newOffset) {
134119
assert((isa<amdttg::BufferLoadOp, amdttg::BufferLoadToLocalOp>(load)));
135120
const int offsetIdx = isa<amdttg::BufferLoadOp>(load) ? 1 : 2;
@@ -143,11 +128,11 @@ struct AdvanceBasePointer : public OpRewritePattern<scf::ForOp> {
143128
}
144129

145130
// Perform a series of checks to decide if the given operation could be optimized.
146-
// If optimization is possible, return filled LoadData
147-
static std::optional<LoadData> analyzeLoad(Operation *loadOp,
148-
scf::ForOp targetFor) {
149-
LDBG("Analyzing: " << *loadOp);
150-
Value maybeOffsetsBlockArg = getOffset(loadOp);
131+
// If optimization is possible, return filled BufferOpInfo
132+
static std::optional<BufferOpInfo>
133+
analyzeBufferOp(amdttg::BufferOpAddressinInterface op, scf::ForOp targetFor) {
134+
LDBG("Analyzing: " << *op);
135+
Value maybeOffsetsBlockArg = op.getOffsets();
151136
auto maybeOffsetDefOp = maybeOffsetsBlockArg.getDefiningOp();
152137
if (maybeOffsetDefOp && isa<arith::AddIOp>(maybeOffsetDefOp)) {
153138
for (auto &use : maybeOffsetDefOp->getUses()) {
@@ -174,7 +159,7 @@ struct AdvanceBasePointer : public OpRewritePattern<scf::ForOp> {
174159
LDBG("Rejected: expect load offset to be a target loop argument");
175160
return {};
176161
}
177-
auto basePtr = getBasePtr(loadOp);
162+
auto basePtr = op.getPtr();
178163
auto defOpBlock = basePtr.getParentBlock();
179164
if (!defOpBlock->getParentOp()->isProperAncestor(targetFor)) {
180165
LDBG("Rejected: expect load base Ptr to be invariant to the loop");
@@ -207,33 +192,33 @@ struct AdvanceBasePointer : public OpRewritePattern<scf::ForOp> {
207192
return {};
208193
}
209194
Value offsetInitializer = forOp.getInitArgs()[offsetOperandNo];
210-
LoadData data = {loadOp, advanceStep, Value(), offsetInitializer,
211-
incrementOp};
195+
BufferOpInfo data = {op, advanceStep, Value(), offsetInitializer,
196+
incrementOp};
212197
LDBG("Load is suitable for offset pointer optimization");
213198
return data;
214199
}
215200

216201
// Create scalar values which will increment load base ptr
217-
// Fills appropriate fields in given LoadData structures
202+
// Fills appropriate fields in given BufferOpInfo structures
218203
static void createScalarIncrements(PatternRewriter &rewriter,
219-
SmallVector<LoadData> &loads) {
220-
for (auto &loadData : loads) {
221-
auto scalarStep = scalarizeValue(rewriter, loadData.offsetIncrement);
222-
loadData.baseIncrement = scalarStep;
204+
SmallVector<BufferOpInfo> &loads) {
205+
for (auto &BufferOpInfo : loads) {
206+
auto scalarStep = scalarizeValue(rewriter, BufferOpInfo.offsetIncrement);
207+
BufferOpInfo.baseIncrement = scalarStep;
223208
}
224209
}
225210

226-
static bool isAddFirst(LoadData &ld) {
227-
return getOffset(ld.load).getDefiningOp() == ld.incrementOp;
211+
static bool isAddFirst(BufferOpInfo &info) {
212+
return info.op.getOffsets().getDefiningOp() == info.incrementOp;
228213
}
229214

230215
static scf::ForOp
231216
cloneLoopWithBasePtrIncrements(PatternRewriter &rewriter, scf::ForOp forOp,
232-
SmallVector<LoadData> &loads) {
217+
SmallVector<BufferOpInfo> &infoList) {
233218
// Create new loop with additional arguments
234219
llvm::SmallVector<Value> newLoopArgs(forOp.getInitArgs());
235-
for (auto loadData : loads) {
236-
newLoopArgs.push_back(getBasePtr(loadData.load));
220+
for (auto info : infoList) {
221+
newLoopArgs.push_back(info.op.getPtr());
237222
}
238223
rewriter.setInsertionPoint(forOp);
239224
auto newForOp = rewriter.create<scf::ForOp>(
@@ -251,19 +236,19 @@ struct AdvanceBasePointer : public OpRewritePattern<scf::ForOp> {
251236
rewriter.clone(op, mapping);
252237
}
253238
// Create base pointer increment operations
254-
auto basePtrs = newBlock->getArguments().take_back(loads.size());
239+
auto basePtrs = newBlock->getArguments().take_back(infoList.size());
255240
llvm::SmallVector<Value> nextIterBasePtrs;
256-
for (auto [loadData, basePtr] : llvm::zip(loads, basePtrs)) {
257-
if (isAddFirst(loadData)) {
241+
for (auto [info, basePtr] : llvm::zip(infoList, basePtrs)) {
242+
if (isAddFirst(info)) {
258243
rewriter.setInsertionPoint(newBlock, newBlock->begin());
259244
} else {
260245
rewriter.setInsertionPoint(newBlock, newBlock->end());
261246
}
262-
Value step = loadData.baseIncrement;
263-
if (mapping.contains(loadData.baseIncrement)) {
264-
step = mapping.lookup(loadData.baseIncrement);
247+
Value step = info.baseIncrement;
248+
if (mapping.contains(info.baseIncrement)) {
249+
step = mapping.lookup(info.baseIncrement);
265250
}
266-
auto loc = loadData.incrementOp->getLoc();
251+
auto loc = info.incrementOp->getLoc();
267252
auto ptrType = basePtr.getType();
268253
auto nextIterBasePtr =
269254
rewriter.create<triton::AddPtrOp>(loc, ptrType, basePtr, step);
@@ -282,46 +267,46 @@ struct AdvanceBasePointer : public OpRewritePattern<scf::ForOp> {
282267
newYieldOperands);
283268
// Replace dynamic load offsets with invariant value
284269
// Replace base ptr with incrementing value
285-
for (auto [loadData, basePtr, nextBasePtr] :
286-
llvm::zip(loads, basePtrs, nextIterBasePtrs)) {
287-
auto newLoad = mapping.lookup<Operation *>(loadData.load);
288-
setOffset(newLoad, loadData.offsetInitializer);
270+
for (auto [info, basePtr, nextBasePtr] :
271+
llvm::zip(infoList, basePtrs, nextIterBasePtrs)) {
272+
auto newLoad = mapping.lookup<Operation *>(info.op.getOperation());
273+
setOffset(newLoad, info.offsetInitializer);
289274
// two cases:
290275
// 1. first advance pointer, then load
291276
// 2. load uses pointers from loop arguments, advanced pointer used on
292277
// next iteration
293-
Value advancingBasePtr = isAddFirst(loadData) ? nextBasePtr : basePtr;
278+
Value advancingBasePtr = isAddFirst(info) ? nextBasePtr : basePtr;
294279
setBasePtr(newLoad, advancingBasePtr);
295280
}
296281
return newForOp;
297282
}
298283

299-
template <typename OpType>
300-
static void collectLoads(SmallVector<LoadData> &loads, scf::ForOp forOp) {
301-
forOp.walk([&loads, forOp](OpType loadOp) {
302-
auto loadData = analyzeLoad(loadOp, forOp);
303-
if (loadData.has_value()) {
304-
loads.push_back(loadData.value());
284+
static SmallVector<BufferOpInfo> collectBufferOps(scf::ForOp forOp) {
285+
SmallVector<BufferOpInfo> list;
286+
forOp.walk([&list, forOp](amdttg::BufferOpAddressinInterface op) {
287+
auto info = analyzeBufferOp(op, forOp);
288+
if (info.has_value()) {
289+
list.push_back(info.value());
305290
}
306291
});
292+
return list;
307293
}
308294

309295
LogicalResult matchAndRewrite(scf::ForOp forOp,
310296
PatternRewriter &rewriter) const override {
311-
LDBG("Analyzing ForOp for for offset pointer optimization: " << forOp);
312-
// Gather buffer loads which could be optimized
313-
SmallVector<LoadData> loads;
314-
collectLoads<triton::amdgpu::BufferLoadOp>(loads, forOp);
315-
collectLoads<triton::amdgpu::BufferLoadToLocalOp>(loads, forOp);
297+
LDBG("Analyzing ForOp for offset pointer optimization: " << forOp);
298+
// Gather buffer operations which could be optimized
299+
SmallVector<BufferOpInfo> infoList = collectBufferOps(forOp);
316300

317-
if (loads.empty())
318-
return rewriter.notifyMatchFailure(forOp, "no suitable buffer loads");
301+
if (infoList.empty())
302+
return rewriter.notifyMatchFailure(forOp,
303+
"no suitable buffer operations");
319304

320305
// Perform IR transformation
321-
createScalarIncrements(rewriter, loads);
322-
auto newForOp = cloneLoopWithBasePtrIncrements(rewriter, forOp, loads);
323-
rewriter.replaceAllUsesWith(forOp.getResults(),
324-
newForOp.getResults().drop_back(loads.size()));
306+
createScalarIncrements(rewriter, infoList);
307+
auto newForOp = cloneLoopWithBasePtrIncrements(rewriter, forOp, infoList);
308+
rewriter.replaceAllUsesWith(
309+
forOp.getResults(), newForOp.getResults().drop_back(infoList.size()));
325310
rewriter.eraseOp(forOp);
326311
return success();
327312
}

0 commit comments

Comments
 (0)