
Commit 3e858ca

WoosukKwon and amd-xiaoyu12 authored and committed
[Misc] Add max_seq_len to CommonAttentionMetadata (vllm-project#23216)
Signed-off-by: Woosuk Kwon <[email protected]>
Signed-off-by: Xiao Yu <[email protected]>
1 parent 753abb5 commit 3e858ca
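
In short: this commit promotes max_seq_len from a value that every attention backend recomputed inside its build() method (via int(seq_lens_cpu.max())) to a field populated once when CommonAttentionMetadata is constructed. The sketch below is a trimmed, self-contained illustration of the producer side, not the real dataclass, which carries many more fields:

import torch
from dataclasses import dataclass

@dataclass
class CommonAttentionMetadata:
    seq_lens: torch.Tensor      # per-request KV lengths
    seq_lens_cpu: torch.Tensor  # CPU copy used for host-side sizing
    max_query_len: int          # longest query in the batch
    max_seq_len: int            # longest context length (new in this commit)

seq_lens_cpu = torch.tensor([7, 19, 4], dtype=torch.int32)
meta = CommonAttentionMetadata(
    seq_lens=seq_lens_cpu,
    seq_lens_cpu=seq_lens_cpu,
    max_query_len=3,
    max_seq_len=int(seq_lens_cpu.max()),  # reduced once, at construction time
)
assert meta.max_seq_len == 19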

File tree

12 files changed: +22 -7 lines changed


tests/v1/attention/utils.py

Lines changed: 2 additions & 0 deletions
@@ -58,6 +58,7 @@ def create_common_attn_metadata(
                             dtype=torch.int32,
                             device=device)
     seq_lens_cpu = seq_lens.cpu()
+    max_seq_len = int(seq_lens_cpu.max())
 
     # Create computed tokens (context length for each sequence)
     context_lens = [
@@ -101,6 +102,7 @@ def create_common_attn_metadata(
         num_reqs=batch_spec.batch_size,
         num_actual_tokens=num_tokens,
         max_query_len=max_query_len,
+        max_seq_len=max_seq_len,
         block_table_tensor=block_table_tensor,
         slot_mapping=slot_mapping,
         causal=True,

tests/v1/spec_decode/test_tree_attention.py

Lines changed: 2 additions & 0 deletions
@@ -50,6 +50,7 @@ def forward_attention(
         dtype=torch.int32,
     )
     context_lens = seq_lens - query_lens
+    max_seq_len = int(seq_lens.max())
     max_query_len = q_len
     num_actual_tokens = query_start_loc[-1]
 
@@ -81,6 +82,7 @@ def forward_attention(
         num_reqs=batch_size,
         num_actual_tokens=num_actual_tokens,
         max_query_len=max_query_len,
+        max_seq_len=max_seq_len,
         block_table_tensor=block_table,
         slot_mapping=slot_mapping,
     )

vllm/v1/attention/backends/flash_attn.py

Lines changed: 1 addition & 1 deletion
@@ -233,7 +233,7 @@ def build(self,
         num_reqs = common_attn_metadata.num_reqs
         num_actual_tokens = common_attn_metadata.num_actual_tokens
         max_query_len = common_attn_metadata.max_query_len
-        max_seq_len = int(common_attn_metadata.seq_lens_cpu.max())
+        max_seq_len = common_attn_metadata.max_seq_len
         query_start_loc = common_attn_metadata.query_start_loc
         seq_lens = common_attn_metadata.seq_lens
         seq_lens_cpu = common_attn_metadata.seq_lens_cpu
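
The same one-line substitution repeats in the backend builders below: build() now reads the precomputed field instead of reducing seq_lens_cpu again. A trimmed sketch of the consumer side (signature simplified, body elided):

def build(self, common_prefix_len, common_attn_metadata):
    # Before this commit, each backend did the reduction itself:
    #   max_seq_len = int(common_attn_metadata.seq_lens_cpu.max())
    max_seq_len = common_attn_metadata.max_seq_len  # now read, not recomputed
    ...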

vllm/v1/attention/backends/flashinfer.py

Lines changed: 1 addition & 1 deletion
@@ -463,7 +463,7 @@ def build(self,
 
         page_size = self.page_size
         max_q_len = common_attn_metadata.max_query_len
-        max_seq_len = common_attn_metadata.seq_lens_cpu.max().item()
+        max_seq_len = common_attn_metadata.max_seq_len
         seq_lens = common_attn_metadata.seq_lens
         seq_lens_cpu = common_attn_metadata.seq_lens_cpu
         block_table_tensor = common_attn_metadata.block_table_tensor

vllm/v1/attention/backends/flex_attention.py

Lines changed: 1 addition & 1 deletion
@@ -305,7 +305,7 @@ def build(self,
         num_actual_tokens = common_attn_metadata.num_actual_tokens
         max_query_len = common_attn_metadata.max_query_len
 
-        max_seq_len = int(common_attn_metadata.seq_lens_cpu.max())
+        max_seq_len = common_attn_metadata.max_seq_len
         query_start_loc = common_attn_metadata.query_start_loc
         seq_lens = common_attn_metadata.seq_lens
         block_table_tensor = common_attn_metadata.block_table_tensor

vllm/v1/attention/backends/rocm_aiter_fa.py

Lines changed: 1 addition & 1 deletion
@@ -270,7 +270,7 @@ def build(self,
 
         num_actual_tokens = common_attn_metadata.num_actual_tokens
         max_query_len = common_attn_metadata.max_query_len
-        max_seq_len = int(common_attn_metadata.seq_lens_cpu.max())
+        max_seq_len = common_attn_metadata.max_seq_len
         query_start_loc = common_attn_metadata.query_start_loc
         seq_lens = common_attn_metadata.seq_lens
         block_table_tensor = common_attn_metadata.block_table_tensor

vllm/v1/attention/backends/tree_attn.py

Lines changed: 1 addition & 1 deletion
@@ -205,7 +205,7 @@ def build(
         q_start_loc = common_attn_metadata.query_start_loc
         max_query_len = common_attn_metadata.max_query_len
         kv_seqlens = common_attn_metadata.seq_lens
-        max_seq_len = int(common_attn_metadata.seq_lens_cpu.max())
+        max_seq_len = common_attn_metadata.max_seq_len
         block_table = common_attn_metadata.block_table_tensor
         slot_mapping = common_attn_metadata.slot_mapping
 
vllm/v1/attention/backends/triton_attn.py

Lines changed: 1 addition & 1 deletion
@@ -90,7 +90,7 @@ def build(self,
         num_actual_tokens = common_attn_metadata.num_actual_tokens
         max_query_len = common_attn_metadata.max_query_len
 
-        max_seq_len = int(common_attn_metadata.seq_lens_cpu.max())
+        max_seq_len = common_attn_metadata.max_seq_len
         query_start_loc = common_attn_metadata.query_start_loc
         seq_lens = common_attn_metadata.seq_lens
         block_table_tensor = common_attn_metadata.block_table_tensor

vllm/v1/attention/backends/utils.py

Lines changed: 6 additions & 0 deletions
@@ -58,6 +58,8 @@ class CommonAttentionMetadata:
     """Total number of tokens in batch"""
     max_query_len: int
     """Longest query in batch"""
+    max_seq_len: int
+    """Longest context length in batch"""
 
     block_table_tensor: torch.Tensor
     slot_mapping: torch.Tensor
@@ -107,6 +109,7 @@ def _make_metadata_with_slice(
 
     seq_lens = attn_metadata.seq_lens[request_slice]
     seq_lens_cpu = attn_metadata.seq_lens_cpu[request_slice]
+    max_seq_len = int(seq_lens_cpu.max())
     num_computed_tokens_cpu = attn_metadata.num_computed_tokens_cpu[
         request_slice]
 
@@ -128,6 +131,7 @@ def _make_metadata_with_slice(
         num_reqs=num_requests,
         num_actual_tokens=num_actual_tokens,
         max_query_len=max_query_len,
+        max_seq_len=max_seq_len,
         block_table_tensor=block_table_tensor,
         slot_mapping=slot_mapping,
     )
@@ -520,6 +524,7 @@ def make_local_attention_virtual_batches(
 
     query_start_loc_cpu = torch.from_numpy(cu_seqlens_q_local)
     seq_lens_cpu = torch.from_numpy(seqlens_k_local)
+    max_seq_len = int(seq_lens_cpu.max())
 
     return CommonAttentionMetadata(
         query_start_loc_cpu=query_start_loc_cpu,
@@ -531,6 +536,7 @@ def make_local_attention_virtual_batches(
         num_reqs=len(seq_lens_cpu),
         num_actual_tokens=common_attn_metadata.num_actual_tokens,
         max_query_len=seqlens_q_local.max(),
+        max_seq_len=max_seq_len,
         block_table_tensor=block_table_local,
         slot_mapping=common_attn_metadata.slot_mapping,
         causal=True,
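
One detail worth noting: _make_metadata_with_slice and make_local_attention_virtual_batches cannot simply copy the parent batch's max_seq_len, because a request slice (or a set of local virtual batches) may exclude the longest request. Both therefore take the maximum over their own seq_lens_cpu, as this standalone illustration (not the vLLM helpers themselves) shows:

import torch

seq_lens_cpu = torch.tensor([7, 19, 4], dtype=torch.int32)

full_max = int(seq_lens_cpu.max())   # 19 for the whole batch
sliced = seq_lens_cpu[slice(0, 1)]   # keep only request 0
sliced_max = int(sliced.max())       # 7: the batch-wide max request is gone
assert (full_max, sliced_max) == (19, 7)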

vllm/v1/attention/backends/xformers.py

Lines changed: 1 addition & 1 deletion
@@ -231,7 +231,7 @@ def build(
         q_seqlens = torch.diff(q_start_loc)
         max_query_len = common_attn_metadata.max_query_len
         kv_seqlens = common_attn_metadata.seq_lens
-        max_seq_len = int(common_attn_metadata.seq_lens_cpu.max())
+        max_seq_len = common_attn_metadata.max_seq_len
         block_table = common_attn_metadata.block_table_tensor
         slot_mapping = common_attn_metadata.slot_mapping
 