#include "mma.cuh"
#include "common.cuh"
+#include "convert.cuh"

using namespace ggml_cuda_mma;

@@ -27,20 +28,35 @@ static __global__ void mul_mat_f(
        const int stride_col_id, const int stride_row_id,
        const int channel_ratio, const int stride_channel_x, const int stride_channel_y, const int stride_channel_dst,
        const int sample_ratio, const int stride_sample_x, const int stride_sample_y, const int stride_sample_dst) {
-#if !defined(GGML_USE_HIP) && !defined(GGML_USE_MUSA)
+    // TODO: handle this in a consistent and simpler way after AMD MFMA support has been added
+#if (!defined(GGML_USE_HIP) && !defined(GGML_USE_MUSA)) || defined(AMD_WMMA_AVAILABLE)
+#if defined(AMD_WMMA_AVAILABLE)
+    // Special case for tf32, just dummy mma layout as wmma doesn't support it.
+    constexpr int tile_B_I = std::is_same_v<T, float> ? 8 : 16;
+    constexpr int tile_C_J = std::is_same_v<T, float> ? 8 : 16;
+    typedef tile<16, 8, T> tile_A;
+    typedef tile<tile_B_I, 8, T> tile_B;
+    typedef tile<16, tile_C_J, float> tile_C;
+
+    constexpr bool a_supported = tile_A::supported();
+    constexpr bool b_supported = tile_B::supported();
+    constexpr bool c_supported = tile_C::supported();
+    constexpr bool supported = a_supported && b_supported && c_supported;
+#else
    constexpr bool I_16_supported = tile<16, 8, T>::supported() && tile<16, 8, float>::supported();
    constexpr bool I_32_supported = tile<32, 8, T>::supported() && tile<32, 8, float>::supported();
-
-    if (!I_16_supported && !I_32_supported) {
-        NO_DEVICE_CODE;
-        return;
-    }
+    constexpr bool supported = I_16_supported || I_32_supported;

    constexpr int I_preferred = I_16_supported ? 16 : 32; // For Turing MMA both work but 16 is ~1% faster.

    typedef tile<I_preferred, 8, T> tile_A;
    typedef tile<8, 8, T> tile_B;
    typedef tile<I_preferred, 8, float> tile_C;
+#endif // defined(AMD_WMMA_AVAILABLE)
+    if constexpr (!supported) {
+        NO_DEVICE_CODE;
+        return;
+    }

    constexpr int warp_size = ggml_cuda_get_physical_warp_size();
    constexpr int tile_k_padded = warp_size + 4;
@@ -161,11 +177,11 @@ static __global__ void mul_mat_f(

                if constexpr (!has_ids) {
                    const float2 tmp = j < cols_per_block ? y2[j*stride_col_y + col] : make_float2(0.0f, 0.0f);
-                   tile_xy[j0*tile_k_padded + threadIdx.x] = {tmp.x, tmp.y};
+                   tile_xy[j0*tile_k_padded + threadIdx.x] = ggml_cuda_cast<T>(tmp);
                } else {
                    const bool valid = j < cols_per_block && (col_base + j) < ncols_dst_total && slot_map[j] >= 0;
                    float2 tmp = valid ? *(const float2 *) &y[slot_map[j]*stride_channel_y + 2*(j*stride_col_y + col)] : make_float2(0.0f, 0.0f);
-                   tile_xy[j0*tile_k_padded + threadIdx.x] = {tmp.x, tmp.y};
+                   tile_xy[j0*tile_k_padded + threadIdx.x] = ggml_cuda_cast<T>(tmp);
                }
            }
        } else {
@@ -239,7 +255,7 @@ static __global__ void mul_mat_f(
        channel_ratio, stride_channel_x, stride_channel_y, stride_channel_dst,
        sample_ratio, stride_sample_x, stride_sample_y, stride_sample_dst);
    NO_DEVICE_CODE;
-#endif // !defined(GGML_USE_HIP) && !defined(GGML_USE_MUSA)
+#endif // (!defined(GGML_USE_HIP) && !defined(GGML_USE_MUSA)) || defined(AMD_WMMA_AVAILABLE)
}

// This kernel is for larger batch sizes of mul_mat_id
@@ -253,20 +269,35 @@ static __global__ void mul_mat_f_ids(
        const int channel_ratio, const int stride_channel_x, const int stride_channel_y, const int stride_channel_dst,
        const int sample_ratio, const int stride_sample_x, const int stride_sample_y, const int stride_sample_dst,
        const uint3 sis1_fd, const uint3 nch_fd) {
-#if !defined(GGML_USE_HIP) && !defined(GGML_USE_MUSA)
+    // TODO: handle this in a consistent and simpler way after AMD MFMA support has been added
+#if (!defined(GGML_USE_HIP) && !defined(GGML_USE_MUSA)) || defined(AMD_WMMA_AVAILABLE)
+#if defined(AMD_WMMA_AVAILABLE)
+    // Special case for tf32, just dummy mma layout as wmma doesn't support it.
+    constexpr int tile_B_I = std::is_same_v<T, float> ? 8 : 16;
+    constexpr int tile_C_J = std::is_same_v<T, float> ? 8 : 16;
+    typedef tile<16, 8, T> tile_A;
+    typedef tile<tile_B_I, 8, T> tile_B;
+    typedef tile<16, tile_C_J, float> tile_C;
+
+    constexpr bool a_supported = tile_A::supported();
+    constexpr bool b_supported = tile_B::supported();
+    constexpr bool c_supported = tile_C::supported();
+    constexpr bool supported = a_supported && b_supported && c_supported;
+#else
    constexpr bool I_16_supported = tile<16, 8, T>::supported() && tile<16, 8, float>::supported();
    constexpr bool I_32_supported = tile<32, 8, T>::supported() && tile<32, 8, float>::supported();
+    constexpr bool supported = I_16_supported || I_32_supported;

-    if (!I_16_supported && !I_32_supported) {
-        NO_DEVICE_CODE;
-        return;
-    }
-
-    constexpr int I_preferred = I_16_supported ? 16 : 32; // For Turing MMA both work butr 16 is ~1% faster.
+    constexpr int I_preferred = I_16_supported ? 16 : 32; // For Turing MMA both work but 16 is ~1% faster.

    typedef tile<I_preferred, 8, T> tile_A;
    typedef tile<8, 8, T> tile_B;
    typedef tile<I_preferred, 8, float> tile_C;
+#endif // defined(AMD_WMMA_AVAILABLE)
+    if constexpr (!supported) {
+        NO_DEVICE_CODE;
+        return;
+    }

    constexpr int warp_size = ggml_cuda_get_physical_warp_size();
    constexpr int tile_k_padded = warp_size + 4;
@@ -408,7 +439,7 @@ static __global__ void mul_mat_f_ids(
#pragma unroll
            for (int j0 = 0; j0 < tile_B::I; ++j0) {
                const float2 tmp = vals_buf[curr_buf][j0];
-               tile_xy[j0*tile_k_padded + threadIdx.x] = {tmp.x, tmp.y};
+               tile_xy[j0*tile_k_padded + threadIdx.x] = ggml_cuda_cast<T>(tmp);
            }

            if (itB + 1 < ntB) {
@@ -492,7 +523,7 @@ static __global__ void mul_mat_f_ids(
        channel_ratio, stride_channel_x, stride_channel_y, stride_channel_dst,
        sample_ratio, stride_sample_x, stride_sample_y, stride_sample_dst, sis1_fd, nch_fd);
    NO_DEVICE_CODE;
-#endif // !defined(GGML_USE_HIP) && !defined(GGML_USE_MUSA)
+#endif // (!defined(GGML_USE_HIP) && !defined(GGML_USE_MUSA)) || defined(AMD_WMMA_AVAILABLE)
}

template <typename T, int cols_per_block, int nwarps>
@@ -554,7 +585,8 @@ void mul_mat_f_cuda(
        cudaStream_t stream, const mmf_ids_data * ids_data) {
    typedef tile<16, 8, T> tile_A_16;
    typedef tile<32, 8, T> tile_A_32;
-    typedef tile< 8, 8, T> tile_B;
+    typedef tile<16, 8, T> tile_B_16;
+    typedef tile< 8, 8, T> tile_B_8;

    GGML_ASSERT(ncols_x % 2 == 0);
    GGML_ASSERT(stride_row % 2 == 0);
@@ -581,7 +613,8 @@ void mul_mat_f_cuda(

    constexpr int rows_per_block = MMF_ROWS_PER_BLOCK;
    const int nbytes_shared_iter = nwarps_best * (volta_mma_available(cc) ? tile_A_32::I : tile_A_16::I) * (warp_size + 4) * 4;
-    const int nbytes_shared_combine = GGML_PAD(cols_per_block, tile_B::I) * (nwarps_best*rows_per_block + 4) * 4;
+    const int nbytes_cols_per_block_pad = amd_wmma_available(cc) ? tile_B_16::I : tile_B_8::I;
+    const int nbytes_shared_combine = GGML_PAD(cols_per_block, nbytes_cols_per_block_pad) * (nwarps_best*rows_per_block + 4) * 4;
    const int nbytes_shared = std::max(nbytes_shared_iter, nbytes_shared_combine);
    const int nbytes_slotmap = ids ? GGML_PAD(cols_per_block, 16) * sizeof(int) : 0;
    const int nbytes_shared_total = nbytes_shared + nbytes_slotmap;
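
A minimal, self-contained sketch of the compile-time dispatch pattern both kernels now use: a tile shape is picked per element type with `std::is_same_v`, and the kernel body is compiled out with `if constexpr` when the chosen tiles are unsupported. The `tile` template and its `supported()` rule below are simplified stand-ins for the real `ggml_cuda_mma` types, not the actual implementation.

```cpp
#include <cstdio>
#include <type_traits>

// Simplified stand-in for ggml_cuda_mma::tile; the real type also holds the fragment data.
template <int I, int J, typename T>
struct tile {
    // Hypothetical support rule for this sketch: only 16x8 tiles are available.
    static constexpr bool supported() { return I == 16 && J == 8; }
};

template <typename T>
void mul_mat_f_sketch() {
    // Mirrors the AMD_WMMA_AVAILABLE branch in the diff: tf32 (T == float) gets an
    // 8-wide dummy B/C layout because WMMA has no tf32 path, everything else gets 16.
    constexpr int tile_B_I = std::is_same_v<T, float> ? 8 : 16;
    constexpr int tile_C_J = std::is_same_v<T, float> ? 8 : 16;
    typedef tile<16,       8,        T>     tile_A;
    typedef tile<tile_B_I, 8,        T>     tile_B;
    typedef tile<16,       tile_C_J, float> tile_C;

    constexpr bool supported = tile_A::supported() && tile_B::supported() && tile_C::supported();
    if constexpr (!supported) {
        // In the real kernel this is NO_DEVICE_CODE: the body is never instantiated.
        std::printf("unsupported tile combination, kernel compiled out\n");
        return;
    } else {
        std::printf("tiles: A 16x8, B %dx8, C 16x%d\n", tile_B_I, tile_C_J);
    }
}

int main() {
    mul_mat_f_sketch<float>(); // tf32 path: dummy 8-wide tiles -> unsupported in this sketch
    mul_mat_f_sketch<int>();   // stand-in for half2/nv_bfloat162: 16x8 tiles -> supported
}
```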
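
The last hunk changes how the combine-stage shared memory is sized: when `amd_wmma_available(cc)` holds, `cols_per_block` is padded up to `tile_B_16::I` (16) instead of `tile_B_8::I` (8), matching the taller B tile on that path. A quick sketch of the arithmetic, using the usual round-up-to-a-multiple semantics of `GGML_PAD`; the concrete values (`cols_per_block = 4`, `MMF_ROWS_PER_BLOCK = 32`, etc.) are illustrative, not taken from the PR.

```cpp
#include <algorithm>
#include <cstdio>

// Round x up to a multiple of n, matching the semantics of ggml's GGML_PAD macro.
constexpr int pad_to(int x, int n) { return (x + n - 1) / n * n; }

int main() {
    // Illustrative values; the real code derives these from the launch configuration.
    const int cols_per_block = 4;
    const int nwarps_best    = 4;
    const int rows_per_block = 32;  // stand-in for MMF_ROWS_PER_BLOCK
    const int warp_size      = 32;

    const int nbytes_shared_iter = nwarps_best * 16 * (warp_size + 4) * 4;

    // Old granularity (tile_B::I == 8) vs. new AMD WMMA granularity (tile_B_16::I == 16).
    const int combine_old = pad_to(cols_per_block,  8) * (nwarps_best*rows_per_block + 4) * 4;
    const int combine_new = pad_to(cols_per_block, 16) * (nwarps_best*rows_per_block + 4) * 4;

    std::printf("iter: %d B, combine old: %d B, combine new: %d B, total: %d B\n",
                nbytes_shared_iter, combine_old, combine_new,
                std::max(nbytes_shared_iter, combine_new));
}
```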