
Commit 5c057e0

[CPU] Refine batch reorder of CPU attention backend (#26096)

Signed-off-by: jiang1.li <[email protected]>
1 parent ed3aeb2 · commit 5c057e0

File tree: 2 files changed (+44 −128 lines)

vllm/v1/attention/backends/cpu_attn.py (42 additions, 87 deletions)

@@ -14,10 +14,9 @@
 from vllm.config import VllmConfig
 from vllm.logger import init_logger
 from vllm.v1.attention.backends.utils import (AttentionMetadataBuilder,
-                                              CommonAttentionMetadata)
-from vllm.v1.core.sched.output import SchedulerOutput
+                                              CommonAttentionMetadata,
+                                              split_decodes_and_prefills)
 from vllm.v1.kv_cache_interface import AttentionSpec
-from vllm.v1.worker.gpu_input_batch import InputBatch
 
 try:
     import intel_extension_for_pytorch.llm.modules as ipex_modules
@@ -102,16 +101,16 @@ class TorchSDPAMetadata(AttentionMetadata):
     """Metadata for PagedAttention."""
     # (batch_size,). The length of sequences (entire tokens seen so far) per
     # sequence.
-    seq_lens_tensor: Optional[torch.Tensor]
+    decode_seq_lens_tensor: Optional[torch.Tensor]
     # Maximum sequence length in the batch. 0 if it is prefill-only batch.
-    max_decode_seq_len: int
+    decode_max_seq_len: int
     # (batch_size, max_blocks_per_seq).
     # Block addresses per sequence. (Seq id -> list of physical block)
     # E.g., [0, 1, 2] means tokens are stored in 0th, 1st, and 2nd blocks
     # in the kv cache. Each block can contain up to block_size tokens.
     # 2nd dimensions are padded up to max_blocks_per_seq if it is cuda-graph
     # captured.
-    block_tables: Optional[torch.Tensor]
+    decode_block_tables: Optional[torch.Tensor]
     """Metadata for TorchSDPABackend.
     """
     # Currently, input sequences can only contain all prompts
@@ -121,9 +120,9 @@ class TorchSDPAMetadata(AttentionMetadata):
 
     # For chunked prefill only
     max_query_len: Optional[int] = None
-    max_kv_len: Optional[int] = None
+    prefill_max_seq_len: Optional[int] = None
     prefill_query_start_loc: Optional[torch.Tensor] = None
-    kv_start_loc: Optional[torch.Tensor] = None
+    prefill_seq_start_loc: Optional[torch.Tensor] = None
    prefill_block_tables: Optional[torch.Tensor] = None
 
     # For V1 logits index only
@@ -307,8 +306,8 @@ def get_seq_len_block_table_args(
                 or attn_type == AttentionType.ENCODER_ONLY):
             # Decoder self-attention
             # Choose max_seq_len based on whether we are in prompt_run
-            return (self.seq_lens_tensor, self.max_decode_seq_len,
-                    self.block_tables)
+            return (self.decode_seq_lens_tensor, self.decode_max_seq_len,
+                    self.decode_block_tables)
         elif attn_type == AttentionType.ENCODER_DECODER:
             # Enc/dec cross-attention KVs match encoder sequence length;
             # cross-attention utilizes special "cross" block tables
@@ -323,19 +322,14 @@ def get_seq_len_block_table_args(
 
 
 class TorchSDPAMetadataBuilderV1(AttentionMetadataBuilder[TorchSDPAMetadata]):
+    reorder_batch_threshold: int = 1
 
     def __init__(self, kv_cache_spec: AttentionSpec, layer_names: list[str],
                  vllm_config: VllmConfig, device: torch.device) -> None:
         super().__init__(kv_cache_spec, layer_names, vllm_config, device)
 
         self.scheduler_config = vllm_config.scheduler_config
-
-        # For reorder
-        self.reorder_prompt_req_index_list = np.empty(
-            vllm_config.scheduler_config.max_num_seqs, dtype=np.int64)
-        self.reorder_decode_req_index_list = np.empty(
-            vllm_config.scheduler_config.max_num_seqs, dtype=np.int64)
-        self.num_prompt_req: int = 0
+        self._init_reorder_batch_threshold(1, False)
 
         self.seq_start_loc_cpu = torch.zeros(
             vllm_config.scheduler_config.max_num_seqs + 1,
@@ -344,50 +338,6 @@ def __init__(self, kv_cache_spec: AttentionSpec, layer_names: list[str],
         )
         self.seq_start_loc_np = self.seq_start_loc_cpu.numpy()
 
-    def reorder_batch(self, input_batch: InputBatch,
-                      scheduler_output: SchedulerOutput) -> bool:
-        prompt_list_idx = 0
-        decode_list_idx = 0
-        for req_index in range(input_batch.num_reqs):
-            if input_batch.num_computed_tokens_cpu[
-                    req_index] < input_batch.num_prompt_tokens[req_index]:
-                # prompt stage
-                self.reorder_prompt_req_index_list[prompt_list_idx] = req_index
-                prompt_list_idx += 1
-            else:
-                # decode stage
-                self.reorder_decode_req_index_list[decode_list_idx] = req_index
-                decode_list_idx += 1
-        assert decode_list_idx + prompt_list_idx == input_batch.num_reqs
-
-        # Update prompt requests number
-        self.num_prompt_req = prompt_list_idx
-
-        reorder_req_num = 0
-        for req_index in range(decode_list_idx):
-            if self.reorder_decode_req_index_list[req_index] < prompt_list_idx:
-                reorder_req_num += 1
-            else:
-                break
-
-        if reorder_req_num == 0:
-            return False
-
-        reorder_prompt_list = (
-            self.reorder_prompt_req_index_list[:prompt_list_idx]
-            [-reorder_req_num:])
-        reorder_decode_list = (
-            self.reorder_decode_req_index_list[:decode_list_idx]
-            [:reorder_req_num])
-        assert reorder_decode_list.size == reorder_prompt_list.size
-
-        for idx in range(reorder_req_num):
-            prompt_req_index = reorder_prompt_list[idx].item()
-            decode_req_index = reorder_decode_list[idx].item()
-            input_batch.swap_states(prompt_req_index, decode_req_index)
-
-        return True
-
     def build(self,
               common_prefix_len: int,
               common_attn_metadata: CommonAttentionMetadata,
@@ -397,41 +347,46 @@ def build(self,
 
         seq_lens_cpu = common_attn_metadata.seq_lens_cpu
         seq_lens_np = seq_lens_cpu.numpy()
-        num_prompt_req = self.num_prompt_req
-        max_prefill_seq_len = seq_lens_np[:num_prompt_req].max().item(
-        ) if num_prompt_req > 0 else 0
-        max_decode_seq_len = seq_lens_np[num_prompt_req:num_reqs].max().item(
-        ) if num_prompt_req < num_reqs else 0
-        self.seq_start_loc_np[0] = 0
-        np.cumsum(seq_lens_np, out=self.seq_start_loc_np[1:num_reqs + 1])
 
         query_start_loc_cpu = common_attn_metadata.query_start_loc_cpu
-        num_prefill_tokens = int(query_start_loc_cpu[num_prompt_req].item())
-        num_decode_tokens = int(query_start_loc_cpu[num_reqs].item() -
-                                num_prefill_tokens)
+        query_start_loc_np = query_start_loc_cpu.numpy()
+
+        num_decodes, num_prefills, num_decode_tokens, num_prefill_tokens =\
+            split_decodes_and_prefills(common_attn_metadata,
+                                       decode_threshold=self.reorder_batch_threshold,
+                                       require_uniform=True)
+
+        max_prefill_seq_len = seq_lens_np[num_decodes:num_reqs].max().item(
+        ) if num_prefills > 0 else 0
+        max_decode_seq_len = seq_lens_np[:num_decodes].max().item(
+        ) if num_prefills < num_reqs else 0
+        self.seq_start_loc_np[0] = 0
+        np.cumsum(seq_lens_np, out=self.seq_start_loc_np[1:num_reqs + 1])
 
         slot_mapping = common_attn_metadata.slot_mapping.long()
         block_table_tensor = common_attn_metadata.block_table_tensor
+        query_start_loc_np = query_start_loc_cpu.numpy()
+        query_start_loc_np[num_decodes:num_reqs + 1] -= num_decode_tokens
 
         attn_metadata = TorchSDPAMetadata(
-            num_prefills=num_prompt_req,
+            num_prefills=num_prefills,
             num_prefill_tokens=num_prefill_tokens,
             num_decode_tokens=num_decode_tokens,
             slot_mapping=slot_mapping,
             # to ensure inference when chunked_prefill is disabled
             seq_lens=seq_lens_cpu.tolist(),
-            seq_lens_tensor=seq_lens_cpu[num_prompt_req:num_reqs],  # decode
-            max_decode_seq_len=max_decode_seq_len,  # decode
-            block_tables=block_table_tensor[num_prompt_req:num_reqs],  # decode
+            decode_seq_lens_tensor=seq_lens_cpu[:num_decodes],  # decode
+            decode_max_seq_len=max_decode_seq_len,  # decode
+            decode_block_tables=block_table_tensor[:num_decodes],  # decode
             chunked_prefill=self.scheduler_config.chunked_prefill_enabled,
             max_query_len=max_query_len,
-            max_kv_len=max_prefill_seq_len,
-            prefill_query_start_loc=query_start_loc_cpu[:num_prompt_req +
+            prefill_max_seq_len=max_prefill_seq_len,
+            prefill_query_start_loc=query_start_loc_cpu[num_decodes:num_reqs +
                                                         1],  # prefill
-            kv_start_loc=self.seq_start_loc_cpu[:num_prompt_req +
-                                                1],  # prefill
-            prefill_block_tables=block_table_tensor[:
-                                                    num_prompt_req],  # prefill
+            prefill_seq_start_loc=self.seq_start_loc_cpu[num_decodes:num_reqs +
                                                          1],  # prefill
+            prefill_block_tables=block_table_tensor[
+                num_decodes:num_reqs],  # prefill
             query_start_loc=query_start_loc_cpu[:num_reqs +
                                                 1],  # for logits index
         )
@@ -596,14 +551,14 @@ def forward(
                 import intel_extension_for_pytorch.llm.modules as ipex_modules
                 output = torch.empty_like(query)
                 ipex_modules.PagedAttention.flash_attn_varlen_func(
-                    output[:prefill_meta.num_prefill_tokens, :, :],
-                    query[:prefill_meta.num_prefill_tokens, :, :],
+                    output[prefill_meta.num_decode_tokens:, :, :],
+                    query[prefill_meta.num_decode_tokens:, :, :],
                     key_cache,
                     value_cache,
                     prefill_meta.prefill_query_start_loc,
-                    prefill_meta.kv_start_loc,
+                    prefill_meta.prefill_seq_start_loc,
                     prefill_meta.max_query_len,
-                    prefill_meta.max_kv_len,
+                    prefill_meta.prefill_max_seq_len,
                     self.scale,
                     True,
                     prefill_meta.prefill_block_tables,
@@ -621,8 +576,8 @@
             ) = decode_meta.get_seq_len_block_table_args(attn_type)
 
             self.paged_attn_impl.forward_decode(
-                output[attn_metadata.num_prefill_tokens:, :, :],
-                query[attn_metadata.num_prefill_tokens:, :, :],
+                output[:attn_metadata.num_decode_tokens, :, :],
+                query[:attn_metadata.num_decode_tokens, :, :],
                 key_cache,
                 value_cache,
                 block_tables_arg,
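
The build() and forward() changes above assume the runner has already reordered the batch so that decode requests occupy a contiguous prefix, which is what lets the metadata use plain slices such as seq_lens_cpu[:num_decodes] and block_table_tensor[num_decodes:num_reqs]. As a rough illustration of what split_decodes_and_prefills provides, here is a sketch re-implemented from query_start_loc alone; the function name and example data are made up and this is not the actual vLLM helper:

# Illustrative sketch (not the vLLM helper itself): given a batch already
# reordered so decode requests come first, recover the decode/prefill split
# that build() relies on. Example values below are made up.
import numpy as np

def split_decodes_and_prefills_sketch(query_start_loc: np.ndarray,
                                      decode_threshold: int = 1):
    # Per-request query lengths; decodes schedule at most `decode_threshold`
    # new tokens, prefills schedule more.
    query_lens = np.diff(query_start_loc)
    is_prefill = query_lens > decode_threshold
    # Decodes form a contiguous prefix after reordering, so the index of the
    # first prefill is also the number of decodes.
    num_decodes = int(is_prefill.argmax()) if is_prefill.any() else len(query_lens)
    num_prefills = len(query_lens) - num_decodes
    num_decode_tokens = int(query_start_loc[num_decodes])
    num_prefill_tokens = int(query_start_loc[-1]) - num_decode_tokens
    return num_decodes, num_prefills, num_decode_tokens, num_prefill_tokens

# Example: two decode requests (1 token each) followed by one prefill (8 tokens).
query_start_loc = np.array([0, 1, 2, 10])
print(split_decodes_and_prefills_sketch(query_start_loc))  # (2, 1, 2, 8)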

vllm/v1/worker/cpu_model_runner.py (2 additions, 41 deletions)

@@ -9,7 +9,6 @@
 from vllm.config import VllmConfig
 from vllm.logger import init_logger
 from vllm.model_executor.model_loader import get_model
-from vllm.v1.attention.backends.cpu_attn import TorchSDPAMetadataBuilderV1
 from vllm.v1.utils import CpuGpuBuffer
 from vllm.v1.worker.gpu_model_runner import GPUModelRunner
 
@@ -33,50 +32,12 @@ def __init__(self, vllm_config: VllmConfig, device: torch.device):
 
         self._postprocess_tensors()
 
+    # Note: Remove the override after new attention backend finished
     def _may_reorder_batch(self, scheduler_output: "SchedulerOutput") -> None:
-        """
-        Update the order of requests in the batch based on the attention
-        backend's needs. For example, some attention backends (namely MLA) may
-        want to separate requests based on if the attention computation will be
-        compute-bound or memory-bound.
-
-        Args:
-            scheduler_output: The scheduler output.
-        """
-        # Attention free models have zero kv_cache_groups, however models
-        # like Mamba are also attention free but use the kv_cache for
-        # keeping its internal state. This is why we check the number
-        # of kv_cache groups instead of solely checking
-        # for self.model_config.is_attention_free.
-        if len(self.kv_cache_config.kv_cache_groups) == 0:
-            return
-
         if len(self.kv_cache_config.kv_cache_groups) > 1:
             raise ValueError("Multiple KVCacheGroups is not"
                              "currently supported with CPU model runner.")
-
-        # Guard against encoder-only / pooling models where `attn_groups`
-        # may be empty or lack the expected metadata_builder.
-        # Without this check, accessing `attn_groups[0][0]` would trigger
-        # an AssertionError on CPU backend.
-        if not hasattr(self, "attn_groups") or not self.attn_groups:
-            return
-        if not self.attn_groups[0]:
-            return
-
-        mb = getattr(self.attn_groups[0][0], "metadata_builders", None)
-        if isinstance(mb, list):
-            if not isinstance(mb[0], TorchSDPAMetadataBuilderV1):
-                return
-            mb[0].reorder_batch(self.input_batch, scheduler_output)
-            return
-        elif not isinstance(mb, TorchSDPAMetadataBuilderV1):
-            # Encoder-only / rerank models do not benefit from reordering,
-            # so we safely skip here.
-            return
-
-        # Safe path for decoder/attention-heavy models
-        mb.reorder_batch(self.input_batch, scheduler_output)
+        super()._may_reorder_batch(scheduler_output)
 
     def _postprocess_tensors(self) -> None:
         # Note: replace device tensors with cpu tensors
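
With the override reduced to a call into super()._may_reorder_batch(), the reordering itself is handled by the inherited GPUModelRunner path, presumably keyed off the builder's reorder_batch_threshold set via _init_reorder_batch_threshold(1, False) above. The deleted CPU-specific reorder_batch grouped prompt (prefill) requests ahead of decodes, whereas the shared path groups decode requests first, which is why build() and forward() now slice decodes from the front of the batch. A minimal standalone sketch of threshold-based decode-first reordering (hypothetical names and example data, not the vLLM implementation):

# Conceptual sketch of decode-first reordering by scheduled-token count.
# `num_scheduled_tokens` and the returned permutation are stand-ins for the
# real input-batch state; this is not a vLLM API.
def reorder_decodes_first(num_scheduled_tokens: list[int],
                          reorder_batch_threshold: int = 1) -> list[int]:
    order = list(range(len(num_scheduled_tokens)))
    left, right = 0, len(order) - 1
    # Two-pointer pass: swap any prefill sitting in the decode prefix with a
    # decode sitting in the prefill suffix.
    while left < right:
        if num_scheduled_tokens[order[left]] <= reorder_batch_threshold:
            left += 1            # already a decode, keep it in the prefix
        elif num_scheduled_tokens[order[right]] > reorder_batch_threshold:
            right -= 1           # already a prefill, keep it in the suffix
        else:
            order[left], order[right] = order[right], order[left]
            left += 1
            right -= 1
    return order

# Requests scheduled with 1, 7, 1, and 5 tokens: the two decode requests
# (1 token each) end up in front, prefills after them.
print(reorder_decodes_first([1, 7, 1, 5]))  # [0, 2, 1, 3]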
