31 changes: 5 additions & 26 deletions src/kernels/attention/generate_instantiation_cu.py
@@ -22,19 +22,18 @@
MHA_KERNEL_TEMPLATE = """
#include "mha_kernel_sm80.cuh" // IWYU pragma: export
#include "mha_params.h" // IWYU pragma: export
#include "mha_traits_sm80.h" // IWYU pragma: export

namespace llm {{

using Traits = MHATraitsSM80<{DTYPE}, {HEAD_DIM}, {BLK_M}, {BLK_N}, {BLK_K}>;
using Params = MHAPagedKVParams;

template void launch_mha_kernel_sm80<Traits,
Params,
template void launch_mha_kernel_sm80</*DTYPE=*/{DTYPE},
/*HEAD_DIM=*/{HEAD_DIM},
/*EVEN_K=*/{EVEN_K},
/*ALIBI=*/{ALIBI},
/*SOFT_CAP=*/{SOFT_CAP},
/*LOCAL=*/{LOCAL}>(const Params& params,
/*LOCAL=*/{LOCAL},
Params>(const Params& params,
cudaStream_t stream);
}} // namespace llm
"""
@@ -59,24 +58,16 @@
class MHAKernel:
dtype: str
head_dim: int
blk_m: int
blk_n: int
blk_k: int
even_k: bool
alibi: bool
soft_cap: bool
local: bool

@property
def template(self) -> str:
assert self.head_dim % self.blk_k == 0

return MHA_KERNEL_TEMPLATE.format(
DTYPE=DTYPE_MAP[self.dtype],
HEAD_DIM=self.head_dim,
BLK_M=self.blk_m,
BLK_N=self.blk_n,
BLK_K=self.blk_k,
EVEN_K=BOOL_MAP[self.even_k],
ALIBI=BOOL_MAP[self.alibi],
SOFT_CAP=BOOL_MAP[self.soft_cap],
@@ -88,7 +79,7 @@ def filename(self) -> str:
def to_str(val: bool) -> str:
return "1" if val else "0"

return f"mha_{self.dtype}_hd{self.head_dim}_m{self.blk_m}_n{self.blk_n}_k{self.blk_k}_ek{to_str(self.even_k)}_al{to_str(self.alibi)}_sc{to_str(self.soft_cap)}_lc{to_str(self.local)}_sm80.cu"
return f"mha_{self.dtype}_hd{self.head_dim}_ek{to_str(self.even_k)}_al{to_str(self.alibi)}_sc{to_str(self.soft_cap)}_lc{to_str(self.local)}_sm80.cu"


@dataclass
@@ -125,33 +116,21 @@ def gen_mha_kernels() -> Iterator[MHAKernel]:
for (
dtype,
head_dim,
blk_m,
blk_n,
blk_k,
even_k,
alibi,
soft_cap,
local,
) in itertools.product(
["fp16", "bf16"], # dtype
[64, 96, 128, 256], # head_dim
[64], # blk_m
[64], # blk_n
[32, 64], # blk_k
[False, True], # even_k
[False, True], # alibi
[False, True], # soft_cap
[False, True], # local
):
# skip invalid configurations
if head_dim % blk_k != 0:
continue
yield MHAKernel(
dtype=dtype,
head_dim=head_dim,
blk_m=blk_m,
blk_n=blk_n,
blk_k=blk_k,
even_k=even_k,
alibi=alibi,
soft_cap=soft_cap,
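For reference, here is a sketch of what one generated instantiation file (e.g. mha_fp16_hd96_ek0_al0_sc0_lc0_sm80.cu) might look like after this change. The concrete type that DTYPE_MAP produces for "fp16" is not shown in this hunk, so cute::half_t below is an assumption; the exact output is produced by MHA_KERNEL_TEMPLATE above.

#include "mha_kernel_sm80.cuh"  // IWYU pragma: export
#include "mha_params.h"  // IWYU pragma: export

namespace llm {

using Params = MHAPagedKVParams;

// One explicit instantiation per (dtype, head_dim, even_k, alibi, soft_cap,
// local) combination; the tile shape is no longer part of the instantiation.
template void launch_mha_kernel_sm80</*DTYPE=*/cute::half_t,
                                     /*HEAD_DIM=*/96,
                                     /*EVEN_K=*/false,
                                     /*ALIBI=*/false,
                                     /*SOFT_CAP=*/false,
                                     /*LOCAL=*/false,
                                     Params>(const Params& params,
                                             cudaStream_t stream);

}  // namespace llm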
85 changes: 15 additions & 70 deletions src/kernels/attention/mha_dispatch_sm80.cuh
@@ -3,96 +3,41 @@
#include <cute/int_tuple.hpp>
#include <cute/layout.hpp>

#include "mha_traits_sm80.h"
#include "static_dispatch.h"

namespace llm {
// forward declaration
template <typename Traits,
typename Params,
template <typename Dtype,
int HEAD_DIM,
bool EVEN_K,
bool ALIBI,
bool SOFT_CAP,
bool LOCAL>
bool LOCAL,
typename Params>
void launch_mha_kernel_sm80(const Params& params, cudaStream_t stream);

namespace detail {
// user-facing function to run the attention kernel
template <typename Dtype, int HEAD_DIM, typename Params>
void run_mha_kernel_sm80(Params& params, cudaStream_t stream = nullptr) {
// normalize params for performance optimization
params.normalize();

template <typename Traits, typename Params>
void dispatch_mha_kernel_sm80(const Params& params, cudaStream_t stream) {
// dispatch to proper kernel instantiation based on params
DISPATCH_BOOL(params.head_dim == Traits::kHeadDim, EVEN_K, [&] {
DISPATCH_BOOL(params.head_dim == HEAD_DIM, EVEN_K, [&] {
DISPATCH_BOOL(params.alibi_slopes_ptr != nullptr, ALIBI, [&] {
DISPATCH_BOOL(params.logits_soft_cap > 0, SOFT_CAP, [&] {
DISPATCH_BOOL(params.sliding_window >= 0, LOCAL, [&] {
launch_mha_kernel_sm80<Traits,
Params,
launch_mha_kernel_sm80<Dtype,
HEAD_DIM,
EVEN_K,
ALIBI,
SOFT_CAP,
LOCAL>(params, stream);
LOCAL,
Params>(params, stream);
});
});
});
});
}

} // namespace detail

// user-facing function to run the attention kernel
template <typename Dtype, int HEAD_DIM, typename Params>
void run_mha_kernel_sm80(Params& params, cudaStream_t stream = nullptr) {
// normalize params for performance optimization
params.normalize();

// TODO: tune block shape MNK based on the head dim and smem size
// https://docs.nvidia.com/cuda/cuda-c-programming-guide/index.html#features-and-technical-specifications-technical-specifications-per-compute-capability
// SM           | 7.0 | 7.2 | 7.5 | 8.0 | 8.6 | 8.7 | 8.9 | 9.0 | 10.x | 12.0 |
// Max SMEM (KB)| 96  | 96  | 64  | 164 | 100 | 164 | 100 | 228 | 228  | 100  |
// valid dynamic shared memory sizes for different compute capabilities:
// * 7.0 | 7.2 : 0, 8, 16, 32, 64, 96
// * 7.5 : 0, 32, 64
// * 8.0 | 8.7 : 0, 8, 16, 32, 64, 100, 132, 164
// * 8.6 | 8.9 : 0, 8, 16, 32, 64, 100
// * 9.0 | 10.x: 0, 8, 16, 32, 64, 100, 132, 164, 196, 228
// * 12.0 : 0, 8, 16, 32, 64, 100
if constexpr (HEAD_DIM == 64) {
using Traits = MHATraitsSM80<Dtype,
HEAD_DIM,
/*BLK_M=*/64,
/*BLK_N=*/64,
/*BLK_K=*/64>;
detail::dispatch_mha_kernel_sm80<Traits>(params, stream);
} else if constexpr (HEAD_DIM == 96) {
using Traits = MHATraitsSM80<Dtype,
HEAD_DIM,
/*BLK_M=*/64,
/*BLK_N=*/64,
/*BLK_K=*/32>;
detail::dispatch_mha_kernel_sm80<Traits>(params, stream);
} else if constexpr (HEAD_DIM == 128) {
using Traits = MHATraitsSM80<Dtype,
HEAD_DIM,
/*BLK_M=*/64,
/*BLK_N=*/64,
/*BLK_K=*/64>;
detail::dispatch_mha_kernel_sm80<Traits>(params, stream);
} else if constexpr (HEAD_DIM == 256) {
using Traits = MHATraitsSM80<Dtype,
HEAD_DIM,
/*BLK_M=*/64,
/*BLK_N=*/64,
/*BLK_K=*/64>;
detail::dispatch_mha_kernel_sm80<Traits>(params, stream);
} else {
// use the default block size
using Traits = MHATraitsSM80<Dtype,
HEAD_DIM,
/*BLK_M=*/64,
/*BLK_N=*/64,
/*BLK_K=*/64>;
detail::dispatch_mha_kernel_sm80<Traits>(params, stream);
}
}

} // namespace llm
} // namespace llm
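The run_mha_kernel_sm80 entry point above now dispatches the four boolean axes directly with DISPATCH_BOOL from static_dispatch.h, without building an MHATraitsSM80 first. As a rough illustration of the mechanism only (the actual macro in static_dispatch.h may be written differently), a boolean static-dispatch macro can be sketched like this:

// Hypothetical sketch: turn a runtime condition into a compile-time bool
// that the lambda body can pass as a template argument.
#define DISPATCH_BOOL(COND, CONST_NAME, ...) \
  [&] {                                      \
    if (COND) {                              \
      constexpr bool CONST_NAME = true;      \
      return __VA_ARGS__();                  \
    } else {                                 \
      constexpr bool CONST_NAME = false;     \
      return __VA_ARGS__();                  \
    }                                        \
  }()

A caller then only pins the compile-time dtype and head dim, e.g. run_mha_kernel_sm80<cute::half_t, /*HEAD_DIM=*/128>(params, stream) (again assuming cute::half_t as the fp16 type); EVEN_K, ALIBI, SOFT_CAP, and LOCAL are resolved at runtime against the pre-built instantiations.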
27 changes: 23 additions & 4 deletions src/kernels/attention/mha_kernel_sm80.cuh
@@ -13,6 +13,7 @@
#include "layout_convertor.h"
#include "mask.h"
#include "mha_tile.h"
#include "mha_traits_sm80.h"
#include "online_softmax.cuh"

namespace llm {
@@ -436,17 +437,35 @@ __global__ __launch_bounds__(Traits::kThreadNum) void mha_kernel_sm80(
epilogue(tOrO);
}

template <typename Traits,
typename Params,
template <typename Dtype,
int HEAD_DIM,
bool EVEN_K,
bool ALIBI,
bool SOFT_CAP,
bool LOCAL>
bool LOCAL,
typename Params>
void launch_mha_kernel_sm80(const Params& params, cudaStream_t stream) {
const auto batch_size = params.batch_size;
const auto n_kv_heads = params.n_kv_heads;
const auto max_q_packed_len = params.max_q_len * params.group_size;

// TODO: tune block shape MNK based on the head dim and smem size
// https://docs.nvidia.com/cuda/cuda-c-programming-guide/index.html#features-and-technical-specifications-technical-specifications-per-compute-capability
// SM           | 7.0 | 7.2 | 7.5 | 8.0 | 8.6 | 8.7 | 8.9 | 9.0 | 10.x | 12.0 |
// Max SMEM (KB)| 96  | 96  | 64  | 164 | 100 | 164 | 100 | 228 | 228  | 100  |
// valid dynamic shared memory sizes for different compute capabilities:
// * 7.0 | 7.2 : 0, 8, 16, 32, 64, 96
// * 7.5 : 0, 32, 64
// * 8.0 | 8.7 : 0, 8, 16, 32, 64, 100, 132, 164
// * 8.6 | 8.9 : 0, 8, 16, 32, 64, 100
// * 9.0 | 10.x: 0, 8, 16, 32, 64, 100, 132, 164, 196, 228
// * 12.0 : 0, 8, 16, 32, 64, 100

constexpr int BLK_M = 64;
constexpr int BLK_N = 64;
constexpr int BLK_K = HEAD_DIM % 64 == 0 ? 64 : 32;
using Traits = MHATraitsSM80<Dtype, HEAD_DIM, BLK_M, BLK_N, BLK_K>;

const auto smem_size = sizeof(MHASharedStorage<Traits>);
auto mha_kernel =
mha_kernel_sm80<Traits, Params, EVEN_K, ALIBI, SOFT_CAP, LOCAL>;
@@ -460,4 +479,4 @@ void launch_mha_kernel_sm80(const Params& params, cudaStream_t stream) {
mha_kernel<<<grid, block, smem_size, stream>>>(params);
}

} // namespace llm
} // namespace llm
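To make the new tile-shape selection concrete: BLK_M and BLK_N stay fixed at 64, and BLK_K drops to 32 only when HEAD_DIM is not a multiple of 64, i.e. head_dim 96 in the generated set. This is also why the generator no longer needs the head_dim % blk_k validity check: every generated head dim is a multiple of its derived BLK_K. A minimal compile-time check of that rule (illustrative only, not part of the patch):

// Mirrors the BLK_K selection in launch_mha_kernel_sm80 for the head dims
// enumerated by generate_instantiation_cu.py (64, 96, 128, 256).
constexpr int blk_k_for(int head_dim) { return head_dim % 64 == 0 ? 64 : 32; }
static_assert(blk_k_for(64) == 64, "hd64 uses BLK_K=64");
static_assert(blk_k_for(96) == 32, "hd96 uses BLK_K=32");
static_assert(blk_k_for(128) == 64, "hd128 uses BLK_K=64");
static_assert(blk_k_for(256) == 64, "hd256 uses BLK_K=64");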