Commit 8d41e66

feat: added kernel builder for attn (#493)
1 parent 83a4415 commit 8d41e66
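
This change routes attention kernel composition through a new KernelBuilder keyed on a CUTLASS architecture tag, and FmhaRunner now takes that tag as its first template parameter. A minimal sketch of the new call path, assuming cutlass::half_t and a head dim of 64 purely as illustrative choices (the test below picks both at runtime via dispatch macros) and an already-populated llm::FmhaParams named params:

    using ArchTag = cutlass::arch::Sm120;
    // dtype and head dim are placeholders here, not values taken from this commit
    llm::FmhaRunner<ArchTag, cutlass::half_t, /*kHeadDim=*/64>::run(params, /*stream=*/nullptr);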

File tree

6 files changed: +126, -39 lines changed


src/kernels/attention/common/fmha_block.h

Lines changed: 4 additions & 8 deletions
@@ -14,17 +14,13 @@ using namespace cute;
 // AttentionTile specialization for AttentionParams
 template <typename TileShape,  // (BLK_M, BLK_N, BLK_K)
           typename Element,    // Element type
+          typename StrideQ,    // (B, Q, H, D)
+          typename StrideK,    // (B, K, KH, D)
+          typename StrideV,    // (B, K, KH, D)
+          typename StrideO,    // (B, Q, H, D)
           bool kLocal>
 struct FmhaBlock {
-  // (B, Q, H, D)
-  using StrideQ = Stride<int64_t, int64_t, int64_t, _1>;
-  using StrideO = StrideQ;
-  // (B, K, KH, D)
-  using StrideK = Stride<int64_t, int64_t, int64_t, _1>;
-  using StrideV = StrideK;
-
   // Host side parameters
-
   struct Arguments {
     const void* __restrict__ q_ptr;
     const void* __restrict__ k_ptr;
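
With the strides lifted out of FmhaBlock into template parameters, callers now spell out the Q/K/V/O layouts at instantiation time. A sketch of a concrete instantiation under this change, assuming a TileShape alias like the one defined in fmha_runner.h, with cutlass::half_t and kLocal=false used purely as placeholders:

    // Q/O are (B, Q, H, D) and K/V are (B, K, KH, D); the innermost dim is contiguous
    using StrideQ = cute::Stride<int64_t, int64_t, int64_t, cute::_1>;
    using StrideK = cute::Stride<int64_t, int64_t, int64_t, cute::_1>;
    using Block = llm::FmhaBlock<TileShape, cutlass::half_t,
                                 StrideQ, StrideK, /*StrideV=*/StrideK, /*StrideO=*/StrideQ,
                                 /*kLocal=*/false>;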

src/kernels/attention/fmha_runner.h

Lines changed: 22 additions & 29 deletions
@@ -5,20 +5,14 @@
 #include <cute/layout.hpp>
 #include <cute/tensor.hpp>

-#include "collective/sm120_collective_epilogue.cuh"
-#include "collective/sm120_collective_fmha_mainloop_ws.cuh"
-#include "common/fmha_block.h"
-#include "common/tile_scheduler.cuh"
 #include "device/fmha.cuh"
 #include "fmha_params.h"
-#include "kernel/sm120_kernel_fmha_ws.cuh"
+#include "kernel/kernel_builder.h"  // IWYU pragma: keep

 namespace llm {
-// ? Should include ArchTag?
-// * select right kernel based on ArchTag?
-// ? how to support fast compliling?
+// TODO: support fast compiling
 // * only compile the kernel for the target compute capability
-template <typename Element, int kHeadDim>
+template <class ArchTag, typename Element, int kHeadDim>
 class FmhaRunner {
  public:
   static bool run(const FmhaParams& params, cudaStream_t stream = nullptr) {

@@ -64,26 +58,25 @@ class FmhaRunner {

     using TileShape = Shape<Int<BLK_M>, Int<BLK_N>, Int<kHeadDim>>;

-    using Block = FmhaBlock<TileShape, Element, LOCAL>;
-
-    using CollectiveMainloop = Sm120CollectiveFMhaWs<TileShape,
-                                                     Element,
-                                                     EVEN_K,
-                                                     ALIBI,
-                                                     SOFT_CAP,
-                                                     LOCAL,
-                                                     KV_USE_TMA>;
-    using CollectiveEpilogue =
-        Sm120CollectiveEpilogue<TileShape, Element, EVEN_K>;
-
-    // TODO: support persistent kernels
-    using TileScheduler = SingleTileScheduler;
-
-    using AttnKernel = Sm120KernelFmhaWs<ProblemShape,
-                                         Block,
-                                         CollectiveMainloop,
-                                         CollectiveEpilogue,
-                                         TileScheduler>;
+    // (B, Q, H, D)
+    using StrideQ = Stride<int64_t, int64_t, int64_t, _1>;
+    using StrideK = Stride<int64_t, int64_t, int64_t, _1>;
+    using StrideV = StrideK;
+    using StrideO = StrideQ;
+
+    using AttnKernel = typename KernelBuilder<ArchTag,
+                                              ProblemShape,
+                                              TileShape,
+                                              Element,
+                                              StrideQ,
+                                              StrideK,
+                                              StrideV,
+                                              StrideO,
+                                              EVEN_K,
+                                              ALIBI,
+                                              SOFT_CAP,
+                                              LOCAL,
+                                              KV_USE_TMA>::Kernel;

     assert(params.n_heads % params.n_kv_heads == 0 &&
            "n_heads must be divisible by n_kv_heads");
src/kernels/attention/kernel/builders/kernel_builder_decl.h

Lines changed: 26 additions & 0 deletions

@@ -0,0 +1,26 @@
+#pragma once
+
+#include <cute/util/type_traits.hpp>
+
+namespace llm {
+
+template <class ArchTag,
+          class ProblemShape,
+          class TileShape,
+          class Element,
+          class StrideQ,
+          class StrideK,
+          class StrideV,
+          class StrideO,
+          bool EVEN_K,
+          bool ALIBI,
+          bool SOFT_CAP,
+          bool LOCAL,
+          bool KV_USE_TMA,
+          class Enable = void>
+struct KernelBuilder {
+  static_assert(cute::dependent_false<Element>,
+                "Could not build a kernel for the given parameters.");
+};
+
+}  // namespace llm
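
The primary template above deliberately has nothing usable in it: requesting KernelBuilder<...>::Kernel for a parameter combination that no specialization covers fails at compile time with the message in the static_assert. A stripped-down, self-contained illustration of this dependent-false dispatch pattern (simplified names, not the real signature):

    #include <type_traits>

    // dependent false: only evaluated when the primary template is instantiated
    template <class T>
    inline constexpr bool dependent_false_v = false;

    struct Sm120 {};
    struct Sm90 {};

    template <class ArchTag, class Enable = void>
    struct BuilderSketch {
      static_assert(dependent_false_v<ArchTag>, "no kernel for this architecture");
    };

    // per-arch specializations opt in by providing ::Kernel
    template <class ArchTag>
    struct BuilderSketch<ArchTag, std::enable_if_t<std::is_same_v<ArchTag, Sm120>>> {
      using Kernel = int;  // stands in for Sm120KernelFmhaWs<...>
    };

    using Ok = BuilderSketch<Sm120>::Kernel;     // compiles
    // using Bad = BuilderSketch<Sm90>::Kernel;  // would trip the static_assert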

src/kernels/attention/kernel/builders/sm120_kernel_builder.inl

Lines changed: 65 additions & 0 deletions

@@ -0,0 +1,65 @@
+#pragma once
+
+#include <cutlass/arch/arch.h>
+
+#include <cute/tensor.hpp>
+
+#include "collective/sm120_collective_epilogue.cuh"
+#include "collective/sm120_collective_fmha_mainloop_ws.cuh"
+#include "common/fmha_block.h"
+#include "common/tile_scheduler.cuh"
+#include "kernel/sm120_kernel_fmha_ws.cuh"
+#include "kernel_builder_decl.h"
+
+namespace llm {
+
+template <class ProblemShape,
+          class TileShape,
+          class Element,
+          class StrideQ,
+          class StrideK,
+          class StrideV,
+          class StrideO,
+          bool EVEN_K,
+          bool ALIBI,
+          bool SOFT_CAP,
+          bool LOCAL,
+          bool KV_USE_TMA>
+struct KernelBuilder<cutlass::arch::Sm120,
+                     ProblemShape,
+                     TileShape,
+                     Element,
+                     StrideQ,
+                     StrideK,
+                     StrideV,
+                     StrideO,
+                     EVEN_K,
+                     ALIBI,
+                     SOFT_CAP,
+                     LOCAL,
+                     KV_USE_TMA,
+                     cute::enable_if_t<not cute::is_tuple_v<Element>>> {
+  using Block =
+      FmhaBlock<TileShape, Element, StrideQ, StrideK, StrideV, StrideO, LOCAL>;
+
+  using CollectiveMainloop = Sm120CollectiveFMhaWs<TileShape,
+                                                   Element,
+                                                   EVEN_K,
+                                                   ALIBI,
+                                                   SOFT_CAP,
+                                                   LOCAL,
+                                                   KV_USE_TMA>;
+  using CollectiveEpilogue =
+      Sm120CollectiveEpilogue<TileShape, Element, EVEN_K>;
+
+  // TODO: support persistent kernels
+  using TileScheduler = SingleTileScheduler;
+
+  using Kernel = Sm120KernelFmhaWs<ProblemShape,
+                                   Block,
+                                   CollectiveMainloop,
+                                   CollectiveEpilogue,
+                                   TileScheduler>;
+};
+
+}  // namespace llm
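
The trailing cute::enable_if_t<not cute::is_tuple_v<Element>> constraint keeps this specialization from matching tuple element types, which presumably leaves room for a separate Sm120 specialization for those later without ambiguity. A small self-contained sketch of how the extra Enable slot lets two constrained specializations of the same arch tag coexist (std::is_class_v stands in for cute::is_tuple_v; names are illustrative only):

    #include <type_traits>

    struct Sm120 {};
    struct FakeTuple {};  // stands in for a cute tuple element type

    template <class ArchTag, class Element, class Enable = void>
    struct BuilderSketch;  // primary left undefined for brevity

    // non-tuple element types select this specialization
    template <class Element>
    struct BuilderSketch<Sm120, Element, std::enable_if_t<!std::is_class_v<Element>>> {
      static constexpr bool tuple_path = false;
    };

    // tuple-like element types would select this (hypothetical) one
    template <class Element>
    struct BuilderSketch<Sm120, Element, std::enable_if_t<std::is_class_v<Element>>> {
      static constexpr bool tuple_path = true;
    };

    static_assert(!BuilderSketch<Sm120, float>::tuple_path);
    static_assert(BuilderSketch<Sm120, FakeTuple>::tuple_path);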

src/kernels/attention/kernel/kernel_builder.h

Lines changed: 7 additions & 0 deletions

@@ -0,0 +1,7 @@
+#pragma once
+
+// kernel builder declarations
+#include "builders/kernel_builder_decl.h"  // IWYU pragma: keep
+
+// kernel builder implementations
+#include "builders/sm120_kernel_builder.inl"  // IWYU pragma: keep

src/kernels/attention/tests/sm120_fmha_test.cu

Lines changed: 2 additions & 2 deletions
@@ -89,10 +89,10 @@ torch::Tensor sm120_fmha(
           : nullptr;

   // params.max_q_len = max_q_len;
-
+  using ArchTag = cutlass::arch::Sm120;
   DISPATCH_TORCH_DTYPE_(query.dtype(), Dtype, [&] {
     DISPATCH_HEAD_DIM_(head_dim, HEAD_DIM, [&] {
-      FmhaRunner<Dtype, HEAD_DIM>::run(params, /*stream=*/nullptr);
+      FmhaRunner<ArchTag, Dtype, HEAD_DIM>::run(params, /*stream=*/nullptr);
     });
   });
   return out;
