
Commit aa1a077

kernel: handle kv block range for attention kernel (#382)
1 parent aa54c85 commit aa1a077

File tree

5 files changed: 35 additions & 123 deletions

src/kernels/attention/CMakeLists.txt
Lines changed: 1 addition & 1 deletion

@@ -60,7 +60,7 @@ cc_test(
     attention_kernel_sm80_varlen_test.cu
     attention_kernel_sm80_pagedkv_test.cu
   DEPS
-    :attention.kernel
+    :attention.template
     absl::random_random
     GTest::gtest_main
     torch

src/kernels/attention/attention_kernel_sm80.cuh
Lines changed: 14 additions & 10 deletions

@@ -80,10 +80,10 @@ __global__ void mha_kernel_sm80(__grid_constant__ const Params params) {
   }
 
   const int head_dim = params.head_dim;
+  const int sliding_window = LOCAL ? params.sliding_window : kv_len;
   const float logits_soft_cap = params.logits_soft_cap;
   const float sm_scale = params.sm_scale;
   const float sm_scale_log2 = params.sm_scale_log2;
-  const float sliding_window = LOCAL ? params.sliding_window : kv_len;
   const float alibi_slope =
       ALIBI ? (params.alibi_slopes_ptr[head_idx] / sm_scale) : 0.0f;
 
@@ -156,12 +156,12 @@ __global__ void mha_kernel_sm80(__grid_constant__ const Params params) {
   Tensor tKsK = gmem_thr_copy_KV.partition_D(sK);
   auto produce_k = [&](int ni) {
     auto tKgK = gmem_thr_copy_KV.partition_S(gK(_, _, ni));
-    // skip zero fill oob for k since mask will mask out oob with -inf
+    // skip zfill_mn for k since mask will mask out oob with -inf
     safe_copy<EVEN_K,
               /*EVEN_MN=*/false,
               /*ZERO_FILL_MN=*/false,
               /*ZERO_FILL_K=*/true>(
-        gmem_tiled_copy_Q,
+        gmem_tiled_copy_KV,
         tKgK,
         tKsK,
         tKcKV,
@@ -171,12 +171,12 @@ __global__ void mha_kernel_sm80(__grid_constant__ const Params params) {
   Tensor tVsV = gmem_thr_copy_KV.partition_D(sV);
   auto produce_v = [&](int ni) {
     auto tVgV = gmem_thr_copy_KV.partition_S(gV(_, _, ni));
-    // TODO: skip zero fill oob for v, may have nan issue
+    // skipping ZFILL_MN for v may cause nan issue
     safe_copy<EVEN_K,
               /*EVEN_MN=*/false,
               /*ZERO_FILL_MN=*/true,
               /*ZERO_FILL_K=*/true>(
-        gmem_tiled_copy_Q,
+        gmem_tiled_copy_KV,
         tVgV,
         tVsV,
         tKcKV,
@@ -299,13 +299,21 @@ __global__ void mha_kernel_sm80(__grid_constant__ const Params params) {
         make_coord(q_len - m_block * kBlockM, head_dim));
   };
 
+  const int diagonal = m_block * kBlockM + kv_len - q_len;
+  // process kv in range: [kv_idx_min, kv_idx_max)
+  const int kv_idx_min = std::max(0, diagonal - sliding_window);
+  const int kv_idx_max = std::min(kv_len, diagonal + kBlockM);
+  const int n_block_min = LOCAL ? kv_idx_min / kBlockN : 0;
+  const int n_block_max = cute::ceil_div(kv_idx_max, kBlockN);
+  // TODO: handle n_block_min >= n_block_max
+
   // ############### Prologue ###############
 
   // produce q: [] => [q]
   produce_q();
   cp_async_fence();
   // produce k: [q] => [q, k]
-  produce_k(0);
+  produce_k(n_block_min);
   cp_async_fence();
 
   // ############### Mainloop ###############
@@ -324,10 +332,6 @@ __global__ void mha_kernel_sm80(__grid_constant__ const Params params) {
   Mask<kBlockM, kBlockM, ALIBI, LOCAL> mask(
       q_len, kv_len, sliding_window, alibi_slope);
 
-  // TODO: control block min/max precisely
-  const int n_block_min = 0;
-  const int n_block_max = cute::ceil_div(kv_len, kBlockN);
-
   clear(tOrAccO);
   CUTE_NO_UNROLL
   for (int ni = n_block_min; ni < n_block_max; ++ni) {
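
For reference, the block-range arithmetic added above can be read in isolation. The sketch below is a host-side C++ restatement of those lines, not kernel code: the function name and parameters are illustrative, kBlockM/kBlockN stand in for the kernel's compile-time tile sizes, and cute::ceil_div is replaced by a plain helper.

#include <algorithm>

// kv tile range for one query tile (m_block): [n_block_min, n_block_max)
struct KvBlockRange {
  int n_block_min;  // first kv block to visit
  int n_block_max;  // one past the last kv block to visit
};

constexpr int ceil_div(int a, int b) { return (a + b - 1) / b; }

KvBlockRange kv_block_range(int m_block, int q_len, int kv_len,
                            int sliding_window, bool local,
                            int kBlockM, int kBlockN) {
  // mirror the kernel: without LOCAL, the window spans the whole kv sequence
  const int window = local ? sliding_window : kv_len;
  // kv index on the causal diagonal for the first query row of this tile
  const int diagonal = m_block * kBlockM + kv_len - q_len;
  // kv indices this tile can attend to: [kv_idx_min, kv_idx_max)
  const int kv_idx_min = std::max(0, diagonal - window);
  const int kv_idx_max = std::min(kv_len, diagonal + kBlockM);
  const int n_block_min = local ? kv_idx_min / kBlockN : 0;
  const int n_block_max = ceil_div(kv_idx_max, kBlockN);
  return {n_block_min, n_block_max};
}

For example, with q_len = kv_len = 1024, kBlockM = kBlockN = 64, sliding_window = 256 and m_block = 8, the diagonal is 512, so only kv blocks [4, 9) are visited, instead of the previous fixed range [0, 16).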

src/kernels/attention/attention_kernel_sm80_varlen_test.cu
Lines changed: 1 addition & 101 deletions

@@ -4,112 +4,12 @@
 
 #include "attention_launch_sm80.cuh"
 #include "attention_params.h"
+#include "attention_ref.h"
 #include "cute/layout.hpp"
 #include "static_dispatch.h"
 
 namespace llm {
 namespace {
-// Multi-head attention implementation using pytorch
-torch::Tensor attention_ref(
-    torch::Tensor query,                          // [q_len, n_heads, head_dim]
-    torch::Tensor key,                            // [kv_len, n_kv_heads, head_dim]
-    torch::Tensor value,                          // [kv_len, n_kv_heads, head_dim]
-    torch::optional<torch::Tensor> alibi_slopes,  // [n_heads]
-    float logits_soft_cap,
-    int32_t sliding_window) {
-  const auto q_len = query.size(-3);
-  const auto kv_len = key.size(-3);
-  const auto n_heads = query.size(-2);
-  const auto n_kv_heads = key.size(-2);
-  const auto head_dim = query.size(-1);
-  assert(kv_len >= q_len);
-
-  if (n_heads != n_kv_heads) {
-    assert(n_heads % n_kv_heads == 0);
-    const auto group_size = n_heads / n_kv_heads;
-    key = key.repeat_interleave(/*repeats=*/group_size, /*dim=*/-2);
-    value = value.repeat_interleave(/*repeats=*/group_size, /*dim=*/-2);
-  }
-
-  const float sm_scale = 1.0 / sqrt(head_dim);
-  // query * key => [n_heads, q_len, kv_len]
-  auto scores = torch::einsum("qhd,khd->hqk",
-                              {query.to(torch::kFloat), key.to(torch::kFloat)});
-  // apply scale
-  scores *= sm_scale;
-
-  // apply softcap if needed
-  if (logits_soft_cap != 0.0) {
-    scores = torch::tanh(scores / logits_soft_cap) * logits_soft_cap;
-  }
-
-  // apply alibi bias
-  if (alibi_slopes) {
-    const auto& slopes = alibi_slopes.value();
-    // calculate alibi attention bias
-    // since it's causal mask, we can just use [0, 1, ..., kv_len)
-    auto distance = torch::arange(0, kv_len, query.options());
-    // [n_heads, 1, kv_len]
-    auto bias = distance.view({1, 1, kv_len}) * slopes.view({n_heads, 1, 1});
-    scores += bias;
-  }
-
-  auto mask = torch::ones({q_len, kv_len}, torch::kBool);
-  if (sliding_window >= 0) {
-    // sliding window mask
-    // returns the upper triangular part of a matrix
-    mask = torch::triu(mask, /*diagonal=*/kv_len - q_len - sliding_window);
-  }
-
-  // apply causal mask
-  // causal mask: returns the lower triangular part of a matrix
-  mask = torch::tril(mask, /*diagonal=*/kv_len - q_len).to(query);
-  scores = scores.masked_fill(mask == 0, -INFINITY);
-
-  // safe softmax
-  scores = torch::softmax(scores, /*dim=*/-1);
-
-  // score * value => [q_len, n_heads, head_dim]
-  return torch::einsum("hqk,khd->qhd", {scores, value.to(torch::kFloat)})
-      .type_as(query);
-}
-
-torch::Tensor attention_varlen_ref(
-    torch::Tensor query,                          // [q_len, n_heads, head_dim]
-    torch::Tensor key,                            // [kv_len, n_kv_heads, head_dim]
-    torch::Tensor value,                          // [kv_len, n_kv_heads, head_dim]
-    torch::Tensor q_cu_lens,                      // [batch_size + 1]
-    torch::Tensor kv_cu_lens,                     // [batch_size + 1]
-    torch::optional<torch::Tensor> alibi_slopes,  // [n_heads]
-    float logits_soft_cap,
-    int32_t sliding_window) {
-  torch::Tensor q_cu_lens_cpu = q_cu_lens.cpu();
-  torch::Tensor kv_cu_seq_lens_cpu = kv_cu_lens.cpu();
-  const size_t n_seqs = q_cu_lens_cpu.numel() - 1;
-  const int32_t* q_cu_lens_ptr = q_cu_lens_cpu.data_ptr<int32_t>();
-  const int32_t* kv_cu_lens_ptr = kv_cu_seq_lens_cpu.data_ptr<int32_t>();
-
-  std::vector<torch::Tensor> out_list;
-  // process sequences one by one
-  for (int64_t i = 0; i < n_seqs; ++i) {
-    // calculate attention for each sequence
-    const int32_t q_start = q_cu_lens_ptr[i];
-    const int32_t q_end = q_cu_lens_ptr[i + 1];
-    const int32_t kv_start = kv_cu_lens_ptr[i];
-    const int32_t kv_end = kv_cu_lens_ptr[i + 1];
-
-    torch::Tensor q = query.slice(/*dim=*/0, /*start=*/q_start, /*end=*/q_end);
-    torch::Tensor k = key.slice(/*dim=*/0, /*start=*/kv_start, /*end=*/kv_end);
-    torch::Tensor v =
-        value.slice(/*dim=*/0, /*start=*/kv_start, /*end=*/kv_end);
-
-    auto output =
-        attention_ref(q, k, v, alibi_slopes, logits_soft_cap, sliding_window);
-    out_list.push_back(output);
-  }
-  return torch::cat(out_list, /*dim=*/0);
-}
-
 torch::Tensor attention_varlen_sm80(
     torch::Tensor query,  // [q_len, n_heads, head_dim]
     torch::Tensor key,    // [kv_len, n_kv_heads, head_dim]
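
The reference implementations deleted here now live in the shared attention_ref.h header included above. One property of that reference worth noting is how its triu/tril masking lines up with the kernel's new block range: for query row i, the admitted kv indices are exactly [i + kv_len - q_len - sliding_window, i + kv_len - q_len]. The snippet below is an illustrative, self-contained libtorch check of that claim, not part of the test; the constants are arbitrary.

#include <torch/torch.h>
#include <cstdio>

int main() {
  const int q_len = 4, kv_len = 8, sliding_window = 3;
  // same mask construction as the reference implementation
  auto mask = torch::ones({q_len, kv_len}, torch::kBool);
  mask = torch::triu(mask, /*diagonal=*/kv_len - q_len - sliding_window);
  mask = torch::tril(mask, /*diagonal=*/kv_len - q_len);
  for (int i = 0; i < q_len; ++i) {
    // per-row causal diagonal, matching the kernel's `diagonal` for kBlockM = 1
    const int diagonal = i + kv_len - q_len;
    for (int j = 0; j < kv_len; ++j) {
      const bool expected = (j >= diagonal - sliding_window) && (j <= diagonal);
      if (mask[i][j].item<bool>() != expected) {
        std::printf("mismatch at (%d, %d)\n", i, j);
      }
    }
  }
  std::printf("mask matches per-row window\n");
  return 0;
}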

src/kernels/attention/attention_traits_sm80.h
Lines changed: 6 additions & 2 deletions

@@ -117,16 +117,20 @@ struct AttentionTraitsSM80 {
   // O smem: (BLK_M, K):(K, 1), k-major, same as Q
   using SmemLayoutO = SmemLayoutQ;
 
+  // use 128-bit vectorizing copy
+  using VectorizingCopy = AutoVectorizingCopyWithAssumedAlignment<128>;
+
   // s2g tiled copy for O
   using GmemTiledCopyO = decltype(make_tiled_copy(
-      Copy_Atom<DefaultCopy, DType>{},
+      Copy_Atom<VectorizingCopy, DType>{},
      GmemCopyThrLayout{},     // Thr layout: (_16,_8)/(_32, _4)
      Layout<Shape<_1, _8>>{}  // Val layout: 8 vals per read
      ));
 
   // r2s tiled copy for O
   using SmemTiledCopyO =
-      decltype(make_tiled_copy_C(Copy_Atom<DefaultCopy, DType>{}, TiledMma{}));
+      decltype(make_tiled_copy_C(Copy_Atom<VectorizingCopy, DType>{},
+                                 TiledMma{}));
 
   // constexpr values for kernel launch
   static constexpr size_t kSmemSize =
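
A note on the copy-atom swap above: judging by the template name, AutoVectorizingCopyWithAssumedAlignment<128> lets CuTe assume 16-byte alignment and emit up to 128-bit vectorized loads/stores, whereas DefaultCopy makes no such over-alignment assumption and may fall back to narrower copies. A minimal stand-alone sketch of the pattern, using an illustrative thread/value layout rather than the kernel's actual GmemCopyThrLayout:

#include <cute/tensor.hpp>

using namespace cute;

using DType = cute::half_t;
// 128-bit vectorizing copy: assumes source/destination are 16-byte aligned
using VectorizingCopy = AutoVectorizingCopyWithAssumedAlignment<128>;

// each of the 16x8 threads copies 8 contiguous half_t values (= 128 bits)
using GmemTiledCopyO = decltype(make_tiled_copy(
    Copy_Atom<VectorizingCopy, DType>{},
    Layout<Shape<_16, _8>, Stride<_8, _1>>{},  // thread layout (illustrative)
    Layout<Shape<_1, _8>>{}));                 // value layout: 8 vals per thread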

src/kernels/attention/flash_attn/src/kernel_traits.h
Lines changed: 13 additions & 9 deletions

@@ -32,12 +32,15 @@ struct Flash_kernel_traits {
   using MMA_Atom_Arch = MMA_Atom<SM75_16x8x8_F32F16F16F32_TN>;
 #endif
 
+  // use 128-bit vectorizing copy
+  using VectorizingCopy = AutoVectorizingCopyWithAssumedAlignment<128>;
+
 #if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 750
   using SmemCopyAtom = Copy_Atom<SM75_U32x4_LDSM_N, elem_type>;
   using SmemCopyAtomTransposed = Copy_Atom<SM75_U16x8_LDSM_T, elem_type>;
 #else
-  using SmemCopyAtom = Copy_Atom<DefaultCopy, elem_type>;
-  using SmemCopyAtomTransposed = Copy_Atom<DefaultCopy, elem_type>;
+  using SmemCopyAtom = Copy_Atom<VectorizingCopy, elem_type>;
+  using SmemCopyAtomTransposed = Copy_Atom<VectorizingCopy, elem_type>;
 #endif
 };
 
@@ -49,6 +52,7 @@ struct Flash_fwd_kernel_traits : public Base {
   using ElementAccum = typename Base::ElementAccum;
   using index_t = typename Base::index_t;
   static constexpr bool Has_cp_async = Base::Has_cp_async;
+  using VectorizingCopy = typename Base::VectorizingCopy;
   using SmemCopyAtom = typename Base::SmemCopyAtom;
   using SmemCopyAtomTransposed = typename Base::SmemCopyAtomTransposed;
 
@@ -97,8 +101,8 @@ struct Flash_fwd_kernel_traits : public Base {
   using SmemLayoutO = decltype(tile_to_shape(
       SmemLayoutAtomO{},
       Shape<Int<kBlockM>, Int<kHeadDim>>{}));
-  using SmemCopyAtomO = Copy_Atom<DefaultCopy, Element>;
-  using SmemCopyAtomOaccum = Copy_Atom<DefaultCopy, ElementAccum>;
+  using SmemCopyAtomO = Copy_Atom<VectorizingCopy, Element>;
+  using SmemCopyAtomOaccum = Copy_Atom<VectorizingCopy, ElementAccum>;
 
   static constexpr int kSmemQSize = size(SmemLayoutQ{}) * sizeof(Element);
   static constexpr int kSmemKVSize = size(SmemLayoutKV{}) * 2 * sizeof(Element);
@@ -121,7 +125,7 @@ struct Flash_fwd_kernel_traits : public Base {
   using Gmem_copy_struct = std::conditional_t<
       Has_cp_async,
       SM80_CP_ASYNC_CACHEGLOBAL<cute::uint128_t>,
-      DefaultCopy
+      VectorizingCopy
   >;
   using GmemTiledCopyQKV = decltype(
       make_tiled_copy(Copy_Atom<Gmem_copy_struct, Element>{},
@@ -140,7 +144,7 @@ struct Flash_fwd_kernel_traits : public Base {
       Layout<Shape<Int<kGmemRowsPerThread>, _8>, Stride<_8, _1>>{}));
 
   using GmemTiledCopyO = decltype(
-      make_tiled_copy(Copy_Atom<DefaultCopy, Element>{},
+      make_tiled_copy(Copy_Atom<VectorizingCopy, Element>{},
                       GmemLayoutAtom{},
                       Layout<Shape<_1, _8>>{}));  // Val layout, 8 vals per store
 
@@ -152,7 +156,7 @@ struct Flash_fwd_kernel_traits : public Base {
                     Stride< _16, _1>>
       >;
   using GmemTiledCopyOaccum = decltype(
-      make_tiled_copy(Copy_Atom<DefaultCopy, ElementAccum>{},
+      make_tiled_copy(Copy_Atom<VectorizingCopy, ElementAccum>{},
                       GmemLayoutAtomOaccum{},
                       Layout<Shape < _1, _4>>{}));  // Val layout, 4 vals per store
   using GmemLayoutAtomRotcossin = GmemLayoutAtom;
@@ -161,15 +165,15 @@ struct Flash_fwd_kernel_traits : public Base {
                       GmemLayoutAtomRotcossin{},
                       Layout<Shape < _1, _4>>{}));  // Val layout, 4 vals per load
   using GmemTiledCopyRotcossinCont = decltype(
-      make_tiled_copy(Copy_Atom<DefaultCopy, Element>{},
+      make_tiled_copy(Copy_Atom<VectorizingCopy, Element>{},
                       GmemLayoutAtomRotcossin{},
                       Layout<Shape < _1, _8>>{}));  // Val layout, 8 vals per load
   using GmemTiledCopyRotcossinPaged = decltype(
       make_tiled_copy(Copy_Atom<UniversalCopy<uint64_t>, Element>{},
                       GmemLayoutAtomRotcossin{},
                       Layout<Shape<Int<kGmemRowsPerThread>, _4>, Stride<_4, _1>>{}));  // Val layout, 4 vals per load
   using GmemTiledCopyRotcossinContPaged = decltype(
-      make_tiled_copy(Copy_Atom<DefaultCopy, Element>{},
+      make_tiled_copy(Copy_Atom<VectorizingCopy, Element>{},
                       GmemLayoutAtomRotcossin{},
                       Layout<Shape<Int<kGmemRowsPerThread>, _8>, Stride<_8, _1>>{}));  // Val layout, 8 vals per load
 };
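
The same DefaultCopy to VectorizingCopy substitution runs through the vendored flash-attention traits, with one conditional: when cp.async is available (SM80+), global-to-shared QKV loads keep using SM80_CP_ASYNC_CACHEGLOBAL on 128-bit chunks, and the 128-bit vectorizing atom is only the fallback. Below is an isolated sketch of that selection, with Has_cp_async as an illustrative constant rather than the inherited trait:

#include <type_traits>
#include <cute/tensor.hpp>

using namespace cute;

using Element = cute::half_t;
using VectorizingCopy = AutoVectorizingCopyWithAssumedAlignment<128>;

// illustrative stand-in for Base::Has_cp_async (true on SM80+ builds)
inline constexpr bool Has_cp_async = true;

// cp.async cache-global copies move 16-byte chunks g->s asynchronously;
// without cp.async, the same 128-bit vectorizing atom is used synchronously
using Gmem_copy_struct = std::conditional_t<Has_cp_async,
                                            SM80_CP_ASYNC_CACHEGLOBAL<cute::uint128_t>,
                                            VectorizingCopy>;

using GmemCopyAtomQKV = Copy_Atom<Gmem_copy_struct, Element>;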
