Commit d827851

[Blackwell] Hoist constant TMem allocation out of the loop (triton-lang#5857)

For cases where the TMem input is constant, hoisting the allocation out of the loop avoids having to store it multiple times.

1 parent 0cb0140 commit d827851
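
In rough terms, the new hoisting turns IR like the following into the hoisted form. This is a simplified sketch based on the lit test added in this commit; value names are shortened and types are elided, so it is not verbatim pass output:

// Before: the constant scale is re-materialized and stored to TMem on every iteration.
scf.for %i = %c0 to %n step %c1 : i32 {
  %s  = arith.trunci %n : i32 to i8
  %v  = tt.splat %s : i8 -> tensor<128x4xi8, #blocked1>
  %tm = ttng.tmem_alloc %v : (...) -> !ttg.memdesc<128x4xi8, #tmem_scales, #ttng.tensor_memory>
  ttng.tc_gen5_mma_scaled %A, %B, %acc, %scale_a, %tm, %true, %true lhs = e5m2 rhs = e2m1 : (...)
}

// After: the loop-invariant chain (trunci -> splat -> tmem_alloc) is hoisted,
// so the TMem allocation and store happen once, before the loop.
%s  = arith.trunci %n : i32 to i8
%v  = tt.splat %s : i8 -> tensor<128x4xi8, #blocked1>
%tm = ttng.tmem_alloc %v : (...) -> !ttg.memdesc<128x4xi8, #tmem_scales, #ttng.tensor_memory>
scf.for %i = %c0 to %n step %c1 : i32 {
  ttng.tc_gen5_mma_scaled %A, %B, %acc, %scale_a, %tm, %true, %true lhs = e5m2 rhs = e2m1 : (...)
}

Only immutable tmem_alloc inputs fed by a chain of single-operand, side-effect-free, speculatable ops are hoisted this way, mirroring the checks in the new hoistInvariantInputs helper below.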

3 files changed: +175 -104 lines changed

lib/Dialect/TritonNvidiaGPU/Transforms/KeepAccInTMem.cpp

Lines changed: 134 additions & 101 deletions
@@ -79,6 +79,138 @@ class TMEMToGlobal : public OpRewritePattern<triton::StoreOp> {
   }
 };
 
+static void addTMEMLoad(IRRewriter &rewriter, ttng::TMEMAllocOp localAlloc,
+                        Operation *user, int argNo) {
+  rewriter.setInsertionPoint(user);
+  auto load = rewriter.create<ttng::TMEMLoadOp>(
+      user->getLoc(), user->getOperand(argNo).getType(),
+      localAlloc->getResult(0));
+  user->setOperand(argNo, load);
+}
+
+static bool canKeepAccInTmem(scf::ForOp forOp, Operation *mmaOp,
+                             ttng::TMEMAllocOp &localAlloc,
+                             ttng::TMEMLoadOp &localLoad,
+                             SmallVector<std::pair<Operation *, int>> &accUsers,
+                             unsigned &yieldArgNo) {
+  // The expected sequence of instructions:
+  // %acc_tm = ttg.local_alloc %acc
+  // ttng.tc_gen5_mma %A_sh, %B_sh, %acc_tm
+  // %acc_res = ttg.local_load %acc_tm
+  localAlloc = mmaOp->getOperand(2).getDefiningOp<ttng::TMEMAllocOp>();
+  if (!localAlloc) {
+    return false;
+  }
+  for (auto user : localAlloc->getUsers()) {
+    if (isa<ttng::TMEMLoadOp>(user)) {
+      localLoad = cast<ttng::TMEMLoadOp>(user);
+    } else if (user != mmaOp) {
+      // The accumulator is used by another operation, not something we
+      // expect.
+      localLoad = nullptr;
+      return false;
+    }
+  }
+
+  SmallVector<Value> queue;
+  queue.push_back(localLoad->getResult(0));
+  bool foundDotCycle = false;
+  while (!queue.empty()) {
+    Value value = queue.pop_back_val();
+    for (auto &use : value.getUses()) {
+      if (use.getOwner() == localAlloc) {
+        foundDotCycle = true;
+        continue;
+      }
+      if (auto yieldOp = dyn_cast<scf::YieldOp>(use.getOwner())) {
+        if (yieldOp->getParentOp() == forOp) {
+          yieldArgNo = use.getOperandNumber();
+          queue.push_back(forOp.getRegionIterArg(yieldArgNo));
+          continue;
+        }
+        if (auto ifOp = dyn_cast<scf::IfOp>(yieldOp->getParentOp())) {
+          // TODO: Accumulator being used in the yield of ifOp means that
+          // it is being modified in the other branch of the ifOp. This is not
+          // something we can handle yet.
+          return false;
+        }
+        // Not sure what we are doing here. Back out.
+        return false;
+      }
+      accUsers.emplace_back(use.getOwner(), use.getOperandNumber());
+    }
+  }
+  return foundDotCycle;
+}
+
+static void hoistReadModifyWrite(Operation *mmaOp, scf::ForOp forOp) {
+  // For the transformation to make sense, the accumulator must be
+  // reused by the same MMA operation in subsequent iterations.
+  SmallVector<std::pair<Operation *, int>> accUsers;
+  ttng::TMEMAllocOp localAlloc = nullptr;
+  ttng::TMEMLoadOp localLoad = nullptr;
+  unsigned yieldArgNo;
+  if (!canKeepAccInTmem(forOp, mmaOp, localAlloc, localLoad, accUsers,
+                        yieldArgNo)) {
+    return;
+  }
+
+  assert(localLoad != nullptr);
+  assert(localAlloc != nullptr);
+  Type loadType = localLoad->getResult(0).getType();
+  IRRewriter rewriter(forOp);
+  localAlloc->moveBefore(forOp);
+  localAlloc->setOperand(0, forOp.getInitArgs()[yieldArgNo]);
+  mmaOp->setOperand(2, localAlloc->getResult(0));
+  // Unlink the local_load from the yield. Short-circuit the unused yield
+  // value with the corresponding iter arg.
+  forOp.getBody()->getTerminator()->setOperand(
+      yieldArgNo, forOp.getRegionIterArg(yieldArgNo));
+
+  // Add TMEM loads before all the uses.
+  // TODO: We could be more efficient here, reusing loads instead of
+  // creating new ones for each use.
+  for (auto [user, argNo] : accUsers) {
+    addTMEMLoad(rewriter, localAlloc, user, argNo);
+  }
+
+  rewriter.setInsertionPointAfter(forOp);
+  auto afterLoopLoad = rewriter.create<ttng::TMEMLoadOp>(
+      forOp.getLoc(), loadType, localAlloc->getResult(0));
+  forOp->getResult(yieldArgNo).replaceAllUsesWith(afterLoopLoad->getResult(0));
+
+  localLoad->erase();
+}
+
+// Hoist invariant tmem_alloc. This could technically be done as general LICM,
+// but controlling the tmem live range more precisely is likely to be
+// important.
+static void hoistInvariantInputs(Operation *mmaOp, scf::ForOp forOp) {
+  for (auto operand : mmaOp->getOperands()) {
+    if (forOp.isDefinedOutsideOfLoop(operand))
+      continue;
+    auto tmemAllocOp = operand.getDefiningOp<ttng::TMEMAllocOp>();
+    if (!tmemAllocOp || tmemAllocOp.getType().getMutableMemory())
+      continue;
+    assert(tmemAllocOp.getSrc());
+    Value src = tmemAllocOp.getSrc();
+    SmallVector<Operation *> opToHoist = {tmemAllocOp.getOperation()};
+    // Also hoist simple unary elementwise ops that may have sunk into the loop.
+    while (Operation *defOp = src.getDefiningOp()) {
+      if (forOp.isDefinedOutsideOfLoop(src))
+        break;
+      if (!(isMemoryEffectFree(defOp) && isSpeculatable(defOp) &&
+            defOp->getNumOperands() == 1))
+        break;
+      opToHoist.push_back(defOp);
+      src = defOp->getOperand(0);
+    }
+    if (!forOp.isDefinedOutsideOfLoop(src))
+      continue;
+    for (auto op : llvm::reverse(opToHoist)) {
+      forOp.moveOutOfLoop(op);
+    }
+  }
+}
 class TritonNvidiaGPUKeepAccInTMemPass
     : public TritonNvidiaGPUKeepAccInTMemPassBase<
           TritonNvidiaGPUKeepAccInTMemPass> {
@@ -99,70 +231,6 @@ class TritonNvidiaGPUKeepAccInTMemPass
     }
   }
 
-  bool canKeepAccInTmem(scf::ForOp forOp, Operation *mmaOp,
-                        ttng::TMEMAllocOp &localAlloc,
-                        ttng::TMEMLoadOp &localLoad,
-                        SmallVector<std::pair<Operation *, int>> &accUsers,
-                        unsigned &yieldArgNo) {
-    // The expected sequence of instructions:
-    // %acc_tm = ttg.local_alloc %acc
-    // ttng.tc_gen5_mma %A_sh, %B_sh, %acc_tm
-    // %acc_res = ttg.local_load %acc_tm
-    localAlloc = mmaOp->getOperand(2).getDefiningOp<ttng::TMEMAllocOp>();
-    if (!localAlloc) {
-      return false;
-    }
-    for (auto user : localAlloc->getUsers()) {
-      if (isa<ttng::TMEMLoadOp>(user)) {
-        localLoad = cast<ttng::TMEMLoadOp>(user);
-      } else if (user != mmaOp) {
-        // The accumulator is used by another operation, not something we
-        // expect.
-        localLoad = nullptr;
-        return false;
-      }
-    }
-
-    SmallVector<Value> queue;
-    queue.push_back(localLoad->getResult(0));
-    bool foundDotCycle = false;
-    while (!queue.empty()) {
-      Value value = queue.pop_back_val();
-      for (auto &use : value.getUses()) {
-        if (use.getOwner() == localAlloc) {
-          foundDotCycle = true;
-          continue;
-        }
-        if (auto yieldOp = dyn_cast<scf::YieldOp>(use.getOwner())) {
-          if (yieldOp->getParentOp() == forOp) {
-            yieldArgNo = use.getOperandNumber();
-            queue.push_back(forOp.getRegionIterArg(yieldArgNo));
-            continue;
-          }
-          if (auto ifOp = dyn_cast<scf::IfOp>(yieldOp->getParentOp())) {
-            // TODO: Accumulator being used in the yield of ifOp means that
-            // it is being modified in the other branch of the ifOp. This is not
-            // something we can handle yet.
-            return false;
-          }
-          // Not sure what are we doing here. Back out.
-          return false;
-        }
-        accUsers.emplace_back(use.getOwner(), use.getOperandNumber());
-      }
-    }
-    return foundDotCycle;
-  }
-
-  void addTMEMLoad(IRRewriter &rewriter, ttng::TMEMAllocOp localAlloc,
-                   Operation *user, int argNo) {
-    rewriter.setInsertionPoint(user);
-    auto load = rewriter.create<ttng::TMEMLoadOp>(
-        user->getLoc(), user->getOperand(argNo).getType(),
-        localAlloc->getResult(0));
-    user->setOperand(argNo, load);
-  }
-
   void runOnForOp(scf::ForOp forOp) {
     SmallVector<Operation *> mmaOps;
     forOp.walk([&](Operation *mmaOp) {
@@ -177,43 +245,8 @@ class TritonNvidiaGPUKeepAccInTMemPass
     }
 
     for (auto mmaOp : mmaOps) {
-      // For the transformation to make sense, the accumulator must be
-      // reused by the same MMA operation in subsequent iterations.
-      SmallVector<std::pair<Operation *, int>> accUsers;
-      ttng::TMEMAllocOp localAlloc = nullptr;
-      ttng::TMEMLoadOp localLoad = nullptr;
-      unsigned yieldArgNo;
-      if (!canKeepAccInTmem(forOp, mmaOp, localAlloc, localLoad, accUsers,
-                            yieldArgNo)) {
-        continue;
-      }
-
-      assert(localLoad != nullptr);
-      assert(localAlloc != nullptr);
-      Type loadType = localLoad->getResult(0).getType();
-      IRRewriter rewriter(forOp);
-      localAlloc->moveBefore(forOp);
-      localAlloc->setOperand(0, forOp.getInitArgs()[yieldArgNo]);
-      mmaOp->setOperand(2, localAlloc->getResult(0));
-      // Unlink the local_load from the yield. Short circuit the unused yield
-      // value with the corresponding iter arg.
-      forOp.getBody()->getTerminator()->setOperand(
-          yieldArgNo, forOp.getRegionIterArg(yieldArgNo));
-
-      // Add TMEM loads before all the uses
-      // TODO: We could be more efficient here, reusing loads instead of
-      // creating new ones for each use.
-      for (auto [user, argNo] : accUsers) {
-        addTMEMLoad(rewriter, localAlloc, user, argNo);
-      }
-
-      rewriter.setInsertionPointAfter(forOp);
-      auto afterLoopLoad = rewriter.create<ttng::TMEMLoadOp>(
-          forOp.getLoc(), loadType, localAlloc->getResult(0));
-      forOp->getResult(yieldArgNo)
-          .replaceAllUsesWith(afterLoopLoad->getResult(0));
-
-      localLoad->erase();
+      hoistReadModifyWrite(mmaOp, forOp);
+      hoistInvariantInputs(mmaOp, forOp);
     }
   }
 };
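
For context, hoistReadModifyWrite above is the pre-existing accumulator transformation, now factored out of runOnForOp. Schematically it performs the following rewrite; this is a simplified sketch in which value names and the elided types are illustrative only, not actual pass output:

// Before: the accumulator round-trips through TMem on every iteration.
%res = scf.for %i = %c0 to %n step %c1 iter_args(%acc = %init) -> (...) {
  %acc_tm  = ttng.tmem_alloc %acc : (...)
  ttng.tc_gen5_mma %A_sh, %B_sh, %acc_tm, ...
  %acc_res = ttng.tmem_load %acc_tm : (...)
  scf.yield %acc_res : ...
}

// After: the allocation is created once from the init value, the MMA updates it
// in place, and the accumulator is read back from TMem only after the loop.
%acc_tm = ttng.tmem_alloc %init : (...)
scf.for %i = %c0 to %n step %c1 {
  ttng.tc_gen5_mma %A_sh, %B_sh, %acc_tm, ...
}
%res = ttng.tmem_load %acc_tm : (...)

Any other in-loop uses of the accumulator get a fresh ttng.tmem_load inserted directly before them, which is what addTMEMLoad does.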

python/test/unit/language/test_matmul.py

Lines changed: 11 additions & 3 deletions
@@ -761,6 +761,7 @@ def mxfp8_mxfp4_matmul( #
         BLOCK_N: tl.constexpr,  #
         BLOCK_K: tl.constexpr,  #
         NUM_STAGES: tl.constexpr):  #
+    tensor_scale: tl.constexpr = isinstance(a_scale.dtype, tl.pointer_type)
     pid = tl.program_id(axis=0)
     num_pid_m = tl.cdiv(M, BLOCK_M)
     pid_m = pid % num_pid_m
@@ -781,7 +782,10 @@ def mxfp8_mxfp4_matmul( #
     for k in tl.range(0, tl.cdiv(K, BLOCK_K), num_stages=NUM_STAGES):
         a = tl.load(a_ptrs)
         b = tl.load(b_ptrs)
-        scale_a = tl.load(a_scale_ptr)
+        if tensor_scale:
+            scale_a = tl.load(a_scale_ptr)
+        else:
+            scale_a = tl.full(a_scale_ptr.shape, a_scale.to(tl.int8), dtype=tl.int8)
         scale_b = tl.load(b_scale_ptr)
         accumulator = tl.dot_scaled(a, scale_a, "e5m2", b, scale_b, "e2m1", accumulator)
         a_ptrs += BLOCK_K * stride_ak
@@ -801,8 +805,9 @@
                                              (128, 256, 256), (128, 128, 64), (128, 64, 128)])
 @pytest.mark.parametrize("NUM_STAGES", [1, 3])
 @pytest.mark.parametrize("B_TRANS", [True, False])
+@pytest.mark.parametrize("CONST_SCALE", [True, False])
 @pytest.mark.skipif(torch.cuda.get_device_capability()[0] < 10, reason="Requires compute capability >= 10")
-def test_mxfp8_mxfp4_matmul(M, N, K, BLOCK_M, BLOCK_N, BLOCK_K, NUM_STAGES, B_TRANS, device):
+def test_mxfp8_mxfp4_matmul(M, N, K, BLOCK_M, BLOCK_N, BLOCK_K, NUM_STAGES, B_TRANS, CONST_SCALE, device):
     if BLOCK_N == 256 and BLOCK_K == 256:
         NUM_STAGES = 2
 
@@ -826,12 +831,15 @@ def test_mxfp8_mxfp4_matmul(M, N, K, BLOCK_M, BLOCK_N, BLOCK_K, NUM_STAGES, B_TR
     b_scale = b_scale_mxfp4.data
 
     a_scale_ref = a_scale_mxfp4.to(torch.float32).repeat_interleave(32, dim=1)[:M, :K]
+    if CONST_SCALE:
+        a_scale_ref = torch.full_like(a_scale_ref, 2.0)
+        a_scale = 128  # 2.0 in e8m0
     b_scale_ref = b_scale_mxfp4.to(torch.float32).repeat_interleave(32, dim=1).T.contiguous()[:K, :N]
     ref_out = torch.matmul(a_ref * a_scale_ref, b_ref * b_scale_ref)
 
     output = a.new_empty((M, N), dtype=torch.float32)
     grid = (triton.cdiv(M, BLOCK_M) * triton.cdiv(N, BLOCK_N), 1)
-    out = mxfp8_mxfp4_matmul[grid](a, b, output, a_scale, b_scale, M, N, K, a_scale.stride(0), a.stride(0), a.stride(1),
+    out = mxfp8_mxfp4_matmul[grid](a, b, output, a_scale, b_scale, M, N, K, b_scale.stride(0), a.stride(0), a.stride(1),
                                    b.stride(0), b.stride(1), output.stride(0), output.stride(1), BLOCK_M, BLOCK_N,
                                    BLOCK_K, NUM_STAGES=NUM_STAGES)
     ttgir = out.asm["ttgir"]

test/TritonGPU/blackwell_acc_tmem.mlir

Lines changed: 30 additions & 0 deletions
@@ -111,3 +111,33 @@ module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, ttg.targ
     tt.return %res_f16 : tensor<128x128xf16, #blocked>
   }
 }
+
+// -----
+#blocked = #ttg.blocked<{sizePerThread = [1, 16], threadsPerWarp = [4, 8], warpsPerCTA = [4, 1], order = [1, 0]}>
+#blocked1 = #ttg.blocked<{sizePerThread = [1, 4], threadsPerWarp = [32, 1], warpsPerCTA = [1, 4], order = [1, 0]}>
+#shared = #ttg.nvmma_shared<{swizzlingByteWidth = 128, transposed = false, elementBitWidth = 8}>
+#shared1 = #ttg.nvmma_shared<{swizzlingByteWidth = 128, transposed = true, elementBitWidth = 8, fp4Padded = true}>
+#smem = #ttg.shared_memory
+#tmem = #ttng.tensor_memory_encoding<blockM = 128, blockN = 128, unpacked = true>
+#tmem_scales = #ttng.tensor_memory_scales_encoding<>
+module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, ttg.target = "cuda:100", "ttg.threads-per-warp" = 32 : i32} {
+  // CHECK-LABEL: @hoist_constant_inputs
+  tt.func public @hoist_constant_inputs(%arg0: !ttg.memdesc<128x128xf8E5M2, #shared, #smem>, %arg1: !ttg.memdesc<64x128xi8, #shared1, #smem>, %arg2: !ttg.memdesc<128x4xi8, #tmem_scales, #ttng.tensor_memory>, %arg3: i32, %arg4: !ttg.memdesc<128x128xf32, #tmem, #ttng.tensor_memory, mutable>) {
+    %true = arith.constant true
+    %cst = arith.constant dense<0.000000e+00> : tensor<128x128xf32, #blocked>
+    %c0_i32 = arith.constant 0 : i32
+    %c1_i32 = arith.constant 1 : i32
+    // CHECK: arith.trunci
+    // CHECK: tt.splat
+    // CHECK: ttng.tmem_alloc
+    // CHECK: scf.for
+    // CHECK: ttng.tc_gen5_mma_scaled
+    scf.for %arg5 = %c0_i32 to %arg3 step %c1_i32 : i32 {
+      %0 = arith.trunci %arg3 : i32 to i8
+      %1 = tt.splat %0 : i8 -> tensor<128x4xi8, #blocked1>
+      %2 = ttng.tmem_alloc %1 : (tensor<128x4xi8, #blocked1>) -> !ttg.memdesc<128x4xi8, #tmem_scales, #ttng.tensor_memory>
+      ttng.tc_gen5_mma_scaled %arg0, %arg1, %arg4, %arg2, %2, %true, %true lhs = e5m2 rhs = e2m1 : (!ttg.memdesc<128x128xf8E5M2, #shared, #smem>, !ttg.memdesc<64x128xi8, #shared1, #smem>, !ttg.memdesc<128x128xf32, #tmem, #ttng.tensor_memory, mutable>, !ttg.memdesc<128x4xi8, #tmem_scales, #ttng.tensor_memory>, !ttg.memdesc<128x4xi8, #tmem_scales, #ttng.tensor_memory>, i1, i1) -> ()
+    }
+    tt.return
+  }
+}
