@@ -276,3 +276,22 @@ module attributes {"triton_gpu.target" = "cuda:90", "triton_gpu.num-ctas" = 1 :
276276 tt.return %r : tensor <128 x64 xf32 , #mma >
277277 }
278278}

// -----

#blocked = #triton_gpu.blocked<{sizePerThread = [16, 1], threadsPerWarp = [32, 1], warpsPerCTA = [4, 1], order = [1, 0]}>
#blocked1 = #triton_gpu.blocked<{sizePerThread = [1, 16], threadsPerWarp = [1, 32], warpsPerCTA = [1, 4], order = [0, 1]}>
#mma = #triton_gpu.nvidia_mma<{versionMajor = 2, versionMinor = 0, warpsPerCTA = [4, 1], instrShape = [16, 8]}>
module attributes {"triton_gpu.target" = "cuda:90", "triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 4 : i32, "triton_gpu.threads-per-warp" = 32 : i32} {
  // A tt.trans feeding (via convert_layout) operand A of tt.dot: the pass is
  // expected to rewrite this through shared memory as
  // local_alloc -> memdesc_trans -> local_load instead of a register-level
  // layout conversion.
  // CHECK-LABEL: mmav2_reorder_transpose
  // CHECK: triton_gpu.local_alloc
  // CHECK: triton_gpu.memdesc_trans
  // CHECK: triton_gpu.local_load
  // CHECK: tt.dot
  tt.func @mmav2_reorder_transpose(%t: tensor<32x128xf16, #blocked1>, %dotb: tensor<32x64xf16, #triton_gpu.dot_op<{opIdx = 1, parent = #mma, kWidth = 2}>>, %dotc: tensor<128x64xf32, #mma>) -> tensor<128x64xf32, #mma> {
    %a = tt.trans %t {order = array<i32: 1, 0>} : tensor<32x128xf16, #blocked1> -> tensor<128x32xf16, #blocked>
    %cv = triton_gpu.convert_layout %a : tensor<128x32xf16, #blocked> -> tensor<128x32xf16, #triton_gpu.dot_op<{opIdx = 0, parent = #mma, kWidth = 2}>>
    %r = tt.dot %cv, %dotb, %dotc, inputPrecision = tf32 : tensor<128x32xf16, #triton_gpu.dot_op<{opIdx = 0, parent = #mma, kWidth = 2}>> * tensor<32x64xf16, #triton_gpu.dot_op<{opIdx = 1, parent = #mma, kWidth = 2}>> -> tensor<128x64xf32, #mma>
    tt.return %r : tensor<128x64xf32, #mma>
  }
}
0 commit comments