Skip to content

Commit 47f914d

Browse files
authored
refactor: split mla kernels into collective_mla collective_epilogue (#475)
1 parent faa8a3f commit 47f914d

17 files changed

+1294
-1117
lines changed

src/kernels/attention/CMakeLists.txt

Lines changed: 4 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -11,16 +11,13 @@ cc_library(
1111
layout_convertor.h
1212
fast_cast.cuh
1313
online_softmax.cuh
14+
safe_copy.h
1415
mask.h
1516
static_dispatch.h
1617
mha_params.h
1718
mha_tile.h
18-
mha_kernel_sm80.cuh
19-
sm80_mha_dispatch.cuh
2019
mla_params.h
2120
mla_tile.h
22-
mla_traits_sm80.h
23-
mla_kernel_sm80.cuh
2421
attn_combine_kernel.cuh
2522
DEPS
2623
cutlass
@@ -74,9 +71,8 @@ cc_test(
7471
NAME
7572
mla_kernel_test
7673
SRCS
77-
mla_traits_test.cpp
78-
mla_kernel_sm80_test.cu
79-
mla_kernel_sm80_pagedkv_test.cu
74+
sm80_mla_test.cu
75+
sm80_mla_pagedkv_test.cu
8076
DEPS
8177
:attention.template
8278
absl::random_random
@@ -117,7 +113,7 @@ nvbench_binary(
117113

118114
nvbench_binary(
119115
NAME
120-
mla_sm80_bench
116+
sm80_mla_bench
121117
SRCS
122118
mla_sm80_bench.cu
123119
DEPS

src/kernels/attention/attn_combine_kernel.cuh

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -6,8 +6,8 @@
66
#include <cute/layout.hpp>
77
#include <cute/tensor.hpp>
88

9-
#include "cute_extensions.cuh"
109
#include "fast_cast.cuh"
10+
#include "safe_copy.h"
1111

1212
namespace llm {
1313

@@ -239,4 +239,4 @@ void launch_attn_combine_kernel(const Params& params, cudaStream_t stream) {
239239
combine_kernel<<<grid, kThreads, 0, stream>>>(params);
240240
}
241241

242-
} // namespace llm
242+
} // namespace llm

src/kernels/attention/generate_instantiation_cu.py

Lines changed: 8 additions & 27 deletions
Original file line numberDiff line numberDiff line change
@@ -39,17 +39,18 @@
3939
"""
4040

4141
MLA_KERNEL_TEMPLATE = """
42-
#include "mla_kernel_sm80.cuh" // IWYU pragma: export
42+
#include "sm80_mla_launch.cuh" // IWYU pragma: export
4343
#include "mla_params.h" // IWYU pragma: export
44-
#include "mla_traits_sm80.h" // IWYU pragma: export
4544
4645
namespace llm {{
4746
48-
using Traits = MLATraitsSM80<{DTYPE}, {HEAD_DIM}, {ROPE_HEAD_DIM}, {BLK_M}, {BLK_N}, {BLK_K}, {STAGES}>;
4947
using Params = MLAPagedKVParams;
5048
51-
template void launch_mla_kernel_sm80<Traits, Params>(const Params& params,
52-
cudaStream_t stream);
49+
template void sm80_launch_mla_kernel</*DTYPE=*/{DTYPE},
50+
/*HEAD_DIM=*/{HEAD_DIM},
51+
/*ROPE_HEAD_DIM=*/{ROPE_HEAD_DIM},
52+
Params>(const Params& params,
53+
cudaStream_t stream);
5354
}} // namespace llm
5455
"""
5556

@@ -87,28 +88,18 @@ class MLAKernel:
8788
dtype: str
8889
head_dim: int
8990
rope_head_dim: int
90-
blk_m: int
91-
blk_n: int
92-
blk_k: int
93-
stages: int
9491

9592
@property
9693
def template(self) -> str:
97-
assert self.head_dim % self.blk_k == 0
98-
9994
return MLA_KERNEL_TEMPLATE.format(
10095
DTYPE=DTYPE_MAP[self.dtype],
10196
HEAD_DIM=self.head_dim,
10297
ROPE_HEAD_DIM=self.rope_head_dim,
103-
BLK_M=self.blk_m,
104-
BLK_N=self.blk_n,
105-
BLK_K=self.blk_k,
106-
STAGES=self.stages,
10798
)
10899

109100
@property
110101
def filename(self) -> str:
111-
return f"mla_{self.dtype}_hd{self.head_dim}_rhd{self.rope_head_dim}_m{self.blk_m}_n{self.blk_n}_k{self.blk_k}_s{self.stages}_sm80.cu"
102+
return f"sm80_mla_{self.dtype}_hd{self.head_dim}_rhd{self.rope_head_dim}.cu"
112103

113104

114105
def gen_mha_kernels() -> Iterator[MHAKernel]:
@@ -141,25 +132,15 @@ def gen_mha_kernels() -> Iterator[MHAKernel]:
141132
def gen_mla_kernels() -> Iterator[MLAKernel]:
142133
# TODO: choose BLK_M, BLK_N, BLK_K, STAGES based on compute capability
143134
# mla kernel instantiations
144-
for dtype, head_dim, rope_head_dim, (
145-
blk_m,
146-
blk_n,
147-
blk_k,
148-
stages,
149-
) in itertools.product(
135+
for dtype, head_dim, rope_head_dim in itertools.product(
150136
["fp16", "bf16"], # dtype
151137
[512], # head_dim
152138
[64], # rope_head_dim
153-
[(64, 16, 128, 1)], # blk_m, blk_n, blk_k, stages
154139
):
155140
yield MLAKernel(
156141
dtype=dtype,
157142
head_dim=head_dim,
158143
rope_head_dim=rope_head_dim,
159-
blk_m=blk_m,
160-
blk_n=blk_n,
161-
blk_k=blk_k,
162-
stages=stages,
163144
)
164145

165146

0 commit comments

Comments
 (0)