Skip to content

Commit 5a14866

Browse files
pawelszczerbukzwu-2025
authored and committed
[PIPELINE] Improve pipelining for mmav5 with tmem operands (triton-lang#6908)
For MMAv5 with operands coming from `tmem_alloc`, we can potentially push the `wait_barrier` to the next iteration, up to the point where the tmem buffer is allocated.
1 parent 618277b commit 5a14866

File tree

2 files changed

+41
-0
lines changed

2 files changed

+41
-0
lines changed

lib/Dialect/TritonGPU/Transforms/Pipeliner/MMAv5PipelineUtility.cpp

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -25,6 +25,10 @@ bool ttng::MMAv5PipelineableOperandsHelper::comesFromLoadOrOutsideLoop(
2525
while (isa<ttg::MemDescTransOp, ttg::MemDescReshapeOp>(v.getDefiningOp())) {
2626
v = v.getDefiningOp()->getOperand(0);
2727
}
28+
if (auto tmemAlloc = dyn_cast<ttng::TMEMAllocOp>(v.getDefiningOp())) {
29+
foundLoad = tmemAlloc;
30+
return false;
31+
}
2832
auto localAlloc = dyn_cast<ttg::LocalAllocOp>(v.getDefiningOp());
2933
if (!localAlloc) {
3034
return false;

test/TritonGPU/pipeline-lower-loop.mlir

Lines changed: 37 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1577,3 +1577,40 @@ module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, ttg.targ
15771577
tt.return
15781578
}
15791579
}
1580+
1581+
// -----
1582+
1583+
#blocked = #ttg.blocked<{sizePerThread = [1, 8], threadsPerWarp = [4, 8], warpsPerCTA = [4, 1], order = [1, 0]}>
1584+
#blocked1 = #ttg.blocked<{sizePerThread = [1, 64], threadsPerWarp = [32, 1], warpsPerCTA = [4, 1], order = [0, 1]}>
1585+
#blocked2 = #ttg.blocked<{sizePerThread = [1, 128], threadsPerWarp = [32, 1], warpsPerCTA = [4, 1], order = [0, 1]}>
1586+
#shared = #ttg.nvmma_shared<{swizzlingByteWidth = 128, transposed = false, elementBitWidth = 16}>
1587+
#smem = #ttg.shared_memory
1588+
#tmem = #ttng.tensor_memory_encoding<blockM = 128, blockN = 128, unpacked = true>
1589+
module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, ttg.target = "cuda:100", "ttg.threads-per-warp" = 32 : i32} {
1590+
// Check that wait is pushed to the next stage, right before the tmem_alloc, and after the prologue.
1591+
// CHECK-LABEL: @wait_before_tmem_alloc
1592+
// CHECK: scf.for
1593+
// CHECK: "prologue"() {loop.cluster = 0 : i32, loop.stage = 2 : i32}
1594+
// CHECK: ttng.tmem_alloc {{.*}} {loop.cluster = 2 : i32, loop.stage = 2 : i32}
1595+
// CHECK: ttg.async_copy_global_to_local {{.*}} {loop.cluster = 2 : i32, loop.stage = 2 : i32}
1596+
// CHECK: ttng.tc_gen5_mma {{.*}} {loop.cluster = 2 : i32, loop.stage = 2 : i32}
1597+
// CHECK: ttng.wait_barrier {{.*}} {loop.cluster = 1 : i32, loop.stage = 3 : i32}
1598+
tt.func public @wait_before_tmem_alloc(%A: tensor<128x128x!tt.ptr<f16>, #blocked> {tt.contiguity = 16 : i32, tt.divisibility = 16 : i32},
1599+
%B: tensor<128x128xf16, #blocked> {tt.contiguity = 16 : i32, tt.divisibility = 16 : i32},
1600+
%arg1: i32, %arg2: i32, %arg3: i32) attributes {noinline = false} {
1601+
%cst = arith.constant dense<0.000000e+00> : tensor<128x128xf32, #blocked1>
1602+
%true = arith.constant true
1603+
%0 = ttng.tmem_alloc : () -> !ttg.memdesc<128x128xf32, #tmem, #ttng.tensor_memory, mutable>
1604+
ttng.tmem_store %cst, %0, %true : tensor<128x128xf32, #blocked1> -> !ttg.memdesc<128x128xf32, #tmem, #ttng.tensor_memory, mutable>
1605+
scf.for %arg4 = %arg1 to %arg2 step %arg3 : i32 {
1606+
%2 = "prologue"() {loop.cluster = 0 : i32, loop.stage = 2 : i32} : () -> tensor<128x128xf16, #blocked2>
1607+
%8 = ttng.tmem_alloc %B {loop.cluster = 0 : i32, loop.stage = 2 : i32} : (tensor<128x128xf16, #blocked>) -> !ttg.memdesc<128x128xf16, #tmem, #ttng.tensor_memory>
1608+
%5 = tt.load %A {loop.cluster = 0 : i32, loop.stage = 2 : i32} : tensor<128x128x!tt.ptr<f16>, #blocked>
1609+
%6 = ttg.local_alloc %5 {loop.cluster = 0 : i32, loop.stage = 2 : i32} : (tensor<128x128xf16, #blocked>) -> !ttg.memdesc<128x128xf16, #shared, #smem>
1610+
ttng.tc_gen5_mma %6, %8, %0, %true, %true {loop.cluster = 0 : i32, loop.stage = 2 : i32, tt.self_latency = 1 : i32} : !ttg.memdesc<128x128xf16, #shared, #smem>, !ttg.memdesc<128x128xf16, #tmem, #ttng.tensor_memory>, !ttg.memdesc<128x128xf32, #tmem, #ttng.tensor_memory, mutable>
1611+
} {tt.scheduled_max_stage = 2 : i32}
1612+
%1 = ttng.tmem_load %0 : !ttg.memdesc<128x128xf32, #tmem, #ttng.tensor_memory, mutable> -> tensor<128x128xf32, #blocked1>
1613+
"use"(%1) : (tensor<128x128xf32, #blocked1>) -> ()
1614+
tt.return
1615+
}
1616+
}

0 commit comments

Comments (0)