
Commit a6af766

kernel: separate oob iterations for better performance. (#384)

1 parent: 973c9b5
File tree: 5 files changed (+176, -71 lines)
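
The idea of the change, as a minimal standalone sketch (plain C++; the helper names are placeholders, not code from this repo): KV blocks are now walked from the last block backwards, and only the first ceil_div(kBlockM, kBlockN) + 1 iterations, the only ones that can touch the sequence end or the causal/local boundary of a kBlockM-row query tile, pay for out-of-bound masking; the remaining iterations run an unmasked fast path.

constexpr int ceil_div(int a, int b) { return (a + b - 1) / b; }

// Placeholders for the real per-block work done by the kernel.
void process_kv_block_masked(int /*n_block_idx*/) {}    // applies OOB/boundary mask
void process_kv_block_unmasked(int /*n_block_idx*/) {}  // no per-element bound checks

void process_query_tile(int n_block_min, int n_block_max, int kBlockM, int kBlockN) {
  // A query tile of kBlockM rows spans a masking band at most kBlockM wide, so it
  // can intersect at most ceil_div(kBlockM, kBlockN) + 1 KV blocks; the +1 also
  // covers the possibly ragged last block of the sequence.
  const int n_oob_mask = ceil_div(kBlockM, kBlockN) + 1;

  int n_block_idx = n_block_max - 1;  // start from the last (rightmost) KV block
  for (int i = 0; i < n_oob_mask && n_block_idx >= n_block_min; ++i, --n_block_idx) {
    process_kv_block_masked(n_block_idx);
  }
  for (; n_block_idx >= n_block_min; --n_block_idx) {
    process_kv_block_unmasked(n_block_idx);
  }
}

int main() {
  // e.g. kBlockM = kBlockN = 64 and 10 KV blocks: 2 masked iterations, 8 unmasked
  process_query_tile(/*n_block_min=*/0, /*n_block_max=*/10, /*kBlockM=*/64, /*kBlockN=*/64);
  return 0;
}

This mirrors the two loops added in attention_kernel_sm80.cuh below; the real kernel additionally prefetches the next K/V block with cp.async inside each iteration.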

src/kernels/attention/attention_kernel_sm80.cuh

Lines changed: 104 additions & 52 deletions
@@ -57,9 +57,9 @@ __global__ void mha_kernel_sm80(__grid_constant__ const Params params) {
   using SmemTiledCopyO = typename Traits::SmemTiledCopyO;

   const int m_block = blockIdx.x;
-  const auto batch_idx = blockIdx.y;
-  const auto head_idx = blockIdx.z;
-  const auto tidx = threadIdx.x;
+  const int batch_idx = blockIdx.y;
+  const int head_idx = blockIdx.z;
+  const int tidx = threadIdx.x;

   AttentionTile<Params> tile(params);

@@ -75,7 +75,7 @@ __global__ void mha_kernel_sm80(__grid_constant__ const Params params) {
   const int kv_len = size<0>(K);

   if (m_block * kBlockM >= q_len) {
-    // out of bound, return
+    // m out of bound, return
     return;
   }

@@ -134,46 +134,51 @@ __global__ void mha_kernel_sm80(__grid_constant__ const Params params) {
   // (BLK_M, HEAD_DIM) -> (blk_m, head_dim)
   Tensor cQ = make_identity_tensor(Shape<_BLK_M, _HEAD_DIM>{});
   Tensor tQcQ = gmem_thr_copy_Q.partition_S(cQ);
-  // (BLK_N, HEAD_DIM) -> (blk_n, head_dim)
-  Tensor cKV = make_identity_tensor(Shape<_BLK_N, _HEAD_DIM>{});
-  Tensor tKcKV = gmem_thr_copy_KV.partition_S(cKV);

   auto produce_q = [&]() {
     auto tQgQ = gmem_thr_copy_Q.partition_S(gQ);
     auto tQsQ = gmem_thr_copy_Q.partition_D(sQ);
+    auto max_coord = make_coord(q_len - m_block * kBlockM, head_dim);
     safe_copy</*EVEN_MN=*/false, EVEN_K>(
-        gmem_tiled_copy_Q,
-        tQgQ,
-        tQsQ,
-        tQcQ,
-        make_coord(q_len - m_block * kBlockM, head_dim));
+        gmem_tiled_copy_Q, tQgQ, tQsQ, tQcQ, max_coord);
   };

-  // TODO: seperate mask iterations
+  // (BLK_N, HEAD_DIM) -> (blk_n, head_dim)
+  Tensor cKV = make_identity_tensor(Shape<_BLK_N, _HEAD_DIM>{});
+  Tensor tKVcKV = gmem_thr_copy_KV.partition_S(cKV);
+
   Tensor tKsK = gmem_thr_copy_KV.partition_D(sK);
   auto produce_k = [&](int ni) {
     auto tKgK = gmem_thr_copy_KV.partition_S(gK(_, _, ni));
+    auto max_coord = make_coord(kv_len - ni * kBlockN, head_dim);
     // skip zfill_mn for k since mask will mask out oob with -inf
     safe_copy</*EVEN_MN=*/false,
               EVEN_K,
-              /*ZERO_FILL_MN=*/false>(
-        gmem_tiled_copy_KV,
-        tKgK,
-        tKsK,
-        tKcKV,
-        make_coord(kv_len - ni * kBlockN, head_dim));
+              /*ZFILL_MN=*/false>(
+        gmem_tiled_copy_KV, tKgK, tKsK, tKVcKV, max_coord);
+  };
+
+  auto produce_k_no_oob = [&](int ni) {
+    auto tKgK = gmem_thr_copy_KV.partition_S(gK(_, _, ni));
+    auto max_coord = make_coord(kv_len - ni * kBlockN, head_dim);
+    safe_copy</*EVEN_MN=*/true, EVEN_K>(
+        gmem_tiled_copy_KV, tKgK, tKsK, tKVcKV, max_coord);
   };

   Tensor tVsV = gmem_thr_copy_KV.partition_D(sV);
   auto produce_v = [&](int ni) {
     auto tVgV = gmem_thr_copy_KV.partition_S(gV(_, _, ni));
+    auto max_coord = make_coord(kv_len - ni * kBlockN, head_dim);
     // skipping ZFILL_MN for v may cause nan issue
     safe_copy</*EVEN_MN=*/false, EVEN_K>(
-        gmem_tiled_copy_KV,
-        tVgV,
-        tVsV,
-        tKcKV,
-        make_coord(kv_len - ni * kBlockN, head_dim));
+        gmem_tiled_copy_KV, tVgV, tVsV, tKVcKV, max_coord);
+  };
+
+  auto produce_v_no_oob = [&](int ni) {
+    auto tVgV = gmem_thr_copy_KV.partition_S(gV(_, _, ni));
+    auto max_coord = make_coord(kv_len - ni * kBlockN, head_dim);
+    safe_copy</*EVEN_MN=*/true, EVEN_K>(
+        gmem_tiled_copy_KV, tVgV, tVsV, tKVcKV, max_coord);
   };

   TiledMma tiled_mma;
@@ -281,84 +286,131 @@ __global__ void mha_kernel_sm80(__grid_constant__ const Params params) {

     // wait for smem copy done before gmem copy
     __syncthreads();
+
+    auto max_coord = make_coord(q_len - m_block * kBlockM, head_dim);
     safe_copy</*EVEN_MN=*/false,
               EVEN_K,
-              /*ZERO_FILL_MN=*/false,
-              /*ZERO_FILL_K=*/false>(
-        gmem_tiled_copy_O,
-        tOsO,
-        tOgO,
-        tOcO,
-        make_coord(q_len - m_block * kBlockM, head_dim));
+              /*ZFILL_MN=*/false,
+              /*ZFILL_K=*/false>(
+        gmem_tiled_copy_O, tOsO, tOgO, tOcO, max_coord);
   };

+  // output accumulator, (MMA,MMA_M,MMA_K)
+  auto tOrAccO = partition_fragment_C(tiled_mma, Shape<_BLK_M, _HEAD_DIM>{});
+  auto tOrAccO_rc_view =
+      make_tensor(tOrAccO.data(), Layout::to_rowcol(tOrAccO.layout()));
+  clear(tOrAccO);
+
   const int diagonal = m_block * kBlockM + kv_len - q_len;
   // process kv in range: [kv_idx_min, kv_idx_max)
   const int kv_idx_min = std::max(0, diagonal - sliding_window);
   const int kv_idx_max = std::min(kv_len, diagonal + kBlockM);
   const int n_block_min = LOCAL ? kv_idx_min / kBlockN : 0;
   const int n_block_max = cute::ceil_div(kv_idx_max, kBlockN);
-  // TODO: handle n_block_min >= n_block_max

-  // ############### Prologue ###############
+  if (n_block_min >= n_block_max) {
+    // write output to gmem
+    epilogue(tOrAccO);
+    return;
+  }

+  // ############### Prologue ###############
+  int n_block_idx = n_block_max - 1;
   // produce q: [] => [q]
   produce_q();
   cp_async_fence();
   // produce k: [q] => [q, k]
-  produce_k(n_block_min);
+  produce_k(n_block_idx);
   cp_async_fence();

   // ############### Mainloop ###############

-  // output accumulator, (MMA,MMA_M,MMA_K)
-  auto tOrAccO = partition_fragment_C(tiled_mma, Shape<_BLK_M, _HEAD_DIM>{});
-  auto tOrAccO_rc_view =
-      make_tensor(tOrAccO.data(), Layout::to_rowcol(tOrAccO.layout()));
+  OnlineSoftmax<kRowsPerMMA * size<1>(tOrAccO)> softmax(sm_scale_log2);
+  Mask<kBlockM, kBlockM, ALIBI, LOCAL> mask(
+      q_len, kv_len, sliding_window, alibi_slope);

   // attention score accumulator, (MMA,MMA_M,MMA_N)
   auto tSrAccS = partition_fragment_C(tiled_mma, Shape<_BLK_M, _BLK_N>{});
   auto tSrAccS_rc_view =
       make_tensor(tSrAccS.data(), Layout::to_rowcol(tSrAccS.layout()));
+  // separate oob mask iterations for better performance
+  constexpr int n_oob_mask = cute::ceil_div(kBlockM, kBlockN) + 1;

-  OnlineSoftmax<kRowsPerMMA * size<1>(tOrAccO)> softmax(sm_scale_log2);
-  Mask<kBlockM, kBlockM, ALIBI, LOCAL> mask(
-      q_len, kv_len, sliding_window, alibi_slope);
-
-  clear(tOrAccO);
-  CUTE_NO_UNROLL
-  for (int ni = n_block_min; ni < n_block_max; ++ni) {
+  // oob mask iterations
+  CUTE_UNROLL
+  for (int i = 0; i < n_oob_mask; ++i) {
     clear(tSrAccS);

     // wait k, queue: [q, k] => []
     cp_async_wait<0>();
     __syncthreads();

     // produce v, [] => [v]
-    produce_v(ni);
+    if (i == 0) {
+      produce_v(n_block_idx);
+    } else {
+      produce_v_no_oob(n_block_idx);
+    }
     cp_async_fence();


     compute_qk(tSrAccS);

-    // apply soft cap if needed
     if constexpr (SOFT_CAP) {
       apply_logits_soft_cap(tSrAccS);
     }
+    mask.apply(tSrAccS_rc_view, m_block, n_block_idx, tidx);
+    softmax.rescale(tSrAccS_rc_view, tOrAccO_rc_view);

-    // apply mask for block (m_block, ni)
-    mask.apply(tSrAccS_rc_view, m_block, ni, tidx);
+    // wait v, [v] => []
+    cp_async_wait<0>();
+    __syncthreads();
+
+    // produce next k: [] => [k]
+    if (n_block_idx > n_block_min) {
+      produce_k_no_oob(n_block_idx - 1);
+    }
+    cp_async_fence();
+
+    // 2> O = softmax(S)*V
+    compute_sv(tSrAccS, tOrAccO);
+
+    --n_block_idx;
+    if (n_block_idx < n_block_min) {
+      // no more kv blocks to process
+      break;
+    }
+  }

-    // apply softmax and rescale
+  // non-oob mask iterations
+  CUTE_NO_UNROLL
+  for (; n_block_idx >= n_block_min; --n_block_idx) {
+    clear(tSrAccS);
+
+    // wait k, queue: [q, k] => []
+    cp_async_wait<0>();
+    __syncthreads();
+
+    // produce v, [] => [v]
+    produce_v_no_oob(n_block_idx);
+    cp_async_fence();
+
+
+    compute_qk(tSrAccS);
+
+    if constexpr (SOFT_CAP) {
+      apply_logits_soft_cap(tSrAccS);
+    }
+    mask.apply</*OOB_MASK=*/false>(tSrAccS_rc_view, m_block, n_block_idx, tidx);
     softmax.rescale(tSrAccS_rc_view, tOrAccO_rc_view);

     // wait v, [v] => []
     cp_async_wait<0>();
     __syncthreads();

     // produce next k: [] => [k]
-    if (ni != n_block_max - 1) {
-      produce_k(ni + 1);
+    if (n_block_idx > n_block_min) {
+      produce_k_no_oob(n_block_idx - 1);
     }
     cp_async_fence();
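
The fast-path producers (produce_k_no_oob / produce_v_no_oob) differ from the originals only in the EVEN_MN template flag, and K additionally skips zero-fill (ZFILL_MN=false) because the mask later writes -inf into out-of-bound scores anyway. The body of safe_copy is not part of this diff, so the following is only a simplified scalar sketch, in CUDA C++ with hypothetical names, of what such flags typically control:

// Illustration only: a scalar stand-in for what a predicated copy such as
// safe_copy<EVEN_MN, EVEN_K, ZFILL_MN, ZFILL_K> plausibly does. The real helper
// is CuTe-based; the names and logic below are assumptions made for explanation.
template <bool EVEN_MN, bool ZFILL_MN, typename T>
__device__ void copy_tile_g2s(const T* src,    // gmem tile (BLK_N x HEAD_DIM)
                              T* dst,          // smem tile (BLK_N x HEAD_DIM)
                              int rows,        // BLK_N
                              int cols,        // HEAD_DIM
                              int row_stride,  // gmem row stride
                              int max_rows) {  // rows left before the sequence ends
  for (int r = threadIdx.x; r < rows; r += blockDim.x) {
    // EVEN_MN=true: the caller guarantees the whole tile is in bounds,
    // so the row predicate (and its registers/branches) disappears.
    const bool in_bounds = EVEN_MN || (r < max_rows);
    for (int c = 0; c < cols; ++c) {
      if (in_bounds) {
        dst[r * cols + c] = src[r * row_stride + c];
      } else if (ZFILL_MN) {
        dst[r * cols + c] = T(0);  // zero-fill OOB rows (needed for V)
      }
      // else: leave smem untouched; fine for K because the mask adds -inf later
    }
  }
}

With EVEN_MN=true the per-row bound check is compiled out entirely, which is why this variant is only used for KV blocks that are known to lie fully inside the sequence.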

src/kernels/attention/attention_kernel_sm80_pagedkv_test.cu

Lines changed: 15 additions & 5 deletions
@@ -6,9 +6,21 @@
 #include "attention_params.h"
 #include "attention_ref.h"
 #include "cute/layout.hpp"
-#include "static_dispatch.h"

 namespace llm {
+#define DISPATCH_HEAD_DIM_(HEAD_DIM_V, HEAD_DIM_NAME, ...) \
+  [&] {                                                    \
+    if (HEAD_DIM_V <= 64) {                                \
+      constexpr static int HEAD_DIM_NAME = 64;             \
+      return __VA_ARGS__();                                \
+    } else if (HEAD_DIM_V <= 256) {                        \
+      constexpr static int HEAD_DIM_NAME = 256;            \
+      return __VA_ARGS__();                                \
+    } else {                                               \
+      assert(false);                                       \
+    }                                                      \
+  }()
+
 namespace {
 torch::Tensor attention_pagedkv_sm80(
     torch::Tensor query, // [q_seq_len, n_heads, head_dim]
@@ -61,10 +73,8 @@ torch::Tensor attention_pagedkv_sm80(
   params.block_cu_lens = block_cu_lens.const_data_ptr<int32_t>();
   params.block_size = block_size;

-  DISPATCH_TORCH_DTYPE(query.dtype(), DTYPE, [&] {
-    DISPATCH_HEAD_DIM(head_dim, HEAD_DIM, [&] {
-      run_attention_kernel_sm80<DTYPE, HEAD_DIM>(params);
-    });
+  DISPATCH_HEAD_DIM_(head_dim, HEAD_DIM, [&] {
+    run_attention_kernel_sm80<cute::half_t, HEAD_DIM>(params);
   });
   return out;
 }
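
The DISPATCH_HEAD_DIM_ macro added above uses the common immediately-invoked-lambda trick to turn a runtime head_dim into a constexpr template argument. A self-contained sketch of the pattern (run_kernel is a stand-in for illustration, not this repo's function):

#include <cassert>
#include <cstdio>

template <int HEAD_DIM>
void run_kernel() { std::printf("instantiated for HEAD_DIM=%d\n", HEAD_DIM); }

#define DISPATCH_HEAD_DIM_(HEAD_DIM_V, HEAD_DIM_NAME, ...) \
  [&] {                                                    \
    if (HEAD_DIM_V <= 64) {                                \
      constexpr static int HEAD_DIM_NAME = 64;             \
      return __VA_ARGS__();                                \
    } else if (HEAD_DIM_V <= 256) {                        \
      constexpr static int HEAD_DIM_NAME = 256;            \
      return __VA_ARGS__();                                \
    } else {                                               \
      assert(false);                                       \
    }                                                      \
  }()

int main() {
  int head_dim = 128;  // runtime value, e.g. taken from a tensor shape
  // expands to: constexpr static int HEAD_DIM = 256; run_kernel<HEAD_DIM>();
  DISPATCH_HEAD_DIM_(head_dim, HEAD_DIM, [&] { run_kernel<HEAD_DIM>(); });
  return 0;
}

Note that this paged-KV test (and the varlen test further below) drops the dtype dispatch and hard-codes cute::half_t, so only the head-dim dispatch remains; the plain attention test that follows keeps both dispatches via DISPATCH_TORCH_DTYPE_.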

src/kernels/attention/attention_kernel_sm80_test.cu

Lines changed: 28 additions & 2 deletions
@@ -10,6 +10,32 @@
 #include "static_dispatch.h"

 namespace llm {
+#define DISPATCH_HEAD_DIM_(HEAD_DIM_V, HEAD_DIM_NAME, ...) \
+  [&] {                                                    \
+    if (HEAD_DIM_V <= 64) {                                \
+      constexpr static int HEAD_DIM_NAME = 64;             \
+      return __VA_ARGS__();                                \
+    } else if (HEAD_DIM_V <= 256) {                        \
+      constexpr static int HEAD_DIM_NAME = 256;            \
+      return __VA_ARGS__();                                \
+    } else {                                               \
+      assert(false);                                       \
+    }                                                      \
+  }()
+
+#define DISPATCH_TORCH_DTYPE_(TORCH_DTYPE, TYPE_NAME, ...) \
+  [&] {                                                    \
+    if (TORCH_DTYPE == torch::kHalf) {                     \
+      using TYPE_NAME = cute::half_t;                      \
+      return __VA_ARGS__();                                \
+    } else if (TORCH_DTYPE == torch::kBFloat16) {          \
+      using TYPE_NAME = cute::bfloat16_t;                  \
+      return __VA_ARGS__();                                \
+    } else {                                               \
+      assert(false);                                       \
+    }                                                      \
+  }()
+
 namespace {
 torch::Tensor attention_sm80(
     torch::Tensor query, // [batch_size, q_len, n_heads, head_dim]
@@ -57,8 +83,8 @@ torch::Tensor attention_sm80(
   params.logits_soft_cap = logits_soft_cap;
   params.sliding_window = sliding_window;

-  DISPATCH_TORCH_DTYPE(query.dtype(), DTYPE, [&] {
-    DISPATCH_HEAD_DIM(head_dim, HEAD_DIM, [&] {
+  DISPATCH_TORCH_DTYPE_(query.dtype(), DTYPE, [&] {
+    DISPATCH_HEAD_DIM_(head_dim, HEAD_DIM, [&] {
       run_attention_kernel_sm80<DTYPE, HEAD_DIM>(params);
     });
   });
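
DISPATCH_TORCH_DTYPE_ applies the same pattern to the element type, and the two macros nest so that one runtime (dtype, head_dim) pair selects exactly one template instantiation. A compilable sketch with stand-ins (a plain enum replaces torch::ScalarType and empty structs replace the cute types, so it builds without torch or cute; DISPATCH_DTYPE_ is a hypothetical simplified name):

#include <cassert>
#include <cstdio>

enum class DType { kHalf, kBFloat16 };
struct half_t {};      // stand-in for cute::half_t
struct bfloat16_t {};  // stand-in for cute::bfloat16_t

template <typename DTYPE, int HEAD_DIM>
void run_attention_kernel() {
  std::printf("instantiated with HEAD_DIM=%d\n", HEAD_DIM);
}

#define DISPATCH_DTYPE_(DTYPE_V, TYPE_NAME, ...) \
  [&] {                                          \
    if (DTYPE_V == DType::kHalf) {               \
      using TYPE_NAME = half_t;                  \
      return __VA_ARGS__();                      \
    } else if (DTYPE_V == DType::kBFloat16) {    \
      using TYPE_NAME = bfloat16_t;              \
      return __VA_ARGS__();                      \
    } else {                                     \
      assert(false);                             \
    }                                            \
  }()

int main() {
  DType dtype = DType::kBFloat16;  // runtime value
  // In the test above the inner call is the DISPATCH_HEAD_DIM_ lambda;
  // here the head dim is fixed to keep the sketch short.
  DISPATCH_DTYPE_(dtype, DTYPE, [&] { run_attention_kernel<DTYPE, 64>(); });
  return 0;
}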

src/kernels/attention/attention_kernel_sm80_varlen_test.cu

Lines changed: 15 additions & 5 deletions
@@ -6,9 +6,21 @@
 #include "attention_params.h"
 #include "attention_ref.h"
 #include "cute/layout.hpp"
-#include "static_dispatch.h"

 namespace llm {
+#define DISPATCH_HEAD_DIM_(HEAD_DIM_V, HEAD_DIM_NAME, ...) \
+  [&] {                                                    \
+    if (HEAD_DIM_V <= 64) {                                \
+      constexpr static int HEAD_DIM_NAME = 64;             \
+      return __VA_ARGS__();                                \
+    } else if (HEAD_DIM_V <= 256) {                        \
+      constexpr static int HEAD_DIM_NAME = 256;            \
+      return __VA_ARGS__();                                \
+    } else {                                               \
+      assert(false);                                       \
+    }                                                      \
+  }()
+
 namespace {
 torch::Tensor attention_varlen_sm80(
     torch::Tensor query, // [q_len, n_heads, head_dim]
@@ -54,10 +66,8 @@ torch::Tensor attention_varlen_sm80(
   params.q_cu_lens = q_cu_lens.const_data_ptr<int32_t>();
   params.kv_cu_lens = kv_cu_lens.const_data_ptr<int32_t>();

-  DISPATCH_TORCH_DTYPE(query.dtype(), DTYPE, [&] {
-    DISPATCH_HEAD_DIM(head_dim, HEAD_DIM, [&] {
-      run_attention_kernel_sm80<DTYPE, HEAD_DIM>(params);
-    });
+  DISPATCH_HEAD_DIM_(head_dim, HEAD_DIM, [&] {
+    run_attention_kernel_sm80<cute::half_t, HEAD_DIM>(params);
   });
   return out;
 }
