@@ -89,7 +89,7 @@ __device__ void paged_attention_kernel(
   const scalar_t* __restrict__ q,       // [num_seqs, num_heads, head_size]
   const scalar_t* __restrict__ k_cache, // [num_blocks, num_kv_heads, head_size/x, block_size, x]
   const scalar_t* __restrict__ v_cache, // [num_blocks, num_kv_heads, head_size, block_size]
-  const int* __restrict__ head_mapping, // [num_heads]
+  const int num_kv_heads,               // [num_heads]
   const float scale,
   const int* __restrict__ block_tables, // [num_seqs, max_num_blocks_per_seq]
   const int* __restrict__ context_lens, // [num_seqs]
@@ -132,7 +132,8 @@ __device__ void paged_attention_kernel(
 
   const int head_idx = blockIdx.x;
   const int num_heads = gridDim.x;
-  const int kv_head_idx = head_mapping[head_idx];
+  const int num_queries_per_kv = num_heads / num_kv_heads;
+  const int kv_head_idx = head_idx / num_queries_per_kv;
   const float alibi_slope = alibi_slopes == nullptr ? 0.f : alibi_slopes[head_idx];
 
   // A vector type to store a part of a key or a query.
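
The substantive change is in the hunk above: kv_head_idx is no longer read from a device-side head_mapping table but computed directly from head_idx, which assumes query heads are assigned to KV heads in contiguous, equally sized groups (the usual grouped-query / multi-query attention layout, with num_heads divisible by num_kv_heads). A minimal host-side C++ sketch of that assumed equivalence, using illustrative head counts that are not taken from the commit:

#include <cassert>
#include <vector>

int main() {
  const int num_heads = 32;    // query heads (example value, assumed)
  const int num_kv_heads = 8;  // KV heads (example value, assumed)
  const int num_queries_per_kv = num_heads / num_kv_heads;

  // Old scheme: an explicit per-query-head table, built with contiguous groups.
  std::vector<int> head_mapping(num_heads);
  for (int h = 0; h < num_heads; ++h) {
    head_mapping[h] = h / num_queries_per_kv;
  }

  // New scheme: derive the KV head index arithmetically, as the kernel now does.
  for (int head_idx = 0; head_idx < num_heads; ++head_idx) {
    const int kv_head_idx = head_idx / num_queries_per_kv;
    assert(kv_head_idx == head_mapping[head_idx]);
  }
  return 0;
}

This replaces a global-memory lookup in the kernel with integer arithmetic and removes the need to build and upload a mapping tensor on the host.
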
@@ -401,7 +402,7 @@ __global__ void paged_attention_v1_kernel(
   const scalar_t* __restrict__ q,       // [num_seqs, num_heads, head_size]
   const scalar_t* __restrict__ k_cache, // [num_blocks, num_kv_heads, head_size/x, block_size, x]
   const scalar_t* __restrict__ v_cache, // [num_blocks, num_kv_heads, head_size, block_size]
-  const int* __restrict__ head_mapping, // [num_heads]
+  const int num_kv_heads,               // [num_heads]
   const float scale,
   const int* __restrict__ block_tables, // [num_seqs, max_num_blocks_per_seq]
   const int* __restrict__ context_lens, // [num_seqs]
@@ -412,7 +413,7 @@ __global__ void paged_attention_v1_kernel(
   const int kv_head_stride) {
   paged_attention_kernel<scalar_t, HEAD_SIZE, BLOCK_SIZE, NUM_THREADS>(
     /* exp_sums */ nullptr, /* max_logits */ nullptr,
-    out, q, k_cache, v_cache, head_mapping, scale, block_tables, context_lens,
+    out, q, k_cache, v_cache, num_kv_heads, scale, block_tables, context_lens,
     max_num_blocks_per_seq, alibi_slopes, q_stride, kv_block_stride, kv_head_stride);
 }
 
@@ -430,7 +431,7 @@ __global__ void paged_attention_v2_kernel(
   const scalar_t* __restrict__ q,       // [num_seqs, num_heads, head_size]
   const scalar_t* __restrict__ k_cache, // [num_blocks, num_kv_heads, head_size/x, block_size, x]
   const scalar_t* __restrict__ v_cache, // [num_blocks, num_kv_heads, head_size, block_size]
-  const int* __restrict__ head_mapping, // [num_heads]
+  const int num_kv_heads,               // [num_heads]
   const float scale,
   const int* __restrict__ block_tables, // [num_seqs, max_num_blocks_per_seq]
   const int* __restrict__ context_lens, // [num_seqs]
@@ -440,7 +441,7 @@ __global__ void paged_attention_v2_kernel(
   const int kv_block_stride,
   const int kv_head_stride) {
   paged_attention_kernel<scalar_t, HEAD_SIZE, BLOCK_SIZE, NUM_THREADS, PARTITION_SIZE>(
-    exp_sums, max_logits, tmp_out, q, k_cache, v_cache, head_mapping, scale,
+    exp_sums, max_logits, tmp_out, q, k_cache, v_cache, num_kv_heads, scale,
     block_tables, context_lens, max_num_blocks_per_seq, alibi_slopes,
     q_stride, kv_block_stride, kv_head_stride);
 }
@@ -556,7 +557,7 @@ __global__ void paged_attention_v2_reduce_kernel(
     query_ptr, \
     key_cache_ptr, \
     value_cache_ptr, \
-    head_mapping_ptr, \
+    num_kv_heads, \
     scale, \
     block_tables_ptr, \
     context_lens_ptr, \
@@ -576,7 +577,7 @@ void paged_attention_v1_launcher(
   torch::Tensor& query,
   torch::Tensor& key_cache,
   torch::Tensor& value_cache,
-  torch::Tensor& head_mapping,
+  int num_kv_heads,
   float scale,
   torch::Tensor& block_tables,
   torch::Tensor& context_lens,
@@ -602,7 +603,6 @@ void paged_attention_v1_launcher(
   T* query_ptr = reinterpret_cast<T*>(query.data_ptr());
   T* key_cache_ptr = reinterpret_cast<T*>(key_cache.data_ptr());
   T* value_cache_ptr = reinterpret_cast<T*>(value_cache.data_ptr());
-  int* head_mapping_ptr = reinterpret_cast<int*>(head_mapping.data_ptr());
   int* block_tables_ptr = block_tables.data_ptr<int>();
   int* context_lens_ptr = context_lens.data_ptr<int>();
 
@@ -651,7 +651,7 @@ void paged_attention_v1_launcher(
     query, \
     key_cache, \
     value_cache, \
-    head_mapping, \
+    num_kv_heads, \
     scale, \
     block_tables, \
     context_lens, \
@@ -681,7 +681,7 @@ void paged_attention_v1(
   torch::Tensor& query,        // [num_seqs, num_heads, head_size]
   torch::Tensor& key_cache,    // [num_blocks, num_heads, head_size/x, block_size, x]
   torch::Tensor& value_cache,  // [num_blocks, num_heads, head_size, block_size]
-  torch::Tensor& head_mapping, // [num_heads]
+  int num_kv_heads,            // [num_heads]
   float scale,
   torch::Tensor& block_tables, // [num_seqs, max_num_blocks_per_seq]
   torch::Tensor& context_lens, // [num_seqs]
@@ -708,7 +708,7 @@ void paged_attention_v1(
     query_ptr, \
     key_cache_ptr, \
     value_cache_ptr, \
-    head_mapping_ptr, \
+    num_kv_heads, \
     scale, \
     block_tables_ptr, \
     context_lens_ptr, \
@@ -739,7 +739,7 @@ void paged_attention_v2_launcher(
   torch::Tensor& query,
   torch::Tensor& key_cache,
   torch::Tensor& value_cache,
-  torch::Tensor& head_mapping,
+  int num_kv_heads,
   float scale,
   torch::Tensor& block_tables,
   torch::Tensor& context_lens,
@@ -768,7 +768,6 @@ void paged_attention_v2_launcher(
   T* query_ptr = reinterpret_cast<T*>(query.data_ptr());
   T* key_cache_ptr = reinterpret_cast<T*>(key_cache.data_ptr());
   T* value_cache_ptr = reinterpret_cast<T*>(value_cache.data_ptr());
-  int* head_mapping_ptr = reinterpret_cast<int*>(head_mapping.data_ptr());
   int* block_tables_ptr = block_tables.data_ptr<int>();
   int* context_lens_ptr = context_lens.data_ptr<int>();
 
@@ -823,7 +822,7 @@ void paged_attention_v2_launcher(
     query, \
     key_cache, \
     value_cache, \
-    head_mapping, \
+    num_kv_heads, \
     scale, \
     block_tables, \
     context_lens, \
@@ -856,7 +855,7 @@ void paged_attention_v2(
   torch::Tensor& query,        // [num_seqs, num_heads, head_size]
   torch::Tensor& key_cache,    // [num_blocks, num_heads, head_size/x, block_size, x]
   torch::Tensor& value_cache,  // [num_blocks, num_heads, head_size, block_size]
-  torch::Tensor& head_mapping, // [num_heads]
+  int num_kv_heads,            // [num_heads]
   float scale,
   torch::Tensor& block_tables, // [num_seqs, max_num_blocks_per_seq]
   torch::Tensor& context_lens, // [num_seqs]
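
On the host side, the public paged_attention_v1/v2 entry points now take num_kv_heads as a plain int instead of a head_mapping tensor. A hedged sketch of how a caller could obtain that value, assuming the key cache keeps the [num_blocks, num_kv_heads, head_size/x, block_size, x] layout noted in the kernel-side comments (the helper name below is illustrative, not part of the commit):

#include <torch/torch.h>

// Assumption: dimension 1 of the paged key cache is the KV-head count.
int infer_num_kv_heads(const torch::Tensor& key_cache) {
  return static_cast<int>(key_cache.size(1));
}

Passing a single integer keeps the head grouping implicit in the head counts, so callers no longer need to materialize a per-head mapping tensor before each attention call.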