Skip to content

Commit f47029e

Browse files
[AMD] Fix BlockPingpong for non-MFMA dot (#9618)
This PR fixes a crash in the blocked pingpong optimization when it is applied to an FMA-based dot. --------- Co-authored-by: Alexander Efimov <efimov.alexander@gmail.com>
1 parent 0585186 commit f47029e

File tree

2 files changed

+79
-1
lines changed

2 files changed

+79
-1
lines changed

test/TritonGPU/amd/amd-block-pingpong.mlir

Lines changed: 73 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2032,3 +2032,76 @@ module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 8 : i32, ttg.targ
20322032
tt.return %result#0 : tensor<128x128xf32, #mma>
20332033
}
20342034
}
2035+
2036+
// -----
2037+
2038+
// Test with an FMA-based dot; the pingpong pass should skip optimizing such kernels.
2039+
// Based on pingpong_small test.
2040+
2041+
// CHECK-LABEL: fma_dot_neg
2042+
// CHECK-NOT: rocdl.sched.barrier
2043+
// CHECK-NOT: rocdl.s.setprio
2044+
// CHECK-NOT: async
2045+
2046+
#blocked = #ttg.blocked<{sizePerThread = [8, 1], threadsPerWarp = [8, 8], warpsPerCTA = [1, 4], order = [0, 1]}>
2047+
#blocked1 = #ttg.blocked<{sizePerThread = [1, 8], threadsPerWarp = [8, 8], warpsPerCTA = [4, 1], order = [1, 0]}>
2048+
#fake_mma = #ttg.blocked<{sizePerThread = [4, 4], threadsPerWarp = [8, 8], warpsPerCTA = [2, 2], order = [1, 0]}>
2049+
#shared = #ttg.swizzled_shared<{vec = 8, perPhase = 1, maxPhase = 8, order = [1, 0]}>
2050+
#shared1 = #ttg.swizzled_shared<{vec = 8, perPhase = 1, maxPhase = 8, order = [0, 1]}>
2051+
#smem = #ttg.shared_memory
2052+
module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, ttg.target = "hip:gfx942", "ttg.threads-per-warp" = 64 : i32} {
2053+
tt.func public @fma_dot_neg(%arg0: !tt.ptr<f16> {tt.divisibility = 16 : i32, tt.pointer_range = 32 : i32}, %arg1: !tt.ptr<f16> {tt.divisibility = 16 : i32, tt.pointer_range = 32 : i32}, %arg2: !tt.ptr<f16> {tt.divisibility = 16 : i32, tt.pointer_range = 32 : i32}, %arg3: i32 {tt.divisibility = 16 : i32}, %arg4: i32 {tt.divisibility = 16 : i32}) {
2054+
%cst = arith.constant dense<0.000000e+00> : tensor<128x128xf32, #fake_mma>
2055+
%c1_i32 = arith.constant 1 : i32
2056+
%cst_0 = arith.constant dense<64> : tensor<64x128xi32, #blocked>
2057+
%cst_1 = arith.constant dense<64> : tensor<128x64xi32, #blocked1>
2058+
%c0_i32 = arith.constant 0 : i32
2059+
%c64_i32 = arith.constant 64 : i32
2060+
%0 = tt.splat %arg0 : !tt.ptr<f16> -> tensor<128x1x!tt.ptr<f16>, #blocked1>
2061+
%1 = tt.get_program_id x : i32
2062+
%2 = tt.splat %1 : i32 -> tensor<128xi32, #ttg.slice<{dim = 1, parent = #blocked1}>>
2063+
%3 = tt.make_range {end = 128 : i32, start = 0 : i32} : tensor<128xi32, #ttg.slice<{dim = 1, parent = #blocked1}>>
2064+
%4 = arith.addi %2, %3 : tensor<128xi32, #ttg.slice<{dim = 1, parent = #blocked1}>>
2065+
%5 = tt.expand_dims %4 {axis = 1 : i32} : tensor<128xi32, #ttg.slice<{dim = 1, parent = #blocked1}>> -> tensor<128x1xi32, #blocked1>
2066+
%6 = tt.splat %arg3 : i32 -> tensor<128x1xi32, #blocked1>
2067+
%7 = arith.muli %5, %6 : tensor<128x1xi32, #blocked1>
2068+
%8 = tt.addptr %0, %7 : tensor<128x1x!tt.ptr<f16>, #blocked1>, tensor<128x1xi32, #blocked1>
2069+
%9 = tt.broadcast %8 : tensor<128x1x!tt.ptr<f16>, #blocked1> -> tensor<128x64x!tt.ptr<f16>, #blocked1>
2070+
%10 = tt.make_range {end = 64 : i32, start = 0 : i32} : tensor<64xi32, #ttg.slice<{dim = 0, parent = #blocked1}>>
2071+
%11 = tt.expand_dims %10 {axis = 0 : i32} : tensor<64xi32, #ttg.slice<{dim = 0, parent = #blocked1}>> -> tensor<1x64xi32, #blocked1>
2072+
%12 = tt.broadcast %11 : tensor<1x64xi32, #blocked1> -> tensor<128x64xi32, #blocked1>
2073+
%13 = tt.addptr %9, %12 : tensor<128x64x!tt.ptr<f16>, #blocked1>, tensor<128x64xi32, #blocked1>
2074+
%14 = tt.splat %arg1 : !tt.ptr<f16> -> tensor<64x1x!tt.ptr<f16>, #blocked>
2075+
%15 = tt.make_range {end = 64 : i32, start = 0 : i32} : tensor<64xi32, #ttg.slice<{dim = 1, parent = #blocked}>>
2076+
%16 = tt.expand_dims %15 {axis = 1 : i32} : tensor<64xi32, #ttg.slice<{dim = 1, parent = #blocked}>> -> tensor<64x1xi32, #blocked>
2077+
%17 = tt.addptr %14, %16 : tensor<64x1x!tt.ptr<f16>, #blocked>, tensor<64x1xi32, #blocked>
2078+
%18 = tt.broadcast %17 : tensor<64x1x!tt.ptr<f16>, #blocked> -> tensor<64x128x!tt.ptr<f16>, #blocked>
2079+
%19 = tt.splat %arg4 : i32 -> tensor<64x128xi32, #blocked>
2080+
%20 = tt.addptr %18, %19 : tensor<64x128x!tt.ptr<f16>, #blocked>, tensor<64x128xi32, #blocked>
2081+
%21 = ttg.local_alloc : () -> !ttg.memdesc<1x128x64xf16, #shared, #smem, mutable>
2082+
%22 = ttg.local_alloc : () -> !ttg.memdesc<1x64x128xf16, #shared1, #smem, mutable>
2083+
%23 = ttg.memdesc_index %21[%c0_i32] : !ttg.memdesc<1x128x64xf16, #shared, #smem, mutable> -> !ttg.memdesc<128x64xf16, #shared, #smem, mutable>
2084+
%24 = ttg.memdesc_index %22[%c0_i32] : !ttg.memdesc<1x64x128xf16, #shared1, #smem, mutable> -> !ttg.memdesc<64x128xf16, #shared1, #smem, mutable>
2085+
%25:6 = scf.for %arg5 = %c0_i32 to %c64_i32 step %c1_i32 iter_args(%arg6 = %cst, %arg7 = %13, %arg8 = %20, %arg9 = %c0_i32, %arg10 = %23, %arg11 = %24) -> (tensor<128x128xf32, #fake_mma>, tensor<128x64x!tt.ptr<f16>, #blocked1>, tensor<64x128x!tt.ptr<f16>, #blocked>, i32, !ttg.memdesc<128x64xf16, #shared, #smem, mutable>, !ttg.memdesc<64x128xf16, #shared1, #smem, mutable>) : i32 {
2086+
%26 = tt.addptr %arg7, %cst_1 : tensor<128x64x!tt.ptr<f16>, #blocked1>, tensor<128x64xi32, #blocked1>
2087+
%27 = tt.load %26 : tensor<128x64x!tt.ptr<f16>, #blocked1>
2088+
%28 = tt.addptr %arg8, %cst_0 : tensor<64x128x!tt.ptr<f16>, #blocked>, tensor<64x128xi32, #blocked>
2089+
%29 = tt.load %28 : tensor<64x128x!tt.ptr<f16>, #blocked>
2090+
%30 = ttg.local_load %arg10 : !ttg.memdesc<128x64xf16, #shared, #smem, mutable> -> tensor<128x64xf16, #ttg.dot_op<{opIdx = 0, parent = #fake_mma}>>
2091+
%31 = ttg.local_load %arg11 : !ttg.memdesc<64x128xf16, #shared1, #smem, mutable> -> tensor<64x128xf16, #ttg.dot_op<{opIdx = 1, parent = #fake_mma}>>
2092+
%32 = arith.negf %31 : tensor<64x128xf16, #ttg.dot_op<{opIdx = 1, parent = #fake_mma}>>
2093+
%33 = tt.dot %30, %32, %arg6 : tensor<128x64xf16, #ttg.dot_op<{opIdx = 0, parent = #fake_mma}>> * tensor<64x128xf16, #ttg.dot_op<{opIdx = 1, parent = #fake_mma}>> -> tensor<128x128xf32, #fake_mma>
2094+
%34 = arith.addi %arg9, %c1_i32 : i32
2095+
%35 = arith.cmpi slt, %34, %c1_i32 : i32
2096+
%36 = arith.select %35, %34, %c0_i32 : i32
2097+
%37 = ttg.memdesc_index %21[%36] : !ttg.memdesc<1x128x64xf16, #shared, #smem, mutable> -> !ttg.memdesc<128x64xf16, #shared, #smem, mutable>
2098+
ttg.local_store %27, %37 : tensor<128x64xf16, #blocked1> -> !ttg.memdesc<128x64xf16, #shared, #smem, mutable>
2099+
%38 = ttg.memdesc_index %22[%36] : !ttg.memdesc<1x64x128xf16, #shared1, #smem, mutable> -> !ttg.memdesc<64x128xf16, #shared1, #smem, mutable>
2100+
ttg.local_store %29, %38 : tensor<64x128xf16, #blocked> -> !ttg.memdesc<64x128xf16, #shared1, #smem, mutable>
2101+
scf.yield %33, %26, %28, %36, %37, %38 : tensor<128x128xf32, #fake_mma>, tensor<128x64x!tt.ptr<f16>, #blocked1>, tensor<64x128x!tt.ptr<f16>, #blocked>, i32, !ttg.memdesc<128x64xf16, #shared, #smem, mutable>, !ttg.memdesc<64x128xf16, #shared1, #smem, mutable>
2102+
}
2103+
ttg.local_dealloc %21 : !ttg.memdesc<1x128x64xf16, #shared, #smem, mutable>
2104+
ttg.local_dealloc %22 : !ttg.memdesc<1x64x128xf16, #shared1, #smem, mutable>
2105+
tt.return
2106+
}
2107+
}

third_party/amd/lib/TritonAMDGPUTransforms/BlockPingpong.cpp

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1070,7 +1070,12 @@ void Pingponger::getDotPingponged() {
10701070
auto encoding = cast<RankedTensorType>(aType).getEncoding();
10711071
auto srcEncoding = cast<ttg::DotOperandEncodingAttr>(encoding);
10721072
kWidth = srcEncoding.getKWidth();
1073-
auto mfmaEncoding = cast<ttg::AMDMfmaEncodingAttr>(srcEncoding.getParent());
1073+
auto mfmaEncoding =
1074+
dyn_cast<ttg::AMDMfmaEncodingAttr>(srcEncoding.getParent());
1075+
if (!mfmaEncoding) {
1076+
LDBG("Encountered non-MFMA layout");
1077+
return;
1078+
}
10741079
SmallVector<int64_t> intShape;
10751080
auto mnkDim = mfmaEncoding.getInstrShape();
10761081
intShape.push_back(mnkDim[0]);

0 commit comments

Comments
 (0)