@@ -285,62 +285,6 @@ module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, ttg.targ
 
 // -----
 
-#blocked = #ttg.blocked<{sizePerThread = [1, 1], threadsPerWarp = [1, 32], warpsPerCTA = [1, 4], order = [1, 0], CGALayout = [[1, 0]]}>
-#blocked1 = #ttg.blocked<{sizePerThread = [1, 1], threadsPerWarp = [1, 32], warpsPerCTA = [1, 4], order = [1, 0], CGALayout = [[0, 0]]}>
-module attributes {"ttg.num-ctas" = 2 : i32, "ttg.num-warps" = 4 : i32, ttg.target = "cuda:100", "ttg.threads-per-warp" = 32 : i32} {
-  // LAYOUT_16x256{LITERAL}: #ttg.linear<{register = [[0, 1], [8, 0], [0, 8], [0, 16], [0, 32], [0, 64], [0, 128]], lane = [[0, 2], [0, 4], [1, 0], [2, 0], [4, 0]], warp = [[16, 0], [32, 0]], block = [[64, 0]]}>
-  // CHECK-DAG: #[[$TMEM:.+]] = #ttng.tensor_memory_encoding<blockM = 64, blockN = 256, colStride = 1, CTASplitM = 2>
-  // CHECK-DAG: #[[$B:.+]] = #ttg.blocked<{sizePerThread = [1, 1], threadsPerWarp = [1, 32], warpsPerCTA = [1, 4], order = [1, 0], CGALayout = {{\[\[1, 0\]\]}}}>
-  // CHECK-DAG: #[[$L:.+]] = #ttg.linear<{register = {{\[\[0, 1\], \[0, 2\], \[0, 4\], \[0, 8\], \[0, 16\], \[0, 32\], \[0, 64\]\]}}, lane = {{\[\[1, 0\], \[2, 0\], \[4, 0\], \[8, 0\], \[0, 128\]\]}}, warp = {{\[\[16, 0\], \[32, 0\]\]}}, block = {{\[\[64, 0\]\]}}}>
-  // CHECK-LABEL: mmav5_multi_ctas
-  // CHECK-DAG: %[[TRUE:.+]] = arith.constant true
-  // CHECK-DAG: %[[A:.+]] = ttg.local_alloc %{{.*}} : (tensor<128x64xf16, #{{.*}}>) -> !ttg.memdesc<128x64xf16, #{{.*}}, #smem
-  // CHECK-DAG: %[[B:.+]] = ttg.local_alloc %{{.*}} : (tensor<64x256xf16, #{{.*}}>) -> !ttg.memdesc<64x256xf16, #{{.*}}, #smem
-  // CHECK-DAG: %[[ACC:.+]], %[[ACC_TOK:.+]] = ttng.tmem_alloc %{{.*}} : (tensor<128x256xf32, #{{.*}}>) -> (!ttg.memdesc<128x256xf32, #{{.*}}, #ttng.tensor_memory, mutable>, !ttg.async.token)
-  // CHECK: %[[MMA_TOK:.+]] = ttng.tc_gen5_mma %[[A]], %[[B]], %[[ACC]][%[[ACC_TOK]]], %[[TRUE]], %[[TRUE]] : !ttg.memdesc<128x64xf16, #shared1, #smem>, !ttg.memdesc<64x256xf16, #shared, #smem>, !ttg.memdesc<128x256xf32, #[[$TMEM]], #ttng.tensor_memory, mutable>
-  // CHECK: %[[R:.+]], %{{.*}} = ttng.tmem_load %[[ACC]][%[[MMA_TOK]]] : !ttg.memdesc<128x256xf32, #{{.*}}, #ttng.tensor_memory, mutable> -> tensor<128x256xf32
-  // CHECK: %[[CVT:.+]] = ttg.convert_layout %[[R]] : tensor<128x256xf32, #[[$L]]> -> tensor<128x256xf32, #[[$B]]>
-  // CHECK: tt.return %[[CVT]] : tensor<128x256xf32
-  tt.func public @mmav5_multi_ctas(%a: tensor<128x64xf16, #blocked>, %b: tensor<64x256xf16, #blocked1>, %c: tensor<128x256xf32, #blocked>) -> tensor<128x256xf32, #blocked> {
-    %ad = ttg.convert_layout %a : tensor<128x64xf16, #blocked> -> tensor<128x64xf16, #ttg.dot_op<{opIdx = 0, parent = #blocked}>>
-    %bd = ttg.convert_layout %b : tensor<64x256xf16, #blocked1> -> tensor<64x256xf16, #ttg.dot_op<{opIdx = 1, parent = #blocked}>>
-    %d = tt.dot %ad, %bd, %c, inputPrecision = tf32 : tensor<128x64xf16, #ttg.dot_op<{opIdx = 0, parent = #blocked}>> * tensor<64x256xf16, #ttg.dot_op<{opIdx = 1, parent = #blocked}>> -> tensor<128x256xf32, #blocked>
-    tt.return %d : tensor<128x256xf32, #blocked>
-  }
-}
-
-
-// -----
-
-#blocked = #ttg.blocked<{sizePerThread = [1, 1], threadsPerWarp = [1, 32], warpsPerCTA = [1, 4], order = [1, 0], CGALayout = [[1, 0]]}>
-#blocked1 = #ttg.blocked<{sizePerThread = [1, 1], threadsPerWarp = [1, 32], warpsPerCTA = [1, 4], order = [1, 0], CGALayout = [[1, 0]]}>
-#blocked2 = #ttg.blocked<{sizePerThread = [1, 1], threadsPerWarp = [1, 32], warpsPerCTA = [1, 4], order = [1, 0], CGALayout = [[1, 0]]}>
-module attributes {"ttg.num-ctas" = 2 : i32, "ttg.num-warps" = 4 : i32, ttg.target = "cuda:100", "ttg.threads-per-warp" = 32 : i32} {
-  // CHECK-DAG: #[[$TMEM:.+]] = #ttng.tensor_memory_encoding<blockM = 64, blockN = 256, colStride = 1, CTASplitM = 2, twoCTAs = true>
-  // CHECK-DAG: #[[$B:.+]] = #ttg.blocked<{sizePerThread = [1, 1], threadsPerWarp = [1, 32], warpsPerCTA = [1, 4], order = [1, 0], CGALayout = {{\[\[1, 0\]\]}}}>
-  // CHECK-DAG: #[[$L:.+]] = #ttg.linear<{register = {{\[\[0, 1\], \[0, 2\], \[0, 4\], \[0, 8\], \[0, 16\], \[0, 32\], \[0, 64\]\]}}, lane = {{\[\[1, 0\], \[2, 0\], \[4, 0\], \[8, 0\], \[16, 0\]\]}}, warp = {{\[\[32, 0\], \[0, 128\]\]}}, block = {{\[\[64, 0\]\]}}}>
-  // CHECK-DAG: #ttg.nvmma_shared<{swizzlingByteWidth = 128, transposed = false, elementBitWidth = 16, CGALayout = {{\[\[1, 0\]\]}}}>
-  // CHECK-DAG: #ttg.nvmma_shared<{swizzlingByteWidth = 128, transposed = false, elementBitWidth = 16, CGALayout = {{\[\[0, 1\]\]}}}>
-  // CHECK-LABEL: mmav5_2ctas
-  // CHECK-DAG: %[[TRUE:.+]] = arith.constant true
-  // CHECK-DAG: %[[A:.+]] = ttg.local_alloc %{{.*}} : (tensor<128x64xf16, #{{.*}}>) -> !ttg.memdesc<128x64xf16, #{{.*}}, #smem
-  // CHECK-DAG: %[[B:.+]] = ttg.local_alloc %{{.*}} : (tensor<64x256xf16, #{{.*}}>) -> !ttg.memdesc<64x256xf16, #{{.*}}, #smem
-  // CHECK-DAG: %[[ACC:.+]], %[[ACC_TOK:.+]] = ttng.tmem_alloc %{{.*}} : (tensor<128x256xf32, #{{.*}}>) -> (!ttg.memdesc<128x256xf32, #{{.*}}, #ttng.tensor_memory, mutable>, !ttg.async.token)
-  // CHECK: %[[MMA_TOK:.+]] = ttng.tc_gen5_mma %[[A]], %[[B]], %[[ACC]][%[[ACC_TOK]]], %[[TRUE]], %[[TRUE]] {two_ctas} : !ttg.memdesc<128x64xf16, #shared, #smem>, !ttg.memdesc<64x256xf16, #shared1, #smem>, !ttg.memdesc<128x256xf32, #[[$TMEM]], #ttng.tensor_memory, mutable>
-  // CHECK: %[[R:.+]], %{{.*}} = ttng.tmem_load %[[ACC]][%[[MMA_TOK]]] : !ttg.memdesc<128x256xf32, #{{.*}}, #ttng.tensor_memory, mutable> -> tensor<128x256xf32
-  // CHECK: %[[CVT:.+]] = ttg.convert_layout %[[R]] : tensor<128x256xf32, #[[$L]]> -> tensor<128x256xf32, #[[$B]]>
-  // CHECK: tt.return %[[CVT]] : tensor<128x256xf32
-  tt.func public @mmav5_2ctas(%a: tensor<128x64xf16, #blocked2>, %b_ptr: tensor<64x256x!tt.ptr<f16>, #blocked1>, %c: tensor<128x256xf32, #blocked>) -> tensor<128x256xf32, #blocked> {
-    %ad = ttg.convert_layout %a : tensor<128x64xf16, #blocked2> -> tensor<128x64xf16, #ttg.dot_op<{opIdx = 0, parent = #blocked}>>
-    %b = tt.load %b_ptr : tensor<64x256x!tt.ptr<f16>, #blocked1>
-    %bd = ttg.convert_layout %b : tensor<64x256xf16, #blocked1> -> tensor<64x256xf16, #ttg.dot_op<{opIdx = 1, parent = #blocked}>>
-    %d = tt.dot %ad, %bd, %c, inputPrecision = tf32 : tensor<128x64xf16, #ttg.dot_op<{opIdx = 0, parent = #blocked}>> * tensor<64x256xf16, #ttg.dot_op<{opIdx = 1, parent = #blocked}>> -> tensor<128x256xf32, #blocked>
-    tt.return %d : tensor<128x256xf32, #blocked>
-  }
-}
-
-// -----
-
 #blocked = #ttg.blocked<{sizePerThread = [1, 1], threadsPerWarp = [1, 32], warpsPerCTA = [1, 4], order = [1, 0]}>
 #blocked1 = #ttg.blocked<{sizePerThread = [1, 1], threadsPerWarp = [16, 2], warpsPerCTA = [4, 1], order = [1, 0]}>
 #blocked2 = #ttg.blocked<{sizePerThread = [1, 1], threadsPerWarp = [2, 16], warpsPerCTA = [4, 1], order = [1, 0]}>