Commit 82cef71

support buffer_load_to_local
1 parent 6205f38 commit 82cef71

2 files changed: +85 -20 lines changed

test/TritonGPU/amd/amd-convert-buffer-ops-base-ptr-increment.mlir

Lines changed: 33 additions & 0 deletions
@@ -46,6 +46,39 @@ module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 1 : i32, "ttg.thr
 
 // -----
 
+// COMMON-LABEL: buffer_load_to_local
+// COMMON-DAG: [[X_OFFSET_CST:%.*]] = arith.constant dense<123>
+// COMMON: scf.for {{.*}} iter_args({{.*}}, [[X_BASE:%.*]] = {{.*}}
+// COMMON: amdgpu.buffer_load_to_local [[X_BASE]]{{\[}}[[X_OFFSET_CST]]{{\]}}
+// COMMON: [[NEXT_X_BASE:%.*]] = tt.addptr [[X_BASE]], %c64_i32
+// COMMON: scf.yield {{.*}}, [[NEXT_X_BASE]]
+
+#blocked = #ttg.blocked<{sizePerThread = [1, 1], threadsPerWarp = [8, 8], warpsPerCTA = [1, 1], order = [1, 0]}>
+#shared = #ttg.swizzled_shared<{vec = 1, perPhase = 1, maxPhase = 1, order = [1, 0]}>
+#smem = #ttg.shared_memory
+module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 1 : i32, "ttg.threads-per-warp" = 64 : i32} {
+  tt.func public @buffer_load_to_local(%X: !tt.ptr<f16> {tt.divisibility = 16 : i32, tt.pointer_range = 32 : i32}) attributes {noinline = false} {
+    %cst = arith.constant dense<64> : tensor<16x64xi32, #blocked>
+    %c0 = arith.constant 0 : index
+    %c128 = arith.constant 128 : index
+    %c1 = arith.constant 1 : index
+
+    %Xoffset_init = arith.constant dense<123> : tensor<16x64xi32, #blocked>
+
+    %x_dummy_buffer = ttg.local_alloc : () -> !ttg.memdesc<16x64xf16, #shared, #smem, mutable, 16x64>
+
+    %for = scf.for %idx = %c0 to %c128 step %c1 iter_args(%Xoffset = %Xoffset_init) -> (tensor<16x64xi32, #blocked>) {
+      %x = amdgpu.buffer_load_to_local %X[%Xoffset] into %x_dummy_buffer : <f16>[tensor<16x64xi32, #blocked>] -> !ttg.memdesc<16x64xf16, #shared, #smem, mutable, 16x64>
+
+      %Xoffset_next = arith.addi %Xoffset, %cst : tensor<16x64xi32, #blocked>
+      scf.yield %Xoffset_next : tensor<16x64xi32, #blocked>
+    }
+    tt.return
+  }
+}
+
+// -----
+
 // COMMON-LABEL: add_before_load
 // COMMON-DAG: [[X_OFFSET_CST:%.*]] = arith.constant dense<123>
 // COMMON: scf.for {{.*}} iter_args({{.*}}, [[X_BASE:%.*]] = {{.*}})

third_party/amd/lib/TritonAMDGPUTransforms/OptimizeBufferOpPtr.cpp

Lines changed: 52 additions & 20 deletions
@@ -25,6 +25,7 @@ using mlir::triton::AMD::ISAFamily;
 
 namespace ttg = mlir::triton::gpu;
 namespace tt = mlir::triton;
+namespace amdttg = mlir::triton::amdgpu;
 
 namespace mlir {
 
@@ -106,19 +107,47 @@ struct AdvanceBasePointer : public OpRewritePattern<scf::ForOp> {
   // optimization offsetInitialized is a value of offset on first loop iteration
   // incrementOp is an operation that advances offset tensor
   struct LoadData {
-    triton::amdgpu::BufferLoadOp load;
+    Operation *load;
     Value offsetIncrement;
     Value baseIncrement;
     Value offsetInitializer;
     Operation *incrementOp;
   };
 
+  static Value getOffset(Operation *load) {
+    if (auto specific = dyn_cast<amdttg::BufferLoadOp>(load))
+      return specific.getOffsets();
+    if (auto specific = dyn_cast<amdttg::BufferLoadToLocalOp>(load))
+      return specific.getOffsets();
+    assert(false && "unsupported operation type");
+  }
+
+  static Value getBasePtr(Operation *load) {
+    if (auto specific = dyn_cast<amdttg::BufferLoadOp>(load))
+      return specific.getPtr();
+    if (auto specific = dyn_cast<amdttg::BufferLoadToLocalOp>(load))
+      return specific.getPtr();
+    assert(false && "unsupported operation type");
+  }
+
+  static void setOffset(Operation *load, Value newOffset) {
+    assert((isa<amdttg::BufferLoadOp, amdttg::BufferLoadToLocalOp>(load)));
+    const int offsetIdx = isa<amdttg::BufferLoadOp>(load) ? 1 : 2;
+    load->setOperand(offsetIdx, newOffset);
+  }
+
+  static void setBasePtr(Operation *load, Value newBasePtr) {
+    assert((isa<amdttg::BufferLoadOp, amdttg::BufferLoadToLocalOp>(load)));
+    const int ptrIdx = isa<amdttg::BufferLoadOp>(load) ? 0 : 1;
+    load->setOperand(ptrIdx, newBasePtr);
+  }
+
   // Perform series of checks to decide if given operation could be optimized.
   // If optimization is possible, return filled LoadData
-  static std::optional<LoadData>
-  analyzeLoad(triton::amdgpu::BufferLoadOp loadOp, scf::ForOp targetFor) {
-    LDBG("Analyzing: " << loadOp);
-    Value maybeOffsetsBlockArg = loadOp.getOffsets();
+  static std::optional<LoadData> analyzeLoad(Operation *loadOp,
+                                             scf::ForOp targetFor) {
+    LDBG("Analyzing: " << *loadOp);
+    Value maybeOffsetsBlockArg = getOffset(loadOp);
     auto maybeOffsetDefOp = maybeOffsetsBlockArg.getDefiningOp();
     if (maybeOffsetDefOp && isa<arith::AddIOp>(maybeOffsetDefOp)) {
       for (auto &use : maybeOffsetDefOp->getUses()) {
@@ -145,7 +174,7 @@ struct AdvanceBasePointer : public OpRewritePattern<scf::ForOp> {
       LDBG("Rejected: expect load offset to be a target loop argument");
       return {};
     }
-    auto basePtr = loadOp.getPtr();
+    auto basePtr = getBasePtr(loadOp);
     auto defOpBlock = basePtr.getParentBlock();
     if (!defOpBlock->getParentOp()->isProperAncestor(targetFor)) {
       LDBG("Rejected: expect load base Ptr to be invariant to the loop");
@@ -195,7 +224,7 @@ struct AdvanceBasePointer : public OpRewritePattern<scf::ForOp> {
   }
 
   static bool isAddFirst(LoadData &ld) {
-    return ld.load.getOffsets().getDefiningOp() == ld.incrementOp;
+    return getOffset(ld.load).getDefiningOp() == ld.incrementOp;
   }
 
   static scf::ForOp
@@ -204,7 +233,7 @@ struct AdvanceBasePointer : public OpRewritePattern<scf::ForOp> {
     // Create new loop with additional arguments
     llvm::SmallVector<Value> newLoopArgs(forOp.getInitArgs());
    for (auto loadData : loads) {
-      newLoopArgs.push_back(loadData.load.getPtr());
+      newLoopArgs.push_back(getBasePtr(loadData.load));
    }
    rewriter.setInsertionPoint(forOp);
    auto newForOp = rewriter.create<scf::ForOp>(
@@ -255,32 +284,35 @@ struct AdvanceBasePointer : public OpRewritePattern<scf::ForOp> {
     // Replace base ptr with incrementing value
     for (auto [loadData, basePtr, nextBasePtr] :
          llvm::zip(loads, basePtrs, nextIterBasePtrs)) {
-      auto newLoad = cast<triton::amdgpu::BufferLoadOp>(
-          mapping.lookup<Operation *>(loadData.load));
-      constexpr int ptrIdx = 0;
-      constexpr int offsetIdx = 1;
-      newLoad.setOperand(offsetIdx, loadData.offsetInitializer);
+      auto newLoad = mapping.lookup<Operation *>(loadData.load);
+      setOffset(newLoad, loadData.offsetInitializer);
       // two cases:
       // 1. first advance pointer, then load
       // 2. load uses pointers from loop arguments, advanced pointer used on
       // next iteration
       Value advancingBasePtr = isAddFirst(loadData) ? nextBasePtr : basePtr;
-      newLoad.setOperand(ptrIdx, advancingBasePtr);
+      setBasePtr(newLoad, advancingBasePtr);
     }
     return newForOp;
   }
 
-  LogicalResult matchAndRewrite(scf::ForOp forOp,
-                                PatternRewriter &rewriter) const override {
-    LDBG("Analyzing ForOp for for offset pointer optimization: " << forOp);
-    // Gather buffer loads which could be optimized
-    SmallVector<LoadData> loads;
-    forOp.walk([&loads, forOp](triton::amdgpu::BufferLoadOp loadOp) {
+  template <typename OpType>
+  static void collectLoads(SmallVector<LoadData> &loads, scf::ForOp forOp) {
+    forOp.walk([&loads, forOp](OpType loadOp) {
       auto loadData = analyzeLoad(loadOp, forOp);
       if (loadData.has_value()) {
         loads.push_back(loadData.value());
       }
     });
+  }
+
+  LogicalResult matchAndRewrite(scf::ForOp forOp,
+                                PatternRewriter &rewriter) const override {
+    LDBG("Analyzing ForOp for for offset pointer optimization: " << forOp);
+    // Gather buffer loads which could be optimized
+    SmallVector<LoadData> loads;
+    collectLoads<triton::amdgpu::BufferLoadOp>(loads, forOp);
+    collectLoads<triton::amdgpu::BufferLoadToLocalOp>(loads, forOp);
 
     if (loads.empty())
       return rewriter.notifyMatchFailure(forOp, "no suitable buffer loads");
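
A note on the dispatch style: the new getOffset/getBasePtr/setOffset/setBasePtr helpers switch on the concrete op via dyn_cast chains and hard-coded operand indices. An equivalent formulation could use llvm::TypeSwitch together with the ODS-generated mutable accessors, which avoids the index constants. The sketch below is not part of the commit; it assumes the amdttg namespace alias from this patch and that BufferLoadOp/BufferLoadToLocalOp expose getOffsets()/getOffsetsMutable() under the usual ODS naming.

#include "llvm/ADT/TypeSwitch.h"

// Hypothetical alternative to the dyn_cast-chain helpers above.
static Value getOffsetAlt(Operation *load) {
  return llvm::TypeSwitch<Operation *, Value>(load)
      .Case<amdttg::BufferLoadOp, amdttg::BufferLoadToLocalOp>(
          [](auto op) { return op.getOffsets(); })
      .Default([](Operation *) -> Value {
        llvm_unreachable("unsupported operation type");
      });
}

static void setOffsetAlt(Operation *load, Value newOffset) {
  llvm::TypeSwitch<Operation *>(load)
      .Case<amdttg::BufferLoadOp, amdttg::BufferLoadToLocalOp>(
          // assign() replaces the single 'offsets' operand without relying on
          // its position in the operand list.
          [&](auto op) { op.getOffsetsMutable().assign(newOffset); })
      .Default(
          [](Operation *) { llvm_unreachable("unsupported operation type"); });
}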
