@@ -25,6 +25,7 @@ using mlir::triton::AMD::ISAFamily;
 
 namespace ttg = mlir::triton::gpu;
 namespace tt = mlir::triton;
+namespace amdttg = mlir::triton::amdgpu;
 
 namespace mlir {
@@ -106,19 +107,47 @@ struct AdvanceBasePointer : public OpRewritePattern<scf::ForOp> {
   // optimization offsetInitializer is a value of offset on first loop iteration
   // incrementOp is an operation that advances offset tensor
   struct LoadData {
-    triton::amdgpu::BufferLoadOp load;
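+    // load is either an amdttg::BufferLoadOp or an amdttg::BufferLoadToLocalOp;
+    // the helpers below dispatch on the concrete type.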
+    Operation *load;
     Value offsetIncrement;
     Value baseIncrement;
     Value offsetInitializer;
     Operation *incrementOp;
   };
 
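+  // Helpers that let the rest of the pattern treat BufferLoadOp and
+  // BufferLoadToLocalOp uniformly: both ops carry a base pointer and an
+  // offsets tensor, but at different operand positions.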
+  static Value getOffset(Operation *load) {
+    if (auto specific = dyn_cast<amdttg::BufferLoadOp>(load))
+      return specific.getOffsets();
+    if (auto specific = dyn_cast<amdttg::BufferLoadToLocalOp>(load))
+      return specific.getOffsets();
+    llvm_unreachable("unsupported operation type");
+  }
+
+  static Value getBasePtr(Operation *load) {
+    if (auto specific = dyn_cast<amdttg::BufferLoadOp>(load))
+      return specific.getPtr();
+    if (auto specific = dyn_cast<amdttg::BufferLoadToLocalOp>(load))
+      return specific.getPtr();
+    llvm_unreachable("unsupported operation type");
+  }
+
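+  // The hard-coded operand indices below must stay in sync with the ops'
+  // definitions: BufferLoadToLocalOp's indices are shifted by one because
+  // its first operand is the destination in local (LDS) memory.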
+  static void setOffset(Operation *load, Value newOffset) {
+    assert((isa<amdttg::BufferLoadOp, amdttg::BufferLoadToLocalOp>(load)));
+    const int offsetIdx = isa<amdttg::BufferLoadOp>(load) ? 1 : 2;
+    load->setOperand(offsetIdx, newOffset);
+  }
+
+  static void setBasePtr(Operation *load, Value newBasePtr) {
+    assert((isa<amdttg::BufferLoadOp, amdttg::BufferLoadToLocalOp>(load)));
+    const int ptrIdx = isa<amdttg::BufferLoadOp>(load) ? 0 : 1;
+    load->setOperand(ptrIdx, newBasePtr);
+  }
+
   // Perform series of checks to decide if given operation could be optimized.
   // If optimization is possible, return filled LoadData
-  static std::optional<LoadData>
-  analyzeLoad(triton::amdgpu::BufferLoadOp loadOp, scf::ForOp targetFor) {
-    LDBG("Analyzing: " << loadOp);
-    Value maybeOffsetsBlockArg = loadOp.getOffsets();
+  static std::optional<LoadData> analyzeLoad(Operation *loadOp,
+                                             scf::ForOp targetFor) {
+    LDBG("Analyzing: " << *loadOp);
+    Value maybeOffsetsBlockArg = getOffset(loadOp);
     auto maybeOffsetDefOp = maybeOffsetsBlockArg.getDefiningOp();
     if (maybeOffsetDefOp && isa<arith::AddIOp>(maybeOffsetDefOp)) {
       for (auto &use : maybeOffsetDefOp->getUses()) {
@@ -145,7 +174,7 @@ struct AdvanceBasePointer : public OpRewritePattern<scf::ForOp> {
       LDBG("Rejected: expect load offset to be a target loop argument");
       return {};
     }
-    auto basePtr = loadOp.getPtr();
+    auto basePtr = getBasePtr(loadOp);
     auto defOpBlock = basePtr.getParentBlock();
     if (!defOpBlock->getParentOp()->isProperAncestor(targetFor)) {
       LDBG("Rejected: expect load base Ptr to be invariant to the loop");
@@ -195,7 +224,7 @@ struct AdvanceBasePointer : public OpRewritePattern<scf::ForOp> {
   }
 
   static bool isAddFirst(LoadData &ld) {
-    return ld.load.getOffsets().getDefiningOp() == ld.incrementOp;
+    return getOffset(ld.load).getDefiningOp() == ld.incrementOp;
   }
 
   static scf::ForOp
@@ -204,7 +233,7 @@ struct AdvanceBasePointer : public OpRewritePattern<scf::ForOp> {
     // Create new loop with additional arguments
     llvm::SmallVector<Value> newLoopArgs(forOp.getInitArgs());
     for (auto loadData : loads) {
-      newLoopArgs.push_back(loadData.load.getPtr());
+      newLoopArgs.push_back(getBasePtr(loadData.load));
     }
     rewriter.setInsertionPoint(forOp);
     auto newForOp = rewriter.create<scf::ForOp>(
@@ -255,32 +284,35 @@ struct AdvanceBasePointer : public OpRewritePattern<scf::ForOp> {
     // Replace base ptr with incrementing value
     for (auto [loadData, basePtr, nextBasePtr] :
          llvm::zip(loads, basePtrs, nextIterBasePtrs)) {
-      auto newLoad = cast<triton::amdgpu::BufferLoadOp>(
-          mapping.lookup<Operation *>(loadData.load));
-      constexpr int ptrIdx = 0;
-      constexpr int offsetIdx = 1;
-      newLoad.setOperand(offsetIdx, loadData.offsetInitializer);
+      auto newLoad = mapping.lookup<Operation *>(loadData.load);
+      setOffset(newLoad, loadData.offsetInitializer);
       // two cases:
       // 1. first advance pointer, then load
       // 2. load uses pointers from loop arguments, advanced pointer used on
       // next iteration
       Value advancingBasePtr = isAddFirst(loadData) ? nextBasePtr : basePtr;
-      newLoad.setOperand(ptrIdx, advancingBasePtr);
+      setBasePtr(newLoad, advancingBasePtr);
     }
     return newForOp;
   }
 
-  LogicalResult matchAndRewrite(scf::ForOp forOp,
-                                PatternRewriter &rewriter) const override {
-    LDBG("Analyzing ForOp for for offset pointer optimization: " << forOp);
-    // Gather buffer loads which could be optimized
-    SmallVector<LoadData> loads;
-    forOp.walk([&loads, forOp](triton::amdgpu::BufferLoadOp loadOp) {
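+  // Walk forOp and collect every load of type OpType that analyzeLoad
+  // accepts as optimizable.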
+  template <typename OpType>
+  static void collectLoads(SmallVector<LoadData> &loads, scf::ForOp forOp) {
+    forOp.walk([&loads, forOp](OpType loadOp) {
       auto loadData = analyzeLoad(loadOp, forOp);
       if (loadData.has_value()) {
         loads.push_back(loadData.value());
       }
     });
+  }
+
+  LogicalResult matchAndRewrite(scf::ForOp forOp,
+                                PatternRewriter &rewriter) const override {
+    LDBG("Analyzing ForOp for offset pointer optimization: " << forOp);
+    // Gather buffer loads which could be optimized
+    SmallVector<LoadData> loads;
+    collectLoads<triton::amdgpu::BufferLoadOp>(loads, forOp);
+    collectLoads<triton::amdgpu::BufferLoadToLocalOp>(loads, forOp);
 
     if (loads.empty())
       return rewriter.notifyMatchFailure(forOp, "no suitable buffer loads");