Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
42 changes: 20 additions & 22 deletions src/kernels/attention/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@ include(cc_library)
include(cc_test)

cc_library(
NAME
NAME
attention.template
HDRS
fast_math.h
Expand All @@ -15,9 +15,8 @@ cc_library(
static_dispatch.h
mha_params.h
mha_tile.h
mha_traits_sm80.h
mha_kernel_sm80.cuh
mha_dispatch_sm80.cuh
sm80_mha_dispatch.cuh
mla_params.h
mla_tile.h
mla_traits_sm80.h
Expand All @@ -39,11 +38,11 @@ execute_process(
WORKING_DIRECTORY ${CMAKE_CURRENT_BINARY_DIR}/
COMMAND_ERROR_IS_FATAL ANY
)
# globbing all generated files in sub directory "generated"
file(GLOB GENERATED_SRC_FILES "${CMAKE_CURRENT_BINARY_DIR}/generated/*.cu")
# globbing all generated files in sub directory "gensrc"
file(GLOB GENERATED_SRC_FILES "${CMAKE_CURRENT_BINARY_DIR}/gensrc/*.cu")

cc_library(
NAME
NAME
attention.kernels
HDRS
attn_api.h
Expand All @@ -62,9 +61,8 @@ cc_test(
mha_kernel_test
SRCS
# mha_cpu_test.cpp
mha_traits_test.cpp
mha_kernel_sm80_test.cu
mha_kernel_sm80_pagedkv_test.cu
sm80_mha_test.cu
sm80_mha_pagedkv_test.cu
DEPS
:attention.template
absl::random_random
Expand Down Expand Up @@ -99,31 +97,31 @@ cc_test(
)

nvbench_binary(
NAME
mha_sm80_bench
SRCS
mha_sm80_bench.cu
NAME
sm80_mha_bench
SRCS
sm80_mha_bench.cu
DEPS
:attention.template
:attention.template
)

nvbench_binary(
NAME
mha_sm80_pagedkv_bench
SRCS
mha_sm80_pagedkv_bench.cu
NAME
sm80_mha_pagedkv_bench
SRCS
sm80_mha_pagedkv_bench.cu
DEPS
absl::random_random
:attention.template
)

nvbench_binary(
NAME
NAME
mla_sm80_bench
SRCS
SRCS
mla_sm80_bench.cu
DEPS
:attention.template
:attention.template
)

add_subdirectory(tools)
add_subdirectory(tools)
4 changes: 2 additions & 2 deletions src/kernels/attention/attn_api.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -3,8 +3,8 @@
#include <ATen/cuda/CUDAContext.h>

#include "cute/layout.hpp"
#include "mha_dispatch_sm80.cuh"
#include "mha_params.h"
#include "sm80_mha_dispatch.cuh"
#include "static_dispatch.h"

namespace llm {
Expand Down Expand Up @@ -66,7 +66,7 @@ void paged_kv_varlen_mha(
cudaStream_t stream = at::cuda::getCurrentCUDAStream().stream();
DISPATCH_HEAD_DIM(head_dim, HEAD_DIM, [&] {
DISPATCH_TORCH_DTYPE(query.scalar_type(), DTYPE, [&] {
run_mha_kernel_sm80<DTYPE, HEAD_DIM>(params, stream);
sm80_run_mha<DTYPE, HEAD_DIM>(params, stream);
});
});
}
Expand Down
10 changes: 5 additions & 5 deletions src/kernels/attention/generate_instantiation_cu.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,21 +20,21 @@


MHA_KERNEL_TEMPLATE = """
#include "mha_kernel_sm80.cuh" // IWYU pragma: export
#include "sm80_mha_launch.cuh" // IWYU pragma: export
#include "mha_params.h" // IWYU pragma: export

namespace llm {{

using Params = MHAPagedKVParams;

template void launch_mha_kernel_sm80</*DTYPE=*/{DTYPE},
template void sm80_launch_mha_kernel</*DTYPE=*/{DTYPE},
/*HEAD_DIM=*/{HEAD_DIM},
/*EVEN_K=*/{EVEN_K},
/*ALIBI=*/{ALIBI},
/*SOFT_CAP=*/{SOFT_CAP},
/*LOCAL=*/{LOCAL},
Params>(const Params& params,
cudaStream_t stream);
cudaStream_t stream);
}} // namespace llm
"""

Expand Down Expand Up @@ -79,7 +79,7 @@ def filename(self) -> str:
def to_str(val: bool) -> str:
return "1" if val else "0"

return f"mha_{self.dtype}_hd{self.head_dim}_ek{to_str(self.even_k)}_al{to_str(self.alibi)}_sc{to_str(self.soft_cap)}_lc{to_str(self.local)}_sm80.cu"
return f"sm80_mha_{self.dtype}_hd{self.head_dim}_ek{to_str(self.even_k)}_al{to_str(self.alibi)}_sc{to_str(self.soft_cap)}_lc{to_str(self.local)}.cu"


@dataclass
Expand Down Expand Up @@ -164,7 +164,7 @@ def gen_mla_kernels() -> Iterator[MLAKernel]:


if __name__ == "__main__":
output_dir = Path.cwd() / "generated"
output_dir = Path.cwd() / "gensrc"
shutil.rmtree(output_dir, ignore_errors=True)
output_dir.mkdir(parents=True, exist_ok=True)

Expand Down
Loading
Loading