@@ -124,6 +124,26 @@ tt.func @test_canonicalize_convert_local_load(%arg0: !ttg.async.token) -> tensor

// -----

+#blocked = #ttg.blocked<{sizePerThread = [1, 32], threadsPerWarp = [32, 1], warpsPerCTA = [4, 2], order = [0, 1]}>
+#linear = #ttg.linear<{register = [[0, 1], [0, 2], [0, 4], [0, 8], [0, 16]], lane = [[1, 0], [2, 0], [4, 0], [8, 0], [0, 32]], warp = [[32, 0], [64, 0], [16, 0]], block = []}>
+#tmem = #ttng.tensor_memory_encoding<blockM = 128, blockN = 64, unpacked = false>
+// CHECK-LABEL: test_canonicalize_convert_tmem_store
+module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 8 : i32, ttg.target = "cuda:100", "ttg.threads-per-warp" = 32 : i32} {
+  tt.func @test_canonicalize_convert_tmem_store(
+      %arg0: tensor<128x64xbf16, #linear>,
+      %arg1: !ttg.memdesc<128x64xbf16, #tmem, #ttng.tensor_memory, mutable>
+  ) {
+    %true = arith.constant true
+    // CHECK-NOT: ttg.convert_layout
+    %1 = ttg.convert_layout %arg0 : tensor<128x64xbf16, #linear> -> tensor<128x64xbf16, #blocked>
+    // CHECK: ttng.tmem_store %{{.*}} : tensor<128x64xbf16, #linear> ->
+    ttng.tmem_store %1, %arg1, %true : tensor<128x64xbf16, #blocked> -> !ttg.memdesc<128x64xbf16, #tmem, #ttng.tensor_memory, mutable>
+    tt.return
+  }
+}
+
+// -----
+
#blocked = #ttg.blocked<{sizePerThread = [1, 4], threadsPerWarp = [8, 4], warpsPerCTA = [1, 1], order = [1, 0], CTAsPerCGA = [1, 1], CTASplitNum = [1, 1], CTAOrder = [1, 0]}>
#shared = #ttg.swizzled_shared<{vec = 1, perPhase = 2, maxPhase = 8, order = [1, 0], CTAsPerCGA = [1, 1], CTASplitNum = [1, 1], CTAOrder = [1, 0]}>
#smem = #ttg.shared_memory