@@ -104,7 +104,7 @@ __device__ void paged_attention_kernel(
   const int num_kv_heads,                 // [num_heads]
   const float scale,
   const int* __restrict__ block_tables,   // [num_seqs, max_num_blocks_per_seq]
-  const int* __restrict__ context_lens,   // [num_seqs]
+  const int* __restrict__ seq_lens,       // [num_seqs]
   const int max_num_blocks_per_seq,
   const float* __restrict__ alibi_slopes, // [num_heads]
   const int q_stride,
@@ -115,23 +115,23 @@ __device__ void paged_attention_kernel(
   const int partition_idx = blockIdx.z;
   const int max_num_partitions = gridDim.z;
   constexpr bool USE_PARTITIONING = PARTITION_SIZE > 0;
-  const int context_len = context_lens[seq_idx];
-  if (USE_PARTITIONING && partition_idx * PARTITION_SIZE >= context_len) {
+  const int seq_len = seq_lens[seq_idx];
+  if (USE_PARTITIONING && partition_idx * PARTITION_SIZE >= seq_len) {
     // No work to do. Terminate the thread block.
     return;
   }
 
-  const int num_context_blocks = DIVIDE_ROUND_UP(context_len, BLOCK_SIZE);
-  const int num_blocks_per_partition = USE_PARTITIONING ? PARTITION_SIZE / BLOCK_SIZE : num_context_blocks;
+  const int num_seq_blocks = DIVIDE_ROUND_UP(seq_len, BLOCK_SIZE);
+  const int num_blocks_per_partition = USE_PARTITIONING ? PARTITION_SIZE / BLOCK_SIZE : num_seq_blocks;
 
   // [start_block_idx, end_block_idx) is the range of blocks to process.
   const int start_block_idx = USE_PARTITIONING ? partition_idx * num_blocks_per_partition : 0;
-  const int end_block_idx = MIN(start_block_idx + num_blocks_per_partition, num_context_blocks);
+  const int end_block_idx = MIN(start_block_idx + num_blocks_per_partition, num_seq_blocks);
   const int num_blocks = end_block_idx - start_block_idx;
 
   // [start_token_idx, end_token_idx) is the range of tokens to process.
   const int start_token_idx = start_block_idx * BLOCK_SIZE;
-  const int end_token_idx = MIN(start_token_idx + num_blocks * BLOCK_SIZE, context_len);
+  const int end_token_idx = MIN(start_token_idx + num_blocks * BLOCK_SIZE, seq_len);
   const int num_tokens = end_token_idx - start_token_idx;
 
   constexpr int THREAD_GROUP_SIZE = MAX(WARP_SIZE / BLOCK_SIZE, 1);
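The hunk above is the renamed index math: for each (sequence, partition) pair the kernel derives which KV-cache blocks and which token positions its thread block should cover. Below is a minimal host-side sketch of that arithmetic, not code from this patch; DIVIDE_ROUND_UP and MIN are standalone stand-ins assumed to be the usual ceiling divide and minimum, and the sizes are example values.

#include <cstdio>

// Stand-ins for the kernel's helper macros (assumed definitions).
#define DIVIDE_ROUND_UP(a, b) (((a) + (b) - 1) / (b))
#define MIN(a, b) ((a) < (b) ? (a) : (b))

int main() {
  const int seq_len = 1000;        // tokens already in the sequence (example)
  const int BLOCK_SIZE = 16;       // KV-cache block size (example)
  const int PARTITION_SIZE = 512;  // v2 partition size; 0 would disable partitioning
  const int partition_idx = 1;     // second partition, i.e. blockIdx.z in the kernel

  const int num_seq_blocks = DIVIDE_ROUND_UP(seq_len, BLOCK_SIZE);       // 63
  const int num_blocks_per_partition = PARTITION_SIZE / BLOCK_SIZE;      // 32 (partitioning enabled)
  const int start_block_idx = partition_idx * num_blocks_per_partition;  // 32
  const int end_block_idx = MIN(start_block_idx + num_blocks_per_partition,
                                num_seq_blocks);                         // 63
  const int start_token_idx = start_block_idx * BLOCK_SIZE;              // 512
  const int end_token_idx = MIN(start_token_idx +
                                (end_block_idx - start_block_idx) * BLOCK_SIZE,
                                seq_len);                                // 1000
  printf("blocks [%d, %d), tokens [%d, %d)\n",
         start_block_idx, end_block_idx, start_token_idx, end_token_idx);
  return 0;
}

With these example numbers a 1000-token sequence needs two partitions, and the second one stops at the true seq_len rather than at the padded block boundary, which is exactly what the MIN against seq_len enforces.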
@@ -245,12 +245,12 @@ __device__ void paged_attention_kernel(
       // This includes a reduction across the threads in the same thread group.
       float qk = scale * Qk_dot<scalar_t, THREAD_GROUP_SIZE>::dot(q_vecs[thread_group_offset], k_vecs);
       // Add the ALiBi bias if slopes are given.
-      qk += (alibi_slope != 0) ? alibi_slope * (token_idx - context_len + 1) : 0;
+      qk += (alibi_slope != 0) ? alibi_slope * (token_idx - seq_len + 1) : 0;
 
       if (thread_group_offset == 0) {
         // Store the partial reductions to shared memory.
         // NOTE(woosuk): It is required to zero out the masked logits.
-        const bool mask = token_idx >= context_len;
+        const bool mask = token_idx >= seq_len;
         logits[token_idx - start_token_idx] = mask ? 0.f : qk;
         // Update the max value.
         qk_max = mask ? qk_max : fmaxf(qk_max, qk);
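The renamed ALiBi line keeps the same bias shape as before: the key at position seq_len - 1 (the newest token) gets zero bias, and earlier keys get a penalty proportional to how far back they sit. A tiny standalone sketch of what that expression evaluates to, using an assumed per-head slope purely for illustration (not part of the patch):

#include <cstdio>

int main() {
  const float alibi_slope = 0.0625f;  // example per-head slope (assumption)
  const int seq_len = 8;
  for (int token_idx = 0; token_idx < seq_len; ++token_idx) {
    // Same expression as the kernel uses: 0 for the newest token, negative for older ones.
    const float bias = alibi_slope * (token_idx - seq_len + 1);
    printf("token %d: bias = %+.4f\n", token_idx, bias);
  }
  return 0;
}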
@@ -364,14 +364,14 @@ __device__ void paged_attention_kernel(
       } else {
         v_vec = *reinterpret_cast<const V_vec*>(v_ptr + offset);
       }
-      if (block_idx == num_context_blocks - 1) {
+      if (block_idx == num_seq_blocks - 1) {
         // NOTE(woosuk): When v_vec contains the tokens that are out of the context,
         // we should explicitly zero out the values since they may contain NaNs.
         // See https://github.com/vllm-project/vllm/issues/641#issuecomment-1682544472
         scalar_t* v_vec_ptr = reinterpret_cast<scalar_t*>(&v_vec);
         #pragma unroll
         for (int j = 0; j < V_VEC_SIZE; j++) {
-          v_vec_ptr[j] = token_idx + j < context_len ? v_vec_ptr[j] : zero_value;
+          v_vec_ptr[j] = token_idx + j < seq_len ? v_vec_ptr[j] : zero_value;
         }
       }
       accs[i] += dot(logits_vec, v_vec);
@@ -457,7 +457,7 @@ __global__ void paged_attention_v1_kernel(
   const int num_kv_heads,                 // [num_heads]
   const float scale,
   const int* __restrict__ block_tables,   // [num_seqs, max_num_blocks_per_seq]
-  const int* __restrict__ context_lens,   // [num_seqs]
+  const int* __restrict__ seq_lens,       // [num_seqs]
   const int max_num_blocks_per_seq,
   const float* __restrict__ alibi_slopes, // [num_heads]
   const int q_stride,
@@ -466,7 +466,7 @@ __global__ void paged_attention_v1_kernel(
   const float kv_scale) {
   paged_attention_kernel<scalar_t, cache_t, HEAD_SIZE, BLOCK_SIZE, NUM_THREADS, IS_FP8_KV_CACHE>(
     /* exp_sums */ nullptr, /* max_logits */ nullptr,
-    out, q, k_cache, v_cache, num_kv_heads, scale, block_tables, context_lens,
+    out, q, k_cache, v_cache, num_kv_heads, scale, block_tables, seq_lens,
     max_num_blocks_per_seq, alibi_slopes, q_stride, kv_block_stride, kv_head_stride, kv_scale);
 }
 
@@ -489,7 +489,7 @@ __global__ void paged_attention_v2_kernel(
   const int num_kv_heads,                 // [num_heads]
   const float scale,
   const int* __restrict__ block_tables,   // [num_seqs, max_num_blocks_per_seq]
-  const int* __restrict__ context_lens,   // [num_seqs]
+  const int* __restrict__ seq_lens,       // [num_seqs]
   const int max_num_blocks_per_seq,
   const float* __restrict__ alibi_slopes, // [num_heads]
   const int q_stride,
@@ -498,7 +498,7 @@ __global__ void paged_attention_v2_kernel(
   const float kv_scale) {
   paged_attention_kernel<scalar_t, cache_t, HEAD_SIZE, BLOCK_SIZE, NUM_THREADS, IS_FP8_KV_CACHE, PARTITION_SIZE>(
     exp_sums, max_logits, tmp_out, q, k_cache, v_cache, num_kv_heads, scale,
-    block_tables, context_lens, max_num_blocks_per_seq, alibi_slopes,
+    block_tables, seq_lens, max_num_blocks_per_seq, alibi_slopes,
     q_stride, kv_block_stride, kv_head_stride, kv_scale);
 }
 
@@ -513,13 +513,13 @@ __global__ void paged_attention_v2_reduce_kernel(
   const float* __restrict__ exp_sums,     // [num_seqs, num_heads, max_num_partitions]
   const float* __restrict__ max_logits,   // [num_seqs, num_heads, max_num_partitions]
   const scalar_t* __restrict__ tmp_out,   // [num_seqs, num_heads, max_num_partitions, head_size]
-  const int* __restrict__ context_lens,   // [num_seqs]
+  const int* __restrict__ seq_lens,       // [num_seqs]
   const int max_num_partitions) {
   const int num_heads = gridDim.x;
   const int head_idx = blockIdx.x;
   const int seq_idx = blockIdx.y;
-  const int context_len = context_lens[seq_idx];
-  const int num_partitions = DIVIDE_ROUND_UP(context_len, PARTITION_SIZE);
+  const int seq_len = seq_lens[seq_idx];
+  const int num_partitions = DIVIDE_ROUND_UP(seq_len, PARTITION_SIZE);
   if (num_partitions == 1) {
     // No need to reduce. Only copy tmp_out to out.
     scalar_t* out_ptr = out + seq_idx * num_heads * HEAD_SIZE + head_idx * HEAD_SIZE;
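The renamed reduce kernel recomputes, per sequence, how many partitions the v2 main kernel produced from that sequence's seq_len. Only when a sequence spans more than one partition does the cross-partition reduction run; otherwise the partial result is copied straight to out, as the context lines above note. A quick standalone illustration of that partition count (DIVIDE_ROUND_UP is a stand-in assumed to be a ceiling divide, and PARTITION_SIZE = 512 is an assumed value of the kernel's compile-time constant):

#include <cstdio>

#define DIVIDE_ROUND_UP(a, b) (((a) + (b) - 1) / (b))

int main() {
  const int PARTITION_SIZE = 512;  // assumed value of the v2 partition size
  const int example_seq_lens[] = {100, 512, 513, 4096};
  for (int seq_len : example_seq_lens) {
    const int num_partitions = DIVIDE_ROUND_UP(seq_len, PARTITION_SIZE);
    printf("seq_len %4d -> %d partition(s)%s\n", seq_len, num_partitions,
           num_partitions == 1 ? " (copy only, no reduction)" : "");
  }
  return 0;
}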
@@ -616,7 +616,7 @@ __global__ void paged_attention_v2_reduce_kernel(
     num_kv_heads, \
     scale, \
     block_tables_ptr, \
-    context_lens_ptr, \
+    seq_lens_ptr, \
     max_num_blocks_per_seq, \
     alibi_slopes_ptr, \
     q_stride, \
@@ -639,8 +639,8 @@ void paged_attention_v1_launcher(
   int num_kv_heads,
   float scale,
   torch::Tensor& block_tables,
-  torch::Tensor& context_lens,
-  int max_context_len,
+  torch::Tensor& seq_lens,
+  int max_seq_len,
   const c10::optional<torch::Tensor>& alibi_slopes,
   float kv_scale) {
   int num_seqs = query.size(0);
@@ -664,11 +664,11 @@ void paged_attention_v1_launcher(
   CACHE_T* key_cache_ptr = reinterpret_cast<CACHE_T*>(key_cache.data_ptr());
   CACHE_T* value_cache_ptr = reinterpret_cast<CACHE_T*>(value_cache.data_ptr());
   int* block_tables_ptr = block_tables.data_ptr<int>();
-  int* context_lens_ptr = context_lens.data_ptr<int>();
+  int* seq_lens_ptr = seq_lens.data_ptr<int>();
 
   constexpr int NUM_WARPS = NUM_THREADS / WARP_SIZE;
-  int padded_max_context_len = DIVIDE_ROUND_UP(max_context_len, BLOCK_SIZE) * BLOCK_SIZE;
-  int logits_size = padded_max_context_len * sizeof(float);
+  int padded_max_seq_len = DIVIDE_ROUND_UP(max_seq_len, BLOCK_SIZE) * BLOCK_SIZE;
+  int logits_size = padded_max_seq_len * sizeof(float);
   int outputs_size = (NUM_WARPS / 2) * head_size * sizeof(float);
   // Python-side check in vllm.worker.worker._check_if_can_support_max_seq_len
   // Keep that in sync with the logic here!
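The v1 launcher sizes its shared-memory logits buffer from the renamed max_seq_len, rounded up to a whole number of KV-cache blocks, alongside the outputs buffer shown in the context lines. A minimal sketch of just that arithmetic, not code from this patch; the macro stand-in and the head size, thread count, and max_seq_len values are assumptions chosen only to make the numbers concrete:

#include <cstdio>

#define DIVIDE_ROUND_UP(a, b) (((a) + (b) - 1) / (b))

int main() {
  const int max_seq_len = 1000;                                  // example value (assumption)
  const int BLOCK_SIZE = 16;                                     // example KV-cache block size
  const int NUM_THREADS = 128, WARP_SIZE = 32, head_size = 128;  // assumed example values

  const int NUM_WARPS = NUM_THREADS / WARP_SIZE;                                         // 4
  const int padded_max_seq_len = DIVIDE_ROUND_UP(max_seq_len, BLOCK_SIZE) * BLOCK_SIZE;  // 1008
  const int logits_size = padded_max_seq_len * sizeof(float);                            // 4032 bytes
  const int outputs_size = (NUM_WARPS / 2) * head_size * sizeof(float);                  // 1024 bytes
  printf("padded_max_seq_len=%d logits=%dB outputs=%dB\n",
         padded_max_seq_len, logits_size, outputs_size);
  return 0;
}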
@@ -715,8 +715,8 @@ void paged_attention_v1_launcher(
     num_kv_heads, \
     scale, \
     block_tables, \
-    context_lens, \
-    max_context_len, \
+    seq_lens, \
+    max_seq_len, \
     alibi_slopes, \
     kv_scale);
 
@@ -746,9 +746,9 @@ void paged_attention_v1(
   int num_kv_heads,               // [num_heads]
   float scale,
   torch::Tensor& block_tables,    // [num_seqs, max_num_blocks_per_seq]
-  torch::Tensor& context_lens,    // [num_seqs]
+  torch::Tensor& seq_lens,        // [num_seqs]
   int block_size,
-  int max_context_len,
+  int max_seq_len,
   const c10::optional<torch::Tensor>& alibi_slopes,
   const std::string& kv_cache_dtype,
   float kv_scale) {
@@ -790,7 +790,7 @@ void paged_attention_v1(
     num_kv_heads, \
     scale, \
     block_tables_ptr, \
-    context_lens_ptr, \
+    seq_lens_ptr, \
     max_num_blocks_per_seq, \
     alibi_slopes_ptr, \
     q_stride, \
@@ -803,7 +803,7 @@ void paged_attention_v1(
     exp_sums_ptr, \
     max_logits_ptr, \
     tmp_out_ptr, \
-    context_lens_ptr, \
+    seq_lens_ptr, \
     max_num_partitions);
 
 template <
@@ -824,8 +824,8 @@ void paged_attention_v2_launcher(
   int num_kv_heads,
   float scale,
   torch::Tensor& block_tables,
-  torch::Tensor& context_lens,
-  int max_context_len,
+  torch::Tensor& seq_lens,
+  int max_seq_len,
   const c10::optional<torch::Tensor>& alibi_slopes,
   float kv_scale) {
   int num_seqs = query.size(0);
@@ -852,10 +852,10 @@ void paged_attention_v2_launcher(
   CACHE_T* key_cache_ptr = reinterpret_cast<CACHE_T*>(key_cache.data_ptr());
   CACHE_T* value_cache_ptr = reinterpret_cast<CACHE_T*>(value_cache.data_ptr());
   int* block_tables_ptr = block_tables.data_ptr<int>();
-  int* context_lens_ptr = context_lens.data_ptr<int>();
+  int* seq_lens_ptr = seq_lens.data_ptr<int>();
 
   constexpr int NUM_WARPS = NUM_THREADS / WARP_SIZE;
-  int max_num_partitions = DIVIDE_ROUND_UP(max_context_len, PARTITION_SIZE);
+  int max_num_partitions = DIVIDE_ROUND_UP(max_seq_len, PARTITION_SIZE);
   int logits_size = PARTITION_SIZE * sizeof(float);
   int outputs_size = (NUM_WARPS / 2) * head_size * sizeof(float);
 
@@ -909,8 +909,8 @@ void paged_attention_v2_launcher(
     num_kv_heads, \
     scale, \
     block_tables, \
-    context_lens, \
-    max_context_len, \
+    seq_lens, \
+    max_seq_len, \
     alibi_slopes, \
     kv_scale);
 
@@ -943,9 +943,9 @@ void paged_attention_v2(
   int num_kv_heads,               // [num_heads]
   float scale,
   torch::Tensor& block_tables,    // [num_seqs, max_num_blocks_per_seq]
-  torch::Tensor& context_lens,    // [num_seqs]
+  torch::Tensor& seq_lens,        // [num_seqs]
   int block_size,
-  int max_context_len,
+  int max_seq_len,
   const c10::optional<torch::Tensor>& alibi_slopes,
   const std::string& kv_cache_dtype,
   float kv_scale) {