@@ -79,6 +79,138 @@ class TMEMToGlobal : public OpRewritePattern<triton::StoreOp> {
7979 }
8080};
8181
82+ static void addTMEMLoad (IRRewriter &rewriter, ttng::TMEMAllocOp localAlloc,
83+ Operation *user, int argNo) {
84+ rewriter.setInsertionPoint (user);
85+ auto load = rewriter.create <ttng::TMEMLoadOp>(
86+ user->getLoc (), user->getOperand (argNo).getType (),
87+ localAlloc->getResult (0 ));
88+ user->setOperand (argNo, load);
89+ }
90+
91+ static bool canKeepAccInTmem (scf::ForOp forOp, Operation *mmaOp,
92+ ttng::TMEMAllocOp &localAlloc,
93+ ttng::TMEMLoadOp &localLoad,
94+ SmallVector<std::pair<Operation *, int >> &accUsers,
95+ unsigned &yieldArgNo) {
96+ // The expected sequence of instructions:
97+ // %acc_tm = ttg.local_alloc %acc
98+ // ttng.tc_gen5_mma %A_sh, %B_sh, %acc_tm
99+ // %acc_res = ttg.local_load %acc_tm
100+ localAlloc = mmaOp->getOperand (2 ).getDefiningOp <ttng::TMEMAllocOp>();
101+ if (!localAlloc) {
102+ return false ;
103+ }
104+ for (auto user : localAlloc->getUsers ()) {
105+ if (isa<ttng::TMEMLoadOp>(user)) {
106+ localLoad = cast<ttng::TMEMLoadOp>(user);
107+ } else if (user != mmaOp) {
108+ // The accumulator is used by another operation, not something we
109+ // expect.
110+ localLoad = nullptr ;
111+ return false ;
112+ }
113+ }
114+
115+ SmallVector<Value> queue;
116+ queue.push_back (localLoad->getResult (0 ));
117+ bool foundDotCycle = false ;
118+ while (!queue.empty ()) {
119+ Value value = queue.pop_back_val ();
120+ for (auto &use : value.getUses ()) {
121+ if (use.getOwner () == localAlloc) {
122+ foundDotCycle = true ;
123+ continue ;
124+ }
125+ if (auto yieldOp = dyn_cast<scf::YieldOp>(use.getOwner ())) {
126+ if (yieldOp->getParentOp () == forOp) {
127+ yieldArgNo = use.getOperandNumber ();
128+ queue.push_back (forOp.getRegionIterArg (yieldArgNo));
129+ continue ;
130+ }
131+ if (auto ifOp = dyn_cast<scf::IfOp>(yieldOp->getParentOp ())) {
132+ // TODO: Accumulator being used in the yield of ifOp means that
133+ // it is being modified in the other branch of the ifOp. This is not
134+ // something we can handle yet.
135+ return false ;
136+ }
137+ // Not sure what are we doing here. Back out.
138+ return false ;
139+ }
140+ accUsers.emplace_back (use.getOwner (), use.getOperandNumber ());
141+ }
142+ }
143+ return foundDotCycle;
144+ }
145+
146+ static void hoistReadModifyWrite (Operation *mmaOp, scf::ForOp forOp) {
147+ // For the transformation to make sense, the accumulator must be
148+ // reused by the same MMA operation in subsequent iterations.
149+ SmallVector<std::pair<Operation *, int >> accUsers;
150+ ttng::TMEMAllocOp localAlloc = nullptr ;
151+ ttng::TMEMLoadOp localLoad = nullptr ;
152+ unsigned yieldArgNo;
153+ if (!canKeepAccInTmem (forOp, mmaOp, localAlloc, localLoad, accUsers,
154+ yieldArgNo)) {
155+ return ;
156+ }
157+
158+ assert (localLoad != nullptr );
159+ assert (localAlloc != nullptr );
160+ Type loadType = localLoad->getResult (0 ).getType ();
161+ IRRewriter rewriter (forOp);
162+ localAlloc->moveBefore (forOp);
163+ localAlloc->setOperand (0 , forOp.getInitArgs ()[yieldArgNo]);
164+ mmaOp->setOperand (2 , localAlloc->getResult (0 ));
165+ // Unlink the local_load from the yield. Short circuit the unused yield
166+ // value with the corresponding iter arg.
167+ forOp.getBody ()->getTerminator ()->setOperand (
168+ yieldArgNo, forOp.getRegionIterArg (yieldArgNo));
169+
170+ // Add TMEM loads before all the uses
171+ // TODO: We could be more efficient here, reusing loads instead of
172+ // creating new ones for each use.
173+ for (auto [user, argNo] : accUsers) {
174+ addTMEMLoad (rewriter, localAlloc, user, argNo);
175+ }
176+
177+ rewriter.setInsertionPointAfter (forOp);
178+ auto afterLoopLoad = rewriter.create <ttng::TMEMLoadOp>(
179+ forOp.getLoc (), loadType, localAlloc->getResult (0 ));
180+ forOp->getResult (yieldArgNo).replaceAllUsesWith (afterLoopLoad->getResult (0 ));
181+
182+ localLoad->erase ();
183+ }
184+
185+ // Hoist invariant tmem_alloc. This could technically be done as general LICM
186+ // but controlling tmem liveranga more precisley is likely to be important.
187+ static void hoistInvariantInputs (Operation *mmaOp, scf::ForOp forOp) {
188+ for (auto operand : mmaOp->getOperands ()) {
189+ if (forOp.isDefinedOutsideOfLoop (operand))
190+ continue ;
191+ auto tmemAllocOp = operand.getDefiningOp <ttng::TMEMAllocOp>();
192+ if (!tmemAllocOp || tmemAllocOp.getType ().getMutableMemory ())
193+ continue ;
194+ assert (tmemAllocOp.getSrc ());
195+ Value src = tmemAllocOp.getSrc ();
196+ SmallVector<Operation *> opToHoist = {tmemAllocOp.getOperation ()};
197+ // Also hoist simple unary elementwise that may have sinked into the loop.
198+ while (Operation *defOp = src.getDefiningOp ()) {
199+ if (forOp.isDefinedOutsideOfLoop (src))
200+ break ;
201+ if (!(isMemoryEffectFree (defOp) && isSpeculatable (defOp) &&
202+ defOp->getNumOperands () == 1 ))
203+ break ;
204+ opToHoist.push_back (defOp);
205+ src = defOp->getOperand (0 );
206+ }
207+ if (!forOp.isDefinedOutsideOfLoop (src))
208+ continue ;
209+ for (auto op : llvm::reverse (opToHoist)) {
210+ forOp.moveOutOfLoop (op);
211+ }
212+ }
213+ }
82214class TritonNvidiaGPUKeepAccInTMemPass
83215 : public TritonNvidiaGPUKeepAccInTMemPassBase<
84216 TritonNvidiaGPUKeepAccInTMemPass> {
@@ -99,70 +231,6 @@ class TritonNvidiaGPUKeepAccInTMemPass
99231 }
100232 }
101233
102- bool canKeepAccInTmem (scf::ForOp forOp, Operation *mmaOp,
103- ttng::TMEMAllocOp &localAlloc,
104- ttng::TMEMLoadOp &localLoad,
105- SmallVector<std::pair<Operation *, int >> &accUsers,
106- unsigned &yieldArgNo) {
107- // The expected sequence of instructions:
108- // %acc_tm = ttg.local_alloc %acc
109- // ttng.tc_gen5_mma %A_sh, %B_sh, %acc_tm
110- // %acc_res = ttg.local_load %acc_tm
111- localAlloc = mmaOp->getOperand (2 ).getDefiningOp <ttng::TMEMAllocOp>();
112- if (!localAlloc) {
113- return false ;
114- }
115- for (auto user : localAlloc->getUsers ()) {
116- if (isa<ttng::TMEMLoadOp>(user)) {
117- localLoad = cast<ttng::TMEMLoadOp>(user);
118- } else if (user != mmaOp) {
119- // The accumulator is used by another operation, not something we
120- // expect.
121- localLoad = nullptr ;
122- return false ;
123- }
124- }
125-
126- SmallVector<Value> queue;
127- queue.push_back (localLoad->getResult (0 ));
128- bool foundDotCycle = false ;
129- while (!queue.empty ()) {
130- Value value = queue.pop_back_val ();
131- for (auto &use : value.getUses ()) {
132- if (use.getOwner () == localAlloc) {
133- foundDotCycle = true ;
134- continue ;
135- }
136- if (auto yieldOp = dyn_cast<scf::YieldOp>(use.getOwner ())) {
137- if (yieldOp->getParentOp () == forOp) {
138- yieldArgNo = use.getOperandNumber ();
139- queue.push_back (forOp.getRegionIterArg (yieldArgNo));
140- continue ;
141- }
142- if (auto ifOp = dyn_cast<scf::IfOp>(yieldOp->getParentOp ())) {
143- // TODO: Accumulator being used in the yield of ifOp means that
144- // it is being modified in the other branch of the ifOp. This is not
145- // something we can handle yet.
146- return false ;
147- }
148- // Not sure what are we doing here. Back out.
149- return false ;
150- }
151- accUsers.emplace_back (use.getOwner (), use.getOperandNumber ());
152- }
153- }
154- return foundDotCycle;
155- }
156-
157- void addTMEMLoad (IRRewriter &rewriter, ttng::TMEMAllocOp localAlloc,
158- Operation *user, int argNo) {
159- rewriter.setInsertionPoint (user);
160- auto load = rewriter.create <ttng::TMEMLoadOp>(
161- user->getLoc (), user->getOperand (argNo).getType (),
162- localAlloc->getResult (0 ));
163- user->setOperand (argNo, load);
164- }
165-
166234 void runOnForOp (scf::ForOp forOp) {
167235 SmallVector<Operation *> mmaOps;
168236 forOp.walk ([&](Operation *mmaOp) {
@@ -177,43 +245,8 @@ class TritonNvidiaGPUKeepAccInTMemPass
177245 }
178246
179247 for (auto mmaOp : mmaOps) {
180- // For the transformation to make sense, the accumulator must be
181- // reused by the same MMA operation in subsequent iterations.
182- SmallVector<std::pair<Operation *, int >> accUsers;
183- ttng::TMEMAllocOp localAlloc = nullptr ;
184- ttng::TMEMLoadOp localLoad = nullptr ;
185- unsigned yieldArgNo;
186- if (!canKeepAccInTmem (forOp, mmaOp, localAlloc, localLoad, accUsers,
187- yieldArgNo)) {
188- continue ;
189- }
190-
191- assert (localLoad != nullptr );
192- assert (localAlloc != nullptr );
193- Type loadType = localLoad->getResult (0 ).getType ();
194- IRRewriter rewriter (forOp);
195- localAlloc->moveBefore (forOp);
196- localAlloc->setOperand (0 , forOp.getInitArgs ()[yieldArgNo]);
197- mmaOp->setOperand (2 , localAlloc->getResult (0 ));
198- // Unlink the local_load from the yield. Short circuit the unused yield
199- // value with the corresponding iter arg.
200- forOp.getBody ()->getTerminator ()->setOperand (
201- yieldArgNo, forOp.getRegionIterArg (yieldArgNo));
202-
203- // Add TMEM loads before all the uses
204- // TODO: We could be more efficient here, reusing loads instead of
205- // creating new ones for each use.
206- for (auto [user, argNo] : accUsers) {
207- addTMEMLoad (rewriter, localAlloc, user, argNo);
208- }
209-
210- rewriter.setInsertionPointAfter (forOp);
211- auto afterLoopLoad = rewriter.create <ttng::TMEMLoadOp>(
212- forOp.getLoc (), loadType, localAlloc->getResult (0 ));
213- forOp->getResult (yieldArgNo)
214- .replaceAllUsesWith (afterLoopLoad->getResult (0 ));
215-
216- localLoad->erase ();
248+ hoistReadModifyWrite (mmaOp, forOp);
249+ hoistInvariantInputs (mmaOp, forOp);
217250 }
218251 }
219252};
0 commit comments