@@ -3,45 +3,34 @@
 #include <cute/int_tuple.hpp>
 #include <cute/layout.hpp>
 
-#include "mha_kernel_sm80.cuh"
 #include "mha_traits_sm80.h"
 #include "static_dispatch.h"
 
 namespace llm {
-namespace detail {
+// forward declaration
 template <typename Traits,
           typename Params,
           bool EVEN_K,
           bool ALIBI,
           bool SOFT_CAP,
           bool LOCAL>
-void launch_mha_kernel(const Params& params, cudaStream_t stream) {
-  const auto batch_size = params.batch_size;
-  const auto n_kv_heads = params.n_kv_heads;
-  const auto max_q_packed_len = params.max_q_len * params.group_size;
+void launch_mha_kernel_sm80(const Params& params, cudaStream_t stream);
 
-  const auto smem_size = Traits::kSmemSize;
-  auto mha_kernel =
-      mha_kernel_sm80<Traits, Params, EVEN_K, ALIBI, SOFT_CAP, LOCAL>;
-  cudaFuncSetAttribute(
-      mha_kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_size);
-  // TODO: support persistent kernels
-  dim3 grid(cute::ceil_div(max_q_packed_len, Traits::kBlockM),
-            batch_size,
-            n_kv_heads);
-  dim3 block = Traits::kThreadNum;
-  mha_kernel<<<grid, block, smem_size, stream>>>(params);
-}
+namespace detail {
 
 template <typename Traits, typename Params>
-void run_mha_kernel(const Params& params, cudaStream_t stream) {
+void dispatch_mha_kernel_sm80(const Params& params, cudaStream_t stream) {
   // dispatch to proper kernel instantiation based on params
   DISPATCH_BOOL(params.head_dim == Traits::kHeadDim, EVEN_K, [&] {
     DISPATCH_BOOL(params.alibi_slopes_ptr != nullptr, ALIBI, [&] {
       DISPATCH_BOOL(params.logits_soft_cap > 0, SOFT_CAP, [&] {
         DISPATCH_BOOL(params.sliding_window >= 0, LOCAL, [&] {
-          launch_mha_kernel<Traits, Params, EVEN_K, ALIBI, SOFT_CAP, LOCAL>(
-              params, stream);
+          launch_mha_kernel_sm80<Traits,
+                                 Params,
+                                 EVEN_K,
+                                 ALIBI,
+                                 SOFT_CAP,
+                                 LOCAL>(params, stream);
         });
       });
     });
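For context on the dispatch chain above: DISPATCH_BOOL comes from static_dispatch.h, which is not shown in this diff. It is assumed to follow the usual static-dispatch pattern of lifting a runtime boolean into a constexpr name visible to the continuation. A minimal sketch of such a macro, not the repository's actual definition:

// Sketch of a boolean static-dispatch macro (assumption; the real one is
// in static_dispatch.h). The lambda body text expands in both branches,
// so it is compiled twice, once with CONST_NAME == true and once with
// CONST_NAME == false, and the matching instantiation is picked at runtime.
#define DISPATCH_BOOL(COND, CONST_NAME, ...) \
  [&] {                                      \
    if (COND) {                              \
      constexpr bool CONST_NAME = true;      \
      return __VA_ARGS__();                  \
    } else {                                 \
      constexpr bool CONST_NAME = false;     \
      return __VA_ARGS__();                  \
    }                                        \
  }()

With four nested DISPATCH_BOOL calls, the dispatcher instantiates 2^4 = 16 kernel specializations per Traits, which is why moving the launch behind a forward declaration matters for header compile time.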
@@ -63,36 +52,36 @@ void run_mha_kernel_sm80(Params& params, cudaStream_t stream = nullptr) {
                                  /*BLK_M=*/64,
                                  /*BLK_N=*/64,
                                  /*BLK_K=*/64>;
-    detail::run_mha_kernel<Traits>(params, stream);
+    detail::dispatch_mha_kernel_sm80<Traits>(params, stream);
   } else if constexpr (HEAD_DIM == 96) {
     using Traits = MHATraitsSM80<Dtype,
                                  HEAD_DIM,
                                  /*BLK_M=*/64,
                                  /*BLK_N=*/64,
                                  /*BLK_K=*/32>;
-    detail::run_mha_kernel<Traits>(params, stream);
+    detail::dispatch_mha_kernel_sm80<Traits>(params, stream);
   } else if constexpr (HEAD_DIM == 128) {
     using Traits = MHATraitsSM80<Dtype,
                                  HEAD_DIM,
                                  /*BLK_M=*/64,
                                  /*BLK_N=*/64,
                                  /*BLK_K=*/64>;
-    detail::run_mha_kernel<Traits>(params, stream);
+    detail::dispatch_mha_kernel_sm80<Traits>(params, stream);
   } else if constexpr (HEAD_DIM == 256) {
     using Traits = MHATraitsSM80<Dtype,
                                  HEAD_DIM,
                                  /*BLK_M=*/64,
                                  /*BLK_N=*/64,
                                  /*BLK_K=*/64>;
-    detail::run_mha_kernel<Traits>(params, stream);
+    detail::dispatch_mha_kernel_sm80<Traits>(params, stream);
   } else {
     // use the default block size
     using Traits = MHATraitsSM80<Dtype,
                                  HEAD_DIM,
                                  /*BLK_M=*/64,
                                  /*BLK_N=*/64,
                                  /*BLK_K=*/64>;
-    detail::run_mha_kernel<Traits>(params, stream);
+    detail::dispatch_mha_kernel_sm80<Traits>(params, stream);
   }
 }
 
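The launch body deleted in the first hunk presumably moves out of this header into a separate translation unit behind the new forward declaration, which is also why the mha_kernel_sm80.cuh include is dropped. A sketch of what that out-of-line definition could look like, reconstructed from the deleted lines (the destination file name is a guess, not taken from this commit):

// Hypothetical mha_kernel_launch_sm80.cu (file name assumed):
#include <cute/int_tuple.hpp>

#include "mha_kernel_sm80.cuh"

namespace llm {

template <typename Traits,
          typename Params,
          bool EVEN_K,
          bool ALIBI,
          bool SOFT_CAP,
          bool LOCAL>
void launch_mha_kernel_sm80(const Params& params, cudaStream_t stream) {
  const auto batch_size = params.batch_size;
  const auto n_kv_heads = params.n_kv_heads;
  const auto max_q_packed_len = params.max_q_len * params.group_size;

  const auto smem_size = Traits::kSmemSize;
  auto mha_kernel =
      mha_kernel_sm80<Traits, Params, EVEN_K, ALIBI, SOFT_CAP, LOCAL>;
  // opt in to dynamic shared memory beyond the 48KB per-block default
  cudaFuncSetAttribute(
      mha_kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_size);
  // one CTA per (query tile, batch, kv head)
  // TODO: support persistent kernels
  dim3 grid(cute::ceil_div(max_q_packed_len, Traits::kBlockM),
            batch_size,
            n_kv_heads);
  dim3 block = Traits::kThreadNum;
  mha_kernel<<<grid, block, smem_size, stream>>>(params);
}

}  // namespace llm

Because the template definition is no longer visible at the call site, such a .cu file would also need explicit instantiations for every (Traits, Params, flag) combination the dispatcher can reach.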
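For completeness, a hypothetical call site: the second hunk's header shows run_mha_kernel_sm80(Params& params, cudaStream_t stream = nullptr), and its body branches on constexpr Dtype and HEAD_DIM, so the caller is expected to have already dispatched on dtype and head size. The concrete types below are illustrative, not from this commit:

// Hypothetical usage after runtime dispatch on dtype and head_dim:
constexpr int kHeadDim = 128;
run_mha_kernel_sm80<cute::half_t, kHeadDim>(params, stream);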