
Commit 028f93e

zhang-hui-yulo and zhang hui authored
HIP: RDNA4 tensor core support for MMF (ggml-org#17077)
* mmf for rdna4
* align the padding for rdna4
* forbid mul_mat_f for rdna4
* fix as comment
* remove device kernels
* add constexpr for early return
* update based on review comment
* change based on the review comment
* pass compile error
* keep code consistency

---------

Co-authored-by: zhang hui <[email protected]>
1 parent 8e9ddba commit 028f93e

File tree

5 files changed (+180, -23 lines)

ggml/src/ggml-cuda/common.cuh

Lines changed: 8 additions & 0 deletions
@@ -224,6 +224,10 @@ static const char * cu_get_error_str(CUresult err) {
 #define AMD_MFMA_AVAILABLE
 #endif // defined(GGML_USE_HIP) && defined(CDNA) && !defined(GGML_HIP_NO_MMQ_MFMA)

+#if defined(GGML_USE_HIP) && defined(RDNA4)
+#define AMD_WMMA_AVAILABLE
+#endif // defined(GGML_USE_HIP) && defined(RDNA4)
+
 // The Volta instructions are in principle available on Turing or newer but they are effectively unusable:
 #if !defined(GGML_USE_HIP) && __CUDA_ARCH__ == GGML_CUDA_CC_VOLTA
 #define VOLTA_MMA_AVAILABLE
@@ -283,6 +287,10 @@ static bool amd_mfma_available(const int cc) {
 #endif //!defined(GGML_HIP_NO_MMQ_MFMA)
 }

+static bool amd_wmma_available(const int cc) {
+    return GGML_CUDA_CC_IS_RDNA4(cc);
+}
+
 static bool volta_mma_available(const int cc) {
     return GGML_CUDA_CC_IS_NVIDIA(cc) && ggml_cuda_highest_compiled_arch(cc) == GGML_CUDA_CC_VOLTA;
 }
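The macro and the helper added here are two halves of the same gate: AMD_WMMA_AVAILABLE is evaluated at compile time inside device code (it is only defined when HIP compiles for an RDNA4 target), while amd_wmma_available(cc) is the host-side check used when choosing a kernel for the device actually present. A minimal sketch of how such a pair is typically combined in ggml's CUDA/HIP backend (hypothetical kernel and launcher names, not part of this commit):

__global__ void wmma_kernel_sketch(/* ... */) {
#if defined(AMD_WMMA_AVAILABLE)
    // Compiled only when HIP targets RDNA4: gfx12 WMMA builtins may be emitted here.
#else
    NO_DEVICE_CODE; // any other target: this body must never be reached
#endif // defined(AMD_WMMA_AVAILABLE)
}

static void launch_sketch(const int cc) {
    if (amd_wmma_available(cc)) {
        // select the WMMA-specialized kernel for this RDNA4 device
    } else {
        // fall back to the existing MMA/MFMA or vector paths
    }
}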

ggml/src/ggml-cuda/convert.cuh

Lines changed: 9 additions & 0 deletions
@@ -39,6 +39,15 @@ template<typename dst_t, typename src_t>
         return __float2bfloat16(float(x));
     } else if constexpr(std::is_same_v<src_t, nv_bfloat16>) {
         return __bfloat162float(x);
+    } else if constexpr(std::is_same_v<src_t, float2> && std::is_same_v<dst_t, half2>) {
+        return __float22half2_rn(x);
+    } else if constexpr(std::is_same_v<src_t, float2> && std::is_same_v<dst_t, nv_bfloat162>) {
+        // bypass compile error on cuda 12.0.1
+#ifdef GGML_USE_HIP
+        return __float22bfloat162_rn(x);
+#else
+        return {x.x, x.y};
+#endif // GGML_USE_HIP
     } else if constexpr(std::is_same_v<dst_t, int32_t>) {
         return int32_t(x);
     } else {
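These two float2 specializations are what the MMF kernels below rely on when they stage activations into half2/bf16 tiles via ggml_cuda_cast<T>(tmp). A small standalone sketch of the same conversions (assumes a CUDA 12+ toolchain; on HIP the bf16 case maps to __float22bfloat162_rn as in the diff; not part of the commit):

#include <cuda_runtime.h>
#include <cuda_fp16.h>
#include <cuda_bf16.h>
#include <cstdio>

__global__ void cast_demo() {
    const float2 tmp = make_float2(1.25f, -3.5f);

    // float2 -> half2: round-to-nearest on both components.
    const half2 h = __float22half2_rn(tmp);

    // float2 -> nv_bfloat162: component-wise conversion; the portable
    // equivalent of the brace-initialized return used on the CUDA path above.
    const nv_bfloat162 b = {__float2bfloat16(tmp.x), __float2bfloat16(tmp.y)};

    printf("half2 = (%f, %f), bf162 = (%f, %f)\n",
           __half2float(h.x), __half2float(h.y),
           __bfloat162float(b.x), __bfloat162float(b.y));
}

int main() {
    cast_demo<<<1, 1>>>();
    cudaDeviceSynchronize();
    return 0;
}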

ggml/src/ggml-cuda/mma.cuh

Lines changed: 107 additions & 0 deletions
@@ -74,6 +74,33 @@ namespace ggml_cuda_mma {
         static constexpr int J = J_;

 #if defined(GGML_USE_HIP)
+#if defined(RDNA4)
+        static constexpr int ne = I * J / 32;
+        T x[ne] = {0};
+
+        static constexpr __device__ bool supported() {
+            if (I == 16 && J == 16) return true;
+            return false;
+        }
+
+        static __device__ __forceinline__ int get_i(const int l) {
+            if constexpr (I == 16 && J == 16) {
+                return 8 * (threadIdx.x / 16) + l;
+            } else {
+                NO_DEVICE_CODE;
+                return -1;
+            }
+        }
+
+        static __device__ __forceinline__ int get_j(const int l) {
+            if constexpr (I == 16 && J == 16) {
+                return threadIdx.x % 16;
+            } else {
+                NO_DEVICE_CODE;
+                return -1;
+            }
+        }
+#else
         static constexpr int ne = I * J / 64;
         T x[ne] = {0};

@@ -119,6 +146,7 @@ namespace ggml_cuda_mma {
                 return -1;
            }
        }
+#endif // defined(RDNA4)
 #elif __CUDA_ARCH__ == GGML_CUDA_CC_VOLTA
         static constexpr int ne = I * J / 32;
         T x[ne] = {0};
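With a 32-wide RDNA4 wavefront, the 16×16 tile gives ne = 8 elements per lane: get_j pins each lane to column threadIdx.x % 16, and get_i hands rows 0-7 to lanes 0-15 and rows 8-15 to lanes 16-31. A host-side sketch that reproduces this mapping (plain C++, not part of the commit):

#include <cstdio>
#include <initializer_list>

int main() {
    const int I = 16, J = 16, wave_size = 32;
    const int ne = I * J / wave_size; // 8 elements per lane
    for (int lane : {0, 1, 16, 31}) {
        printf("lane %2d:", lane);
        for (int l = 0; l < ne; ++l) {
            const int i = 8 * (lane / 16) + l; // get_i(l)
            const int j = lane % 16;           // get_j(l)
            printf(" (%2d,%2d)", i, j);
        }
        printf("\n");
    }
    return 0;
}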
@@ -236,6 +264,32 @@ namespace ggml_cuda_mma {
                 return -1;
            }
        }
+#elif defined(AMD_WMMA_AVAILABLE)
+        static constexpr int ne = I * J / 32;
+        half2 x[ne] = {{0.0f, 0.0f}};
+
+        static constexpr __device__ bool supported() {
+            if (I == 16 && J == 8) return true;
+            return false;
+        }
+
+        static __device__ __forceinline__ int get_i(const int l) {
+            if constexpr (I == 16 && J == 8) {
+                return threadIdx.x % 16;
+            } else {
+                NO_DEVICE_CODE;
+                return -1;
+            }
+        }
+
+        static __device__ __forceinline__ int get_j(const int l) {
+            if constexpr (I == 16 && J == 8) {
+                return 4 * (threadIdx.x / 16) + l;
+            } else {
+                NO_DEVICE_CODE;
+                return -1;
+            }
+        }
 #else
         static constexpr int ne = I * J / WARP_SIZE;
         half2 x[ne] = {{0.0f, 0.0f}};
@@ -285,6 +339,34 @@ namespace ggml_cuda_mma {
     struct tile<I_, J_, nv_bfloat162> {
         static constexpr int I = I_;
         static constexpr int J = J_;
+
+#if defined(AMD_WMMA_AVAILABLE)
+        static constexpr int ne = I * J / 32;
+        nv_bfloat162 x[ne] = {{0.0f, 0.0f}};
+
+        static constexpr __device__ bool supported() {
+            if (I == 16 && J == 8) return true;
+            return false;
+        }
+
+        static __device__ __forceinline__ int get_i(const int l) {
+            if constexpr (I == 16 && J == 8) {
+                return threadIdx.x % 16;
+            } else {
+                NO_DEVICE_CODE;
+                return -1;
+            }
+        }
+
+        static __device__ __forceinline__ int get_j(const int l) {
+            if constexpr (I == 16 && J == 8) {
+                return 4 * (threadIdx.x / 16) + l;
+            } else {
+                NO_DEVICE_CODE;
+                return -1;
+            }
+        }
+#else
         static constexpr int ne = I * J / WARP_SIZE;
         nv_bfloat162 x[ne] = {{0.0f, 0.0f}};

@@ -320,6 +402,7 @@ namespace ggml_cuda_mma {
                 return -1;
            }
        }
+#endif // defined(AMD_WMMA_AVAILABLE)
     };

     template <int I, int J>
@@ -353,6 +436,8 @@ namespace ggml_cuda_mma {
         const int64_t * xs = (int64_t *) ((const int *) xs0 + (threadIdx.x % t.I) * stride + 2 * (threadIdx.x / t.I));
         xi[0] = xs[0];
     }
+#elif defined(AMD_WMMA_AVAILABLE)
+        ggml_cuda_memcpy_1<sizeof(t.x)>(t.x, xs0 + t.get_i(0) * stride + t.get_j(0));
 #else
 #pragma unroll
     for (int l = 0; l < t.ne; ++l) {
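The single ggml_cuda_memcpy_1 call is valid here because, for the 16×8 WMMA fragments, get_i(l) does not depend on l and get_j(l) increases by one per element, so a lane's ne values are contiguous in memory. For this layout it is equivalent to the generic element-wise fallback in the #else branch, roughly:

    // per-element form of the vectorized copy above (equivalent only because
    // the 16x8 fragment layout keeps each lane's elements contiguous)
    for (int l = 0; l < t.ne; ++l) {
        t.x[l] = xs0[t.get_i(l)*stride + t.get_j(l)];
    }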
@@ -639,12 +724,34 @@ namespace ggml_cuda_mma {
                       : "+r"(Dxi[4]), "+r"(Dxi[5]), "+r"(Dxi[6]), "+r"(Dxi[7])
                       : "r"(Axi[2]), "r"(Axi[3]), "r"(Bxi[3]));
 #endif // __CUDA_ARCH__ >= GGML_CUDA_CC_AMPERE
+#elif defined(AMD_WMMA_AVAILABLE)
+        using halfx8_t = __attribute__((ext_vector_type(8))) _Float16;
+        using floatx8_t = __attribute__((ext_vector_type(8))) float;
+        floatx8_t& acc_frag = reinterpret_cast<floatx8_t&>(D.x[0]);
+        const halfx8_t& a_frag = reinterpret_cast<const halfx8_t&>(A.x[0]);
+        const halfx8_t& b_frag = reinterpret_cast<const halfx8_t&>(B.x[0]);
+        acc_frag = __builtin_amdgcn_wmma_f32_16x16x16_f16_w32_gfx12(a_frag, b_frag, acc_frag);
 #else
         GGML_UNUSED_VARS(D, A, B);
         NO_DEVICE_CODE;
 #endif // TURING_MMA_AVAILABLE
     }

+    static __device__ __forceinline__ void mma(
+            tile<16, 16, float> & D, const tile<16, 8, nv_bfloat162> & A, const tile<16, 8, nv_bfloat162> & B) {
+#if defined(AMD_WMMA_AVAILABLE)
+        using bf16x8_t = __attribute__((ext_vector_type(8))) __bf16;
+        using floatx8_t = __attribute__((ext_vector_type(8))) float;
+        floatx8_t& acc_frag = reinterpret_cast<floatx8_t&>(D.x[0]);
+        const bf16x8_t& a_frag = reinterpret_cast<const bf16x8_t&>(A.x[0]);
+        const bf16x8_t& b_frag = reinterpret_cast<const bf16x8_t&>(B.x[0]);
+        acc_frag = __builtin_amdgcn_wmma_f32_16x16x16_bf16_w32_gfx12(a_frag, b_frag, acc_frag);
+#else
+        GGML_UNUSED_VARS(D, A, B);
+        NO_DEVICE_CODE;
+#endif // AMPERE_MMA_AVAILABLE
+    }
+
     static __device__ __forceinline__ void mma(
             tile<16, 16, int> & D, const tile<16, 8, int> & A, const tile<16, 8, int> & B) {
 #if defined(AMD_MFMA_AVAILABLE)
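Both new mma overloads lower to a single gfx12 wave32 WMMA builtin: every lane contributes 8 packed f16/bf16 values of A and B plus 8 f32 accumulators, and receives 8 updated accumulators back. A minimal standalone HIP sketch of the f16 builtin (assumes a ROCm toolchain with an RDNA4 offload target such as --offload-arch=gfx1200; filling A and B with ones makes every output equal to K = 16 regardless of the fragment layout; not part of the commit):

#include <hip/hip_runtime.h>
#include <cstdio>

typedef _Float16 halfx8_t  __attribute__((ext_vector_type(8)));
typedef float    floatx8_t __attribute__((ext_vector_type(8)));

__global__ void wmma_smoke_test(float * out) {
    halfx8_t  a, b;
    floatx8_t acc;
    for (int i = 0; i < 8; ++i) {
        a[i]   = (_Float16) 1.0f;
        b[i]   = (_Float16) 1.0f;
        acc[i] = 0.0f;
    }
#if defined(__gfx1200__) || defined(__gfx1201__)
    // D = A*B + C for one 16x16x16 f16 tile, spread across a 32-wide wavefront.
    acc = __builtin_amdgcn_wmma_f32_16x16x16_f16_w32_gfx12(a, b, acc);
#endif
    out[threadIdx.x] = acc[0];
}

int main() {
    float * d_out = nullptr;
    hipMalloc((void **) &d_out, 32 * sizeof(float));
    wmma_smoke_test<<<1, 32>>>(d_out);
    float h_out[32] = {0.0f};
    hipMemcpy(h_out, d_out, sizeof(h_out), hipMemcpyDeviceToHost);
    printf("lane 0, acc[0] = %.1f (expected 16.0 on gfx12)\n", h_out[0]);
    hipFree(d_out);
    return 0;
}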

ggml/src/ggml-cuda/mmf.cu

Lines changed: 3 additions & 3 deletions
@@ -151,7 +151,7 @@ bool ggml_cuda_should_use_mmf(enum ggml_type type, int cc, int warp_size, const
             return false;
         }
     } else {
-        if (src1_ncols > 16) {
+        if (src1_ncols > 16 || GGML_CUDA_CC_IS_RDNA4(cc)) {
             return false;
         }
     }
@@ -160,9 +160,9 @@ bool ggml_cuda_should_use_mmf(enum ggml_type type, int cc, int warp_size, const
         case GGML_TYPE_F32:
             return ampere_mma_available(cc);
         case GGML_TYPE_F16:
-            return volta_mma_available(cc) || turing_mma_available(cc);
+            return volta_mma_available(cc) || turing_mma_available(cc) || amd_wmma_available(cc);
         case GGML_TYPE_BF16:
-            return ampere_mma_available(cc);
+            return ampere_mma_available(cc) || amd_wmma_available(cc);
         default:
             return false;
     }
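Net effect of these three lines: F16 and BF16 can now take the MMF path on RDNA4 through amd_wmma_available, F32 still requires Ampere MMA (WMMA has no tf32 variant), and the else branch additionally rejects RDNA4, matching the "forbid mul_mat_f for rdna4" item in the commit message. A condensed, hypothetical restatement of the per-type check (relies on the ggml helpers shown above; not the actual function):

static bool mmf_type_supported_sketch(const enum ggml_type type, const int cc) {
    switch (type) {
        case GGML_TYPE_F32:  // no WMMA tf32 path on RDNA4
            return ampere_mma_available(cc);
        case GGML_TYPE_F16:
            return volta_mma_available(cc) || turing_mma_available(cc) || amd_wmma_available(cc);
        case GGML_TYPE_BF16:
            return ampere_mma_available(cc) || amd_wmma_available(cc);
        default:
            return false;
    }
}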

ggml/src/ggml-cuda/mmf.cuh

Lines changed: 53 additions & 20 deletions
@@ -2,6 +2,7 @@

 #include "mma.cuh"
 #include "common.cuh"
+#include "convert.cuh"

 using namespace ggml_cuda_mma;

@@ -27,20 +28,35 @@ static __global__ void mul_mat_f(
         const int stride_col_id, const int stride_row_id,
         const int channel_ratio, const int stride_channel_x, const int stride_channel_y, const int stride_channel_dst,
         const int sample_ratio, const int stride_sample_x, const int stride_sample_y, const int stride_sample_dst) {
-#if !defined(GGML_USE_HIP) && !defined(GGML_USE_MUSA)
+    // TODO: handle this in a consistent and simpler way after AMD MFMA support has been added
+#if (!defined(GGML_USE_HIP) && !defined(GGML_USE_MUSA)) || defined(AMD_WMMA_AVAILABLE)
+#if defined(AMD_WMMA_AVAILABLE)
+    // Special case for tf32, just dummy mma layout as wmma doesn't support it.
+    constexpr int tile_B_I = std::is_same_v<T, float> ? 8 : 16;
+    constexpr int tile_C_J = std::is_same_v<T, float> ? 8 : 16;
+    typedef tile<16, 8, T> tile_A;
+    typedef tile<tile_B_I, 8, T> tile_B;
+    typedef tile<16, tile_C_J, float> tile_C;
+
+    constexpr bool a_supported = tile_A::supported();
+    constexpr bool b_supported = tile_B::supported();
+    constexpr bool c_supported = tile_C::supported();
+    constexpr bool supported = a_supported && b_supported && c_supported;
+#else
     constexpr bool I_16_supported = tile<16, 8, T>::supported() && tile<16, 8, float>::supported();
     constexpr bool I_32_supported = tile<32, 8, T>::supported() && tile<32, 8, float>::supported();
-
-    if (!I_16_supported && !I_32_supported) {
-        NO_DEVICE_CODE;
-        return;
-    }
+    constexpr bool supported = I_16_supported || I_32_supported;

     constexpr int I_preferred = I_16_supported ? 16 : 32; // For Turing MMA both work but 16 is ~1% faster.

     typedef tile<I_preferred, 8, T> tile_A;
     typedef tile<8, 8, T> tile_B;
     typedef tile<I_preferred, 8, float> tile_C;
+#endif // defined(AMD_WMMA_AVAILABLE)
+    if constexpr (!supported) {
+        NO_DEVICE_CODE;
+        return;
+    }

     constexpr int warp_size = ggml_cuda_get_physical_warp_size();
     constexpr int tile_k_padded = warp_size + 4;
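On the WMMA path the A fragment is always 16×8, the accumulator is 16×16, and the float (tf32) instantiation deliberately gets shapes that fail supported(), so its kernel body is compiled out through the new if constexpr early return instead of the old runtime check. A host-side mock of how the shapes resolve per element type (stand-in types, plain C++, not the real ggml tiles):

#include <cstdio>
#include <type_traits>

struct half2_t {};      // stands in for half2
struct bfloat162_t {};  // stands in for nv_bfloat162

template <typename T>
static void print_wmma_shapes(const char * name) {
    // mirrors the selection above: tf32 gets dummy B/C shapes because RDNA4
    // WMMA has no tf32 instruction, so supported() fails and the kernel
    // returns early; simplified here to "only float fails" on RDNA4
    constexpr int  tile_B_I  = std::is_same_v<T, float> ? 8 : 16;
    constexpr int  tile_C_J  = std::is_same_v<T, float> ? 8 : 16;
    constexpr bool supported = !std::is_same_v<T, float>;
    printf("%-14s A=16x8  B=%dx8  C=16x%d  supported=%d\n", name, tile_B_I, tile_C_J, supported);
}

int main() {
    print_wmma_shapes<half2_t>("half2");
    print_wmma_shapes<bfloat162_t>("nv_bfloat162");
    print_wmma_shapes<float>("float (tf32)");
    return 0;
}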
@@ -161,11 +177,11 @@ static __global__ void mul_mat_f(

             if constexpr (!has_ids) {
                 const float2 tmp = j < cols_per_block ? y2[j*stride_col_y + col] : make_float2(0.0f, 0.0f);
-                tile_xy[j0*tile_k_padded + threadIdx.x] = {tmp.x, tmp.y};
+                tile_xy[j0*tile_k_padded + threadIdx.x] = ggml_cuda_cast<T>(tmp);
             } else {
                 const bool valid = j < cols_per_block && (col_base + j) < ncols_dst_total && slot_map[j] >= 0;
                 float2 tmp = valid ? *(const float2*) &y[slot_map[j]*stride_channel_y + 2*(j*stride_col_y + col)] : make_float2(0.0f, 0.0f);
-                tile_xy[j0*tile_k_padded + threadIdx.x] = {tmp.x, tmp.y};
+                tile_xy[j0*tile_k_padded + threadIdx.x] = ggml_cuda_cast<T>(tmp);
             }
         }
     } else {
@@ -239,7 +255,7 @@ static __global__ void mul_mat_f(
         channel_ratio, stride_channel_x, stride_channel_y, stride_channel_dst,
         sample_ratio, stride_sample_x, stride_sample_y, stride_sample_dst);
     NO_DEVICE_CODE;
-#endif // !defined(GGML_USE_HIP) && !defined(GGML_USE_MUSA)
+#endif // (!defined(GGML_USE_HIP) && !defined(GGML_USE_MUSA)) || defined(AMD_WMMA_AVAILABLE)
 }

 //This kernel is for larger batch sizes of mul_mat_id
@@ -253,20 +269,35 @@ static __global__ void mul_mat_f_ids(
         const int channel_ratio, const int stride_channel_x, const int stride_channel_y, const int stride_channel_dst,
         const int sample_ratio, const int stride_sample_x, const int stride_sample_y, const int stride_sample_dst,
         const uint3 sis1_fd, const uint3 nch_fd) {
-#if !defined(GGML_USE_HIP) && !defined(GGML_USE_MUSA)
+    // TODO: handle this in a consistent and simpler way after AMD MFMA support has been added
+#if (!defined(GGML_USE_HIP) && !defined(GGML_USE_MUSA)) || defined(AMD_WMMA_AVAILABLE)
+#if defined(AMD_WMMA_AVAILABLE)
+    // Special case for tf32, just dummy mma layout as wmma doesn't support it.
+    constexpr int tile_B_I = std::is_same_v<T, float> ? 8 : 16;
+    constexpr int tile_C_J = std::is_same_v<T, float> ? 8 : 16;
+    typedef tile<16, 8, T> tile_A;
+    typedef tile<tile_B_I, 8, T> tile_B;
+    typedef tile<16, tile_C_J, float> tile_C;
+
+    constexpr bool a_supported = tile_A::supported();
+    constexpr bool b_supported = tile_B::supported();
+    constexpr bool c_supported = tile_C::supported();
+    constexpr bool supported = a_supported && b_supported && c_supported;
+#else
     constexpr bool I_16_supported = tile<16, 8, T>::supported() && tile<16, 8, float>::supported();
     constexpr bool I_32_supported = tile<32, 8, T>::supported() && tile<32, 8, float>::supported();
+    constexpr bool supported = I_16_supported || I_32_supported;

-    if (!I_16_supported && !I_32_supported) {
-        NO_DEVICE_CODE;
-        return;
-    }
-
-    constexpr int I_preferred = I_16_supported ? 16 : 32; // For Turing MMA both work butr 16 is ~1% faster.
+    constexpr int I_preferred = I_16_supported ? 16 : 32; // For Turing MMA both work but 16 is ~1% faster.

     typedef tile<I_preferred, 8, T> tile_A;
     typedef tile<8, 8, T> tile_B;
     typedef tile<I_preferred, 8, float> tile_C;
+#endif // defined(AMD_WMMA_AVAILABLE)
+    if constexpr (!supported) {
+        NO_DEVICE_CODE;
+        return;
+    }

     constexpr int warp_size = ggml_cuda_get_physical_warp_size();
     constexpr int tile_k_padded = warp_size + 4;
@@ -408,7 +439,7 @@ static __global__ void mul_mat_f_ids(
 #pragma unroll
         for (int j0 = 0; j0 < tile_B::I; ++j0) {
             const float2 tmp = vals_buf[curr_buf][j0];
-            tile_xy[j0*tile_k_padded + threadIdx.x] = {tmp.x, tmp.y};
+            tile_xy[j0*tile_k_padded + threadIdx.x] = ggml_cuda_cast<T>(tmp);
         }

         if (itB + 1 < ntB) {
@@ -492,7 +523,7 @@ static __global__ void mul_mat_f_ids(
         channel_ratio, stride_channel_x, stride_channel_y, stride_channel_dst,
         sample_ratio, stride_sample_x, stride_sample_y, stride_sample_dst, sis1_fd, nch_fd);
     NO_DEVICE_CODE;
-#endif // !defined(GGML_USE_HIP) && !defined(GGML_USE_MUSA)
+#endif // (!defined(GGML_USE_HIP) && !defined(GGML_USE_MUSA)) || defined(AMD_WMMA_AVAILABLE)
 }

 template<typename T, int cols_per_block, int nwarps>
@@ -554,7 +585,8 @@ void mul_mat_f_cuda(
         cudaStream_t stream, const mmf_ids_data * ids_data) {
     typedef tile<16, 8, T> tile_A_16;
     typedef tile<32, 8, T> tile_A_32;
-    typedef tile< 8, 8, T> tile_B;
+    typedef tile<16, 8, T> tile_B_16;
+    typedef tile< 8, 8, T> tile_B_8;

     GGML_ASSERT(ncols_x % 2 == 0);
     GGML_ASSERT(stride_row % 2 == 0);
@@ -581,7 +613,8 @@ void mul_mat_f_cuda(

     constexpr int rows_per_block = MMF_ROWS_PER_BLOCK;
     const int nbytes_shared_iter = nwarps_best * (volta_mma_available(cc) ? tile_A_32::I : tile_A_16::I) * (warp_size + 4) * 4;
-    const int nbytes_shared_combine = GGML_PAD(cols_per_block, tile_B::I) * (nwarps_best*rows_per_block + 4) * 4;
+    const int nbytes_cols_per_block_pad = amd_wmma_available(cc) ? tile_B_16::I : tile_B_8::I;
+    const int nbytes_shared_combine = GGML_PAD(cols_per_block, nbytes_cols_per_block_pad) * (nwarps_best*rows_per_block + 4) * 4;
     const int nbytes_shared = std::max(nbytes_shared_iter, nbytes_shared_combine);
     const int nbytes_slotmap = ids ? GGML_PAD(cols_per_block, 16) * sizeof(int) : 0;
     const int nbytes_shared_total = nbytes_shared + nbytes_slotmap;
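The padding change only affects the combine buffer: on RDNA4 the column count is now rounded up to 16 (tile_B_16::I) instead of 8. A worked example with assumed values (cols_per_block = 4, nwarps_best = 4, rows_per_block = 32, 4-byte floats; the real numbers depend on the launch configuration):

#include <cstdio>

static int pad(const int x, const int n) { return (x + n - 1) / n * n; } // same rounding as GGML_PAD

int main() {
    const int cols_per_block = 4, nwarps_best = 4, rows_per_block = 32;
    const int rows = nwarps_best*rows_per_block + 4; // 132
    printf("pad to  8 (non-RDNA4): %d bytes\n", pad(cols_per_block,  8) * rows * 4); // 4224
    printf("pad to 16 (RDNA4):     %d bytes\n", pad(cols_per_block, 16) * rows * 4); // 8448
    return 0;
}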

0 commit comments