Skip to content

Commit 6b95400

Browse files
authored
refactor: move kernel code into different folders (#487)
1 parent 25e591b commit 6b95400

Some content is hidden

Large commits have some content hidden by default. Use the search box below to find content that may be hidden.

62 files changed

177 additions and 155 deletions (lines changed)

src/kernels/attention/CMakeLists.txt

Lines changed: 57 additions & 39 deletions
Original file line number | Diff line number | Diff line change
@@ -6,6 +6,8 @@ include(cc_test)
66
cc_library(
77
NAME
88
attention.template
9+
INCLUDES
10+
${CMAKE_CURRENT_SOURCE_DIR}
911
HDRS
1012
fast_math.h
1113
layout_convertor.h
@@ -41,13 +43,13 @@ file(GLOB GENERATED_SRC_FILES "${CMAKE_CURRENT_BINARY_DIR}/gensrc/*.cu")
4143
cc_library(
4244
NAME
4345
attention.kernels
46+
INCLUDES
47+
${CMAKE_CURRENT_SOURCE_DIR}
4448
HDRS
4549
attn_api.h
4650
SRCS
4751
attn_api.cpp
4852
${GENERATED_SRC_FILES}
49-
INCLUDES
50-
${CMAKE_CURRENT_SOURCE_DIR}
5153
DEPS
5254
:attention.template
5355
glog::glog
@@ -56,9 +58,11 @@ cc_library(
5658
cc_test(
5759
NAME
5860
mha_kernel_test
61+
INCLUDES
62+
${CMAKE_CURRENT_SOURCE_DIR}
5963
SRCS
6064
# sm80_mha_test.cu
61-
sm80_mha_pagedkv_test.cu
65+
tests/sm80_mha_pagedkv_test.cu
6266
DEPS
6367
:attention.kernels
6468
absl::random_random
@@ -69,77 +73,91 @@ cc_test(
6973
cc_test(
7074
NAME
7175
sm120_fmha_kernel_test
76+
INCLUDES
77+
${CMAKE_CURRENT_SOURCE_DIR}
7278
SRCS
73-
sm120_fmha_test.cu
79+
tests/sm120_fmha_test.cu
7480
DEPS
7581
:attention.template
76-
absl::random_random
7782
:gtest_main
83+
absl::random_random
7884
torch
7985
)
8086

8187
cc_test(
8288
NAME
8389
mla_kernel_test
90+
INCLUDES
91+
${CMAKE_CURRENT_SOURCE_DIR}
8492
SRCS
85-
sm80_mla_test.cu
86-
sm80_mla_pagedkv_test.cu
93+
tests/sm80_mla_test.cu
94+
tests/sm80_mla_pagedkv_test.cu
8795
DEPS
8896
:attention.kernels
89-
absl::random_random
9097
:gtest_main
98+
absl::random_random
9199
torch
92100
)
93101

94102
cc_test(
95103
NAME
96104
attn_combine_kernel_test
105+
INCLUDES
106+
${CMAKE_CURRENT_SOURCE_DIR}
97107
SRCS
98-
attn_combine_kernel_test.cu
108+
tests/attn_combine_kernel_test.cu
99109
DEPS
100110
:attention.template
101111
absl::random_random
102112
:gtest_main
103113
torch
104114
)
105115

106-
nvbench_binary(
107-
NAME
108-
sm80_mha_bench
109-
SRCS
110-
sm80_mha_bench.cu
111-
DEPS
112-
:attention.template
113-
)
114-
115-
nvbench_binary(
116-
NAME
117-
sm80_mha_pagedkv_bench
118-
SRCS
119-
sm80_mha_pagedkv_bench.cu
120-
DEPS
121-
absl::random_random
122-
:attention.template
123-
)
124-
125-
nvbench_binary(
126-
NAME
127-
sm80_mla_bench
128-
SRCS
129-
mla_sm80_bench.cu
130-
DEPS
131-
:attention.template
132-
)
133-
134116
cc_test(
135117
NAME
136118
tma_test
119+
INCLUDES
120+
${CMAKE_CURRENT_SOURCE_DIR}
137121
SRCS
138-
sm120_tma_block_load_test.cu
122+
tests/sm120_tma_block_load_test.cu
139123
DEPS
140124
:gtest_main
141125
:cutlass
142126
torch
143127
)
144128

145-
add_subdirectory(tools)
129+
# nvbench_binary(
130+
# NAME
131+
# sm80_mha_bench
132+
# INCLUDES
133+
# ${CMAKE_CURRENT_SOURCE_DIR}
134+
# SRCS
135+
# sm80_mha_bench.cu
136+
# DEPS
137+
# :attention.template
138+
# )
139+
140+
# nvbench_binary(
141+
# NAME
142+
# sm80_mha_pagedkv_bench
143+
# INCLUDES
144+
# ${CMAKE_CURRENT_SOURCE_DIR}
145+
# SRCS
146+
# sm80_mha_pagedkv_bench.cu
147+
# DEPS
148+
# absl::random_random
149+
# :attention.template
150+
# )
151+
152+
# nvbench_binary(
153+
# NAME
154+
# sm80_mla_bench
155+
# INCLUDES
156+
# ${CMAKE_CURRENT_SOURCE_DIR}
157+
# SRCS
158+
# mla_sm80_bench.cu
159+
# DEPS
160+
# :attention.template
161+
# )
162+
163+
# add_subdirectory(tools)

src/kernels/attention/attn_api.cpp

Lines changed: 4 additions & 3 deletions
Original file line number | Diff line number | Diff line change
@@ -2,10 +2,11 @@
22

33
#include <ATen/cuda/CUDAContext.h>
44

5-
#include "cute/layout.hpp"
5+
#include <cute/layout.hpp>
6+
7+
#include "common/static_dispatch.h"
8+
#include "device/sm80_mha_dispatch.cuh"
69
#include "mha_params.h"
7-
#include "sm80_mha_dispatch.cuh"
8-
#include "static_dispatch.h"
910

1011
namespace llm {
1112
using namespace cute;
File renamed without changes.
File renamed without changes.
File renamed without changes.

src/kernels/attention/sm120_collective_epilogue.cuh renamed to src/kernels/attention/collective/sm120_collective_epilogue.cuh

Lines changed: 2 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -7,8 +7,8 @@
77
#include <cute/layout.hpp>
88
#include <cute/tensor.hpp>
99

10-
#include "fast_cast.cuh"
11-
#include "safe_copy.h"
10+
#include "common/fast_cast.cuh"
11+
#include "common/safe_copy.h"
1212

1313
namespace llm {
1414
using namespace cute;

src/kernels/attention/sm120_collective_fmha_mainloop_ws.cuh renamed to src/kernels/attention/collective/sm120_collective_fmha_mainloop_ws.cuh

Lines changed: 3 additions & 3 deletions
Original file line number | Diff line number | Diff line change
@@ -9,9 +9,9 @@
99
#include <cute/tensor.hpp>
1010
#include <cutlass/pipeline/pipeline.hpp>
1111

12-
#include "fast_cast.cuh"
13-
#include "layout_convertor.h"
14-
#include "safe_copy.h"
12+
#include "common/fast_cast.cuh"
13+
#include "common/layout_convertor.h"
14+
#include "common/safe_copy.h"
1515
#include "sm120_collective_load_cpasync_ws.cuh"
1616

1717
namespace llm {

src/kernels/attention/sm120_collective_load_cpasync_ws.cuh renamed to src/kernels/attention/collective/sm120_collective_load_cpasync_ws.cuh

Lines changed: 1 addition & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -9,7 +9,7 @@
99
#include <cute/layout.hpp>
1010
#include <cute/tensor.hpp>
1111

12-
#include "safe_copy.h"
12+
#include "common/safe_copy.h"
1313

1414
namespace llm {
1515

src/kernels/attention/sm80_collective_epilogue.cuh renamed to src/kernels/attention/collective/sm80_collective_epilogue.cuh

Lines changed: 2 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -7,8 +7,8 @@
77
#include <cute/layout.hpp>
88
#include <cute/tensor.hpp>
99

10-
#include "fast_cast.cuh"
11-
#include "safe_copy.h"
10+
#include "common/fast_cast.cuh"
11+
#include "common/safe_copy.h"
1212

1313
namespace llm {
1414
using namespace cute;

src/kernels/attention/sm80_collective_mha.cuh renamed to src/kernels/attention/collective/sm80_collective_mha.cuh

Lines changed: 3 additions & 3 deletions
Original file line number | Diff line number | Diff line change
@@ -8,9 +8,9 @@
88
#include <cute/layout.hpp>
99
#include <cute/tensor.hpp>
1010

11-
#include "fast_cast.cuh"
12-
#include "layout_convertor.h"
13-
#include "safe_copy.h"
11+
#include "common/fast_cast.cuh"
12+
#include "common/layout_convertor.h"
13+
#include "common/safe_copy.h"
1414

1515
namespace llm {
1616

0 commit comments

Comments
 (0)