
Commit 3521ba4

[Core][Model runner refactoring 1/N] Refactor attn metadata term (#4518)
1 parent 2d7bce9, commit 3521ba4

27 files changed (+554, -525 lines)

benchmarks/kernels/benchmark_paged_attention.py

Lines changed: 12 additions & 13 deletions
@@ -16,7 +16,7 @@
 def main(
     version: str,
     num_seqs: int,
-    context_len: int,
+    seq_len: int,
     num_query_heads: int,
     num_kv_heads: int,
     head_size: int,
@@ -48,12 +48,12 @@ def main(
                                    dtype=torch.float,
                                    device=device)

-    context_lens = [context_len for _ in range(num_seqs)]
-    max_context_len = max(context_lens)
-    context_lens = torch.tensor(context_lens, dtype=torch.int, device=device)
+    seq_lens = [seq_len for _ in range(num_seqs)]
+    max_seq_len = max(seq_lens)
+    seq_lens = torch.tensor(seq_lens, dtype=torch.int, device=device)

     # Create the block tables.
-    max_num_blocks_per_seq = (max_context_len + block_size - 1) // block_size
+    max_num_blocks_per_seq = (max_seq_len + block_size - 1) // block_size
     block_tables = []
     for _ in range(num_seqs):
         block_table = [
@@ -77,8 +77,7 @@ def main(
     # Prepare for the paged attention kernel.
     output = torch.empty_like(query)
     if version == "v2":
-        num_partitions = ((max_context_len + PARTITION_SIZE - 1) //
-                          PARTITION_SIZE)
+        num_partitions = ((max_seq_len + PARTITION_SIZE - 1) // PARTITION_SIZE)
         tmp_output = torch.empty(
             size=(num_seqs, num_query_heads, num_partitions, head_size),
             dtype=output.dtype,
@@ -110,9 +109,9 @@ def run_cuda_benchmark(num_iters: int, profile: bool = False) -> float:
                 num_kv_heads,
                 scale,
                 block_tables,
-                context_lens,
+                seq_lens,
                 block_size,
-                max_context_len,
+                max_seq_len,
                 alibi_slopes,
                 kv_cache_dtype,
                 kv_scale,
@@ -129,9 +128,9 @@ def run_cuda_benchmark(num_iters: int, profile: bool = False) -> float:
                 num_kv_heads,
                 scale,
                 block_tables,
-                context_lens,
+                seq_lens,
                 block_size,
-                max_context_len,
+                max_seq_len,
                 alibi_slopes,
                 kv_cache_dtype,
                 kv_scale,
@@ -166,7 +165,7 @@ def run_cuda_benchmark(num_iters: int, profile: bool = False) -> float:
                         choices=["v1", "v2"],
                         default="v2")
     parser.add_argument("--batch-size", type=int, default=8)
-    parser.add_argument("--context-len", type=int, default=4096)
+    parser.add_argument("--seq_len", type=int, default=4096)
     parser.add_argument("--num-query-heads", type=int, default=64)
     parser.add_argument("--num-kv-heads", type=int, default=8)
     parser.add_argument("--head-size",
@@ -199,7 +198,7 @@ def run_cuda_benchmark(num_iters: int, profile: bool = False) -> float:
     main(
         version=args.version,
         num_seqs=args.batch_size,
-        context_len=args.context_len,
+        seq_len=args.seq_len,
         num_query_heads=args.num_query_heads,
         num_kv_heads=args.num_kv_heads,
         head_size=args.head_size,
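
Note: the change in this file is purely a rename; the benchmark still builds per-sequence lengths and block tables the same way. A minimal standalone sketch of the renamed setup, with the names taken from the diff above and the concrete values (8 sequences, length 4096, block size 16) chosen only for illustration:

import torch

num_seqs, seq_len, block_size = 8, 4096, 16   # illustrative values
seq_lens = [seq_len for _ in range(num_seqs)]
max_seq_len = max(seq_lens)
seq_lens = torch.tensor(seq_lens, dtype=torch.int)

# Ceil division, as in the diff: number of KV-cache blocks needed per sequence.
max_num_blocks_per_seq = (max_seq_len + block_size - 1) // block_size   # 256 here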

csrc/attention/attention_kernels.cu

Lines changed: 38 additions & 38 deletions
@@ -104,7 +104,7 @@ __device__ void paged_attention_kernel(
   const int num_kv_heads,                 // [num_heads]
   const float scale,
   const int* __restrict__ block_tables,   // [num_seqs, max_num_blocks_per_seq]
-  const int* __restrict__ context_lens,   // [num_seqs]
+  const int* __restrict__ seq_lens,       // [num_seqs]
   const int max_num_blocks_per_seq,
   const float* __restrict__ alibi_slopes, // [num_heads]
   const int q_stride,
@@ -115,23 +115,23 @@ __device__ void paged_attention_kernel(
   const int partition_idx = blockIdx.z;
   const int max_num_partitions = gridDim.z;
   constexpr bool USE_PARTITIONING = PARTITION_SIZE > 0;
-  const int context_len = context_lens[seq_idx];
-  if (USE_PARTITIONING && partition_idx * PARTITION_SIZE >= context_len) {
+  const int seq_len = seq_lens[seq_idx];
+  if (USE_PARTITIONING && partition_idx * PARTITION_SIZE >= seq_len) {
     // No work to do. Terminate the thread block.
     return;
   }

-  const int num_context_blocks = DIVIDE_ROUND_UP(context_len, BLOCK_SIZE);
-  const int num_blocks_per_partition = USE_PARTITIONING ? PARTITION_SIZE / BLOCK_SIZE : num_context_blocks;
+  const int num_seq_blocks = DIVIDE_ROUND_UP(seq_len, BLOCK_SIZE);
+  const int num_blocks_per_partition = USE_PARTITIONING ? PARTITION_SIZE / BLOCK_SIZE : num_seq_blocks;

   // [start_block_idx, end_block_idx) is the range of blocks to process.
   const int start_block_idx = USE_PARTITIONING ? partition_idx * num_blocks_per_partition : 0;
-  const int end_block_idx = MIN(start_block_idx + num_blocks_per_partition, num_context_blocks);
+  const int end_block_idx = MIN(start_block_idx + num_blocks_per_partition, num_seq_blocks);
   const int num_blocks = end_block_idx - start_block_idx;

   // [start_token_idx, end_token_idx) is the range of tokens to process.
   const int start_token_idx = start_block_idx * BLOCK_SIZE;
-  const int end_token_idx = MIN(start_token_idx + num_blocks * BLOCK_SIZE, context_len);
+  const int end_token_idx = MIN(start_token_idx + num_blocks * BLOCK_SIZE, seq_len);
   const int num_tokens = end_token_idx - start_token_idx;

   constexpr int THREAD_GROUP_SIZE = MAX(WARP_SIZE / BLOCK_SIZE, 1);
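
For readers following the index arithmetic: the hunk above only renames context_len to seq_len; the per-partition block and token ranges are computed exactly as before. A rough Python sketch of that arithmetic (the PARTITION_SIZE and BLOCK_SIZE values below are illustrative, not taken from this file):

def partition_token_range(seq_len, partition_idx, partition_size=512, block_size=16):
    # DIVIDE_ROUND_UP(seq_len, BLOCK_SIZE)
    num_seq_blocks = (seq_len + block_size - 1) // block_size
    blocks_per_partition = partition_size // block_size
    start_block = partition_idx * blocks_per_partition
    end_block = min(start_block + blocks_per_partition, num_seq_blocks)
    start_token = start_block * block_size
    end_token = min(start_token + (end_block - start_block) * block_size, seq_len)
    return start_token, end_token

partition_token_range(1000, 0)   # (0, 512)
partition_token_range(1000, 1)   # (512, 1000): the last partition is clipped to seq_len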
@@ -245,12 +245,12 @@ __device__ void paged_attention_kernel(
     // This includes a reduction across the threads in the same thread group.
     float qk = scale * Qk_dot<scalar_t, THREAD_GROUP_SIZE>::dot(q_vecs[thread_group_offset], k_vecs);
     // Add the ALiBi bias if slopes are given.
-    qk += (alibi_slope != 0) ? alibi_slope * (token_idx - context_len + 1) : 0;
+    qk += (alibi_slope != 0) ? alibi_slope * (token_idx - seq_len + 1) : 0;

     if (thread_group_offset == 0) {
       // Store the partial reductions to shared memory.
       // NOTE(woosuk): It is required to zero out the masked logits.
-      const bool mask = token_idx >= context_len;
+      const bool mask = token_idx >= seq_len;
       logits[token_idx - start_token_idx] = mask ? 0.f : qk;
       // Update the max value.
       qk_max = mask ? qk_max : fmaxf(qk_max, qk);
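
The ALiBi term is likewise unchanged apart from the rename: the bias is alibi_slope * (token_idx - seq_len + 1), so the most recent token gets zero bias and earlier tokens get increasingly negative biases. A tiny worked example in Python (the slope and length are made up for illustration):

alibi_slope, seq_len = 0.25, 8
bias = [alibi_slope * (token_idx - seq_len + 1) for token_idx in range(seq_len)]
# bias == [-1.75, -1.5, -1.25, -1.0, -0.75, -0.5, -0.25, 0.0]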
@@ -364,14 +364,14 @@ __device__ void paged_attention_kernel(
       } else {
         v_vec = *reinterpret_cast<const V_vec*>(v_ptr + offset);
       }
-      if (block_idx == num_context_blocks - 1) {
+      if (block_idx == num_seq_blocks - 1) {
         // NOTE(woosuk): When v_vec contains the tokens that are out of the context,
         // we should explicitly zero out the values since they may contain NaNs.
         // See https://github.com/vllm-project/vllm/issues/641#issuecomment-1682544472
         scalar_t* v_vec_ptr = reinterpret_cast<scalar_t*>(&v_vec);
 #pragma unroll
         for (int j = 0; j < V_VEC_SIZE; j++) {
-          v_vec_ptr[j] = token_idx + j < context_len ? v_vec_ptr[j] : zero_value;
+          v_vec_ptr[j] = token_idx + j < seq_len ? v_vec_ptr[j] : zero_value;
         }
       }
       accs[i] += dot(logits_vec, v_vec);
@@ -457,7 +457,7 @@ __global__ void paged_attention_v1_kernel(
   const int num_kv_heads,                 // [num_heads]
   const float scale,
   const int* __restrict__ block_tables,   // [num_seqs, max_num_blocks_per_seq]
-  const int* __restrict__ context_lens,   // [num_seqs]
+  const int* __restrict__ seq_lens,       // [num_seqs]
   const int max_num_blocks_per_seq,
   const float* __restrict__ alibi_slopes, // [num_heads]
   const int q_stride,
@@ -466,7 +466,7 @@ __global__ void paged_attention_v1_kernel(
   const float kv_scale) {
   paged_attention_kernel<scalar_t, cache_t, HEAD_SIZE, BLOCK_SIZE, NUM_THREADS, IS_FP8_KV_CACHE>(
     /* exp_sums */ nullptr, /* max_logits */ nullptr,
-    out, q, k_cache, v_cache, num_kv_heads, scale, block_tables, context_lens,
+    out, q, k_cache, v_cache, num_kv_heads, scale, block_tables, seq_lens,
     max_num_blocks_per_seq, alibi_slopes, q_stride, kv_block_stride, kv_head_stride, kv_scale);
 }

@@ -489,7 +489,7 @@ __global__ void paged_attention_v2_kernel(
   const int num_kv_heads,                 // [num_heads]
   const float scale,
   const int* __restrict__ block_tables,   // [num_seqs, max_num_blocks_per_seq]
-  const int* __restrict__ context_lens,   // [num_seqs]
+  const int* __restrict__ seq_lens,       // [num_seqs]
   const int max_num_blocks_per_seq,
   const float* __restrict__ alibi_slopes, // [num_heads]
   const int q_stride,
@@ -498,7 +498,7 @@ __global__ void paged_attention_v2_kernel(
   const float kv_scale) {
   paged_attention_kernel<scalar_t, cache_t, HEAD_SIZE, BLOCK_SIZE, NUM_THREADS, IS_FP8_KV_CACHE, PARTITION_SIZE>(
     exp_sums, max_logits, tmp_out, q, k_cache, v_cache, num_kv_heads, scale,
-    block_tables, context_lens, max_num_blocks_per_seq, alibi_slopes,
+    block_tables, seq_lens, max_num_blocks_per_seq, alibi_slopes,
     q_stride, kv_block_stride, kv_head_stride, kv_scale);
 }

@@ -513,13 +513,13 @@ __global__ void paged_attention_v2_reduce_kernel(
   const float* __restrict__ exp_sums,     // [num_seqs, num_heads, max_num_partitions]
   const float* __restrict__ max_logits,   // [num_seqs, num_heads, max_num_partitions]
   const scalar_t* __restrict__ tmp_out,   // [num_seqs, num_heads, max_num_partitions, head_size]
-  const int* __restrict__ context_lens,   // [num_seqs]
+  const int* __restrict__ seq_lens,       // [num_seqs]
   const int max_num_partitions) {
   const int num_heads = gridDim.x;
   const int head_idx = blockIdx.x;
   const int seq_idx = blockIdx.y;
-  const int context_len = context_lens[seq_idx];
-  const int num_partitions = DIVIDE_ROUND_UP(context_len, PARTITION_SIZE);
+  const int seq_len = seq_lens[seq_idx];
+  const int num_partitions = DIVIDE_ROUND_UP(seq_len, PARTITION_SIZE);
   if (num_partitions == 1) {
     // No need to reduce. Only copy tmp_out to out.
     scalar_t* out_ptr = out + seq_idx * num_heads * HEAD_SIZE + head_idx * HEAD_SIZE;
@@ -616,7 +616,7 @@ __global__ void paged_attention_v2_reduce_kernel(
     num_kv_heads, \
     scale, \
     block_tables_ptr, \
-    context_lens_ptr, \
+    seq_lens_ptr, \
     max_num_blocks_per_seq, \
     alibi_slopes_ptr, \
     q_stride, \
@@ -639,8 +639,8 @@ void paged_attention_v1_launcher(
   int num_kv_heads,
   float scale,
   torch::Tensor& block_tables,
-  torch::Tensor& context_lens,
-  int max_context_len,
+  torch::Tensor& seq_lens,
+  int max_seq_len,
   const c10::optional<torch::Tensor>& alibi_slopes,
   float kv_scale) {
   int num_seqs = query.size(0);
@@ -664,11 +664,11 @@ void paged_attention_v1_launcher(
   CACHE_T* key_cache_ptr = reinterpret_cast<CACHE_T*>(key_cache.data_ptr());
   CACHE_T* value_cache_ptr = reinterpret_cast<CACHE_T*>(value_cache.data_ptr());
   int* block_tables_ptr = block_tables.data_ptr<int>();
-  int* context_lens_ptr = context_lens.data_ptr<int>();
+  int* seq_lens_ptr = seq_lens.data_ptr<int>();

   constexpr int NUM_WARPS = NUM_THREADS / WARP_SIZE;
-  int padded_max_context_len = DIVIDE_ROUND_UP(max_context_len, BLOCK_SIZE) * BLOCK_SIZE;
-  int logits_size = padded_max_context_len * sizeof(float);
+  int padded_max_seq_len = DIVIDE_ROUND_UP(max_seq_len, BLOCK_SIZE) * BLOCK_SIZE;
+  int logits_size = padded_max_seq_len * sizeof(float);
   int outputs_size = (NUM_WARPS / 2) * head_size * sizeof(float);
   // Python-side check in vllm.worker.worker._check_if_can_support_max_seq_len
   // Keep that in sync with the logic here!
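
The v1 launcher's shared-memory sizing keeps the same formula under the new name: the logits buffer is the max sequence length rounded up to a multiple of BLOCK_SIZE, times sizeof(float). A hedged back-of-the-envelope check in Python (NUM_THREADS, WARP_SIZE, and head_size values below are illustrative; the launcher then derives the kernel's dynamic shared-memory request from these two sizes):

max_seq_len, block_size = 4096, 16
num_threads, warp_size, head_size = 128, 32, 128      # illustrative values

padded_max_seq_len = (max_seq_len + block_size - 1) // block_size * block_size
logits_size = padded_max_seq_len * 4                  # sizeof(float) bytes -> 16384
outputs_size = (num_threads // warp_size // 2) * head_size * 4   # 1024 bytes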
@@ -715,8 +715,8 @@ void paged_attention_v1_launcher(
     num_kv_heads, \
     scale, \
     block_tables, \
-    context_lens, \
-    max_context_len, \
+    seq_lens, \
+    max_seq_len, \
     alibi_slopes, \
     kv_scale);

@@ -746,9 +746,9 @@ void paged_attention_v1(
   int num_kv_heads,               // [num_heads]
   float scale,
   torch::Tensor& block_tables,    // [num_seqs, max_num_blocks_per_seq]
-  torch::Tensor& context_lens,    // [num_seqs]
+  torch::Tensor& seq_lens,        // [num_seqs]
   int block_size,
-  int max_context_len,
+  int max_seq_len,
   const c10::optional<torch::Tensor>& alibi_slopes,
   const std::string& kv_cache_dtype,
   float kv_scale) {
@@ -790,7 +790,7 @@ void paged_attention_v1(
     num_kv_heads, \
     scale, \
     block_tables_ptr, \
-    context_lens_ptr, \
+    seq_lens_ptr, \
     max_num_blocks_per_seq, \
     alibi_slopes_ptr, \
     q_stride, \
@@ -803,7 +803,7 @@ void paged_attention_v1(
     exp_sums_ptr, \
     max_logits_ptr, \
     tmp_out_ptr, \
-    context_lens_ptr, \
+    seq_lens_ptr, \
     max_num_partitions);

 template<
@@ -824,8 +824,8 @@ void paged_attention_v2_launcher(
   int num_kv_heads,
   float scale,
   torch::Tensor& block_tables,
-  torch::Tensor& context_lens,
-  int max_context_len,
+  torch::Tensor& seq_lens,
+  int max_seq_len,
   const c10::optional<torch::Tensor>& alibi_slopes,
   float kv_scale) {
   int num_seqs = query.size(0);
@@ -852,10 +852,10 @@ void paged_attention_v2_launcher(
   CACHE_T* key_cache_ptr = reinterpret_cast<CACHE_T*>(key_cache.data_ptr());
   CACHE_T* value_cache_ptr = reinterpret_cast<CACHE_T*>(value_cache.data_ptr());
   int* block_tables_ptr = block_tables.data_ptr<int>();
-  int* context_lens_ptr = context_lens.data_ptr<int>();
+  int* seq_lens_ptr = seq_lens.data_ptr<int>();

   constexpr int NUM_WARPS = NUM_THREADS / WARP_SIZE;
-  int max_num_partitions = DIVIDE_ROUND_UP(max_context_len, PARTITION_SIZE);
+  int max_num_partitions = DIVIDE_ROUND_UP(max_seq_len, PARTITION_SIZE);
   int logits_size = PARTITION_SIZE * sizeof(float);
   int outputs_size = (NUM_WARPS / 2) * head_size * sizeof(float);

@@ -909,8 +909,8 @@ void paged_attention_v2_launcher(
     num_kv_heads, \
     scale, \
     block_tables, \
-    context_lens, \
-    max_context_len, \
+    seq_lens, \
+    max_seq_len, \
     alibi_slopes, \
     kv_scale);

@@ -943,9 +943,9 @@ void paged_attention_v2(
   int num_kv_heads,               // [num_heads]
   float scale,
   torch::Tensor& block_tables,    // [num_seqs, max_num_blocks_per_seq]
-  torch::Tensor& context_lens,    // [num_seqs]
+  torch::Tensor& seq_lens,        // [num_seqs]
   int block_size,
-  int max_context_len,
+  int max_seq_len,
   const c10::optional<torch::Tensor>& alibi_slopes,
   const std::string& kv_cache_dtype,
   float kv_scale) {
