Skip to content

Commit 2d42a38

Browse files
authored
feat: added single tile scheduler for attn kernel (#473)
1 parent e897b6c commit 2d42a38

File tree

5 files changed

+162
-74
lines changed

5 files changed

+162
-74
lines changed

src/kernels/attention/sm80_collective_epilogue.cuh

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -83,7 +83,7 @@ struct Sm80CollectiveEpilogue {
8383
char* smem) {
8484
static constexpr int kBlockM = get<0>(TileShape{});
8585

86-
const auto [m_block_idx, batch_idx, kv_head_idx] = block_coord_mnk;
86+
const auto [batch_idx, m_block_idx, kv_head_idx] = block_coord_mnk;
8787
const auto [q_packed_len, kv_len, head_dim] = problem_shape_mnk;
8888

8989
// Smem

src/kernels/attention/sm80_collective_mha.cuh

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -169,7 +169,7 @@ struct Sm80CollectiveMha {
169169
static constexpr int kBlockM = get<0>(TileShape{});
170170
static constexpr int kBlockN = get<1>(TileShape{});
171171

172-
const auto [m_block_idx, batch_idx, kv_head_idx] = block_coord_mnk;
172+
const auto [batch_idx, m_block_idx, kv_head_idx] = block_coord_mnk;
173173
const auto [q_packed_len, kv_len, head_dim] = problem_shape_mnk;
174174

175175
const int sliding_window = LOCAL ? params.sliding_window : kv_len;

src/kernels/attention/sm80_kernel_mha.cuh

Lines changed: 78 additions & 65 deletions
Original file line numberDiff line numberDiff line change
@@ -13,11 +13,14 @@ namespace llm {
1313

1414
using namespace cute;
1515

16-
template <class CollectiveMainloop_, class CollectiveEpilogue_>
16+
template <class CollectiveMainloop_,
17+
class CollectiveEpilogue_,
18+
class TileScheduler_>
1719
class Sm80KernelMha {
1820
public:
1921
using CollectiveMainloop = CollectiveMainloop_;
2022
using CollectiveEpilogue = CollectiveEpilogue_;
23+
using TileScheduler = TileScheduler_;
2124

2225
using TiledMma = typename CollectiveMainloop::TiledMma;
2326

@@ -39,45 +42,22 @@ class Sm80KernelMha {
3942
// Kernel params
4043
using MainloopParams = typename CollectiveMainloop::Params;
4144
using EpilogueParams = typename CollectiveEpilogue::Params;
45+
using TileSchedulerParams = typename TileScheduler::Params;
46+
47+
// returns grid and block shape for kernel launch
48+
using TileSchedulerArgs = typename TileScheduler::Arguments;
49+
static dim3 get_grid_shape(TileSchedulerArgs const& args) {
50+
return TileScheduler::get_grid_shape(args);
51+
}
52+
static dim3 get_block_shape() { return kMmaThreads; }
4253

4354
template <class Params>
44-
CUTE_DEVICE void operator()(const Params& params, char* smem) {
55+
CUTE_DEVICE void operator()(const Params& params,
56+
const TileSchedulerParams& scheduler_params,
57+
char* smem) {
4558
CollectiveMainloop mha;
4659
CollectiveEpilogue epilogue;
47-
48-
const auto tidx = threadIdx.x;
49-
50-
// block coord
51-
const int m_block_idx = blockIdx.x;
52-
const int batch_idx = blockIdx.y;
53-
const int kv_head_idx = blockIdx.z;
54-
auto block_coord_mnk = make_coord(m_block_idx, batch_idx, kv_head_idx);
55-
56-
// (q_packed_len, HEAD_DIM)
57-
MHATile<Params> tile(params, batch_idx, kv_head_idx);
58-
auto [Q, O] = tile.template get_qo_tile<Element>();
59-
// (kv_len, HEAD_DIM)
60-
auto [K, V] = tile.template get_kv_tile<Element>();
61-
62-
// problem shape
63-
const int q_packed_len = size<0>(Q);
64-
const int kv_len = size<0>(K);
65-
const int head_dim = params.head_dim;
66-
auto problem_shape_mnk = make_shape(q_packed_len, kv_len, head_dim);
67-
68-
if (m_block_idx * kBlockM >= q_packed_len) {
69-
// m out of bound, return
70-
return;
71-
}
72-
73-
// (BLK_M, HEAD_DIM)
74-
Tensor gQ =
75-
local_tile(Q, Shape<BLK_M, HEAD_DIM>{}, make_coord(m_block_idx, _0{}));
76-
Tensor gO =
77-
local_tile(O, Shape<BLK_M, HEAD_DIM>{}, make_coord(m_block_idx, _0{}));
78-
// (BLK_N, HEAD_DIM, n)
79-
Tensor gK = local_tile(K, Shape<BLK_N, HEAD_DIM>{}, make_coord(_, _0{}));
80-
Tensor gV = local_tile(V, Shape<BLK_N, HEAD_DIM>{}, make_coord(_, _0{}));
60+
TileScheduler scheduler(scheduler_params);
8161

8262
// construct params
8363
MainloopParams mainloop_params{params.sliding_window,
@@ -88,35 +68,68 @@ class Sm80KernelMha {
8868
params.group_size};
8969
EpilogueParams epilogue_params;
9070

91-
TiledMma tiled_mma;
92-
// accumulator: MMA,MMA_M,MMA_K)
93-
auto tOrAccO = partition_fragment_C(tiled_mma, Shape<BLK_M, HEAD_DIM>{});
94-
clear(tOrAccO);
95-
96-
constexpr int kRowsPerThr = kRowsPerMMA * size<1>(tOrAccO);
97-
OnlineSoftmax<kRowsPerThr> softmax(params.sm_scale_log2);
98-
99-
// mainloop
100-
mha(mainloop_params,
101-
gQ,
102-
gK,
103-
gV,
104-
tOrAccO,
105-
softmax,
106-
tidx,
107-
block_coord_mnk,
108-
problem_shape_mnk,
109-
smem);
110-
111-
// epilogue
112-
epilogue(epilogue_params,
113-
tOrAccO,
114-
tiled_mma,
115-
gO,
116-
tidx,
117-
block_coord_mnk,
118-
problem_shape_mnk,
119-
smem);
71+
// process each block
72+
for (const auto block_coord : scheduler) {
73+
// block coord: (batch_idx, m_block_idx, kv_head_idx)
74+
const auto [batch_idx, m_block_idx, kv_head_idx] = block_coord;
75+
const auto tidx = threadIdx.x;
76+
77+
// (q_packed_len, HEAD_DIM)
78+
MHATile<Params> tile(params, batch_idx, kv_head_idx);
79+
auto [Q, O] = tile.template get_qo_tile<Element>();
80+
// (kv_len, HEAD_DIM)
81+
auto [K, V] = tile.template get_kv_tile<Element>();
82+
83+
// problem shape
84+
const int q_packed_len = size<0>(Q);
85+
const int kv_len = size<0>(K);
86+
const int head_dim = params.head_dim;
87+
auto problem_shape_mnk = make_shape(q_packed_len, kv_len, head_dim);
88+
89+
if (m_block_idx * kBlockM >= q_packed_len) {
90+
// m out of bounds, skip this block
91+
continue;
92+
}
93+
94+
// (BLK_M, HEAD_DIM)
95+
Tensor gQ = local_tile(
96+
Q, Shape<BLK_M, HEAD_DIM>{}, make_coord(m_block_idx, _0{}));
97+
Tensor gO = local_tile(
98+
O, Shape<BLK_M, HEAD_DIM>{}, make_coord(m_block_idx, _0{}));
99+
// (BLK_N, HEAD_DIM, n)
100+
Tensor gK = local_tile(K, Shape<BLK_N, HEAD_DIM>{}, make_coord(_, _0{}));
101+
Tensor gV = local_tile(V, Shape<BLK_N, HEAD_DIM>{}, make_coord(_, _0{}));
102+
103+
TiledMma tiled_mma;
104+
// accumulator: (MMA, MMA_M, MMA_K)
105+
auto tOrAccO = partition_fragment_C(tiled_mma, Shape<BLK_M, HEAD_DIM>{});
106+
clear(tOrAccO);
107+
108+
constexpr int kRowsPerThr = kRowsPerMMA * size<1>(tOrAccO);
109+
OnlineSoftmax<kRowsPerThr> softmax(params.sm_scale_log2);
110+
111+
// mainloop
112+
mha(mainloop_params,
113+
gQ,
114+
gK,
115+
gV,
116+
tOrAccO,
117+
softmax,
118+
tidx,
119+
block_coord,
120+
problem_shape_mnk,
121+
smem);
122+
123+
// epilogue
124+
epilogue(epilogue_params,
125+
tOrAccO,
126+
tiled_mma,
127+
gO,
128+
tidx,
129+
block_coord,
130+
problem_shape_mnk,
131+
smem);
132+
}
120133
}
121134
};
122135

src/kernels/attention/sm80_mha_launch.cuh

Lines changed: 19 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -9,17 +9,20 @@
99
#include "sm80_collective_epilogue.cuh"
1010
#include "sm80_collective_mha.cuh"
1111
#include "sm80_kernel_mha.cuh"
12+
#include "tile_scheduler.cuh"
1213

1314
namespace llm {
1415

1516
namespace detail {
1617
/// Generic kernel template.
1718
template <typename Operator, typename Params>
1819
__global__ __launch_bounds__(Operator::kMmaThreads) void device_kernel(
19-
__grid_constant__ const Params params) {
20+
__grid_constant__ const Params params,
21+
__grid_constant__ const typename Operator::TileSchedulerParams
22+
scheduler_params) {
2023
extern __shared__ char smem[];
2124
Operator op;
22-
op(params, smem);
25+
op(params, scheduler_params, smem);
2326
}
2427
} // namespace detail
2528

@@ -61,7 +64,17 @@ void sm80_launch_mha_kernel(const Params& params, cudaStream_t stream) {
6164
using CollectiveEpilogue =
6265
Sm80CollectiveEpilogue<TileShape, Dtype, HEAD_DIM, EVEN_K>;
6366

64-
using AttnKernel = Sm80KernelMha<CollectiveMainloop, CollectiveEpilogue>;
67+
// TODO: support persistent kernels
68+
using TileScheduler = SingleTileScheduler;
69+
70+
const auto m_blocks = cute::ceil_div(max_q_packed_len, BLK_M);
71+
typename TileScheduler::Arguments scheduler_args{
72+
batch_size, m_blocks, n_kv_heads};
73+
auto scheduler_params =
74+
TileScheduler::to_underlying_arguments(scheduler_args);
75+
76+
using AttnKernel =
77+
Sm80KernelMha<CollectiveMainloop, CollectiveEpilogue, TileScheduler>;
6578

6679
auto mha_kernel = detail::device_kernel<AttnKernel, Params>;
6780

@@ -71,11 +84,10 @@ void sm80_launch_mha_kernel(const Params& params, cudaStream_t stream) {
7184
mha_kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_size);
7285
}
7386

74-
// TODO: support persistent kernels
75-
dim3 grid(cute::ceil_div(max_q_packed_len, BLK_M), batch_size, n_kv_heads);
76-
dim3 block = AttnKernel::kMmaThreads;
87+
const dim3 grid = AttnKernel::get_grid_shape(scheduler_args);
88+
const dim3 block = AttnKernel::get_block_shape();
7789

78-
mha_kernel<<<grid, block, smem_size, stream>>>(params);
90+
mha_kernel<<<grid, block, smem_size, stream>>>(params, scheduler_params);
7991
// TODO: check launch status
8092
}
8193

src/kernels/attention/tile_scheduler.cuh

Lines changed: 63 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,63 @@
1+
#pragma once
2+
3+
#include <cuda.h>
4+
#include <cuda_runtime.h>
5+
6+
#include <cute/layout.hpp>
7+
#include <cute/tensor.hpp>
8+
9+
namespace llm {
10+
11+
class SingleTileScheduler {
 public:
  // Host-side kernel arguments: the extent of the block-coordinate space.
  struct Arguments {
    int batch_size = 0;
    int m_blocks = 0;
    int n_kv_heads = 0;
  };

  // One CTA is launched per (batch, m_block, kv_head) coordinate.
  static dim3 get_grid_shape(Arguments const& args) {
    return {(uint32_t)args.batch_size,
            (uint32_t)args.m_blocks,
            (uint32_t)args.n_kv_heads};
  }

  // Device-side kernel params: identical to the host arguments.
  using Params = Arguments;
  static Params to_underlying_arguments(const Arguments& args) { return args; }

  // Sentinel tag marking the end of iteration.
  class EndIterator {};

  // Single-step iterator: yields this CTA's block coordinate exactly once.
  class Iterator {
   public:
    CUTE_DEVICE
    Iterator() = default;

    // The block coordinate is simply blockIdx; the grid was shaped as
    // (batch_size, m_blocks, n_kv_heads), so (x, y, z) destructures to
    // (batch_idx, m_block_idx, kv_head_idx) in the kernel.
    CUTE_DEVICE
    dim3 operator*() const { return blockIdx; }

    // Advancing past the first (and only) tile deactivates the iterator.
    CUTE_DEVICE
    Iterator& operator++() {
      active_ = false;
      return *this;
    }

    // Comparison against the end sentinel: iteration continues while active.
    CUTE_DEVICE
    bool operator!=(const EndIterator&) const { return active_; }

   private:
    bool active_ = true;
  };

  // Params carry no state for the single-tile case.
  CUTE_DEVICE
  SingleTileScheduler(const Params& /*params*/) {}

  CUTE_DEVICE
  Iterator begin() const { return {}; }

  CUTE_DEVICE
  EndIterator end() const { return {}; }
};
62+
63+
} // namespace llm

0 commit comments

Comments
 (0)