12 changes: 4 additions & 8 deletions src/kernels/attention/CMakeLists.txt
@@ -11,16 +11,13 @@ cc_library(
layout_convertor.h
fast_cast.cuh
online_softmax.cuh
safe_copy.h
mask.h
static_dispatch.h
mha_params.h
mha_tile.h
mha_kernel_sm80.cuh
sm80_mha_dispatch.cuh
mla_params.h
mla_tile.h
mla_traits_sm80.h
mla_kernel_sm80.cuh
attn_combine_kernel.cuh
DEPS
cutlass
@@ -74,9 +71,8 @@ cc_test(
NAME
mla_kernel_test
SRCS
mla_traits_test.cpp
mla_kernel_sm80_test.cu
mla_kernel_sm80_pagedkv_test.cu
sm80_mla_test.cu
sm80_mla_pagedkv_test.cu
DEPS
:attention.template
absl::random_random
@@ -117,7 +113,7 @@ nvbench_binary(

nvbench_binary(
NAME
mla_sm80_bench
sm80_mla_bench
SRCS
mla_sm80_bench.cu
DEPS
4 changes: 2 additions & 2 deletions src/kernels/attention/attn_combine_kernel.cuh
@@ -6,8 +6,8 @@
#include <cute/layout.hpp>
#include <cute/tensor.hpp>

#include "cute_extensions.cuh"
#include "fast_cast.cuh"
#include "safe_copy.h"

namespace llm {

@@ -239,4 +239,4 @@ void launch_attn_combine_kernel(const Params& params, cudaStream_t stream) {
combine_kernel<<<grid, kThreads, 0, stream>>>(params);
}

} // namespace llm
} // namespace llm
35 changes: 8 additions & 27 deletions src/kernels/attention/generate_instantiation_cu.py
@@ -39,17 +39,18 @@
"""

MLA_KERNEL_TEMPLATE = """
#include "mla_kernel_sm80.cuh" // IWYU pragma: export
#include "sm80_mla_launch.cuh" // IWYU pragma: export
#include "mla_params.h" // IWYU pragma: export
#include "mla_traits_sm80.h" // IWYU pragma: export

namespace llm {{

using Traits = MLATraitsSM80<{DTYPE}, {HEAD_DIM}, {ROPE_HEAD_DIM}, {BLK_M}, {BLK_N}, {BLK_K}, {STAGES}>;
using Params = MLAPagedKVParams;

template void launch_mla_kernel_sm80<Traits, Params>(const Params& params,
cudaStream_t stream);
template void sm80_launch_mla_kernel</*DTYPE=*/{DTYPE},
/*HEAD_DIM=*/{HEAD_DIM},
/*ROPE_HEAD_DIM=*/{ROPE_HEAD_DIM},
Params>(const Params& params,
cudaStream_t stream);
}} // namespace llm
"""

@@ -87,28 +88,18 @@ class MLAKernel:
dtype: str
head_dim: int
rope_head_dim: int
blk_m: int
blk_n: int
blk_k: int
stages: int

@property
def template(self) -> str:
assert self.head_dim % self.blk_k == 0

return MLA_KERNEL_TEMPLATE.format(
DTYPE=DTYPE_MAP[self.dtype],
HEAD_DIM=self.head_dim,
ROPE_HEAD_DIM=self.rope_head_dim,
BLK_M=self.blk_m,
BLK_N=self.blk_n,
BLK_K=self.blk_k,
STAGES=self.stages,
)

@property
def filename(self) -> str:
return f"mla_{self.dtype}_hd{self.head_dim}_rhd{self.rope_head_dim}_m{self.blk_m}_n{self.blk_n}_k{self.blk_k}_s{self.stages}_sm80.cu"
return f"sm80_mla_{self.dtype}_hd{self.head_dim}_rhd{self.rope_head_dim}.cu"


def gen_mha_kernels() -> Iterator[MHAKernel]:
@@ -141,25 +132,15 @@ def gen_mha_kernels() -> Iterator[MHAKernel]:
def gen_mla_kernels() -> Iterator[MLAKernel]:
# TODO: choose BLK_M, BLK_N, BLK_K, STAGES based on compute capability
# mla kernel instantiations
for dtype, head_dim, rope_head_dim, (
blk_m,
blk_n,
blk_k,
stages,
) in itertools.product(
for dtype, head_dim, rope_head_dim in itertools.product(
["fp16", "bf16"], # dtype
[512], # head_dim
[64], # rope_head_dim
[(64, 16, 128, 1)], # blk_m, blk_n, blk_k, stages
):
yield MLAKernel(
dtype=dtype,
head_dim=head_dim,
rope_head_dim=rope_head_dim,
blk_m=blk_m,
blk_n=blk_n,
blk_k=blk_k,
stages=stages,
)
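
With the reduced product above, the generator now emits exactly two MLA files, `sm80_mla_fp16_hd512_rhd64.cu` and `sm80_mla_bf16_hd512_rhd64.cu`, one per dtype instead of one per tile configuration. Assuming `DTYPE_MAP` maps `"fp16"` to `cute::half_t` (the map itself is outside this diff), the fp16 file would expand to roughly:

```cpp
// sm80_mla_fp16_hd512_rhd64.cu: illustrative expansion of
// MLA_KERNEL_TEMPLATE. cute::half_t is an assumed DTYPE_MAP value,
// not shown in this diff.
#include "sm80_mla_launch.cuh"  // IWYU pragma: export
#include "mla_params.h"         // IWYU pragma: export

namespace llm {

using Params = MLAPagedKVParams;

template void sm80_launch_mla_kernel</*DTYPE=*/cute::half_t,
                                     /*HEAD_DIM=*/512,
                                     /*ROPE_HEAD_DIM=*/64,
                                     Params>(const Params& params,
                                             cudaStream_t stream);
}  // namespace llm
```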

