Commit 29a9b31

kernel: generate smaller kernel instantiations (#395)
1 parent e8bb746 commit 29a9b31

11 files changed: +136 -51 lines

src/kernels/attention/CMakeLists.txt

Lines changed: 1 addition & 1 deletion
@@ -16,7 +16,7 @@ cc_library(
     mha_tile.h
     mha_traits_sm80.h
     mha_kernel_sm80.cuh
-    mha_launch_sm80.cuh
+    mha_dispatch_sm80.cuh
   DEPS
     cutlass
 )

src/kernels/attention/attn_api.cpp

Lines changed: 1 addition & 4 deletions
@@ -3,16 +3,13 @@
 #include <ATen/cuda/CUDAContext.h>
 
 #include "cute/layout.hpp"
+#include "mha_dispatch_sm80.cuh"
 #include "mha_params.h"
 #include "static_dispatch.h"
 
 namespace llm {
 using namespace cute;
 
-// forward declaration
-template <typename Dtype, int HEAD_DIM, typename Params>
-void run_mha_kernel_sm80(Params& params, cudaStream_t stream);
-
 void paged_kv_varlen_mha(
     torch::Tensor& out,          // [n_tokens, n_heads, head_dim]
     const torch::Tensor& query,  // [n_tokens, n_heads, head_dim]

src/kernels/attention/generate_instantiation_cu.py

Lines changed: 83 additions & 15 deletions
@@ -7,51 +7,119 @@
 from pathlib import Path
 from typing import Iterator
 
+# map from python to c++ types
 DTYPE_MAP = {
     "fp16": "cute::half_t",
     "bf16": "cute::bfloat16_t",
 }
 
-HEAD_DIMENSIONS = [64, 96, 128, 256]
+BOOL_MAP = {
+    False: "false",
+    True: "true",
+}
+
 
-PAGEDKV_KERNEL_IMPL_TEMPLATE = """
-#include "mha_launch_sm80.cuh"  // IWYU pragma: keep
+MHA_KERNEL_TEMPLATE = """
+#include "mha_kernel_sm80.cuh"  // IWYU pragma: export
+#include "mha_params.h"  // IWYU pragma: export
+#include "mha_traits_sm80.h"  // IWYU pragma: export
 
 namespace llm {{
 
+using Traits = MHATraitsSM80<{DTYPE}, {HEAD_DIM}, {BLK_M}, {BLK_N}, {BLK_K}>;
 using Params = MHAPagedKVParams;
-template void run_mha_kernel_sm80<{DTYPE}, {HEAD_DIM}, Params>(
-    Params& params, cudaStream_t stream);
 
+template void launch_mha_kernel_sm80<Traits,
+                                     Params,
+                                     /*EVEN_K=*/{EVEN_K},
+                                     /*ALIBI=*/{ALIBI},
+                                     /*SOFT_CAP=*/{SOFT_CAP},
+                                     /*LOCAL=*/{LOCAL}>(const Params& params,
+                                                        cudaStream_t stream);
 }}  // namespace llm
 """
 
+
 @dataclass
-class Kernel:
+class MHAKernel:
     dtype: str
     head_dim: int
+    blk_m: int
+    blk_n: int
+    blk_k: int
+    even_k: bool
+    alibi: bool
+    soft_cap: bool
+    local: bool
 
     @property
     def template(self) -> str:
-        return PAGEDKV_KERNEL_IMPL_TEMPLATE.format(
-            DTYPE=DTYPE_MAP[self.dtype], HEAD_DIM=self.head_dim
+        assert self.head_dim % self.blk_k == 0
+
+        return MHA_KERNEL_TEMPLATE.format(
+            DTYPE=DTYPE_MAP[self.dtype],
+            HEAD_DIM=self.head_dim,
+            BLK_M=self.blk_m,
+            BLK_N=self.blk_n,
+            BLK_K=self.blk_k,
+            EVEN_K=BOOL_MAP[self.even_k],
+            ALIBI=BOOL_MAP[self.alibi],
+            SOFT_CAP=BOOL_MAP[self.soft_cap],
+            LOCAL=BOOL_MAP[self.local],
         )
 
     @property
     def filename(self) -> str:
-        return f"mha_{self.dtype}_hd{self.head_dim}_sm80.cu"
+        def to_str(val: bool) -> str:
+            return "1" if val else "0"
 
+        return f"mha_{self.dtype}_hd{self.head_dim}_m{self.blk_m}_n{self.blk_n}_k{self.blk_k}_ek{to_str(self.even_k)}_al{to_str(self.alibi)}_sc{to_str(self.soft_cap)}_lc{to_str(self.local)}.cu"
 
-def get_all_kernels() -> Iterator[Kernel]:
-    for dtype, head_dim in itertools.product(DTYPE_MAP.keys(), HEAD_DIMENSIONS):
-        yield Kernel(dtype=dtype, head_dim=head_dim)
+
+def gen_all_kernels() -> Iterator[MHAKernel]:
+    # mha kernel instantiations
+    for (
+        dtype,
+        head_dim,
+        blk_m,
+        blk_n,
+        blk_k,
+        even_k,
+        alibi,
+        soft_cap,
+        local,
+    ) in itertools.product(
+        ["fp16", "bf16"],  # dtype
+        [64, 96, 128, 256],  # head_dim
+        [64],  # blk_m
+        [64],  # blk_n
+        [32, 64],  # blk_k
+        [False, True],  # even_k
+        [False, True],  # alibi
+        [False, True],  # soft_cap
+        [False, True],  # local
+    ):
+        # skip invalid configurations
+        if head_dim % blk_k != 0:
+            continue
+        yield MHAKernel(
+            dtype=dtype,
+            head_dim=head_dim,
+            blk_m=blk_m,
+            blk_n=blk_n,
+            blk_k=blk_k,
+            even_k=even_k,
+            alibi=alibi,
+            soft_cap=soft_cap,
+            local=local,
+        )
 
 
 if __name__ == "__main__":
     output_dir = Path.cwd() / "generated"
     shutil.rmtree(output_dir, ignore_errors=True)
     output_dir.mkdir(parents=True, exist_ok=True)
 
     # written to several files to speed up compilation
-    for kernel in get_all_kernels():
-        (output_dir / kernel.filename).write_text(kernel.template)
+    for kernel in gen_all_kernels():
+        (output_dir / kernel.filename).write_text(kernel.template)
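
For illustration, here is roughly what the updated script would emit for one configuration; the choice of fp16, head_dim 64, 64x64x64 tiles, EVEN_K=true and all other flags false is hypothetical, picked only to show the template filled in. The file would be named mha_fp16_hd64_m64_n64_k64_ek1_al0_sc0_lc0.cu:

// Sketch of a single generated instantiation file (illustrative configuration,
// not copied from the repository's generated output).
#include "mha_kernel_sm80.cuh"  // IWYU pragma: export
#include "mha_params.h"  // IWYU pragma: export
#include "mha_traits_sm80.h"  // IWYU pragma: export

namespace llm {

using Traits = MHATraitsSM80<cute::half_t, 64, 64, 64, 64>;
using Params = MHAPagedKVParams;

template void launch_mha_kernel_sm80<Traits,
                                     Params,
                                     /*EVEN_K=*/true,
                                     /*ALIBI=*/false,
                                     /*SOFT_CAP=*/false,
                                     /*LOCAL=*/false>(const Params& params,
                                                      cudaStream_t stream);
}  // namespace llm

Each generated translation unit now pins a single (Traits, EVEN_K, ALIBI, SOFT_CAP, LOCAL) combination, which is what keeps the individual instantiations small.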

src/kernels/attention/mha_launch_sm80.cuh renamed to src/kernels/attention/mha_dispatch_sm80.cuh

Lines changed: 15 additions & 26 deletions
@@ -3,45 +3,34 @@
 #include <cute/int_tuple.hpp>
 #include <cute/layout.hpp>
 
-#include "mha_kernel_sm80.cuh"
 #include "mha_traits_sm80.h"
 #include "static_dispatch.h"
 
 namespace llm {
-namespace detail {
+// forward declaration
 template <typename Traits,
           typename Params,
           bool EVEN_K,
           bool ALIBI,
           bool SOFT_CAP,
           bool LOCAL>
-void launch_mha_kernel(const Params& params, cudaStream_t stream) {
-  const auto batch_size = params.batch_size;
-  const auto n_kv_heads = params.n_kv_heads;
-  const auto max_q_packed_len = params.max_q_len * params.group_size;
+void launch_mha_kernel_sm80(const Params& params, cudaStream_t stream);
 
-  const auto smem_size = Traits::kSmemSize;
-  auto mha_kernel =
-      mha_kernel_sm80<Traits, Params, EVEN_K, ALIBI, SOFT_CAP, LOCAL>;
-  cudaFuncSetAttribute(
-      mha_kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_size);
-  // TODO: support persistent kernels
-  dim3 grid(cute::ceil_div(max_q_packed_len, Traits::kBlockM),
-            batch_size,
-            n_kv_heads);
-  dim3 block = Traits::kThreadNum;
-  mha_kernel<<<grid, block, smem_size, stream>>>(params);
-}
+namespace detail {
 
 template <typename Traits, typename Params>
-void run_mha_kernel(const Params& params, cudaStream_t stream) {
+void dispatch_mha_kernel_sm80(const Params& params, cudaStream_t stream) {
   // dispatch to proper kernel instantiation based on params
   DISPATCH_BOOL(params.head_dim == Traits::kHeadDim, EVEN_K, [&] {
     DISPATCH_BOOL(params.alibi_slopes_ptr != nullptr, ALIBI, [&] {
       DISPATCH_BOOL(params.logits_soft_cap > 0, SOFT_CAP, [&] {
         DISPATCH_BOOL(params.sliding_window >= 0, LOCAL, [&] {
-          launch_mha_kernel<Traits, Params, EVEN_K, ALIBI, SOFT_CAP, LOCAL>(
-              params, stream);
+          launch_mha_kernel_sm80<Traits,
+                                 Params,
+                                 EVEN_K,
+                                 ALIBI,
+                                 SOFT_CAP,
+                                 LOCAL>(params, stream);
         });
       });
     });
@@ -63,36 +52,36 @@ void run_mha_kernel_sm80(Params& params, cudaStream_t stream = nullptr) {
                                  /*BLK_M=*/64,
                                  /*BLK_N=*/64,
                                  /*BLK_K=*/64>;
-    detail::run_mha_kernel<Traits>(params, stream);
+    detail::dispatch_mha_kernel_sm80<Traits>(params, stream);
   } else if constexpr (HEAD_DIM == 96) {
     using Traits = MHATraitsSM80<Dtype,
                                  HEAD_DIM,
                                  /*BLK_M=*/64,
                                  /*BLK_N=*/64,
                                  /*BLK_K=*/32>;
-    detail::run_mha_kernel<Traits>(params, stream);
+    detail::dispatch_mha_kernel_sm80<Traits>(params, stream);
   } else if constexpr (HEAD_DIM == 128) {
     using Traits = MHATraitsSM80<Dtype,
                                  HEAD_DIM,
                                  /*BLK_M=*/64,
                                  /*BLK_N=*/64,
                                  /*BLK_K=*/64>;
-    detail::run_mha_kernel<Traits>(params, stream);
+    detail::dispatch_mha_kernel_sm80<Traits>(params, stream);
   } else if constexpr (HEAD_DIM == 256) {
     using Traits = MHATraitsSM80<Dtype,
                                  HEAD_DIM,
                                  /*BLK_M=*/64,
                                  /*BLK_N=*/64,
                                  /*BLK_K=*/64>;
-    detail::run_mha_kernel<Traits>(params, stream);
+    detail::dispatch_mha_kernel_sm80<Traits>(params, stream);
   } else {
     // use the default block size
     using Traits = MHATraitsSM80<Dtype,
                                  HEAD_DIM,
                                  /*BLK_M=*/64,
                                  /*BLK_N=*/64,
                                  /*BLK_K=*/64>;
-    detail::run_mha_kernel<Traits>(params, stream);
+    detail::dispatch_mha_kernel_sm80<Traits>(params, stream);
   }
 }
 
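
Call sites are unchanged by this split: the public run_mha_kernel_sm80 entry point still picks tile traits from HEAD_DIM and then dispatches at runtime on the EVEN_K / ALIBI / SOFT_CAP / LOCAL flags, now resolving to one of the pre-generated explicit instantiations instead of instantiating the kernel in place. A minimal caller sketch follows; the wrapper function is hypothetical, while run_mha_kernel_sm80 and MHAPagedKVParams are the existing names from the sources above:

// Hypothetical call site, for illustration only (compiled as a .cu file).
#include "mha_dispatch_sm80.cuh"
#include "mha_params.h"

void run_fp16_hd64(llm::MHAPagedKVParams& params, cudaStream_t stream) {
  // HEAD_DIM == 64 selects MHATraitsSM80<half_t, 64, 64, 64, 64>; the boolean
  // kernel flags are dispatched at runtime from fields of `params`.
  llm::run_mha_kernel_sm80<cute::half_t, /*HEAD_DIM=*/64>(params, stream);
}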

src/kernels/attention/mha_kernel_sm80.cuh

Lines changed: 24 additions & 0 deletions
@@ -432,4 +432,28 @@ __global__ void mha_kernel_sm80(__grid_constant__ const Params params) {
   epilogue(tOrAccO);
 }
 
+template <typename Traits,
+          typename Params,
+          bool EVEN_K,
+          bool ALIBI,
+          bool SOFT_CAP,
+          bool LOCAL>
+void launch_mha_kernel_sm80(const Params& params, cudaStream_t stream) {
+  const auto batch_size = params.batch_size;
+  const auto n_kv_heads = params.n_kv_heads;
+  const auto max_q_packed_len = params.max_q_len * params.group_size;
+
+  const auto smem_size = Traits::kSmemSize;
+  auto mha_kernel =
+      mha_kernel_sm80<Traits, Params, EVEN_K, ALIBI, SOFT_CAP, LOCAL>;
+  cudaFuncSetAttribute(
+      mha_kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_size);
+  // TODO: support persistent kernels
+  dim3 grid(cute::ceil_div(max_q_packed_len, Traits::kBlockM),
+            batch_size,
+            n_kv_heads);
+  dim3 block = Traits::kThreadNum;
+  mha_kernel<<<grid, block, smem_size, stream>>>(params);
+}
+
 }  // namespace llm

src/kernels/attention/mha_kernel_sm80_pagedkv_test.cu

Lines changed: 2 additions & 1 deletion
@@ -3,7 +3,8 @@
 #include <torch/torch.h>
 
 #include "cute/layout.hpp"
-#include "mha_launch_sm80.cuh"
+#include "mha_dispatch_sm80.cuh"
+#include "mha_kernel_sm80.cuh"  // IWYU pragma: keep
 #include "mha_params.h"
 #include "mha_ref.h"

src/kernels/attention/mha_kernel_sm80_test.cu

Lines changed: 2 additions & 1 deletion
@@ -4,7 +4,8 @@
 #include <cstdint>
 
 #include "cute/layout.hpp"
-#include "mha_launch_sm80.cuh"
+#include "mha_dispatch_sm80.cuh"
+#include "mha_kernel_sm80.cuh"  // IWYU pragma: keep
 #include "mha_params.h"
 #include "mha_ref.h"

src/kernels/attention/mha_kernel_sm80_varlen_test.cu

Lines changed: 2 additions & 1 deletion
@@ -3,7 +3,8 @@
 #include <torch/torch.h>
 
 #include "cute/layout.hpp"
-#include "mha_launch_sm80.cuh"
+#include "mha_dispatch_sm80.cuh"
+#include "mha_kernel_sm80.cuh"  // IWYU pragma: keep
 #include "mha_params.h"
 #include "mha_ref.h"

src/kernels/attention/mha_sm80_bench.cu

Lines changed: 2 additions & 1 deletion
@@ -4,7 +4,8 @@
 #include <cuda/std/chrono>
 #include <nvbench/nvbench.cuh>
 
-#include "mha_launch_sm80.cuh"
+#include "mha_dispatch_sm80.cuh"
+#include "mha_kernel_sm80.cuh"  // IWYU pragma: keep
 #include "mha_params.h"
 
 using namespace llm;

src/kernels/attention/mha_sm80_pagedkv_bench.cu

Lines changed: 2 additions & 1 deletion
@@ -5,7 +5,8 @@
 #include <cuda/std/chrono>
 #include <nvbench/nvbench.cuh>
 
-#include "mha_launch_sm80.cuh"
+#include "mha_dispatch_sm80.cuh"
+#include "mha_kernel_sm80.cuh"  // IWYU pragma: keep
 #include "mha_params.h"
 
 using namespace llm;
