
Commit 6c593d7

kernel: added query packing support for attention (#392)

1 parent 8b1d6cc

10 files changed: +327 additions, -182 deletions

src/kernels/attention/attention_kernel_sm80.cuh
Lines changed: 44 additions & 34 deletions

@@ -56,44 +56,34 @@ __global__ void mha_kernel_sm80(__grid_constant__ const Params params) {
 
   const int m_block = blockIdx.x;
   const int batch_idx = blockIdx.y;
-  const int head_idx = blockIdx.z;
+  const int kv_head_idx = blockIdx.z;
   const int tidx = threadIdx.x;
 
   AttentionTile<Params> tile(params);
 
-  const int group_size = params.n_heads / params.n_kv_heads;
+  // preprocess input parameters
+  const int head_dim = params.head_dim;
+  const int group_size = params.group_size;
+  const float logits_soft_cap = params.logits_soft_cap;
+  const float sm_scale = params.sm_scale;
+  const float sm_scale_log2 = params.sm_scale_log2;
+
   // ProblemShape
-  // (q_len, HEAD_DIM)
-  auto [Q, O] = tile.template get_qo_tile<DType>(batch_idx, head_idx);
+  // (q_packed_len, HEAD_DIM)
+  auto [Q, O] = tile.template get_qo_tile<DType>(batch_idx, kv_head_idx);
   // (kv_len, HEAD_DIM)
-  auto [K, V] =
-      tile.template get_kv_tile<DType>(batch_idx, head_idx / group_size);
+  auto [K, V] = tile.template get_kv_tile<DType>(batch_idx, kv_head_idx);
 
-  const int q_len = size<0>(Q);
+  const int q_packed_len = size<0>(Q);
+  const int q_len = q_packed_len / group_size;
   const int kv_len = size<0>(K);
 
-  if (m_block * kBlockM >= q_len) {
+  if (m_block * kBlockM >= q_packed_len) {
     // m out of bound, return
     return;
   }
 
-  const int head_dim = params.head_dim;
   const int sliding_window = LOCAL ? params.sliding_window : kv_len;
-  const float logits_soft_cap = params.logits_soft_cap;
-  const float sm_scale = params.sm_scale;
-  const float sm_scale_log2 = params.sm_scale_log2;
-  const float alibi_slope =
-      ALIBI ? (params.alibi_slopes_ptr[head_idx] / sm_scale) : 0.0f;
-
-  // preprocess input parameters
-  auto apply_logits_soft_cap = [&](auto& tSrAccS) {
-    if constexpr (SOFT_CAP) {
-      CUTE_UNROLL
-      for (int i = 0; i < size(tSrAccS); ++i) {
-        tSrAccS(i) = ptx::tanh(tSrAccS(i) * logits_soft_cap);
-      }
-    }
-  };
 
   // Gmem
   // (BLK_M, HEAD_DIM)

@@ -136,7 +126,7 @@ __global__ void mha_kernel_sm80(__grid_constant__ const Params params) {
   auto produce_query = [&]() {
     auto tQgQ = gmem_thr_copy_Q.partition_S(gQ);
     auto tQsQ = gmem_thr_copy_Q.partition_D(sQ);
-    auto max_coord = make_coord(q_len - m_block * kBlockM, head_dim);
+    auto max_coord = make_coord(q_packed_len - m_block * kBlockM, head_dim);
     safe_copy</*EVEN_MN=*/false, EVEN_K, /*ZFILL_MN=*/true, /*ZFILL_K=*/true>(
         gmem_tiled_copy_Q, tQgQ, tQsQ, tQcQ, max_coord);
   };

@@ -285,7 +275,7 @@ __global__ void mha_kernel_sm80(__grid_constant__ const Params params) {
     // wait for smem copy done before gmem copy
     __syncthreads();
 
-    auto max_coord = make_coord(q_len - m_block * kBlockM, head_dim);
+    auto max_coord = make_coord(q_packed_len - m_block * kBlockM, head_dim);
     safe_copy</*EVEN_MN=*/false, EVEN_K, /*ZFILL_MN=*/false, /*ZFILL_K=*/false>(
         gmem_tiled_copy_O, tOsO, tOgO, tOcO, max_coord);
   };

@@ -296,7 +286,7 @@ __global__ void mha_kernel_sm80(__grid_constant__ const Params params) {
       make_tensor(tOrAccO.data(), Layout::to_rowcol(tOrAccO.layout()));
   clear(tOrAccO);
 
-  const int diagonal = m_block * kBlockM + kv_len - q_len;
+  const int diagonal = (m_block * kBlockM) / group_size + kv_len - q_len;
   // process kv in range: [kv_idx_min, kv_idx_max)
   const int kv_idx_min = std::max(0, diagonal - sliding_window);
   const int kv_idx_max = std::min(kv_len, diagonal + kBlockM);

@@ -319,15 +309,35 @@ __global__ void mha_kernel_sm80(__grid_constant__ const Params params) {
   cp_async_fence();
 
   // ############### Mainloop ###############
-
-  OnlineSoftmax<kRowsPerMMA * size<1>(tOrAccO)> softmax(sm_scale_log2);
-  Mask<kBlockM, kBlockM, ALIBI, LOCAL> mask(
-      q_len, kv_len, sliding_window, alibi_slope);
-
   // attention score accumulator, (MMA,MMA_M,MMA_N)
   auto tSrAccS = partition_fragment_C(tiled_mma, Shape<_BLK_M, _BLK_N>{});
   auto tSrAccS_rc_view =
       make_tensor(tSrAccS.data(), Layout::to_rowcol(tSrAccS.layout()));
+
+  auto apply_logits_soft_cap = [&](auto& tSrAccS) {
+    if constexpr (SOFT_CAP) {
+      CUTE_UNROLL
+      for (int i = 0; i < size(tSrAccS); ++i) {
+        tSrAccS(i) = ptx::tanh(tSrAccS(i) * logits_soft_cap);
+      }
+    }
+  };
+
+  constexpr int kMMA_M = size<1>(tSrAccS);
+  using Softmax = OnlineSoftmax<kRowsPerMMA * kMMA_M>;
+  using Mask = Mask<kBlockM, kBlockM, kRowsPerMMA, kMMA_M, ALIBI, LOCAL>;
+
+  Softmax softmax(sm_scale_log2);
+  Mask mask(tidx,
+            m_block,
+            q_len,
+            kv_len,
+            kv_head_idx,
+            group_size,
+            sliding_window,
+            sm_scale,
+            params.alibi_slopes_ptr);
+
   // seperate oob mask iterations for better performance
   constexpr int n_oob_mask = cute::ceil_div(kBlockM, kBlockN) + 1;
 

@@ -354,7 +364,7 @@ __global__ void mha_kernel_sm80(__grid_constant__ const Params params) {
     if constexpr (SOFT_CAP) {
       apply_logits_soft_cap(tSrAccS);
     }
-    mask.apply(tSrAccS_rc_view, m_block, n_block_idx, tidx);
+    mask.apply(tSrAccS_rc_view, n_block_idx);
    softmax.rescale(tSrAccS_rc_view, tOrAccO_rc_view);
 
     // wait value, [v] => []

@@ -396,7 +406,7 @@ __global__ void mha_kernel_sm80(__grid_constant__ const Params params) {
     if constexpr (SOFT_CAP) {
       apply_logits_soft_cap(tSrAccS);
     }
-    mask.apply</*OOB_MASK=*/false>(tSrAccS_rc_view, m_block, n_block_idx, tidx);
+    mask.apply</*OOB_MASK=*/false>(tSrAccS_rc_view, n_block_idx);
     softmax.rescale(tSrAccS_rc_view, tOrAccO_rc_view);
 
     // wait value, [v] => []
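Note: the kernel now walks the M dimension in packed query space (all query heads of one KV-head group laid out contiguously) while the causal diagonal is still measured in token space. Below is a minimal host-side sketch of that bookkeeping, assuming made-up values for kBlockM, group_size, q_len, kv_len, and sliding_window; only the arithmetic mirrors the diff above.

// Standalone illustration of the packed-query bookkeeping. All concrete
// numbers are hypothetical examples, not values from this commit.
#include <algorithm>
#include <cstdio>

int main() {
  const int kBlockM = 64;             // CTA tile size along packed M
  const int group_size = 4;           // n_heads / n_kv_heads
  const int q_len = 32;               // tokens per sequence
  const int kv_len = 128;
  const int sliding_window = kv_len;  // no local attention in this example

  // Queries of all heads in one KV-head group are packed along M.
  const int q_packed_len = q_len * group_size;

  for (int m_block = 0; m_block * kBlockM < q_packed_len; ++m_block) {
    // The causal diagonal lives in token space, so the packed block offset
    // is divided back by group_size before comparing against kv positions.
    const int diagonal = (m_block * kBlockM) / group_size + kv_len - q_len;
    const int kv_idx_min = std::max(0, diagonal - sliding_window);
    const int kv_idx_max = std::min(kv_len, diagonal + kBlockM);
    std::printf("m_block %d -> kv range [%d, %d)\n",
                m_block, kv_idx_min, kv_idx_max);
  }
  return 0;
}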

src/kernels/attention/attention_kernel_sm80_test.cu
Lines changed: 0 additions & 1 deletion

@@ -7,7 +7,6 @@
 #include "attention_params.h"
 #include "attention_ref.h"
 #include "cute/layout.hpp"
-#include "static_dispatch.h"
 
 namespace llm {
 #define DISPATCH_HEAD_DIM_(HEAD_DIM_V, HEAD_DIM_NAME, ...) \

src/kernels/attention/attention_launch_sm80.cuh
Lines changed: 8 additions & 4 deletions

@@ -1,5 +1,8 @@
 #pragma once
 
+#include <cute/int_tuple.hpp>
+#include <cute/layout.hpp>
+
 #include "attention_kernel_sm80.cuh"
 #include "attention_traits_sm80.h"
 #include "static_dispatch.h"

@@ -14,17 +17,18 @@ template <typename Traits,
           bool LOCAL>
 void launch_attention_kernel(const Params& params, cudaStream_t stream) {
   const auto batch_size = params.batch_size;
-  const auto n_heads = params.n_heads;
-  const auto max_q_len = params.max_q_len;
+  const auto n_kv_heads = params.n_kv_heads;
+  const auto max_q_packed_len = params.max_q_len * params.group_size;
 
   const auto smem_size = Traits::kSmemSize;
   auto attention_kernel =
       mha_kernel_sm80<Traits, Params, EVEN_K, ALIBI, SOFT_CAP, LOCAL>;
   cudaFuncSetAttribute(
       attention_kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_size);
   // TODO: support persistent kernels
-  dim3 grid(
-      (max_q_len + Traits::kBlockM - 1) / Traits::kBlockM, batch_size, n_heads);
+  dim3 grid(cute::ceil_div(max_q_packed_len, Traits::kBlockM),
+            batch_size,
+            n_kv_heads);
   dim3 block = Traits::kThreadNum;
   attention_kernel<<<grid, block, smem_size, stream>>>(params);
 }
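For reference, a small standalone sketch of the new grid sizing: blockIdx.x now covers blocks of the packed query length and blockIdx.z iterates KV heads rather than query heads. The helper packed_grid, the GridDims struct, and all shapes are hypothetical; ceil_div is written out inline instead of using cute::ceil_div.

// Illustration of the launch-grid change. Example values only.
#include <cstdio>

struct GridDims { int x, y, z; };

GridDims packed_grid(int max_q_len, int n_heads, int n_kv_heads,
                     int batch_size, int kBlockM) {
  const int group_size = n_heads / n_kv_heads;
  const int max_q_packed_len = max_q_len * group_size;
  const int m_blocks = (max_q_packed_len + kBlockM - 1) / kBlockM;  // ceil_div
  return {m_blocks, batch_size, n_kv_heads};
}

int main() {
  // e.g. 32 query heads sharing 8 KV heads -> group_size = 4
  const GridDims grid = packed_grid(/*max_q_len=*/512, /*n_heads=*/32,
                                    /*n_kv_heads=*/8, /*batch_size=*/2,
                                    /*kBlockM=*/64);
  std::printf("grid = (%d, %d, %d)\n", grid.x, grid.y, grid.z);
  return 0;
}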

src/kernels/attention/attention_params.h
Lines changed: 8 additions & 0 deletions

@@ -44,6 +44,7 @@ struct AttentionParamsCommon {
   float sm_scale_log2 = 0.0;
   int32_t block_shift_right = 0;
   int32_t block_mask = 0;
+  int group_size = 0;
 
   // used to initialize the params that used for performance optimization
   void normalize() {

@@ -66,6 +67,8 @@
     }
     sm_scale_log2 = static_cast<float>(sm_scale * M_LOG2E);
 
+    // block size must be power of 2
+    assert(block_size > 0 && (block_size & (block_size - 1)) == 0);
     auto int_log2 = [](int x) {
       int n = 0;
       while (x >>= 1) {

@@ -76,6 +79,9 @@
     block_shift_right = int_log2(block_size);
     block_mask = block_size - 1;
 
+    assert(n_heads % n_kv_heads == 0);
+    group_size = n_heads / n_kv_heads;
+
     normalized = true;
   }
 };

@@ -113,7 +119,9 @@ struct VarLenAttentionParams : public AttentionParamsCommon {
 // paged KV cache
 struct PagedKVAttentionParams : public VarLenAttentionParams {
   // Paged KV cache
+  // the first slot id of each block
   const int* __restrict__ block_table = nullptr;
+  // array of length batch_size + 1 holding starting offset of each sequence.
   const int* __restrict__ block_cu_lens = nullptr;
 };
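A simplified stand-in for the normalize() additions, assuming placeholder field values: the power-of-two check on block_size plus its log2/mask derivation, and the new group_size = n_heads / n_kv_heads. This is a sketch of the arithmetic, not the struct's actual method.

// Example values only; mirrors the checks added to normalize().
#include <cassert>
#include <cstdio>

int main() {
  const int n_heads = 32, n_kv_heads = 8, block_size = 16;

  // block size must be power of 2
  assert(block_size > 0 && (block_size & (block_size - 1)) == 0);
  int block_shift_right = 0;
  for (int x = block_size; x >>= 1;) ++block_shift_right;  // int_log2
  const int block_mask = block_size - 1;

  assert(n_heads % n_kv_heads == 0);
  const int group_size = n_heads / n_kv_heads;

  std::printf("shift=%d mask=%d group_size=%d\n",
              block_shift_right, block_mask, group_size);
  return 0;
}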

src/kernels/attention/attention_tile.h
Lines changed: 93 additions & 46 deletions

@@ -24,21 +24,43 @@ struct AttentionTile<AttentionParams> {
 
   // return the query/output tile: (q_len, head_dim)
   template <typename Element>
-  CUTE_HOST_DEVICE auto get_qo_tile(int batch_idx, int head_idx) const {
+  CUTE_HOST_DEVICE auto get_qo_tile(int batch_idx, int kv_head_idx) const {
     // (batch, seq, head, dim)
-    const auto q_offset = batch_idx * get<0>(params_.q_stride) +
-                          head_idx * get<2>(params_.q_stride);
-    const auto o_offset = batch_idx * get<0>(params_.o_stride) +
-                          head_idx * get<2>(params_.o_stride);
-
-    // q[batch_idx, :, head_idx, :]
-    auto q =
-        make_tensor(make_gmem_ptr((const Element*)params_.q_ptr + q_offset),
-                    make_shape(params_.q_len, params_.head_dim),
-                    make_stride(get<1>(params_.q_stride), _1{}));
-    auto o = make_tensor(make_gmem_ptr((Element*)params_.o_ptr + o_offset),
-                         make_shape(params_.q_len, params_.head_dim),
-                         make_stride(get<1>(params_.o_stride), _1{}));
+
+    // packed all q/o in the same kv head group together
+    // q/o [batch, n_tokens, n_heads, dim]
+    // => q/o [*batch_idx, n_tokens, n_heads, dim]
+    // => q/o [n_tokens, group_size, n_kv_heads, dim]
+    // => q/o [n_tokens, group_size, *kv_head_idx, dim]
+    // => q/o [(group_size, n_tokens), dim]
+    // => q/o [packed_len, dim]
+    const auto group_size = params_.group_size;
+    const auto head_base = kv_head_idx * group_size;
+    auto packed_idx_to_coord = [group_size, head_base](int packed_idx) {
+      const int idx = packed_idx / group_size;
+      const int offset = packed_idx % group_size;
+      // (group_size, n_tokens)
+      return make_coord(head_base + offset, idx);
+    };
+
+    const auto packed_len = params_.q_len * group_size;
+    const auto q_offset = batch_idx * get<0>(params_.q_stride);
+    auto q = make_gather_tensor(
+        make_gmem_ptr((const Element*)params_.q_ptr + q_offset),
+        make_shape(packed_len, params_.head_dim),
+        make_stride(
+            make_stride(get<2>(params_.q_stride), get<1>(params_.q_stride)),
+            _1{}),
+        packed_idx_to_coord);
+
+    const auto o_offset = batch_idx * get<0>(params_.o_stride);
+    auto o = make_gather_tensor(
+        make_gmem_ptr((Element*)params_.o_ptr + o_offset),
+        make_shape(packed_len, params_.head_dim),
+        make_stride(
+            make_stride(get<2>(params_.o_stride), get<1>(params_.o_stride)),
+            _1{}),
+        packed_idx_to_coord);
     return make_tuple(q, o);
   }
 

@@ -75,24 +97,37 @@ struct AttentionTile<VarLenAttentionParams> {
 
   // return the query tile: (q_len, head_dim)
   template <typename Element>
-  CUTE_HOST_DEVICE auto get_qo_tile(int batch_idx, int head_idx) const {
+  CUTE_HOST_DEVICE auto get_qo_tile(int batch_idx, int kv_head_idx) const {
     const auto begin = params_.q_cu_lens[batch_idx];
     const auto qo_len = params_.q_cu_lens[batch_idx + 1] - begin;
-    // (seq, head, dim)
-    const auto q_offset =
-        begin * get<0>(params_.q_stride) + head_idx * get<1>(params_.q_stride);
-    const auto o_offset =
-        begin * get<0>(params_.o_stride) + head_idx * get<1>(params_.o_stride);
-
-    // q[begin:begin + q_len, head_idx, :]
-    auto q =
-        make_tensor(make_gmem_ptr((const Element*)params_.q_ptr + q_offset),
-                    make_shape(qo_len, params_.head_dim),
-                    make_stride(get<0>(params_.q_stride), _1{}));
-    // o[begin:begin + o_len, head_idx, :]
-    auto o = make_tensor(make_gmem_ptr((Element*)params_.o_ptr + o_offset),
-                         make_shape(qo_len, params_.head_dim),
-                         make_stride(get<0>(params_.o_stride), _1{}));
+
+    const auto group_size = params_.group_size;
+    const auto head_base = kv_head_idx * group_size;
+    auto packed_idx_to_coord = [group_size, head_base](int packed_idx) {
+      const int idx = packed_idx / group_size;
+      const int offset = packed_idx % group_size;
+      // (group_size, n_tokens)
+      return make_coord(head_base + offset, idx);
+    };
+
+    const auto packed_len = qo_len * group_size;
+    const auto q_offset = begin * get<0>(params_.q_stride);
+    auto q = make_gather_tensor(
+        make_gmem_ptr((const Element*)params_.q_ptr + q_offset),
+        make_shape(packed_len, params_.head_dim),
+        make_stride(
+            make_stride(get<1>(params_.q_stride), get<0>(params_.q_stride)),
+            _1{}),
+        packed_idx_to_coord);
+
+    const auto o_offset = begin * get<0>(params_.o_stride);
+    auto o = make_gather_tensor(
+        make_gmem_ptr((Element*)params_.o_ptr + o_offset),
+        make_shape(packed_len, params_.head_dim),
+        make_stride(
+            make_stride(get<1>(params_.o_stride), get<0>(params_.o_stride)),
+            _1{}),
+        packed_idx_to_coord);
     return make_tuple(q, o);
   }
 

@@ -132,24 +167,36 @@ struct AttentionTile<PagedKVAttentionParams> {
 
   // return the query/output tile: (q_len, head_dim)
   template <typename Element>
-  CUTE_HOST_DEVICE auto get_qo_tile(int batch_idx, int head_idx) const {
+  CUTE_HOST_DEVICE auto get_qo_tile(int batch_idx, int kv_head_idx) const {
     const auto begin = params_.q_cu_lens[batch_idx];
     const auto qo_len = params_.q_cu_lens[batch_idx + 1] - begin;
-    // (seq, head, dim)
-    const auto q_offset =
-        begin * get<0>(params_.q_stride) + head_idx * get<1>(params_.q_stride);
-    const auto o_offset =
-        begin * get<0>(params_.o_stride) + head_idx * get<1>(params_.o_stride);
-
-    // q[begin:begin + q_len, head_idx, :]
-    auto q =
-        make_tensor(make_gmem_ptr((const Element*)params_.q_ptr + q_offset),
-                    make_shape(qo_len, params_.head_dim),
-                    make_stride(get<0>(params_.q_stride), _1{}));
-    // o[begin:begin + o_len, head_idx, :]
-    auto o = make_tensor(make_gmem_ptr((Element*)params_.o_ptr + o_offset),
-                         make_shape(qo_len, params_.head_dim),
-                         make_stride(get<0>(params_.o_stride), _1{}));
+    const auto group_size = params_.group_size;
+    const auto head_base = kv_head_idx * group_size;
+    auto packed_idx_to_coord = [group_size, head_base](int packed_idx) {
+      const int idx = packed_idx / group_size;
+      const int offset = packed_idx % group_size;
+      // (group_size, n_tokens)
+      return make_coord(head_base + offset, idx);
+    };
+
+    const auto packed_len = qo_len * group_size;
+    const auto q_offset = begin * get<0>(params_.q_stride);
+    auto q = make_gather_tensor(
+        make_gmem_ptr((const Element*)params_.q_ptr + q_offset),
+        make_shape(packed_len, params_.head_dim),
+        make_stride(
+            make_stride(get<1>(params_.q_stride), get<0>(params_.q_stride)),
+            _1{}),
+        packed_idx_to_coord);
+
+    const auto o_offset = begin * get<0>(params_.o_stride);
+    auto o = make_gather_tensor(
+        make_gmem_ptr((Element*)params_.o_ptr + o_offset),
+        make_shape(packed_len, params_.head_dim),
+        make_stride(
+            make_stride(get<1>(params_.o_stride), get<0>(params_.o_stride)),
+            _1{}),
+        packed_idx_to_coord);
     return make_tuple(q, o);
   }
 
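To make the gather mapping concrete: packed query/output rows are laid out as (group_size, n_tokens), so consecutive packed indices walk through the query heads of one KV-head group before advancing to the next token. The sketch below reproduces packed_idx_to_coord on the host with example shapes (group_size, kv_head_idx, n_tokens are hypothetical).

// Host-side illustration of the packed-index-to-coordinate mapping.
#include <cstdio>

int main() {
  const int group_size = 4;   // query heads per KV head
  const int kv_head_idx = 1;  // corresponds to blockIdx.z in the kernel
  const int n_tokens = 3;
  const int head_base = kv_head_idx * group_size;

  const int packed_len = n_tokens * group_size;
  for (int packed_idx = 0; packed_idx < packed_len; ++packed_idx) {
    const int idx = packed_idx / group_size;     // token index
    const int offset = packed_idx % group_size;  // head within the group
    std::printf("packed %2d -> (head=%d, token=%d)\n",
                packed_idx, head_base + offset, idx);
  }
  return 0;
}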

src/kernels/attention/attention_traits_test.cpp
Lines changed: 2 additions & 0 deletions

@@ -3,6 +3,8 @@
 #include <cute/tensor.hpp>
 
 #include "attention_traits_sm80.h"
+#include "cute/layout_composed.hpp"
+#include "gather_tensor.hpp"
 
 namespace llm {
