
Commit aac7ae7

mbrookhart authored and zwu-2025 committed
[TMEM] Remove Unneeded Stores (triton-lang#6892)
Noticed that OptimizeAccumulatorInit and HoistTMEMAlloc were both doing some init rewriting and alloc movement, but that HoistTMEMAlloc was initializing tmem values that OptimizeAccumulatorInit had invalidated via use of the useD flag. This PR adds a pattern, and a test, to remove those redundant stores as part of HoistTMEMAlloc.
1 parent 23b0072
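The redundancy in IR terms, as a schematic sketch (types and most operands elided; the value names are illustrative and not taken from the tests below): HoistTMEMAlloc hoists a zero-initializing ttng.tmem_store above the loop, while OptimizeAccumulatorInit has already arranged for the MMA's useD flag to start out false, so the first iteration overwrites the accumulator and the store is dead whenever the loop runs at least once.

  %acc, %tok = ttng.tmem_alloc ...
  // Hoisted zero-init: dead if the loop executes, because the first
  // tc_gen5_mma sees useD == false and overwrites %acc without reading it.
  %tok0 = ttng.tmem_store %zero, %acc[%tok], %true ...
  scf.for %i = %lb to %ub step %c1 iter_args(%t = %tok0, %useD = %false) ... {
    %mma_tok = ttng.tc_gen5_mma %a, %b, %acc[%t], %useD, %true ...
    scf.yield %mma_tok, %true ...
  }

Rather than deleting such a store, the new RemoveUnusedTMEMStore pattern below predicates it on the loop being empty.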

File tree

6 files changed, +127 −19 lines

include/triton/Dialect/TritonGPU/Transforms/Utility.h

Lines changed: 3 additions & 0 deletions
@@ -247,6 +247,9 @@ SetVector<Value> getNestedOperands(Operation *op);
 // Erase the given loop carried values from the loop, where `loop` is replaced
 // with a new loop.
 void eraseLoopCarriedValues(scf::ForOp &loop, llvm::BitVector indices);
+
+// Get a boolean if the Value is an arith::ConstantOp
+std::optional<bool> getBoolFromConstant(Value cst);
 } // namespace mlir
 
 namespace mlir::triton {

lib/Dialect/TritonGPU/Transforms/HoistTMEMAlloc.cpp

Lines changed: 51 additions & 1 deletion
@@ -1,3 +1,5 @@
+#include "mlir/Dialect/Arith/IR/Arith.h"
+#include "mlir/Dialect/SCF/IR/SCF.h"
 #include "mlir/IR/Dominance.h"
 #include "mlir/Transforms/GreedyPatternRewriteDriver.h"
 #include "triton/Dialect/TritonGPU/IR/Dialect.h"
@@ -92,6 +94,53 @@ class RemoveUnusedTMEMLoad : public OpRewritePattern<TMEMTokenLoadOp> {
   }
 };
 
+class RemoveUnusedTMEMStore : public OpRewritePattern<TMEMTokenStoreOp> {
+public:
+  using OpRewritePattern::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(TMEMTokenStoreOp store,
+                                PatternRewriter &rewriter) const override {
+    auto pred = getBoolFromConstant(store.getPred());
+    if (!pred || pred.value() == false)
+      return failure(); // we've already processed this
+    auto tok = store.getToken();
+    if (!tok.hasOneUse())
+      return failure();
+    auto loop = dyn_cast<scf::ForOp>(*tok.getUsers().begin());
+    if (!loop)
+      return failure();
+    auto loopTok = loop.getBody()->getArgument(
+        tok.getUses().begin()->getOperandNumber() - 2);
+    if (!loopTok.hasOneUse())
+      return failure();
+    auto mma =
+        dyn_cast<nvidia_gpu::MMAv5OpInterface>(*loopTok.getUsers().begin());
+    if (!mma)
+      return failure();
+    auto useD = dyn_cast<BlockArgument>(mma.useAccumulator());
+    if (!useD)
+      return failure();
+    auto parent = useD.getParentBlock()->getParentOp();
+    if (parent != loop)
+      return failure();
+    auto loopInit = loop.getInitArgs()[useD.getArgNumber() - 1];
+    auto val = getBoolFromConstant(loopInit);
+    if (!val)
+      return failure();
+    if (val.value() == true)
+      return failure();
+    auto loc = store.getLoc();
+    rewriter.setInsertionPoint(store);
+    Value diff = rewriter.create<arith::SubIOp>(loc, loop.getUpperBound(),
+                                                loop.getLowerBound());
+    Value zero = rewriter.create<arith::ConstantIntOp>(loc, 0, diff.getType());
+    Value cond = rewriter.create<arith::CmpIOp>(loc, arith::CmpIPredicate::sle,
+                                                diff, zero);
+    store.getPredMutable().assign(cond);
+    return success();
+  }
+};
+
 // Load-store forwarding pattern.
 class CombineTMEMLoadAndStore : public OpRewritePattern<TMEMTokenStoreOp> {
 public:
@@ -411,7 +460,8 @@ struct HoistTMEMAlloc
     mlir::RewritePatternSet patterns(&getContext());
     patterns.add<RotateTMEMStoreInLoop, RotateTMEMLoadInLoop,
                  CombineTMEMLoadAndStore, CombineTMEMStoreAndSelect,
-                 SinkTMEMLoad, RemoveUnusedTMEMLoad>(&getContext());
+                 SinkTMEMLoad, RemoveUnusedTMEMLoad, RemoveUnusedTMEMStore>(
+        &getContext());
     scf::ForOp::getCanonicalizationPatterns(patterns, &getContext());
     if (failed(applyPatternsGreedily(getOperation(), std::move(patterns)))) {
       llvm_unreachable("Failed to hoist tmem_store");
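The rewrite is deliberately conservative: instead of erasing the store, it replaces the constant-true predicate with a "loop is empty" check, so the zero-initialization still happens in the degenerate case where the scf.for runs zero iterations and nothing else would write the accumulator. Schematically (types elided, names illustrative):

  // before
  %tok0 = ttng.tmem_store %cst, %acc[%tok], %true
  // after: fire only when the trip count is zero
  %diff = arith.subi %ub, %lb : i32
  %empty = arith.cmpi sle, %diff, %c0_i32 : i32
  %tok0 = ttng.tmem_store %cst, %acc[%tok], %empty

Subsequent runs of the pattern then skip this store, since its predicate is no longer the constant true.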

lib/Dialect/TritonGPU/Transforms/OptimizeAccumulatorInit.cpp

Lines changed: 0 additions & 12 deletions
@@ -171,18 +171,6 @@ findZeroInitOp(Value accUse, scf::ForOp forOp, bool &loopArgIsZero) {
   return std::nullopt;
 }
 
-std::optional<bool> getBoolFromConstant(Value cst) {
-  auto constantOp = cst.getDefiningOp<arith::ConstantOp>();
-  if (!constantOp) {
-    return std::nullopt;
-  }
-  assert(constantOp.getValue());
-  if (auto boolAttr = dyn_cast<BoolAttr>(constantOp.getValue())) {
-    return boolAttr.getValue();
-  }
-  return std::nullopt;
-}
-
 } // namespace
 
 class OptimizeAccumulatorInitPass

lib/Dialect/TritonGPU/Transforms/Utility.cpp

Lines changed: 12 additions & 0 deletions
@@ -1400,6 +1400,18 @@ void eraseLoopCarriedValues(scf::ForOp &loop, llvm::BitVector indices) {
   loop = newLoop;
 }
 
+std::optional<bool> getBoolFromConstant(Value cst) {
+  auto constantOp = cst.getDefiningOp<arith::ConstantOp>();
+  if (!constantOp) {
+    return std::nullopt;
+  }
+  assert(constantOp.getValue());
+  if (auto boolAttr = dyn_cast<BoolAttr>(constantOp.getValue())) {
+    return boolAttr.getValue();
+  }
+  return std::nullopt;
+}
+
 } // namespace mlir
 
 namespace mlir::triton {
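The helper itself is unchanged, only moved out of OptimizeAccumulatorInit.cpp's anonymous namespace so both passes can share it: it returns std::nullopt unless the value is defined by an arith::ConstantOp carrying a BoolAttr. RemoveUnusedTMEMStore calls it twice, once to skip stores whose predicate is not the constant true (dynamic or already-rewritten predicates), and once to confirm that the useD loop-carried value is initialized to the constant false.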

test/TritonGPU/hoist-tmem-alloc.mlir

Lines changed: 55 additions & 0 deletions
@@ -307,3 +307,58 @@ module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, ttg.targ
     tt.return %res_f16 : tensor<128x128xf16, #blocked>
   }
 }
+
+// -----
+#blocked = #ttg.blocked<{sizePerThread = [1, 128], threadsPerWarp = [32, 1], warpsPerCTA = [4, 1], order = [0, 1]}>
+#blocked1 = #ttg.blocked<{sizePerThread = [1, 1], threadsPerWarp = [1, 32], warpsPerCTA = [4, 1], order = [1, 0]}>
+#blocked2 = #ttg.blocked<{sizePerThread = [1, 2, 64], threadsPerWarp = [32, 1, 1], warpsPerCTA = [4, 1, 1], order = [0, 2, 1]}>
+#blocked3 = #ttg.blocked<{sizePerThread = [1, 64, 2], threadsPerWarp = [32, 1, 1], warpsPerCTA = [4, 1, 1], order = [0, 1, 2]}>
+#blocked4 = #ttg.blocked<{sizePerThread = [1, 64], threadsPerWarp = [32, 1], warpsPerCTA = [4, 1], order = [0, 1]}>
+#blocked5 = #ttg.blocked<{sizePerThread = [1, 1], threadsPerWarp = [1, 32], warpsPerCTA = [2, 2], order = [1, 0]}>
+#shared = #ttg.nvmma_shared<{swizzlingByteWidth = 32, transposed = false, elementBitWidth = 8}>
+#shared1 = #ttg.nvmma_shared<{swizzlingByteWidth = 64, transposed = false, elementBitWidth = 8}>
+#shared2 = #ttg.nvmma_shared<{swizzlingByteWidth = 32, transposed = true, elementBitWidth = 8}>
+#smem = #ttg.shared_memory
+#tmem = #ttng.tensor_memory_encoding<blockM = 128, blockN = 128, unpacked = true>
+module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-stages" = 4 : i32, "ttg.num-warps" = 4 : i32, ttg.target = "cuda:100", "ttg.threads-per-warp" = 32 : i32, "ttg.warp-specialized" = true} {
+  // CHECK-LABEL: @matmul_kernel_tma_persistent_nested
+  tt.func public @matmul_kernel_tma_persistent_nested(%arg0: !tt.tensordesc<tensor<128x32xf8E4M3FN, #shared>>, %arg1: i32, %arg2: i32, %arg3: i64, %arg4: i64, %arg5: !tt.tensordesc<tensor<128x32xf8E4M3FN, #shared>>, %arg6: i32, %arg7: i32, %arg8: i64, %arg9: i64, %arg10: !tt.tensordesc<tensor<128x64xf8E4M3FN, #shared1>>, %arg11: i32, %arg12: i32, %arg13: i64, %arg14: i64, %arg15: i32 {tt.divisibility = 16 : i32}, %arg16: i32 {tt.divisibility = 16 : i32}, %arg17: i32 {tt.divisibility = 16 : i32}) attributes {noinline = false} {
+    %false = arith.constant false
+    %true = arith.constant true
+    // CHECK: %[[ZERO:.*]] = arith.constant 0 : i32
+    %c0_i32 = arith.constant 0 : i32
+    %c1_i32 = arith.constant 1 : i32
+    %c4_i32 = arith.constant 4 : i32
+    %c32_i32 = arith.constant 32 : i32
+    %cst = arith.constant dense<0.000000e+00> : tensor<128x128xf32, #blocked>
+    scf.for %arg18 = %c0_i32 to %c4_i32 step %c1_i32 : i32 {
+      // CHECK: %[[ACC:.*]], %[[TOK:.*]] = ttng.tmem_alloc
+      // CHECK: %[[DIFF:.*]] = arith.subi %[[LIMIT:.*]], %[[START:.*]] : i32
+      // CHECK: %[[COND:.*]] = arith.cmpi sle, %[[DIFF]], %[[ZERO]] : i32
+      // CHECK-NEXT: %[[NTOK:.*]] = ttng.tmem_store %[[CST:.*]], %[[ACC]][%[[TOK]]], %[[COND]]
+      // CHECK-NEXT: scf.for %[[ITER:.*]] = %[[START]] to %[[LIMIT]] step
+      %20:3 = scf.for %arg19 = %arg11 to %arg12 step %c1_i32 iter_args(%arg20 = %cst, %arg21 = %c0_i32, %arg22 = %false) -> (tensor<128x128xf32, #blocked>, i32, i1) : i32 {
+        %28 = tt.descriptor_load %arg0[%arg19, %arg21] : !tt.tensordesc<tensor<128x32xf8E4M3FN, #shared>> -> tensor<128x32xf8E4M3FN, #blocked1>
+        %29 = ttg.local_alloc %28 : (tensor<128x32xf8E4M3FN, #blocked1>) -> !ttg.memdesc<128x32xf8E4M3FN, #shared, #smem>
+        %30 = tt.descriptor_load %arg5[%arg19, %arg21] : !tt.tensordesc<tensor<128x32xf8E4M3FN, #shared>> -> tensor<128x32xf8E4M3FN, #blocked1>
+        %31 = ttg.local_alloc %30 : (tensor<128x32xf8E4M3FN, #blocked1>) -> !ttg.memdesc<128x32xf8E4M3FN, #shared, #smem>
+        %32 = ttg.memdesc_trans %31 {order = array<i32: 1, 0>} : !ttg.memdesc<128x32xf8E4M3FN, #shared, #smem> -> !ttg.memdesc<32x128xf8E4M3FN, #shared2, #smem>
+        %acc, %acc_tok = ttng.tmem_alloc %arg20 : (tensor<128x128xf32, #blocked>) -> (!ttg.memdesc<128x128xf32, #tmem, #ttng.tensor_memory, mutable>, !ttg.async.token)
+        %mma_tok = ttng.tc_gen5_mma %29, %32, %acc[%acc_tok], %arg22, %true : !ttg.memdesc<128x32xf8E4M3FN, #shared, #smem>, !ttg.memdesc<32x128xf8E4M3FN, #shared2, #smem>, !ttg.memdesc<128x128xf32, #tmem, #ttng.tensor_memory, mutable>
+        %34, %load_tok = ttng.tmem_load %acc[%mma_tok] : !ttg.memdesc<128x128xf32, #tmem, #ttng.tensor_memory, mutable> -> tensor<128x128xf32, #blocked>
+        %35 = arith.addi %arg21, %c32_i32 : i32
+        scf.yield %34, %35, %true : tensor<128x128xf32, #blocked>, i32, i1
+      }
+      %21 = tt.reshape %20#0 : tensor<128x128xf32, #blocked> -> tensor<128x2x64xf32, #blocked2>
+      %22 = tt.trans %21 {order = array<i32: 0, 2, 1>} : tensor<128x2x64xf32, #blocked2> -> tensor<128x64x2xf32, #blocked3>
+      %outLHS, %outRHS = tt.split %22 : tensor<128x64x2xf32, #blocked3> -> tensor<128x64xf32, #blocked4>
+      %23 = tt.fp_to_fp %outLHS, rounding = rtne : tensor<128x64xf32, #blocked4> -> tensor<128x64xf8E4M3FN, #blocked4>
+      %24 = ttg.convert_layout %23 : tensor<128x64xf8E4M3FN, #blocked4> -> tensor<128x64xf8E4M3FN, #blocked5>
+      tt.descriptor_store %arg10[%c0_i32, %c0_i32], %24 : !tt.tensordesc<tensor<128x64xf8E4M3FN, #shared1>>, tensor<128x64xf8E4M3FN, #blocked5>
+      %25 = tt.fp_to_fp %outRHS, rounding = rtne : tensor<128x64xf32, #blocked4> -> tensor<128x64xf8E4M3FN, #blocked4>
+      %26 = ttg.convert_layout %25 : tensor<128x64xf8E4M3FN, #blocked4> -> tensor<128x64xf8E4M3FN, #blocked5>
+      tt.descriptor_store %arg10[%c0_i32, %c0_i32], %26 : !tt.tensordesc<tensor<128x64xf8E4M3FN, #shared1>>, tensor<128x64xf8E4M3FN, #blocked5>
+    }
+    tt.return
+  }
+}
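The new test exercises the pattern on a persistent kernel with a nested accumulation loop: the CHECK lines verify that after the pass the hoisted ttng.tmem_store is predicated by the arith.subi/arith.cmpi sle computed from the inner loop's bounds rather than by a constant true.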

test/TritonGPU/loop-pipeline-blackwell.mlir

Lines changed: 6 additions & 6 deletions
@@ -110,7 +110,7 @@ module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, ttg.targ
 #smem = #ttg.shared_memory
 #tmem = #ttng.tensor_memory_encoding<blockM = 128, blockN = 128, unpacked = true>
 module attributes {"ttg.num-warps" = 4 : i32, "ttg.num-ctas" = 1 : i32, ttg.target = "cuda:100", "ttg.threads-per-warp" = 32 : i32} {
-  tt.func @matmul_loop_cast_load(%lb : index, %ub : index, %step : index,
+  tt.func @matmul_loop_cast_load(%lb : i32, %ub : i32, %step : i32,
                                  %A : !tt.ptr<f8E4M3FN> {tt.divisibility = 16 : i32},
                                  %B : !tt.ptr<f8E4M3FN> {tt.divisibility = 16 : i32}) -> tensor<128x128xf32, #C> {
     // CHECK-LABEL: tt.func @matmul_loop_cast_load
@@ -137,7 +137,7 @@ module attributes {"ttg.num-warps" = 4 : i32, "ttg.num-ctas" = 1 : i32, ttg.targ
     %a_off = arith.constant dense<4> : tensor<128x32xi32, #AL>
     %b_off = arith.constant dense<4> : tensor<32x128xi32, #BL>
 
-    %loop:3 = scf.for %iv = %lb to %ub step %step iter_args(%a_ptr = %a_ptr_init, %b_ptr = %b_ptr_init, %prev_c = %c_init) -> (tensor<128x32x!tt.ptr<f8E4M3FN>, #AL>, tensor<32x128x!tt.ptr<f8E4M3FN>, #BL>, tensor<128x128xf32, #C>) {
+    %loop:3 = scf.for %iv = %lb to %ub step %step iter_args(%a_ptr = %a_ptr_init, %b_ptr = %b_ptr_init, %prev_c = %c_init) -> (tensor<128x32x!tt.ptr<f8E4M3FN>, #AL>, tensor<32x128x!tt.ptr<f8E4M3FN>, #BL>, tensor<128x128xf32, #C>) : i32 {
       %a___ = tt.load %a_ptr : tensor<128x32x!tt.ptr<f8E4M3FN>, #AL>
       %a__ = tt.fp_to_fp %a___ : tensor<128x32xf8E4M3FN, #AL> -> tensor<128x32xf16, #AL>
       %a_ = ttg.convert_layout %a__ : tensor<128x32xf16, #AL> -> tensor<128x32xf16, #A>
@@ -250,7 +250,7 @@ tt.func private @pipelined_gather(
 #smem = #ttg.shared_memory
 
 module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, ttg.target = "cuda:100", "ttg.threads-per-warp" = 32 : i32} {
-  tt.func public @block_scale_mxfp_matmul(%lb : index, %ub : index, %step : index, %arg0: !tt.ptr<f8E5M2> {tt.divisibility = 16 : i32}, %arg1: !tt.ptr<f8E5M2> {tt.divisibility = 16 : i32}, %arg3: !tt.ptr<i8> {tt.divisibility = 16 : i32}, %arg4: !tt.ptr<i8> {tt.divisibility = 16 : i32}) -> tensor<128x128xf32, #blocked4> {
+  tt.func public @block_scale_mxfp_matmul(%lb : i32, %ub : i32, %step : i32, %arg0: !tt.ptr<f8E5M2> {tt.divisibility = 16 : i32}, %arg1: !tt.ptr<f8E5M2> {tt.divisibility = 16 : i32}, %arg3: !tt.ptr<i8> {tt.divisibility = 16 : i32}, %arg4: !tt.ptr<i8> {tt.divisibility = 16 : i32}) -> tensor<128x128xf32, #blocked4> {
     // CHECK: ttg.local_alloc : () -> !ttg.memdesc<2x128x256xf8E5M2
     // CHECK: ttg.local_alloc : () -> !ttg.memdesc<2x256x128xf8E5M2
     // Do not multibuffer the scale loads, as we cannot pipeline the mma due to tmem.cp not being used
@@ -288,7 +288,7 @@ module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, ttg.targ
     %arg3_init = tt.addptr %arg3_splat, %57 : tensor<1x2x32x4x4x!tt.ptr<i8>, #blocked2>, tensor<1x2x32x4x4xi32, #blocked2>
     %arg4_init = tt.addptr %arg4_splat, %57 : tensor<1x2x32x4x4x!tt.ptr<i8>, #blocked2>, tensor<1x2x32x4x4xi32, #blocked2>
 
-    %99:5 = scf.for %iv = %lb to %ub step %step iter_args(%arg15 = %cst_1, %arg16 = %arg0_init, %arg17 = %arg1_init, %arg18 = %arg3_init, %arg19 = %arg4_init) -> (tensor<128x128xf32, #blocked4>, tensor<128x256x!tt.ptr<f8E5M2>, #blocked>, tensor<256x128x!tt.ptr<f8E5M2>, #blocked1>, tensor<1x2x32x4x4x!tt.ptr<i8>, #blocked2>, tensor<1x2x32x4x4x!tt.ptr<i8>, #blocked2>) {
+    %99:5 = scf.for %iv = %lb to %ub step %step iter_args(%arg15 = %cst_1, %arg16 = %arg0_init, %arg17 = %arg1_init, %arg18 = %arg3_init, %arg19 = %arg4_init) -> (tensor<128x128xf32, #blocked4>, tensor<128x256x!tt.ptr<f8E5M2>, #blocked>, tensor<256x128x!tt.ptr<f8E5M2>, #blocked1>, tensor<1x2x32x4x4x!tt.ptr<i8>, #blocked2>, tensor<1x2x32x4x4x!tt.ptr<i8>, #blocked2>) : i32 {
       %117 = tt.load %arg16 : tensor<128x256x!tt.ptr<f8E5M2>, #blocked>
       %118 = ttg.local_alloc %117 : (tensor<128x256xf8E5M2, #blocked>) -> !ttg.memdesc<128x256xf8E5M2, #shared, #ttg.shared_memory>
      %119 = tt.load %arg17 : tensor<256x128x!tt.ptr<f8E5M2>, #blocked1>
@@ -338,7 +338,7 @@ module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, ttg.targ
 #smem = #ttg.shared_memory
 
 module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, ttg.target = "cuda:100", "ttg.threads-per-warp" = 32 : i32} {
-  tt.func public @block_scale_mxfp_matmul_tmem_copy(%lb : index, %ub : index, %step : index, %arg0: !tt.ptr<f8E5M2> {tt.divisibility = 16 : i32}, %arg1: !tt.ptr<f8E5M2> {tt.divisibility = 16 : i32}, %arg3: !tt.ptr<i8> {tt.divisibility = 16 : i32}, %arg4: !tt.ptr<i8> {tt.divisibility = 16 : i32}) -> tensor<128x128xf32, #blocked4> {
+  tt.func public @block_scale_mxfp_matmul_tmem_copy(%lb : i32, %ub : i32, %step : i32, %arg0: !tt.ptr<f8E5M2> {tt.divisibility = 16 : i32}, %arg1: !tt.ptr<f8E5M2> {tt.divisibility = 16 : i32}, %arg3: !tt.ptr<i8> {tt.divisibility = 16 : i32}, %arg4: !tt.ptr<i8> {tt.divisibility = 16 : i32}) -> tensor<128x128xf32, #blocked4> {
     // CHECK: ttg.local_alloc : () -> !ttg.memdesc<3x128x256xf8E5M2
     // CHECK: ttg.local_alloc : () -> !ttg.memdesc<3x256x128xf8E5M2
     // CHECK: ttg.local_alloc : () -> !ttg.memdesc<3x1x2x32x4x4xi8
@@ -375,7 +375,7 @@ module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, ttg.targ
     %arg3_init = tt.addptr %arg3_splat, %57 : tensor<1x2x32x4x4x!tt.ptr<i8>, #blocked2>, tensor<1x2x32x4x4xi32, #blocked2>
     %arg4_init = tt.addptr %arg4_splat, %57 : tensor<1x2x32x4x4x!tt.ptr<i8>, #blocked2>, tensor<1x2x32x4x4xi32, #blocked2>
 
-    %99:6 = scf.for %iv = %lb to %ub step %step iter_args(%arg15 = %cst_1, %arg16 = %arg0_init, %arg17 = %arg1_init, %arg18 = %arg3_init, %arg19 = %arg4_init, %init_flag=%false) -> (tensor<128x128xf32, #blocked4>, tensor<128x256x!tt.ptr<f8E5M2>, #blocked>, tensor<256x128x!tt.ptr<f8E5M2>, #blocked1>, tensor<1x2x32x4x4x!tt.ptr<i8>, #blocked2>, tensor<1x2x32x4x4x!tt.ptr<i8>, #blocked2>, i1) {
+    %99:6 = scf.for %iv = %lb to %ub step %step iter_args(%arg15 = %cst_1, %arg16 = %arg0_init, %arg17 = %arg1_init, %arg18 = %arg3_init, %arg19 = %arg4_init, %init_flag=%false) -> (tensor<128x128xf32, #blocked4>, tensor<128x256x!tt.ptr<f8E5M2>, #blocked>, tensor<256x128x!tt.ptr<f8E5M2>, #blocked1>, tensor<1x2x32x4x4x!tt.ptr<i8>, #blocked2>, tensor<1x2x32x4x4x!tt.ptr<i8>, #blocked2>, i1) : i32 {
       %117 = tt.load %arg16 : tensor<128x256x!tt.ptr<f8E5M2>, #blocked>
       %118 = ttg.local_alloc %117 : (tensor<128x256xf8E5M2, #blocked>) -> !ttg.memdesc<128x256xf8E5M2, #shared, #ttg.shared_memory>
      %119 = tt.load %arg17 : tensor<256x128x!tt.ptr<f8E5M2>, #blocked1>
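The only change to these pre-existing tests is switching the loop-bound types from index to i32, presumably because the new pattern materializes its zero constant with arith::ConstantIntOp over the bound type, which accepts integer types but not index.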
