Commit 7044f66

kernel: fix register spilling issue for attention head_dim=256 (#397)
1 parent: 29a9b31

File tree

5 files changed (+72, -115 lines)

  src/kernels/attention/mha_kernel_sm80.cuh
  src/kernels/attention/mha_sm80_bench.cu
  src/kernels/attention/mha_sm80_pagedkv_bench.cu
  src/kernels/attention/mha_traits_sm80.h
  src/kernels/attention/online_softmax.cuh

src/kernels/attention/mha_kernel_sm80.cuh
Lines changed: 25 additions & 63 deletions

@@ -299,21 +299,6 @@ __global__ void mha_kernel_sm80(__grid_constant__ const Params params) {
     return;
   }
 
-  // ############### Prologue ###############
-  int n_block_idx = n_block_max - 1;
-  // produce query: [] => [q]
-  produce_query();
-  cp_async_fence();
-  // produce key: [q] => [q, k]
-  produce_key(n_block_idx);
-  cp_async_fence();
-
-  // ############### Mainloop ###############
-  // attention score accumulator, (MMA,MMA_M,MMA_N)
-  auto tSrAccS = partition_fragment_C(tiled_mma, Shape<_BLK_M, _BLK_N>{});
-  auto tSrAccS_rc_view =
-      make_tensor(tSrAccS.data(), Layout::to_rowcol(tSrAccS.layout()));
-
   auto apply_logits_soft_cap = [&](auto& tSrAccS) {
     if constexpr (SOFT_CAP) {
       CUTE_UNROLL
@@ -323,7 +308,7 @@ __global__ void mha_kernel_sm80(__grid_constant__ const Params params) {
     }
   };
 
-  constexpr int kMMA_M = size<1>(tSrAccS);
+  constexpr int kMMA_M = size<1>(tOrAccO);
   using Softmax = OnlineSoftmax<kRowsPerMMA * kMMA_M>;
   using Mask = Mask<kBlockM, kBlockM, kRowsPerMMA, kMMA_M, ALIBI, LOCAL>;
 
@@ -338,12 +323,26 @@ __global__ void mha_kernel_sm80(__grid_constant__ const Params params) {
             sm_scale,
             params.alibi_slopes_ptr);
 
-  // seperate oob mask iterations for better performance
+  // ############### Prologue ###############
+  // produce query: [] => [q]
+  produce_query();
+  cp_async_fence();
+  // produce key: [q] => [q, k]
+  produce_key(n_block_max - 1);
+  cp_async_fence();
+
+  // ############### Mainloop ###############
   constexpr int n_oob_mask = cute::ceil_div(kBlockM, kBlockN) + 1;
+  const int n_blocks = n_block_max - n_block_min;
 
-  // oob mask iterations
-  CUTE_UNROLL
-  for (int i = 0; i < n_oob_mask; ++i) {
+  CUTE_NO_UNROLL
+  for (int i = 0; i < n_blocks; ++i) {
+    const int n_block_idx = n_block_max - 1 - i;
+
+    // attention score accumulator, (MMA,MMA_M,MMA_N)
+    auto tSrAccS = partition_fragment_C(tiled_mma, Shape<_BLK_M, _BLK_N>{});
+    auto tSrAccS_rc_view =
+        make_tensor(tSrAccS.data(), Layout::to_rowcol(tSrAccS.layout()));
     clear(tSrAccS);
 
     // wait key, queue: [q, k] => []
@@ -361,57 +360,20 @@ __global__ void mha_kernel_sm80(__grid_constant__ const Params params) {
 
     compute_qk(tSrAccS);
 
-    if constexpr (SOFT_CAP) {
-      apply_logits_soft_cap(tSrAccS);
-    }
-    mask.apply(tSrAccS_rc_view, n_block_idx);
-    softmax.rescale(tSrAccS_rc_view, tOrAccO_rc_view);
-
     // wait value, [v] => []
     cp_async_wait<0>();
     __syncthreads();
 
-    // produce next key: [] => [k]
-    if (n_block_idx > n_block_min) {
-      produce_key_no_oob(n_block_idx - 1);
-    }
-    cp_async_fence();
-
-    // 2> O = softmax(S)*V
-    compute_sv(tSrAccS, tOrAccO);
-
-    --n_block_idx;
-    if (n_block_idx < n_block_min) {
-      // no more kv blocks to process
-      break;
-    }
-  }
-
-  // non-oob mask iterations
-  CUTE_NO_UNROLL
-  for (; n_block_idx >= n_block_min; --n_block_idx) {
-    clear(tSrAccS);
-
-    // wait key, queue: [q, k] => []
-    cp_async_wait<0>();
-    __syncthreads();
-
-    // produce value, [] => [v]
-    produce_value_no_oob(n_block_idx);
-    cp_async_fence();
-
-
-    compute_qk(tSrAccS);
-
     if constexpr (SOFT_CAP) {
       apply_logits_soft_cap(tSrAccS);
     }
-    mask.apply</*OOB_MASK=*/false>(tSrAccS_rc_view, n_block_idx);
-    softmax.rescale(tSrAccS_rc_view, tOrAccO_rc_view);
+    if (i < n_oob_mask) {
+      mask.apply(tSrAccS_rc_view, n_block_idx);
+    } else {
+      mask.apply</*OOB_MASK=*/false>(tSrAccS_rc_view, n_block_idx);
+    }
+    softmax.rescale(tSrAccS_rc_view, tOrAccO_rc_view);
 
     // produce next key: [] => [k]
     if (n_block_idx > n_block_min) {

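Note on the change above: the two specialized loops (an unrolled out-of-bounds-masked prologue followed by a non-unrolled steady-state loop) are folded into a single CUTE_NO_UNROLL loop that declares the score accumulator per iteration and selects the masking mode with `i < n_oob_mask`, which, per the commit message, is presumably what relieves the register pressure at head_dim=256. The short C++ sketch below only illustrates that control-flow shape; compute_scores, apply_mask, and the toy float vectors are hypothetical stand-ins, not the kernel's CuTe fragments or helpers.

    // A minimal, self-contained sketch of the unified mainloop structure (C++17).
    #include <cstdio>
    #include <vector>

    // Toy stand-in for computing one KV block's attention scores.
    std::vector<float> compute_scores(int n_block_idx) {
      return std::vector<float>(8, static_cast<float>(n_block_idx));
    }

    // Toy stand-in for mask.apply</*OOB_MASK=*/...>().
    template <bool OOB_MASK>
    void apply_mask(std::vector<float>& s) {
      if constexpr (OOB_MASK) {
        s.back() = 0.0f;  // boundary block: mask the out-of-range tail (toy version)
      }
    }

    void mainloop(int n_block_min, int n_block_max, int n_oob_mask) {
      std::vector<float> o(8, 0.0f);  // output accumulator, lives across all blocks
      const int n_blocks = n_block_max - n_block_min;
      for (int i = 0; i < n_blocks; ++i) {  // CUTE_NO_UNROLL in the kernel
        const int n_block_idx = n_block_max - 1 - i;
        // per-iteration score accumulator, like tSrAccS inside the new loop
        std::vector<float> s = compute_scores(n_block_idx);
        if (i < n_oob_mask) {
          apply_mask</*OOB_MASK=*/true>(s);   // only the first few (boundary) blocks
        } else {
          apply_mask</*OOB_MASK=*/false>(s);  // steady state: no boundary masking
        }
        for (size_t j = 0; j < o.size(); ++j) o[j] += s[j];  // stands in for compute_sv
      }
      std::printf("o[0] = %f\n", o[0]);
    }

    int main() {
      mainloop(/*n_block_min=*/0, /*n_block_max=*/8, /*n_oob_mask=*/2);
      return 0;
    }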
src/kernels/attention/mha_sm80_bench.cu
Lines changed: 2 additions & 14 deletions

@@ -7,22 +7,10 @@
 #include "mha_dispatch_sm80.cuh"
 #include "mha_kernel_sm80.cuh"  // IWYU pragma: keep
 #include "mha_params.h"
+#include "static_dispatch.h"
 
 using namespace llm;
 
-#define DISPATCH_HEAD_DIM_(HEAD_DIM_V, HEAD_DIM_NAME, ...) \
-  [&] {                                                    \
-    if (HEAD_DIM_V <= 64) {                                \
-      constexpr static int HEAD_DIM_NAME = 64;             \
-      return __VA_ARGS__();                                \
-    } else if (HEAD_DIM_V <= 128) {                        \
-      constexpr static int HEAD_DIM_NAME = 128;            \
-      return __VA_ARGS__();                                \
-    } else {                                               \
-      assert(false);                                       \
-    }                                                      \
-  }()
-
 void mha_bench_sm80(nvbench::state& state) {
   // Collect CUPTI metrics
   state.collect_cupti_metrics();
@@ -82,7 +70,7 @@ void mha_bench_sm80(nvbench::state& state) {
   params.sliding_window = sliding_window;
 
   state.exec([&](nvbench::launch& launch) {
-    DISPATCH_HEAD_DIM_(head_dim, HEAD_DIM, [&] {
+    DISPATCH_HEAD_DIM(head_dim, HEAD_DIM, [&] {
       run_mha_kernel_sm80<cute::half_t, HEAD_DIM>(params, launch.get_stream());
     });
   });

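Both benchmark files now include static_dispatch.h and call the shared DISPATCH_HEAD_DIM instead of the local DISPATCH_HEAD_DIM_ macro, which only handled head_dim <= 64 and <= 128. The real macro in static_dispatch.h is not shown in this diff; as an assumption for illustration only, a shared head-dim dispatcher of this kind would look roughly like the sketch below, with an extra branch so head_dim=256 can reach the kernel.

    // Illustrative only -- the actual DISPATCH_HEAD_DIM in static_dispatch.h may differ.
    #include <cassert>

    #define DISPATCH_HEAD_DIM_EXAMPLE(HEAD_DIM_V, HEAD_DIM_NAME, ...) \
      [&] {                                                           \
        if (HEAD_DIM_V <= 64) {                                       \
          constexpr static int HEAD_DIM_NAME = 64;                    \
          return __VA_ARGS__();                                       \
        } else if (HEAD_DIM_V <= 128) {                               \
          constexpr static int HEAD_DIM_NAME = 128;                   \
          return __VA_ARGS__();                                       \
        } else if (HEAD_DIM_V <= 256) {                               \
          constexpr static int HEAD_DIM_NAME = 256;                   \
          return __VA_ARGS__();                                       \
        } else {                                                      \
          assert(false);                                              \
        }                                                             \
      }()

    // Usage mirrors the benchmark call sites:
    //   DISPATCH_HEAD_DIM_EXAMPLE(head_dim, HEAD_DIM, [&] {
    //     run_mha_kernel_sm80<cute::half_t, HEAD_DIM>(params, launch.get_stream());
    //   });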
src/kernels/attention/mha_sm80_pagedkv_bench.cu
Lines changed: 2 additions & 14 deletions

@@ -8,22 +8,10 @@
 #include "mha_dispatch_sm80.cuh"
 #include "mha_kernel_sm80.cuh"  // IWYU pragma: keep
 #include "mha_params.h"
+#include "static_dispatch.h"
 
 using namespace llm;
 
-#define DISPATCH_HEAD_DIM_(HEAD_DIM_V, HEAD_DIM_NAME, ...) \
-  [&] {                                                    \
-    if (HEAD_DIM_V <= 64) {                                \
-      constexpr static int HEAD_DIM_NAME = 64;             \
-      return __VA_ARGS__();                                \
-    } else if (HEAD_DIM_V <= 128) {                        \
-      constexpr static int HEAD_DIM_NAME = 128;            \
-      return __VA_ARGS__();                                \
-    } else {                                               \
-      assert(false);                                       \
-    }                                                      \
-  }()
-
 void mha_bench_sm80(nvbench::state& state) {
   // Collect CUPTI metrics
   state.collect_cupti_metrics();
@@ -130,7 +118,7 @@ void mha_bench_sm80(nvbench::state& state) {
   params.block_cu_lens = block_cu_lens.const_data_ptr<int32_t>();
 
   state.exec([&](nvbench::launch& launch) {
-    DISPATCH_HEAD_DIM_(head_dim, HEAD_DIM, [&] {
+    DISPATCH_HEAD_DIM(head_dim, HEAD_DIM, [&] {
      run_mha_kernel_sm80<cute::half_t, HEAD_DIM>(params, launch.get_stream());
    });
  });

src/kernels/attention/mha_traits_sm80.h
Lines changed: 1 addition & 1 deletion

@@ -93,7 +93,7 @@ struct MHATraitsSM80 {
   // Tiled copy for QKV
   // g2s tiled copy for q
   using GmemTiledCopyQ = decltype(make_tiled_copy(
-      Copy_Atom<SM80_CP_ASYNC_CACHEGLOBAL_ZFILL<cute::uint128_t>, DType>{},
+      Copy_Atom<SM80_CP_ASYNC_CACHEGLOBAL<cute::uint128_t>, DType>{},
      GmemCopyThrLayout{},     // Thr layout: (_16,_8)/(_32, _4)
      Layout<Shape<_1, _8>>{}  // Val layout: 8 vals per read
      ));

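For context on the one-line change in mha_traits_sm80.h: at the PTX level, cute's SM80_CP_ASYNC_CACHEGLOBAL issues a plain 16-byte cp.async.cg copy, while the _ZFILL flavor passes an extra src-size operand so out-of-bounds reads can be zero-filled by the copy itself. The wrappers below are a rough sketch of that difference (my names, not the cute atoms), assuming compilation for sm_80 or newer.

    #include <cstdint>

    // Roughly what the plain cache-global atom does: always copy 16 bytes g2s.
    __device__ __forceinline__ void cp_async_cg_16(void* smem_dst, const void* gmem_src) {
      const uint32_t smem = static_cast<uint32_t>(__cvta_generic_to_shared(smem_dst));
      asm volatile("cp.async.cg.shared.global [%0], [%1], 16;\n" ::"r"(smem),
                   "l"(gmem_src));
    }

    // Roughly what the ZFILL atom does: copy src_bytes (0..16) and zero-fill the
    // rest of the 16-byte destination, which is how out-of-bounds tails get handled.
    __device__ __forceinline__ void cp_async_cg_zfill_16(void* smem_dst,
                                                         const void* gmem_src,
                                                         int src_bytes) {
      const uint32_t smem = static_cast<uint32_t>(__cvta_generic_to_shared(smem_dst));
      asm volatile("cp.async.cg.shared.global [%0], [%1], 16, %2;\n" ::"r"(smem),
                   "l"(gmem_src), "r"(src_bytes));
    }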
src/kernels/attention/online_softmax.cuh
Lines changed: 42 additions & 23 deletions

@@ -52,53 +52,72 @@ struct OnlineSoftmax {
 
   // computes the softmax scores and rescales the output
   // - score = exp(score - row_max`)
-  // - O = O * s_scale
+  // - o = o * s_scale
   // - internal: row_sum = row_sum * s_scale + row_sum`
   template <typename FragmentS, typename FragmentO>
   CUTE_DEVICE void rescale(FragmentS& rAccS, FragmentO& rAccO) {
+    // row_max = max(row_max, scores)
+    FragmentT pre_row_max;
+    cute::copy(row_max_, pre_row_max);
     CUTE_UNROLL
     for (int si = 0; si < size<0>(rAccS); ++si) {
-      // rowmax across 4 threads
-      float cur_rowmax = row_max_(si);
+      float row_max = row_max_(si);
+      // rowmax within a thread
      CUTE_UNROLL
      for (int sj = 0; sj < size<1>(rAccS); ++sj) {
-        cur_rowmax = max(cur_rowmax, rAccS(si, sj));
+        row_max = max(row_max, rAccS(si, sj));
      }
-      cur_rowmax = detail::group_reduce_max<4>(cur_rowmax);
+      // rowmax across 4 threads
+      row_max_(si) = detail::group_reduce_max<4>(row_max);
+    }
 
-      // scores = exp(scores - row_max)
-      const float rowmax_scale = cur_rowmax * sm_scale_;
-      float cur_rowsum = 0;
+    // o = o * s_scale
+    CUTE_UNROLL
+    for (int si = 0; si < size<0>(rAccO); ++si) {
+      const float s_scale =
+          ptx::exp2((pre_row_max(si) - row_max_(si)) * sm_scale_);
+      CUTE_UNROLL
+      for (int sj = 0; sj < size<1>(rAccO); ++sj) {
+        rAccO(si, sj) *= s_scale;
+      }
+    }
+
+    // scores = exp(scores - row_max)
+    CUTE_UNROLL
+    for (int si = 0; si < size<0>(rAccS); ++si) {
+      const float rowmax_scale = row_max_(si) * sm_scale_;
      CUTE_UNROLL
      for (int sj = 0; sj < size<1>(rAccS); sj++) {
        rAccS(si, sj) = ptx::exp2(rAccS(si, sj) * sm_scale_ - rowmax_scale);
-        cur_rowsum += rAccS(si, sj);
      }
+    }
 
-      // scores_scale = exp(max - cur_rowmax)
-      const float scores_scale =
-          ptx::exp2(row_max_(si) * sm_scale_ - rowmax_scale);
-      // o_2 = o_1 * s_scale
+    // row_sum = row_sum * s_scale + row_sum`
+    CUTE_UNROLL
+    for (int si = 0; si < size<0>(rAccS); ++si) {
+      const float s_scale =
+          ptx::exp2((pre_row_max(si) - row_max_(si)) * sm_scale_);
+      row_sum_(si) *= s_scale;
      CUTE_UNROLL
-      for (int sj = 0; sj < size<1>(rAccO); ++sj) {
-        rAccO(si, sj) *= scores_scale;
+      for (int sj = 0; sj < size<1>(rAccS); sj++) {
+        // rowsum within a thread
+        row_sum_(si) += rAccS(si, sj);
      }
-
-      // update row_max and row_sum
-      row_max_(si) = cur_rowmax;
-      // s_2 = s_1 * s_scale + row_sum
-      row_sum_(si) = row_sum_(si) * scores_scale + cur_rowsum;
    }
  }
 
-  // finalizes the softmax computation with O = O / row_sum
+  // finalizes the softmax computation with o = o / row_sum
  template <typename FragmentO>
  CUTE_DEVICE void finalize(FragmentO& rAccO) {
    CUTE_UNROLL
-    for (int oi = 0; oi < size<0>(rAccO); ++oi) {
+    for (int i = 0; i < size(row_sum_); ++i) {
      // rowsum across 4 threads
-      row_sum_(oi) = detail::group_reduce_sum<4>(row_sum_(oi));
+      row_sum_(i) = detail::group_reduce_sum<4>(row_sum_(i));
+    }
 
+    // o = o / row_sum
+    CUTE_UNROLL
+    for (int oi = 0; oi < size<0>(rAccO); ++oi) {
      CUTE_UNROLL
      for (int oj = 0; oj < size<1>(rAccO); ++oj) {
        rAccO(oi, oj) *= ptx::rcp(row_sum_(oi));

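The rescale rewrite in online_softmax.cuh splits the previously fused per-row loop into separate passes (update row_max, rescale o and row_sum by s_scale, exponentiate the scores, accumulate row_sum) and snapshots the old maxima in pre_row_max. The scalar host-side sketch below runs the same online-softmax recurrence on a single row so the algebra is easy to check against a plain softmax; it deliberately ignores the cross-thread group reductions and the base-2/sm_scale folding used in the kernel.

    #include <algorithm>
    #include <cmath>
    #include <cstdio>
    #include <vector>

    // Online softmax over one row, processed in chunks (as the kernel processes
    // KV blocks). State mirrors OnlineSoftmax: running row_max, row_sum, and o.
    struct OnlineSoftmaxRow {
      float row_max = -INFINITY;
      float row_sum = 0.0f;
      float o = 0.0f;  // toy "output": softmax-weighted sum of values

      void rescale(const std::vector<float>& scores, const std::vector<float>& values) {
        const float pre_row_max = row_max;
        // 1) row_max = max(row_max, scores)
        for (float s : scores) row_max = std::max(row_max, s);
        // 2) rescale previous state by s_scale = exp(pre_row_max - row_max)
        const float s_scale = std::exp(pre_row_max - row_max);
        o *= s_scale;
        row_sum *= s_scale;
        // 3) p = exp(score - row_max); accumulate into row_sum and o
        for (size_t i = 0; i < scores.size(); ++i) {
          const float p = std::exp(scores[i] - row_max);
          row_sum += p;
          o += p * values[i];
        }
      }

      // finalize: o = o / row_sum
      float finalize() const { return o / row_sum; }
    };

    int main() {
      OnlineSoftmaxRow osm;
      osm.rescale({0.1f, 2.0f}, {1.0f, 2.0f});   // first KV block
      osm.rescale({3.0f, -1.0f}, {3.0f, 4.0f});  // second KV block
      std::printf("online  : %f\n", osm.finalize());

      // Reference: plain softmax over the concatenated row.
      std::vector<float> s = {0.1f, 2.0f, 3.0f, -1.0f}, v = {1.0f, 2.0f, 3.0f, 4.0f};
      float m = -INFINITY, z = 0.0f, out = 0.0f;
      for (float x : s) m = std::max(m, x);
      for (size_t i = 0; i < s.size(); ++i) {
        z += std::exp(s[i] - m);
        out += std::exp(s[i] - m) * v[i];
      }
      std::printf("baseline: %f\n", out / z);
      return 0;
    }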