|
39 | 39 | """ |
40 | 40 |
|
41 | 41 | MLA_KERNEL_TEMPLATE = """ |
42 | | -#include "mla_kernel_sm80.cuh" // IWYU pragma: export |
| 42 | +#include "sm80_mla_launch.cuh" // IWYU pragma: export |
43 | 43 | #include "mla_params.h" // IWYU pragma: export |
44 | | -#include "mla_traits_sm80.h" // IWYU pragma: export |
45 | 44 |
|
46 | 45 | namespace llm {{ |
47 | 46 |
|
48 | | -using Traits = MLATraitsSM80<{DTYPE}, {HEAD_DIM}, {ROPE_HEAD_DIM}, {BLK_M}, {BLK_N}, {BLK_K}, {STAGES}>; |
49 | 47 | using Params = MLAPagedKVParams; |
50 | 48 |
|
51 | | -template void launch_mla_kernel_sm80<Traits, Params>(const Params& params, |
52 | | - cudaStream_t stream); |
| 49 | +template void sm80_launch_mla_kernel</*DTYPE=*/{DTYPE}, |
| 50 | + /*HEAD_DIM=*/{HEAD_DIM}, |
| 51 | + /*ROPE_HEAD_DIM=*/{ROPE_HEAD_DIM}, |
| 52 | + Params>(const Params& params, |
| 53 | + cudaStream_t stream); |
53 | 54 | }} // namespace llm |
54 | 55 | """ |
55 | 56 |
|
@@ -87,28 +88,18 @@ class MLAKernel: |
87 | 88 | dtype: str |
88 | 89 | head_dim: int |
89 | 90 | rope_head_dim: int |
90 | | - blk_m: int |
91 | | - blk_n: int |
92 | | - blk_k: int |
93 | | - stages: int |
94 | 91 |
|
95 | 92 | @property |
96 | 93 | def template(self) -> str: |
97 | | - assert self.head_dim % self.blk_k == 0 |
98 | | - |
99 | 94 | return MLA_KERNEL_TEMPLATE.format( |
100 | 95 | DTYPE=DTYPE_MAP[self.dtype], |
101 | 96 | HEAD_DIM=self.head_dim, |
102 | 97 | ROPE_HEAD_DIM=self.rope_head_dim, |
103 | | - BLK_M=self.blk_m, |
104 | | - BLK_N=self.blk_n, |
105 | | - BLK_K=self.blk_k, |
106 | | - STAGES=self.stages, |
107 | 98 | ) |
108 | 99 |
|
109 | 100 | @property |
110 | 101 | def filename(self) -> str: |
111 | | - return f"mla_{self.dtype}_hd{self.head_dim}_rhd{self.rope_head_dim}_m{self.blk_m}_n{self.blk_n}_k{self.blk_k}_s{self.stages}_sm80.cu" |
| 102 | + return f"sm80_mla_{self.dtype}_hd{self.head_dim}_rhd{self.rope_head_dim}.cu" |
112 | 103 |
|
113 | 104 |
|
114 | 105 | def gen_mha_kernels() -> Iterator[MHAKernel]: |
@@ -141,25 +132,15 @@ def gen_mha_kernels() -> Iterator[MHAKernel]: |
141 | 132 | def gen_mla_kernels() -> Iterator[MLAKernel]: |
142 | 133 | # TODO: choose BLK_M, BLK_N, BLK_K, STAGES based on compute capability |
143 | 134 | # mla kernel instantiations |
144 | | - for dtype, head_dim, rope_head_dim, ( |
145 | | - blk_m, |
146 | | - blk_n, |
147 | | - blk_k, |
148 | | - stages, |
149 | | - ) in itertools.product( |
| 135 | + for dtype, head_dim, rope_head_dim in itertools.product( |
150 | 136 | ["fp16", "bf16"], # dtype |
151 | 137 | [512], # head_dim |
152 | 138 | [64], # rope_head_dim |
153 | | - [(64, 16, 128, 1)], # blk_m, blk_n, blk_k, stages |
154 | 139 | ): |
155 | 140 | yield MLAKernel( |
156 | 141 | dtype=dtype, |
157 | 142 | head_dim=head_dim, |
158 | 143 | rope_head_dim=rope_head_dim, |
159 | | - blk_m=blk_m, |
160 | | - blk_n=blk_n, |
161 | | - blk_k=blk_k, |
162 | | - stages=stages, |
163 | 144 | ) |
164 | 145 |
|
165 | 146 |
|
|
0 commit comments