Skip to content

Commit 829f32b

Browse files
authored
refactor: split attention kernel into collective mainloop, collective epilogue and kernel (#469)
1 parent 44c6d66 commit 829f32b

31 files changed

+824
-715
lines changed

src/kernels/attention/CMakeLists.txt

Lines changed: 20 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@ include(cc_library)
44
include(cc_test)
55

66
cc_library(
7-
NAME
7+
NAME
88
attention.template
99
HDRS
1010
fast_math.h
@@ -15,9 +15,8 @@ cc_library(
1515
static_dispatch.h
1616
mha_params.h
1717
mha_tile.h
18-
mha_traits_sm80.h
1918
mha_kernel_sm80.cuh
20-
mha_dispatch_sm80.cuh
19+
sm80_mha_dispatch.cuh
2120
mla_params.h
2221
mla_tile.h
2322
mla_traits_sm80.h
@@ -39,11 +38,11 @@ execute_process(
3938
WORKING_DIRECTORY ${CMAKE_CURRENT_BINARY_DIR}/
4039
COMMAND_ERROR_IS_FATAL ANY
4140
)
42-
# globbing all generated files in sub directory "generated"
43-
file(GLOB GENERATED_SRC_FILES "${CMAKE_CURRENT_BINARY_DIR}/generated/*.cu")
41+
# globbing all generated files in sub directory "gensrc"
42+
file(GLOB GENERATED_SRC_FILES "${CMAKE_CURRENT_BINARY_DIR}/gensrc/*.cu")
4443

4544
cc_library(
46-
NAME
45+
NAME
4746
attention.kernels
4847
HDRS
4948
attn_api.h
@@ -62,9 +61,8 @@ cc_test(
6261
mha_kernel_test
6362
SRCS
6463
# mha_cpu_test.cpp
65-
mha_traits_test.cpp
66-
mha_kernel_sm80_test.cu
67-
mha_kernel_sm80_pagedkv_test.cu
64+
sm80_mha_test.cu
65+
sm80_mha_pagedkv_test.cu
6866
DEPS
6967
:attention.template
7068
absl::random_random
@@ -99,31 +97,31 @@ cc_test(
9997
)
10098

10199
nvbench_binary(
102-
NAME
103-
mha_sm80_bench
104-
SRCS
105-
mha_sm80_bench.cu
100+
NAME
101+
sm80_mha_bench
102+
SRCS
103+
sm80_mha_bench.cu
106104
DEPS
107-
:attention.template
105+
:attention.template
108106
)
109107

110108
nvbench_binary(
111-
NAME
112-
mha_sm80_pagedkv_bench
113-
SRCS
114-
mha_sm80_pagedkv_bench.cu
109+
NAME
110+
sm80_mha_pagedkv_bench
111+
SRCS
112+
sm80_mha_pagedkv_bench.cu
115113
DEPS
116114
absl::random_random
117115
:attention.template
118116
)
119117

120118
nvbench_binary(
121-
NAME
119+
NAME
122120
mla_sm80_bench
123-
SRCS
121+
SRCS
124122
mla_sm80_bench.cu
125123
DEPS
126-
:attention.template
124+
:attention.template
127125
)
128126

129-
add_subdirectory(tools)
127+
add_subdirectory(tools)

src/kernels/attention/attn_api.cpp

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -3,8 +3,8 @@
33
#include <ATen/cuda/CUDAContext.h>
44

55
#include "cute/layout.hpp"
6-
#include "mha_dispatch_sm80.cuh"
76
#include "mha_params.h"
7+
#include "sm80_mha_dispatch.cuh"
88
#include "static_dispatch.h"
99

1010
namespace llm {
@@ -66,7 +66,7 @@ void paged_kv_varlen_mha(
6666
cudaStream_t stream = at::cuda::getCurrentCUDAStream().stream();
6767
DISPATCH_HEAD_DIM(head_dim, HEAD_DIM, [&] {
6868
DISPATCH_TORCH_DTYPE(query.scalar_type(), DTYPE, [&] {
69-
run_mha_kernel_sm80<DTYPE, HEAD_DIM>(params, stream);
69+
sm80_run_mha<DTYPE, HEAD_DIM>(params, stream);
7070
});
7171
});
7272
}

src/kernels/attention/generate_instantiation_cu.py

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -20,21 +20,21 @@
2020

2121

2222
MHA_KERNEL_TEMPLATE = """
23-
#include "mha_kernel_sm80.cuh" // IWYU pragma: export
23+
#include "sm80_mha_launch.cuh" // IWYU pragma: export
2424
#include "mha_params.h" // IWYU pragma: export
2525
2626
namespace llm {{
2727
2828
using Params = MHAPagedKVParams;
2929
30-
template void launch_mha_kernel_sm80</*DTYPE=*/{DTYPE},
30+
template void sm80_launch_mha_kernel</*DTYPE=*/{DTYPE},
3131
/*HEAD_DIM=*/{HEAD_DIM},
3232
/*EVEN_K=*/{EVEN_K},
3333
/*ALIBI=*/{ALIBI},
3434
/*SOFT_CAP=*/{SOFT_CAP},
3535
/*LOCAL=*/{LOCAL},
3636
Params>(const Params& params,
37-
cudaStream_t stream);
37+
cudaStream_t stream);
3838
}} // namespace llm
3939
"""
4040

@@ -79,7 +79,7 @@ def filename(self) -> str:
7979
def to_str(val: bool) -> str:
8080
return "1" if val else "0"
8181

82-
return f"mha_{self.dtype}_hd{self.head_dim}_ek{to_str(self.even_k)}_al{to_str(self.alibi)}_sc{to_str(self.soft_cap)}_lc{to_str(self.local)}_sm80.cu"
82+
return f"sm80_mha_{self.dtype}_hd{self.head_dim}_ek{to_str(self.even_k)}_al{to_str(self.alibi)}_sc{to_str(self.soft_cap)}_lc{to_str(self.local)}.cu"
8383

8484

8585
@dataclass
@@ -164,7 +164,7 @@ def gen_mla_kernels() -> Iterator[MLAKernel]:
164164

165165

166166
if __name__ == "__main__":
167-
output_dir = Path.cwd() / "generated"
167+
output_dir = Path.cwd() / "gensrc"
168168
shutil.rmtree(output_dir, ignore_errors=True)
169169
output_dir.mkdir(parents=True, exist_ok=True)
170170

0 commit comments

Comments (0)