From ea2f469c5c1dcbe55ce409a036433acc076f9307 Mon Sep 17 00:00:00 2001
From: Michael Mi
Date: Tue, 17 Jun 2025 22:17:52 -0700
Subject: [PATCH 1/2] refactor: move TileShape into launch_mha_kernel_sm80

---
 .../attention/generate_instantiation_cu.py  | 31 ++-----
 src/kernels/attention/mha_dispatch_sm80.cuh | 85 +++++--------------
 src/kernels/attention/mha_kernel_sm80.cuh   | 15 +++-
 3 files changed, 37 insertions(+), 94 deletions(-)

diff --git a/src/kernels/attention/generate_instantiation_cu.py b/src/kernels/attention/generate_instantiation_cu.py
index bfd8fe4d..9a22918e 100755
--- a/src/kernels/attention/generate_instantiation_cu.py
+++ b/src/kernels/attention/generate_instantiation_cu.py
@@ -22,19 +22,18 @@
 MHA_KERNEL_TEMPLATE = """
 #include "mha_kernel_sm80.cuh"  // IWYU pragma: export
 #include "mha_params.h"         // IWYU pragma: export
-#include "mha_traits_sm80.h"    // IWYU pragma: export
 
 namespace llm {{
 
-using Traits = MHATraitsSM80<{DTYPE}, {HEAD_DIM}, {BLK_M}, {BLK_N}, {BLK_K}>;
-
 using Params = MHAPagedKVParams;
 
-template void launch_mha_kernel_sm80<Traits,
+template void launch_mha_kernel_sm80<{DTYPE},
+                                     {HEAD_DIM},
                                      /*EVEN_K=*/{EVEN_K},
                                      /*ALIBI=*/{ALIBI},
                                      /*SOFT_CAP=*/{SOFT_CAP},
-                                     /*LOCAL=*/{LOCAL}>(const Params& params,
-                                                        cudaStream_t stream);
+                                     /*LOCAL=*/{LOCAL},
+                                     Params>(const Params& params,
+                                             cudaStream_t stream);
 }}  // namespace llm
 """
@@ -59,9 +58,6 @@ class MHAKernel:
     dtype: str
     head_dim: int
-    blk_m: int
-    blk_n: int
-    blk_k: int
     even_k: bool
     alibi: bool
     soft_cap: bool
@@ -69,14 +65,9 @@ class MHAKernel:
     @property
     def template(self) -> str:
-        assert self.head_dim % self.blk_k == 0
-
         return MHA_KERNEL_TEMPLATE.format(
             DTYPE=DTYPE_MAP[self.dtype],
             HEAD_DIM=self.head_dim,
-            BLK_M=self.blk_m,
-            BLK_N=self.blk_n,
-            BLK_K=self.blk_k,
             EVEN_K=BOOL_MAP[self.even_k],
             ALIBI=BOOL_MAP[self.alibi],
             SOFT_CAP=BOOL_MAP[self.soft_cap],
@@ -88,7 +79,7 @@ def filename(self) -> str:
         def to_str(val: bool) -> str:
             return "1" if val else "0"
 
-        return f"mha_{self.dtype}_hd{self.head_dim}_m{self.blk_m}_n{self.blk_n}_k{self.blk_k}_ek{to_str(self.even_k)}_al{to_str(self.alibi)}_sc{to_str(self.soft_cap)}_lc{to_str(self.local)}_sm80.cu"
+        return f"mha_{self.dtype}_hd{self.head_dim}_ek{to_str(self.even_k)}_al{to_str(self.alibi)}_sc{to_str(self.soft_cap)}_lc{to_str(self.local)}_sm80.cu"
 
 
 @dataclass
@@ -125,9 +116,6 @@ def gen_mha_kernels() -> Iterator[MHAKernel]:
     for (
         dtype,
         head_dim,
-        blk_m,
-        blk_n,
-        blk_k,
         even_k,
         alibi,
         soft_cap,
@@ -135,23 +123,14 @@ def gen_mha_kernels() -> Iterator[MHAKernel]:
     ) in itertools.product(
         ["fp16", "bf16"],  # dtype
         [64, 96, 128, 256],  # head_dim
-        [64],  # blk_m
-        [64],  # blk_n
-        [32, 64],  # blk_k
         [False, True],  # even_k
         [False, True],  # alibi
         [False, True],  # soft_cap
         [False, True],  # local
     ):
-        # skip invalid configurations
-        if head_dim % blk_k != 0:
-            continue
         yield MHAKernel(
             dtype=dtype,
             head_dim=head_dim,
-            blk_m=blk_m,
-            blk_n=blk_n,
-            blk_k=blk_k,
            even_k=even_k,
             alibi=alibi,
             soft_cap=soft_cap,
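Note (illustration, not part of the patch): with the tile shape removed from the generated code, each generated translation unit now instantiates launch_mha_kernel_sm80 on the element type and head dimension only, and the BLK_M/BLK_N/BLK_K knobs disappear from both the template and the filename. Assuming DTYPE_MAP maps "fp16" to cute::half_t and BOOL_MAP maps Python booleans to "true"/"false" (both maps live outside the hunks shown here), a generated file such as mha_fp16_hd128_ek1_al0_sc0_lc0_sm80.cu would look roughly like this:

    #include "mha_kernel_sm80.cuh"  // IWYU pragma: export
    #include "mha_params.h"         // IWYU pragma: export

    namespace llm {

    using Params = MHAPagedKVParams;

    // Explicit instantiation: Dtype and HEAD_DIM are baked in here, while the
    // tile shape is chosen later inside launch_mha_kernel_sm80.
    template void launch_mha_kernel_sm80<cute::half_t,
                                         128,
                                         /*EVEN_K=*/true,
                                         /*ALIBI=*/false,
                                         /*SOFT_CAP=*/false,
                                         /*LOCAL=*/false,
                                         Params>(const Params& params,
                                                 cudaStream_t stream);

    }  // namespace llm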
diff --git a/src/kernels/attention/mha_dispatch_sm80.cuh b/src/kernels/attention/mha_dispatch_sm80.cuh
index 314cbb6b..6546e3fa 100644
--- a/src/kernels/attention/mha_dispatch_sm80.cuh
+++ b/src/kernels/attention/mha_dispatch_sm80.cuh
@@ -3,42 +3,19 @@
 #include 
 #include 
 
-#include "mha_traits_sm80.h"
 #include "static_dispatch.h"
 
 namespace llm {
 
 // forward declaration
-template <typename Traits,
-          bool EVEN_K,
-          bool ALIBI,
-          bool SOFT_CAP,
-          bool LOCAL,
-          typename Params>
+template <typename Dtype,
+          int HEAD_DIM,
+          bool EVEN_K,
+          bool ALIBI,
+          bool SOFT_CAP,
+          bool LOCAL,
+          typename Params>
 void launch_mha_kernel_sm80(const Params& params, cudaStream_t stream);
 
-namespace detail {
-
-template <typename Traits, typename Params>
-void dispatch_mha_kernel_sm80(const Params& params, cudaStream_t stream) {
-  // dispatch to proper kernel instantiation based on params
-  DISPATCH_BOOL(params.head_dim == Traits::kHeadDim, EVEN_K, [&] {
-    DISPATCH_BOOL(params.alibi_slopes_ptr != nullptr, ALIBI, [&] {
-      DISPATCH_BOOL(params.logits_soft_cap > 0, SOFT_CAP, [&] {
-        DISPATCH_BOOL(params.sliding_window >= 0, LOCAL, [&] {
-          launch_mha_kernel_sm80<Traits, EVEN_K, ALIBI, SOFT_CAP, LOCAL, Params>(params, stream);
-        });
-      });
-    });
-  });
-}
-
-}  // namespace detail
-
 // user-facing function to run the attention kernel
 template <typename Dtype, int HEAD_DIM, typename Params>
 void run_mha_kernel_sm80(Params& params, cudaStream_t stream = nullptr) {
@@ -56,43 +33,23 @@ void run_mha_kernel_sm80(Params& params, cudaStream_t stream = nullptr) {
   // * 8.6 | 8.9 : 0, 8, 16, 32, 64, 100
   // * 9.0 | 10.x: 0, 8, 16, 32, 64, 100, 132, 164, 196, 228
   // * 12.0      : 0, 8, 16, 32, 64, 100
-  if constexpr (HEAD_DIM == 64) {
-    using Traits = MHATraitsSM80;
-    detail::dispatch_mha_kernel_sm80<Traits>(params, stream);
-  } else if constexpr (HEAD_DIM == 96) {
-    using Traits = MHATraitsSM80;
-    detail::dispatch_mha_kernel_sm80<Traits>(params, stream);
-  } else if constexpr (HEAD_DIM == 128) {
-    using Traits = MHATraitsSM80;
-    detail::dispatch_mha_kernel_sm80<Traits>(params, stream);
-  } else if constexpr (HEAD_DIM == 256) {
-    using Traits = MHATraitsSM80;
-    detail::dispatch_mha_kernel_sm80<Traits>(params, stream);
-  } else {
-    // use the default block size
-    using Traits = MHATraitsSM80;
-    detail::dispatch_mha_kernel_sm80<Traits>(params, stream);
-  }
+  // dispatch to proper kernel instantiation based on params
+  DISPATCH_BOOL(params.head_dim == HEAD_DIM, EVEN_K, [&] {
+    DISPATCH_BOOL(params.alibi_slopes_ptr != nullptr, ALIBI, [&] {
+      DISPATCH_BOOL(params.logits_soft_cap > 0, SOFT_CAP, [&] {
+        DISPATCH_BOOL(params.sliding_window >= 0, LOCAL, [&] {
+          launch_mha_kernel_sm80<Dtype,
+                                 HEAD_DIM,
+                                 EVEN_K,
+                                 ALIBI,
+                                 SOFT_CAP,
+                                 LOCAL,
+                                 Params>(params, stream);
+        });
+      });
+    });
+  });
 }
 
-}  // namespace llm
\ No newline at end of file
+}  // namespace llm
diff --git a/src/kernels/attention/mha_kernel_sm80.cuh b/src/kernels/attention/mha_kernel_sm80.cuh
index 3f87a6c6..ae06a309 100644
--- a/src/kernels/attention/mha_kernel_sm80.cuh
+++ b/src/kernels/attention/mha_kernel_sm80.cuh
@@ -13,6 +13,7 @@
 #include "layout_convertor.h"
 #include "mask.h"
 #include "mha_tile.h"
+#include "mha_traits_sm80.h"
 #include "online_softmax.cuh"
 
 namespace llm {
@@ -436,17 +437,23 @@ __global__ __launch_bounds__(Traits::kThreadNum) void mha_kernel_sm80(
   epilogue(tOrO);
 }
 
-template <typename Traits,
-          bool EVEN_K,
-          bool ALIBI,
-          bool SOFT_CAP,
-          bool LOCAL,
-          typename Params>
+template <typename Dtype,
+          int HEAD_DIM,
+          bool EVEN_K,
+          bool ALIBI,
+          bool SOFT_CAP,
+          bool LOCAL,
+          typename Params>
 void launch_mha_kernel_sm80(const Params& params, cudaStream_t stream) {
   const auto batch_size = params.batch_size;
   const auto n_kv_heads = params.n_kv_heads;
   const auto max_q_packed_len = params.max_q_len * params.group_size;
 
+  constexpr int BLK_M = 64;
+  constexpr int BLK_N = 64;
+  constexpr int BLK_K = HEAD_DIM % 64 == 0 ? 64 : 32;
+  using Traits = MHATraitsSM80<Dtype, HEAD_DIM, BLK_M, BLK_N, BLK_K>;
+
   const auto smem_size = sizeof(MHASharedStorage<Traits>);
   auto mha_kernel = mha_kernel_sm80<Traits, EVEN_K, ALIBI, SOFT_CAP, LOCAL, Params>;
@@ -460,4 +467,4 @@ void launch_mha_kernel_sm80(const Params& params, cudaStream_t stream) {
   mha_kernel<<>>(params);
 }
 
-}  // namespace llm
\ No newline at end of file
+}  // namespace llm
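Note (illustration, not part of the patch): after PATCH 1/2 the runtime flags EVEN_K/ALIBI/SOFT_CAP/LOCAL are resolved inside run_mha_kernel_sm80 via DISPATCH_BOOL, while the tile shape becomes a private detail of launch_mha_kernel_sm80. A minimal caller-side sketch follows; it assumes run_mha_kernel_sm80 is templated on the element type and head dimension (as the dispatch code above implies), that cute::half_t is the fp16 element type, and that the caller has already populated an MHAPagedKVParams. The wrapper name run_attention_example is hypothetical.

    #include <cuda_runtime.h>

    #include "mha_dispatch_sm80.cuh"
    #include "mha_params.h"

    namespace llm {

    void run_attention_example(MHAPagedKVParams& params, cudaStream_t stream) {
      // Dtype and HEAD_DIM are compile-time choices made by an outer dispatcher
      // (not shown). EVEN_K/ALIBI/SOFT_CAP/LOCAL are derived inside
      // run_mha_kernel_sm80 from params.head_dim, params.alibi_slopes_ptr,
      // params.logits_soft_cap and params.sliding_window; the 64x64 tile with
      // BLK_K of 32 or 64 is picked inside launch_mha_kernel_sm80.
      run_mha_kernel_sm80<cute::half_t, /*HEAD_DIM=*/128>(params, stream);
    }

    }  // namespace llm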
From 6b263153058f45cadebd2c35071787a12d21e7ce Mon Sep 17 00:00:00 2001
From: Michael Mi
Date: Tue, 17 Jun 2025 22:19:07 -0700
Subject: [PATCH 2/2] update

---
 src/kernels/attention/mha_dispatch_sm80.cuh | 12 ------------
 src/kernels/attention/mha_kernel_sm80.cuh   | 12 ++++++++++++
 2 files changed, 12 insertions(+), 12 deletions(-)

diff --git a/src/kernels/attention/mha_dispatch_sm80.cuh b/src/kernels/attention/mha_dispatch_sm80.cuh
index 6546e3fa..6c990329 100644
--- a/src/kernels/attention/mha_dispatch_sm80.cuh
+++ b/src/kernels/attention/mha_dispatch_sm80.cuh
@@ -22,18 +22,6 @@ void run_mha_kernel_sm80(Params& params, cudaStream_t stream = nullptr) {
   // normalize params that for performance optimization
   params.normalize();
 
-  // TODO: tune block shape MNK based on the head dim and smem size
-  // https://docs.nvidia.com/cuda/cuda-c-programming-guide/index.html#features-and-technical-specifications-technical-specifications-per-compute-capability
-  //   SM           | 7.0 | 7.2 | 7.5 | 8.0 | 8.6 | 8.7 | 8.9 | 9.0 | 10.x | 12.0 |
-  //   Max SMEM (KB)|    96     |  64 | 164 | 100 | 164 | 100 |    228     |  100 |
-  // valid dynamic shared memory sizes for different compute capabilities:
-  // * 7.0 | 7.2 : 0, 8, 16, 32, 64, 96
-  // * 7.5       : 0, 32, 64
-  // * 8.0 | 8.7 : 0, 8, 16, 32, 64, 100, 132, 164
-  // * 8.6 | 8.9 : 0, 8, 16, 32, 64, 100
-  // * 9.0 | 10.x: 0, 8, 16, 32, 64, 100, 132, 164, 196, 228
-  // * 12.0      : 0, 8, 16, 32, 64, 100
-
   // dispatch to proper kernel instantiation based on params
   DISPATCH_BOOL(params.head_dim == HEAD_DIM, EVEN_K, [&] {
     DISPATCH_BOOL(params.alibi_slopes_ptr != nullptr, ALIBI, [&] {
diff --git a/src/kernels/attention/mha_kernel_sm80.cuh b/src/kernels/attention/mha_kernel_sm80.cuh
index ae06a309..5b57a281 100644
--- a/src/kernels/attention/mha_kernel_sm80.cuh
+++ b/src/kernels/attention/mha_kernel_sm80.cuh
@@ -449,6 +449,18 @@ void launch_mha_kernel_sm80(const Params& params, cudaStream_t stream) {
   const auto n_kv_heads = params.n_kv_heads;
   const auto max_q_packed_len = params.max_q_len * params.group_size;
 
+  // TODO: tune block shape MNK based on the head dim and smem size
+  // https://docs.nvidia.com/cuda/cuda-c-programming-guide/index.html#features-and-technical-specifications-technical-specifications-per-compute-capability
+  //   SM           | 7.0 | 7.2 | 7.5 | 8.0 | 8.6 | 8.7 | 8.9 | 9.0 | 10.x | 12.0 |
+  //   Max SMEM (KB)|    96     |  64 | 164 | 100 | 164 | 100 |    228     |  100 |
+  // valid dynamic shared memory sizes for different compute capabilities:
+  // * 7.0 | 7.2 : 0, 8, 16, 32, 64, 96
+  // * 7.5       : 0, 32, 64
+  // * 8.0 | 8.7 : 0, 8, 16, 32, 64, 100, 132, 164
+  // * 8.6 | 8.9 : 0, 8, 16, 32, 64, 100
+  // * 9.0 | 10.x: 0, 8, 16, 32, 64, 100, 132, 164, 196, 228
+  // * 12.0      : 0, 8, 16, 32, 64, 100
+
   constexpr int BLK_M = 64;
   constexpr int BLK_N = 64;
   constexpr int BLK_K = HEAD_DIM % 64 == 0 ? 64 : 32;
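Closing note (illustration, not part of the series): both patches converge on BLK_M = BLK_N = 64 with BLK_K = 64 when the head dimension is a multiple of 64 and 32 otherwise. The standalone sketch below just walks that rule over the generated head dimensions and prints a rough shared-memory estimate next to the SM 8.0 limit quoted in the comment above; the 2-byte element size and the simple Q + K + V tile accounting are assumptions for illustration and do not reproduce the real MHASharedStorage layout.

    #include <cstdio>
    #include <initializer_list>

    // tile-shape rule from launch_mha_kernel_sm80
    constexpr int blk_k_for(int head_dim) { return head_dim % 64 == 0 ? 64 : 32; }

    int main() {
      constexpr int kBlkM = 64, kBlkN = 64, kBytesPerElem = 2;  // fp16/bf16
      for (int head_dim : {64, 96, 128, 256}) {
        // Q tile: BLK_M x HEAD_DIM; K and V tiles: BLK_N x HEAD_DIM each.
        const int smem_kb = (kBlkM + 2 * kBlkN) * head_dim * kBytesPerElem / 1024;
        std::printf("head_dim=%3d  BLK_K=%2d  ~smem=%3d KB (SM 8.0 limit: 164 KB)\n",
                    head_dim, blk_k_for(head_dim), smem_kb);
      }
      return 0;
    }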