Commit dacaf5a

zhaoyang-star authored
Replace head_mapping params with num_kv_heads in the attention kernel. (#1997)
Co-authored-by: wangguoya <[email protected]>
Co-authored-by: Yang Zhao <[email protected]>
1 parent 24cde76 commit dacaf5a
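
The change is behavior-preserving: the old head_mapping tensor was always built with repeat_interleave over arange(num_kv_heads), so the KV head serving a given query head can be recovered by integer division instead of a table lookup. A minimal sketch of that equivalence (illustrative only, not part of the commit; the head counts are example values, and device="cuda" from the original code is omitted so the sketch runs on CPU):

import torch

num_query_heads = 12   # example values; any exact multiple of num_kv_heads works
num_kv_heads = 4
num_queries_per_kv = num_query_heads // num_kv_heads

# Old approach: materialize a [num_query_heads] lookup table.
head_mapping = torch.repeat_interleave(
    torch.arange(num_kv_heads, dtype=torch.int32), num_queries_per_kv)

# New approach: derive the KV head index arithmetically, as the kernel now does.
derived = torch.arange(num_query_heads, dtype=torch.int32) // num_queries_per_kv

assert torch.equal(head_mapping, derived)  # e.g. query head 7 -> KV head 2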

File tree

5 files changed (+26, -37 lines)

benchmarks/kernels/benchmark_paged_attention.py
csrc/attention/attention_kernels.cu
csrc/ops.h
tests/kernels/test_attention.py
vllm/model_executor/layers/attention.py

benchmarks/kernels/benchmark_paged_attention.py

Lines changed: 2 additions & 6 deletions
@@ -37,10 +37,6 @@ def main(
     query.uniform_(-scale, scale)
 
     assert num_query_heads % num_kv_heads == 0
-    num_queries_per_kv = num_query_heads // num_kv_heads
-    head_mapping = torch.repeat_interleave(
-        torch.arange(num_kv_heads, dtype=torch.int32, device="cuda"),
-        num_queries_per_kv)
     alibi_slopes = None
     if use_alibi:
         alibi_slopes = torch.randn(num_query_heads,
@@ -103,7 +99,7 @@ def run_benchmark(num_iters: int, profile: bool = False) -> float:
                     query,
                     key_cache,
                     value_cache,
-                    head_mapping,
+                    num_kv_heads,
                     scale,
                     block_tables,
                     context_lens,
@@ -120,7 +116,7 @@ def run_benchmark(num_iters: int, profile: bool = False) -> float:
                     query,
                     key_cache,
                     value_cache,
-                    head_mapping,
+                    num_kv_heads,
                     scale,
                     block_tables,
                     context_lens,

csrc/attention/attention_kernels.cu

Lines changed: 15 additions & 16 deletions
@@ -89,7 +89,7 @@ __device__ void paged_attention_kernel(
   const scalar_t* __restrict__ q,           // [num_seqs, num_heads, head_size]
   const scalar_t* __restrict__ k_cache,     // [num_blocks, num_kv_heads, head_size/x, block_size, x]
   const scalar_t* __restrict__ v_cache,     // [num_blocks, num_kv_heads, head_size, block_size]
-  const int* __restrict__ head_mapping,     // [num_heads]
+  const int num_kv_heads,                   // [num_heads]
   const float scale,
   const int* __restrict__ block_tables,     // [num_seqs, max_num_blocks_per_seq]
   const int* __restrict__ context_lens,     // [num_seqs]
@@ -132,7 +132,8 @@ __device__ void paged_attention_kernel(
 
   const int head_idx = blockIdx.x;
   const int num_heads = gridDim.x;
-  const int kv_head_idx = head_mapping[head_idx];
+  const int num_queries_per_kv = num_heads / num_kv_heads;
+  const int kv_head_idx = head_idx / num_queries_per_kv;
   const float alibi_slope = alibi_slopes == nullptr ? 0.f : alibi_slopes[head_idx];
 
   // A vector type to store a part of a key or a query.
@@ -401,7 +402,7 @@ __global__ void paged_attention_v1_kernel(
   const scalar_t* __restrict__ q,           // [num_seqs, num_heads, head_size]
   const scalar_t* __restrict__ k_cache,     // [num_blocks, num_kv_heads, head_size/x, block_size, x]
   const scalar_t* __restrict__ v_cache,     // [num_blocks, num_kv_heads, head_size, block_size]
-  const int* __restrict__ head_mapping,     // [num_heads]
+  const int num_kv_heads,                   // [num_heads]
   const float scale,
   const int* __restrict__ block_tables,     // [num_seqs, max_num_blocks_per_seq]
   const int* __restrict__ context_lens,     // [num_seqs]
@@ -412,7 +413,7 @@ __global__ void paged_attention_v1_kernel(
   const int kv_head_stride) {
   paged_attention_kernel<scalar_t, HEAD_SIZE, BLOCK_SIZE, NUM_THREADS>(
     /* exp_sums */ nullptr, /* max_logits */ nullptr,
-    out, q, k_cache, v_cache, head_mapping, scale, block_tables, context_lens,
+    out, q, k_cache, v_cache, num_kv_heads, scale, block_tables, context_lens,
     max_num_blocks_per_seq, alibi_slopes, q_stride, kv_block_stride, kv_head_stride);
 }
 
@@ -430,7 +431,7 @@ __global__ void paged_attention_v2_kernel(
   const scalar_t* __restrict__ q,           // [num_seqs, num_heads, head_size]
   const scalar_t* __restrict__ k_cache,     // [num_blocks, num_kv_heads, head_size/x, block_size, x]
   const scalar_t* __restrict__ v_cache,     // [num_blocks, num_kv_heads, head_size, block_size]
-  const int* __restrict__ head_mapping,     // [num_heads]
+  const int num_kv_heads,                   // [num_heads]
   const float scale,
   const int* __restrict__ block_tables,     // [num_seqs, max_num_blocks_per_seq]
   const int* __restrict__ context_lens,     // [num_seqs]
@@ -440,7 +441,7 @@ __global__ void paged_attention_v2_kernel(
   const int kv_block_stride,
   const int kv_head_stride) {
   paged_attention_kernel<scalar_t, HEAD_SIZE, BLOCK_SIZE, NUM_THREADS, PARTITION_SIZE>(
-    exp_sums, max_logits, tmp_out, q, k_cache, v_cache, head_mapping, scale,
+    exp_sums, max_logits, tmp_out, q, k_cache, v_cache, num_kv_heads, scale,
     block_tables, context_lens, max_num_blocks_per_seq, alibi_slopes,
     q_stride, kv_block_stride, kv_head_stride);
 }
@@ -556,7 +557,7 @@ __global__ void paged_attention_v2_reduce_kernel(
     query_ptr, \
     key_cache_ptr, \
     value_cache_ptr, \
-    head_mapping_ptr, \
+    num_kv_heads, \
     scale, \
     block_tables_ptr, \
     context_lens_ptr, \
@@ -576,7 +577,7 @@ void paged_attention_v1_launcher(
   torch::Tensor& query,
   torch::Tensor& key_cache,
   torch::Tensor& value_cache,
-  torch::Tensor& head_mapping,
+  int num_kv_heads,
   float scale,
   torch::Tensor& block_tables,
   torch::Tensor& context_lens,
@@ -602,7 +603,6 @@ void paged_attention_v1_launcher(
   T* query_ptr = reinterpret_cast<T*>(query.data_ptr());
   T* key_cache_ptr = reinterpret_cast<T*>(key_cache.data_ptr());
   T* value_cache_ptr = reinterpret_cast<T*>(value_cache.data_ptr());
-  int* head_mapping_ptr = reinterpret_cast<int*>(head_mapping.data_ptr());
   int* block_tables_ptr = block_tables.data_ptr<int>();
   int* context_lens_ptr = context_lens.data_ptr<int>();
 
@@ -651,7 +651,7 @@ void paged_attention_v1_launcher(
     query, \
     key_cache, \
     value_cache, \
-    head_mapping, \
+    num_kv_heads, \
     scale, \
     block_tables, \
     context_lens, \
@@ -681,7 +681,7 @@ void paged_attention_v1(
   torch::Tensor& query,           // [num_seqs, num_heads, head_size]
   torch::Tensor& key_cache,       // [num_blocks, num_heads, head_size/x, block_size, x]
   torch::Tensor& value_cache,     // [num_blocks, num_heads, head_size, block_size]
-  torch::Tensor& head_mapping,    // [num_heads]
+  int num_kv_heads,               // [num_heads]
   float scale,
   torch::Tensor& block_tables,    // [num_seqs, max_num_blocks_per_seq]
   torch::Tensor& context_lens,    // [num_seqs]
@@ -708,7 +708,7 @@ void paged_attention_v1(
     query_ptr, \
     key_cache_ptr, \
     value_cache_ptr, \
-    head_mapping_ptr, \
+    num_kv_heads, \
     scale, \
     block_tables_ptr, \
     context_lens_ptr, \
@@ -739,7 +739,7 @@ void paged_attention_v2_launcher(
   torch::Tensor& query,
   torch::Tensor& key_cache,
   torch::Tensor& value_cache,
-  torch::Tensor& head_mapping,
+  int num_kv_heads,
   float scale,
   torch::Tensor& block_tables,
   torch::Tensor& context_lens,
@@ -768,7 +768,6 @@ void paged_attention_v2_launcher(
   T* query_ptr = reinterpret_cast<T*>(query.data_ptr());
   T* key_cache_ptr = reinterpret_cast<T*>(key_cache.data_ptr());
   T* value_cache_ptr = reinterpret_cast<T*>(value_cache.data_ptr());
-  int* head_mapping_ptr = reinterpret_cast<int*>(head_mapping.data_ptr());
   int* block_tables_ptr = block_tables.data_ptr<int>();
   int* context_lens_ptr = context_lens.data_ptr<int>();
 
@@ -823,7 +822,7 @@ void paged_attention_v2_launcher(
     query, \
     key_cache, \
     value_cache, \
-    head_mapping, \
+    num_kv_heads, \
    scale, \
     block_tables, \
     context_lens, \
@@ -856,7 +855,7 @@ void paged_attention_v2(
   torch::Tensor& query,           // [num_seqs, num_heads, head_size]
   torch::Tensor& key_cache,       // [num_blocks, num_heads, head_size/x, block_size, x]
   torch::Tensor& value_cache,     // [num_blocks, num_heads, head_size, block_size]
-  torch::Tensor& head_mapping,    // [num_heads]
+  int num_kv_heads,               // [num_heads]
   float scale,
   torch::Tensor& block_tables,    // [num_seqs, max_num_blocks_per_seq]
   torch::Tensor& context_lens,    // [num_seqs]
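
Inside the kernel, each thread block still handles one query head (head_idx = blockIdx.x, num_heads = gridDim.x); the only difference is that the KV head index is now computed with integer division instead of a global-memory read from head_mapping. A small Python model of that per-block index math (illustrative, not from the commit) shows how it covers standard multi-head, grouped-query, and multi-query attention:

def kv_head_for(head_idx: int, num_heads: int, num_kv_heads: int) -> int:
    # Mirrors the kernel: num_queries_per_kv = num_heads / num_kv_heads,
    # kv_head_idx = head_idx / num_queries_per_kv (integer division).
    # The Python callers assert num_query_heads % num_kv_heads == 0.
    num_queries_per_kv = num_heads // num_kv_heads
    return head_idx // num_queries_per_kv

assert kv_head_for(5, num_heads=8, num_kv_heads=8) == 5  # MHA: one KV head per query head
assert kv_head_for(5, num_heads=8, num_kv_heads=2) == 1  # GQA: 4 query heads share a KV head
assert kv_head_for(5, num_heads=8, num_kv_heads=1) == 0  # MQA: all query heads read KV head 0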

csrc/ops.h

Lines changed: 2 additions & 2 deletions
@@ -5,7 +5,7 @@ void paged_attention_v1(
   torch::Tensor& query,
   torch::Tensor& key_cache,
   torch::Tensor& value_cache,
-  torch::Tensor& head_mapping,
+  int num_kv_heads,
   float scale,
   torch::Tensor& block_tables,
   torch::Tensor& context_lens,
@@ -21,7 +21,7 @@ void paged_attention_v2(
   torch::Tensor& query,
   torch::Tensor& key_cache,
   torch::Tensor& value_cache,
-  torch::Tensor& head_mapping,
+  int num_kv_heads,
   float scale,
   torch::Tensor& block_tables,
   torch::Tensor& context_lens,

tests/kernels/test_attention.py

Lines changed: 2 additions & 5 deletions
@@ -131,9 +131,6 @@ def test_paged_attention(
 
     assert num_query_heads % num_kv_heads == 0
     num_queries_per_kv = num_query_heads // num_kv_heads
-    head_mapping = torch.repeat_interleave(
-        torch.arange(num_kv_heads, dtype=torch.int32, device="cuda"),
-        num_queries_per_kv)
     alibi_slopes = None
     if use_alibi:
         alibi_slopes = torch.randn(num_query_heads,
@@ -170,7 +167,7 @@ def test_paged_attention(
         query,
         key_cache,
         value_cache,
-        head_mapping,
+        num_kv_heads,
         scale,
         block_tables,
         context_lens,
@@ -202,7 +199,7 @@ def test_paged_attention(
         query,
         key_cache,
         value_cache,
-        head_mapping,
+        num_kv_heads,
         scale,
         block_tables,
         context_lens,

vllm/model_executor/layers/attention.py

Lines changed: 5 additions & 8 deletions
@@ -54,9 +54,6 @@ def __init__(
 
         assert self.num_heads % self.num_kv_heads == 0
         self.num_queries_per_kv = self.num_heads // self.num_kv_heads
-        self.head_mapping = torch.repeat_interleave(
-            torch.arange(self.num_kv_heads, dtype=torch.int32, device="cuda"),
-            self.num_queries_per_kv)
 
         if self.head_size not in _SUPPORTED_HEAD_SIZES:
             raise ValueError(f"head_size ({self.head_size}) is not supported. "
@@ -77,7 +74,7 @@ def forward(
         Args:
             query: shape = [batch_size, seq_len, num_heads * head_size]
             key: shape = [batch_size, seq_len, num_kv_heads * head_size]
-            value: shape = [batch_size, num_kv_heads * head_size]
+            value: shape = [batch_size, seq_len, num_kv_heads * head_size]
             key_cache: shape = [num_blocks, num_kv_heads, head_size/x,
                 block_size, x]
             value_cache: shape = [num_blocks, num_kv_heads, head_size,
@@ -172,7 +169,7 @@ def forward(
                 key_cache,
                 value_cache,
                 input_metadata,
-                self.head_mapping,
+                self.num_kv_heads,
                 self.scale,
                 self.alibi_slopes,
             )
@@ -217,7 +214,7 @@ def _paged_attention(
     key_cache: torch.Tensor,
     value_cache: torch.Tensor,
     input_metadata: InputMetadata,
-    head_mapping: torch.Tensor,
+    num_kv_heads: int,
     scale: float,
     alibi_slopes: Optional[torch.Tensor],
 ) -> torch.Tensor:
@@ -244,7 +241,7 @@ def _paged_attention(
             query,
             key_cache,
             value_cache,
-            head_mapping,
+            num_kv_heads,
             scale,
             input_metadata.block_tables,
             input_metadata.context_lens,
@@ -274,7 +271,7 @@ def _paged_attention(
             query,
             key_cache,
             value_cache,
-            head_mapping,
+            num_kv_heads,
             scale,
             input_metadata.block_tables,
             input_metadata.context_lens,
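
On the Python side, the net effect is that PagedAttention no longer allocates a per-head index tensor at construction time; it only keeps the integer head counts and forwards self.num_kv_heads to the kernel. A minimal sketch of that caller-side pattern (the class shape here is illustrative; only the attribute names and the removed lines come from the diff):

import torch.nn as nn

class PagedAttentionSketch(nn.Module):
    """Illustrative skeleton only; the real layer also handles KV caches, ALiBi, etc."""

    def __init__(self, num_heads: int, num_kv_heads: int) -> None:
        super().__init__()
        self.num_heads = num_heads
        self.num_kv_heads = num_kv_heads
        assert self.num_heads % self.num_kv_heads == 0
        self.num_queries_per_kv = self.num_heads // self.num_kv_heads
        # Before this commit the layer also built:
        #   self.head_mapping = torch.repeat_interleave(
        #       torch.arange(self.num_kv_heads, dtype=torch.int32, device="cuda"),
        #       self.num_queries_per_kv)
        # That tensor is gone; the kernel derives the mapping from num_kv_heads.

layer = PagedAttentionSketch(num_heads=32, num_kv_heads=8)  # example GQA configuration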
