@@ -2032,3 +2032,76 @@ module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 8 : i32, ttg.targ
20322032 tt.return %result#0 : tensor <128 x128 xf32 , #mma >
20332033 }
20342034}

// -----

// Test with FMA based dot; pingpong should skip optimization of such kernels.
// Based on the pingpong_small test.

// CHECK-LABEL: fma_dot_neg
// CHECK-NOT: rocdl.sched.barrier
// CHECK-NOT: rocdl.s.setprio
// CHECK-NOT: async

2046+ #blocked = #ttg.blocked <{sizePerThread = [8 , 1 ], threadsPerWarp = [8 , 8 ], warpsPerCTA = [1 , 4 ], order = [0 , 1 ]}>
2047+ #blocked1 = #ttg.blocked <{sizePerThread = [1 , 8 ], threadsPerWarp = [8 , 8 ], warpsPerCTA = [4 , 1 ], order = [1 , 0 ]}>
2048+ #fake_mma = #ttg.blocked <{sizePerThread = [4 , 4 ], threadsPerWarp = [8 , 8 ], warpsPerCTA = [2 , 2 ], order = [1 , 0 ]}>
2049+ #shared = #ttg.swizzled_shared <{vec = 8 , perPhase = 1 , maxPhase = 8 , order = [1 , 0 ]}>
2050+ #shared1 = #ttg.swizzled_shared <{vec = 8 , perPhase = 1 , maxPhase = 8 , order = [0 , 1 ]}>
2051+ #smem = #ttg.shared_memory
2052+ module attributes {" ttg.num-ctas" = 1 : i32 , " ttg.num-warps" = 4 : i32 , ttg.target = " hip:gfx942" , " ttg.threads-per-warp" = 64 : i32 } {
2053+ tt.func public @fma_dot_neg (%arg0: !tt.ptr <f16 > {tt.divisibility = 16 : i32 , tt.pointer_range = 32 : i32 }, %arg1: !tt.ptr <f16 > {tt.divisibility = 16 : i32 , tt.pointer_range = 32 : i32 }, %arg2: !tt.ptr <f16 > {tt.divisibility = 16 : i32 , tt.pointer_range = 32 : i32 }, %arg3: i32 {tt.divisibility = 16 : i32 }, %arg4: i32 {tt.divisibility = 16 : i32 }) {
2054+ %cst = arith.constant dense <0.000000e+00 > : tensor <128 x128 xf32 , #fake_mma >
2055+ %c1_i32 = arith.constant 1 : i32
2056+ %cst_0 = arith.constant dense <64 > : tensor <64 x128 xi32 , #blocked >
2057+ %cst_1 = arith.constant dense <64 > : tensor <128 x64 xi32 , #blocked1 >
2058+ %c0_i32 = arith.constant 0 : i32
2059+ %c64_i32 = arith.constant 64 : i32
2060+ %0 = tt.splat %arg0 : !tt.ptr <f16 > -> tensor <128 x1 x!tt.ptr <f16 >, #blocked1 >
2061+ %1 = tt.get_program_id x : i32
2062+ %2 = tt.splat %1 : i32 -> tensor <128 xi32 , #ttg.slice <{dim = 1 , parent = #blocked1 }>>
2063+ %3 = tt.make_range {end = 128 : i32 , start = 0 : i32 } : tensor <128 xi32 , #ttg.slice <{dim = 1 , parent = #blocked1 }>>
2064+ %4 = arith.addi %2 , %3 : tensor <128 xi32 , #ttg.slice <{dim = 1 , parent = #blocked1 }>>
2065+ %5 = tt.expand_dims %4 {axis = 1 : i32 } : tensor <128 xi32 , #ttg.slice <{dim = 1 , parent = #blocked1 }>> -> tensor <128 x1 xi32 , #blocked1 >
2066+ %6 = tt.splat %arg3 : i32 -> tensor <128 x1 xi32 , #blocked1 >
2067+ %7 = arith.muli %5 , %6 : tensor <128 x1 xi32 , #blocked1 >
2068+ %8 = tt.addptr %0 , %7 : tensor <128 x1 x!tt.ptr <f16 >, #blocked1 >, tensor <128 x1 xi32 , #blocked1 >
2069+ %9 = tt.broadcast %8 : tensor <128 x1 x!tt.ptr <f16 >, #blocked1 > -> tensor <128 x64 x!tt.ptr <f16 >, #blocked1 >
2070+ %10 = tt.make_range {end = 64 : i32 , start = 0 : i32 } : tensor <64 xi32 , #ttg.slice <{dim = 0 , parent = #blocked1 }>>
2071+ %11 = tt.expand_dims %10 {axis = 0 : i32 } : tensor <64 xi32 , #ttg.slice <{dim = 0 , parent = #blocked1 }>> -> tensor <1 x64 xi32 , #blocked1 >
2072+ %12 = tt.broadcast %11 : tensor <1 x64 xi32 , #blocked1 > -> tensor <128 x64 xi32 , #blocked1 >
2073+ %13 = tt.addptr %9 , %12 : tensor <128 x64 x!tt.ptr <f16 >, #blocked1 >, tensor <128 x64 xi32 , #blocked1 >
2074+ %14 = tt.splat %arg1 : !tt.ptr <f16 > -> tensor <64 x1 x!tt.ptr <f16 >, #blocked >
2075+ %15 = tt.make_range {end = 64 : i32 , start = 0 : i32 } : tensor <64 xi32 , #ttg.slice <{dim = 1 , parent = #blocked }>>
2076+ %16 = tt.expand_dims %15 {axis = 1 : i32 } : tensor <64 xi32 , #ttg.slice <{dim = 1 , parent = #blocked }>> -> tensor <64 x1 xi32 , #blocked >
2077+ %17 = tt.addptr %14 , %16 : tensor <64 x1 x!tt.ptr <f16 >, #blocked >, tensor <64 x1 xi32 , #blocked >
2078+ %18 = tt.broadcast %17 : tensor <64 x1 x!tt.ptr <f16 >, #blocked > -> tensor <64 x128 x!tt.ptr <f16 >, #blocked >
2079+ %19 = tt.splat %arg4 : i32 -> tensor <64 x128 xi32 , #blocked >
2080+ %20 = tt.addptr %18 , %19 : tensor <64 x128 x!tt.ptr <f16 >, #blocked >, tensor <64 x128 xi32 , #blocked >
2081+ %21 = ttg.local_alloc : () -> !ttg.memdesc <1 x128 x64 xf16 , #shared , #smem , mutable >
2082+ %22 = ttg.local_alloc : () -> !ttg.memdesc <1 x64 x128 xf16 , #shared1 , #smem , mutable >
2083+ %23 = ttg.memdesc_index %21 [%c0_i32 ] : !ttg.memdesc <1 x128 x64 xf16 , #shared , #smem , mutable > -> !ttg.memdesc <128 x64 xf16 , #shared , #smem , mutable >
2084+ %24 = ttg.memdesc_index %22 [%c0_i32 ] : !ttg.memdesc <1 x64 x128 xf16 , #shared1 , #smem , mutable > -> !ttg.memdesc <64 x128 xf16 , #shared1 , #smem , mutable >
2085+ %25:6 = scf.for %arg5 = %c0_i32 to %c64_i32 step %c1_i32 iter_args (%arg6 = %cst , %arg7 = %13 , %arg8 = %20 , %arg9 = %c0_i32 , %arg10 = %23 , %arg11 = %24 ) -> (tensor <128 x128 xf32 , #fake_mma >, tensor <128 x64 x!tt.ptr <f16 >, #blocked1 >, tensor <64 x128 x!tt.ptr <f16 >, #blocked >, i32 , !ttg.memdesc <128 x64 xf16 , #shared , #smem , mutable >, !ttg.memdesc <64 x128 xf16 , #shared1 , #smem , mutable >) : i32 {
2086+ %26 = tt.addptr %arg7 , %cst_1 : tensor <128 x64 x!tt.ptr <f16 >, #blocked1 >, tensor <128 x64 xi32 , #blocked1 >
2087+ %27 = tt.load %26 : tensor <128 x64 x!tt.ptr <f16 >, #blocked1 >
2088+ %28 = tt.addptr %arg8 , %cst_0 : tensor <64 x128 x!tt.ptr <f16 >, #blocked >, tensor <64 x128 xi32 , #blocked >
2089+ %29 = tt.load %28 : tensor <64 x128 x!tt.ptr <f16 >, #blocked >
2090+ %30 = ttg.local_load %arg10 : !ttg.memdesc <128 x64 xf16 , #shared , #smem , mutable > -> tensor <128 x64 xf16 , #ttg.dot_op <{opIdx = 0 , parent = #fake_mma }>>
2091+ %31 = ttg.local_load %arg11 : !ttg.memdesc <64 x128 xf16 , #shared1 , #smem , mutable > -> tensor <64 x128 xf16 , #ttg.dot_op <{opIdx = 1 , parent = #fake_mma }>>
2092+ %32 = arith.negf %31 : tensor <64 x128 xf16 , #ttg.dot_op <{opIdx = 1 , parent = #fake_mma }>>
2093+ %33 = tt.dot %30 , %32 , %arg6 : tensor <128 x64 xf16 , #ttg.dot_op <{opIdx = 0 , parent = #fake_mma }>> * tensor <64 x128 xf16 , #ttg.dot_op <{opIdx = 1 , parent = #fake_mma }>> -> tensor <128 x128 xf32 , #fake_mma >
2094+ %34 = arith.addi %arg9 , %c1_i32 : i32
2095+ %35 = arith.cmpi slt , %34 , %c1_i32 : i32
2096+ %36 = arith.select %35 , %34 , %c0_i32 : i32
2097+ %37 = ttg.memdesc_index %21 [%36 ] : !ttg.memdesc <1 x128 x64 xf16 , #shared , #smem , mutable > -> !ttg.memdesc <128 x64 xf16 , #shared , #smem , mutable >
2098+ ttg.local_store %27 , %37 : tensor <128 x64 xf16 , #blocked1 > -> !ttg.memdesc <128 x64 xf16 , #shared , #smem , mutable >
2099+ %38 = ttg.memdesc_index %22 [%36 ] : !ttg.memdesc <1 x64 x128 xf16 , #shared1 , #smem , mutable > -> !ttg.memdesc <64 x128 xf16 , #shared1 , #smem , mutable >
2100+ ttg.local_store %29 , %38 : tensor <64 x128 xf16 , #blocked > -> !ttg.memdesc <64 x128 xf16 , #shared1 , #smem , mutable >
2101+ scf.yield %33 , %26 , %28 , %36 , %37 , %38 : tensor <128 x128 xf32 , #fake_mma >, tensor <128 x64 x!tt.ptr <f16 >, #blocked1 >, tensor <64 x128 x!tt.ptr <f16 >, #blocked >, i32 , !ttg.memdesc <128 x64 xf16 , #shared , #smem , mutable >, !ttg.memdesc <64 x128 xf16 , #shared1 , #smem , mutable >
2102+ }
2103+ ttg.local_dealloc %21 : !ttg.memdesc <1 x128 x64 xf16 , #shared , #smem , mutable >
2104+ ttg.local_dealloc %22 : !ttg.memdesc <1 x64 x128 xf16 , #shared1 , #smem , mutable >
2105+ tt.return
2106+ }
2107+ }