@@ -1577,3 +1577,40 @@ module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, ttg.targ
     tt.return
   }
 }
+
+// -----
+
+#blocked = #ttg.blocked<{sizePerThread = [1, 8], threadsPerWarp = [4, 8], warpsPerCTA = [4, 1], order = [1, 0]}>
+#blocked1 = #ttg.blocked<{sizePerThread = [1, 64], threadsPerWarp = [32, 1], warpsPerCTA = [4, 1], order = [0, 1]}>
+#blocked2 = #ttg.blocked<{sizePerThread = [1, 128], threadsPerWarp = [32, 1], warpsPerCTA = [4, 1], order = [0, 1]}>
+#shared = #ttg.nvmma_shared<{swizzlingByteWidth = 128, transposed = false, elementBitWidth = 16}>
+#smem = #ttg.shared_memory
+#tmem = #ttng.tensor_memory_encoding<blockM = 128, blockN = 128, unpacked = true>
+module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, ttg.target = "cuda:100", "ttg.threads-per-warp" = 32 : i32} {
+  // Check that the wait is pushed to the next stage, right before the tmem_alloc and after the prologue.
+  // CHECK-LABEL: @wait_before_tmem_alloc
+  // CHECK: scf.for
+  // CHECK: "prologue"() {loop.cluster = 0 : i32, loop.stage = 2 : i32}
+  // CHECK: ttng.tmem_alloc {{.*}} {loop.cluster = 2 : i32, loop.stage = 2 : i32}
+  // CHECK: ttg.async_copy_global_to_local {{.*}} {loop.cluster = 2 : i32, loop.stage = 2 : i32}
+  // CHECK: ttng.tc_gen5_mma {{.*}} {loop.cluster = 2 : i32, loop.stage = 2 : i32}
+  // CHECK: ttng.wait_barrier {{.*}} {loop.cluster = 1 : i32, loop.stage = 3 : i32}
+  tt.func public @wait_before_tmem_alloc(%A: tensor<128x128x!tt.ptr<f16>, #blocked> {tt.contiguity = 16 : i32, tt.divisibility = 16 : i32},
+                                         %B: tensor<128x128xf16, #blocked> {tt.contiguity = 16 : i32, tt.divisibility = 16 : i32},
+                                         %arg1: i32, %arg2: i32, %arg3: i32) attributes {noinline = false} {
+    %cst = arith.constant dense<0.000000e+00> : tensor<128x128xf32, #blocked1>
+    %true = arith.constant true
+    %0 = ttng.tmem_alloc : () -> !ttg.memdesc<128x128xf32, #tmem, #ttng.tensor_memory, mutable>
+    ttng.tmem_store %cst, %0, %true : tensor<128x128xf32, #blocked1> -> !ttg.memdesc<128x128xf32, #tmem, #ttng.tensor_memory, mutable>
+    scf.for %arg4 = %arg1 to %arg2 step %arg3 : i32 {
+      %2 = "prologue"() {loop.cluster = 0 : i32, loop.stage = 2 : i32} : () -> tensor<128x128xf16, #blocked2>
+      %8 = ttng.tmem_alloc %B {loop.cluster = 0 : i32, loop.stage = 2 : i32} : (tensor<128x128xf16, #blocked>) -> !ttg.memdesc<128x128xf16, #tmem, #ttng.tensor_memory>
+      %5 = tt.load %A {loop.cluster = 0 : i32, loop.stage = 2 : i32} : tensor<128x128x!tt.ptr<f16>, #blocked>
+      %6 = ttg.local_alloc %5 {loop.cluster = 0 : i32, loop.stage = 2 : i32} : (tensor<128x128xf16, #blocked>) -> !ttg.memdesc<128x128xf16, #shared, #smem>
+      ttng.tc_gen5_mma %6, %8, %0, %true, %true {loop.cluster = 0 : i32, loop.stage = 2 : i32, tt.self_latency = 1 : i32} : !ttg.memdesc<128x128xf16, #shared, #smem>, !ttg.memdesc<128x128xf16, #tmem, #ttng.tensor_memory>, !ttg.memdesc<128x128xf32, #tmem, #ttng.tensor_memory, mutable>
+    } {tt.scheduled_max_stage = 2 : i32}
+    %1 = ttng.tmem_load %0 : !ttg.memdesc<128x128xf32, #tmem, #ttng.tensor_memory, mutable> -> tensor<128x128xf32, #blocked1>
+    "use"(%1) : (tensor<128x128xf32, #blocked1>) -> ()
+    tt.return
+  }
+}