@@ -276,22 +276,22 @@ module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, ttg.targ
 // CHECK: tt.dot
 // CHECK: scf.yield

-#blocked = #triton_gpu.blocked<{sizePerThread = [1, 4], threadsPerWarp = [8, 8], warpsPerCTA = [4, 1], order = [1, 0]}>
-#mma = #triton_gpu.amd_mfma<{versionMajor = 3, versionMinor = 0, warpsPerCTA = [4, 1], instrShape = [32, 32], isTransposed = true}>
-module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 4 : i32, triton_gpu.target = "hip:gfx942", "triton_gpu.threads-per-warp" = 64 : i32} {
+#blocked = #ttg.blocked<{sizePerThread = [1, 4], threadsPerWarp = [8, 8], warpsPerCTA = [4, 1], order = [1, 0]}>
+#mma = #ttg.amd_mfma<{versionMajor = 3, versionMinor = 0, warpsPerCTA = [4, 1], instrShape = [32, 32], isTransposed = true}>
+module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, ttg.target = "hip:gfx942", "ttg.threads-per-warp" = 64 : i32} {
   tt.func public @_triton_gemm_kernel_atomic_rmw(%arg0: !tt.ptr<f16> {tt.divisibility = 16 : i32, tt.pointer_range = 32 : i32} loc(unknown), %arg1: !tt.ptr<f16> {tt.divisibility = 16 : i32, tt.pointer_range = 32 : i32} loc(unknown), %arg2: !tt.ptr<f16> {tt.divisibility = 16 : i32, tt.pointer_range = 32 : i32} loc(unknown), %arg3: i32 {tt.divisibility = 16 : i32} loc(unknown), %arg4: i32 {tt.divisibility = 16 : i32} loc(unknown)) attributes {noinline = false} {
     %cst = arith.constant dense<32> : tensor<32x32xi32, #blocked>
     %c0_i32 = arith.constant 0 : i32
     %c1_i32 = arith.constant 1 : i32
     %c31_i32 = arith.constant 31 : i32
     %c32_i32 = arith.constant 32 : i32
     %cst_0 = arith.constant dense<0.000000e+00> : tensor<32x32xf32, #mma>
-    %0 = tt.make_range {end = 32 : i32, start = 0 : i32} : tensor<32xi32, #triton_gpu.slice<{dim = 1, parent = #blocked}>>
-    %1 = tt.expand_dims %0 {axis = 1 : i32} : tensor<32xi32, #triton_gpu.slice<{dim = 1, parent = #blocked}>> -> tensor<32x1xi32, #blocked>
+    %0 = tt.make_range {end = 32 : i32, start = 0 : i32} : tensor<32xi32, #ttg.slice<{dim = 1, parent = #blocked}>>
+    %1 = tt.expand_dims %0 {axis = 1 : i32} : tensor<32xi32, #ttg.slice<{dim = 1, parent = #blocked}>> -> tensor<32x1xi32, #blocked>
     %2 = tt.splat %arg4 : i32 -> tensor<32x1xi32, #blocked>
     %3 = arith.muli %1, %2 : tensor<32x1xi32, #blocked>
-    %4 = tt.make_range {end = 32 : i32, start = 0 : i32} : tensor<32xi32, #triton_gpu.slice<{dim = 0, parent = #blocked}>>
-    %5 = tt.expand_dims %4 {axis = 0 : i32} : tensor<32xi32, #triton_gpu.slice<{dim = 0, parent = #blocked}>> -> tensor<1x32xi32, #blocked>
+    %4 = tt.make_range {end = 32 : i32, start = 0 : i32} : tensor<32xi32, #ttg.slice<{dim = 0, parent = #blocked}>>
+    %5 = tt.expand_dims %4 {axis = 0 : i32} : tensor<32xi32, #ttg.slice<{dim = 0, parent = #blocked}>> -> tensor<1x32xi32, #blocked>
     %6 = tt.broadcast %3 : tensor<32x1xi32, #blocked> -> tensor<32x32xi32, #blocked>
     %7 = tt.broadcast %5 : tensor<1x32xi32, #blocked> -> tensor<32x32xi32, #blocked>
     %8 = arith.addi %6, %7 : tensor<32x32xi32, #blocked>
@@ -317,19 +317,19 @@ module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 4 :
     %28:3 = scf.for %arg5 = %c0_i32 to %25 step %c1_i32 iter_args(%arg6 = %cst_0, %arg7 = %10, %arg8 = %12) -> (tensor<32x32xf32, #mma>, tensor<32x32x!tt.ptr<f16>, #blocked>, tensor<32x32x!tt.ptr<f16>, #blocked>) : i32 {
       %32 = tt.load %arg7 : tensor<32x32x!tt.ptr<f16>, #blocked>
       %33 = tt.load %arg8 : tensor<32x32x!tt.ptr<f16>, #blocked>
-      %34 = triton_gpu.convert_layout %32 : tensor<32x32xf16, #blocked> -> tensor<32x32xf16, #triton_gpu.dot_op<{opIdx = 0, parent = #mma, kWidth = 4}>>
-      %35 = triton_gpu.convert_layout %33 : tensor<32x32xf16, #blocked> -> tensor<32x32xf16, #triton_gpu.dot_op<{opIdx = 1, parent = #mma, kWidth = 4}>>
-      %36 = tt.dot %34, %35, %arg6 : tensor<32x32xf16, #triton_gpu.dot_op<{opIdx = 0, parent = #mma, kWidth = 4}>> * tensor<32x32xf16, #triton_gpu.dot_op<{opIdx = 1, parent = #mma, kWidth = 4}>> -> tensor<32x32xf32, #mma>
+      %34 = ttg.convert_layout %32 : tensor<32x32xf16, #blocked> -> tensor<32x32xf16, #ttg.dot_op<{opIdx = 0, parent = #mma, kWidth = 4}>>
+      %35 = ttg.convert_layout %33 : tensor<32x32xf16, #blocked> -> tensor<32x32xf16, #ttg.dot_op<{opIdx = 1, parent = #mma, kWidth = 4}>>
+      %36 = tt.dot %34, %35, %arg6 : tensor<32x32xf16, #ttg.dot_op<{opIdx = 0, parent = #mma, kWidth = 4}>> * tensor<32x32xf16, #ttg.dot_op<{opIdx = 1, parent = #mma, kWidth = 4}>> -> tensor<32x32xf32, #mma>
       %37 = tt.addptr %arg7, %cst : tensor<32x32x!tt.ptr<f16>, #blocked>, tensor<32x32xi32, #blocked>
       %38 = tt.addptr %arg8, %27 : tensor<32x32x!tt.ptr<f16>, #blocked>, tensor<32x32xi32, #blocked>
       %39 = arith.truncf %36 : tensor<32x32xf32, #mma> to tensor<32x32xf16, #mma>
-      %40 = triton_gpu.convert_layout %39 : tensor<32x32xf16, #mma> -> tensor<32x32xf16, #blocked>
+      %40 = ttg.convert_layout %39 : tensor<32x32xf16, #mma> -> tensor<32x32xf16, #blocked>
       %41 = tt.atomic_rmw fadd, acq_rel, gpu, %16, %40, %23 : (tensor<32x32x!tt.ptr<f16>, #blocked>, tensor<32x32xf16, #blocked>, tensor<32x32xi1, #blocked>) -> tensor<32x32xf16, #blocked>
       scf.yield %36, %37, %38 : tensor<32x32xf32, #mma>, tensor<32x32x!tt.ptr<f16>, #blocked>, tensor<32x32x!tt.ptr<f16>, #blocked>
     }
     %29 = arith.truncf %28#0 : tensor<32x32xf32, #mma> to tensor<32x32xf16, #mma>
-    %30 = triton_gpu.convert_layout %16 : tensor<32x32x!tt.ptr<f16>, #blocked> -> tensor<32x32x!tt.ptr<f16>, #mma>
-    %31 = triton_gpu.convert_layout %23 : tensor<32x32xi1, #blocked> -> tensor<32x32xi1, #mma>
+    %30 = ttg.convert_layout %16 : tensor<32x32x!tt.ptr<f16>, #blocked> -> tensor<32x32x!tt.ptr<f16>, #mma>
+    %31 = ttg.convert_layout %23 : tensor<32x32xi1, #blocked> -> tensor<32x32xi1, #mma>
     tt.store %30, %29, %31 : tensor<32x32x!tt.ptr<f16>, #mma>
     tt.return
   }