
Commit 4abfd87

[V1] [Hybrid] Validate compatibility of attention backend batch reordering at init time (#21557)
Signed-off-by: Thomas Parnell <[email protected]>
1 parent f5d0f47 commit 4abfd87

File tree

7 files changed: +96 −72 lines
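In short, this commit removes the per-builder `reorder_batch()` hook, which each reordering backend implemented by calling `reorder_batch_to_split_decodes_and_prefills` with `decode_threshold=1`, and replaces it with a declarative class attribute, `reorder_batch_threshold`, that the model runner reads and validates once at initialization. A minimal sketch of the shape of the change (the class names below are illustrative stand-ins, not the full vLLM definitions):

```python
from typing import ClassVar, Optional


class MetadataBuilderSketch:
    # New in this commit: None means the backend does not reorder the
    # batch; an integer is the query length at or below which a request
    # is treated as a decode and pulled to the front of the batch.
    reorder_batch_threshold: ClassVar[Optional[int]] = None

    # Old approach (removed by this commit): each backend that wanted
    # decode-first ordering overrode a hook along these lines and did
    # the reordering itself on every scheduler step:
    #
    # def reorder_batch(self, input_batch, scheduler_output) -> bool:
    #     return reorder_batch_to_split_decodes_and_prefills(
    #         input_batch, scheduler_output, decode_threshold=1)


class DecodeFirstBuilderSketch(MetadataBuilderSketch):
    # New approach: backends such as FlashInfer, Mamba2 and MLA simply
    # declare the threshold; the runner performs the reordering centrally.
    reorder_batch_threshold: ClassVar[int] = 1
```

The per-file diffs below apply exactly this pattern to the FlashInfer, Mamba2 and MLA builders, drop a redundant no-op override from the ROCm AITER backend, and move the reordering itself (plus the init-time compatibility check) into the model runners.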

vllm/v1/attention/backends/flashinfer.py

Lines changed: 12 additions & 16 deletions
@@ -4,7 +4,7 @@
 from __future__ import annotations

 from dataclasses import dataclass
-from typing import TYPE_CHECKING, ClassVar, Optional, Union
+from typing import ClassVar, Optional, Union

 import torch
 from flashinfer import (BatchDecodeWithPagedKVCacheWrapper,
@@ -21,17 +21,17 @@
 from vllm.utils import cdiv, is_pin_memory_available
 from vllm.utils.flashinfer import use_trtllm_decode_attention
 from vllm.v1.attention.backends.flash_attn import use_cascade_attention
-from vllm.v1.attention.backends.utils import (
-    AttentionCGSupport, AttentionMetadataBuilder, CommonAttentionMetadata,
-    get_kv_cache_layout, get_per_layer_parameters,
-    infer_global_hyperparameters, reorder_batch_to_split_decodes_and_prefills,
-    split_decodes_and_prefills)
+# yapf conflicts with isort for this block
+# yapf: disable
+from vllm.v1.attention.backends.utils import (AttentionCGSupport,
+                                              AttentionMetadataBuilder,
+                                              CommonAttentionMetadata,
+                                              get_kv_cache_layout,
+                                              get_per_layer_parameters,
+                                              infer_global_hyperparameters,
+                                              split_decodes_and_prefills)
 from vllm.v1.kv_cache_interface import AttentionSpec

-if TYPE_CHECKING:
-    from vllm.v1.core.sched.output import SchedulerOutput
-    from vllm.v1.worker.gpu_input_batch import InputBatch
-
 FLASHINFER_WORKSPACE_BUFFER_SIZE = 256 * 1024 * 1024

 logger = init_logger(__name__)
@@ -179,6 +179,8 @@ class FlashInferMetadataBuilder(AttentionMetadataBuilder[FlashInferMetadata]):
     attn_cudagraph_support: ClassVar[AttentionCGSupport] = \
         AttentionCGSupport.PURE_DECODE_ONLY

+    reorder_batch_threshold: ClassVar[int] = 1
+
     def __init__(self, kv_cache_spec: AttentionSpec, layer_names: list[str],
                  vllm_config: VllmConfig, device: torch.device):
         self.device = device
@@ -239,12 +241,6 @@ def __init__(self, kv_cache_spec: AttentionSpec, layer_names: list[str],
             dtype=torch.int32,
             device=self.device)

-    def reorder_batch(self, input_batch: InputBatch,
-                      scheduler_output: SchedulerOutput) -> bool:
-        return reorder_batch_to_split_decodes_and_prefills(input_batch,
-                                                           scheduler_output,
-                                                           decode_threshold=1)
-
     def _get_workspace_buffer(self):
         if self._workspace_buffer is None:
             self._workspace_buffer = torch.empty(

vllm/v1/attention/backends/mamba_attn.py

Lines changed: 6 additions & 14 deletions
@@ -2,21 +2,17 @@
 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project
 import math
 from dataclasses import dataclass
-from typing import TYPE_CHECKING, Optional
+from typing import ClassVar, Optional

 import torch

 from vllm.attention.backends.abstract import AttentionBackend
 from vllm.config import VllmConfig
-from vllm.v1.attention.backends.utils import (
-    AttentionMetadataBuilder, CommonAttentionMetadata,
-    reorder_batch_to_split_decodes_and_prefills, split_decodes_and_prefills)
+from vllm.v1.attention.backends.utils import (AttentionMetadataBuilder,
+                                              CommonAttentionMetadata,
+                                              split_decodes_and_prefills)
 from vllm.v1.kv_cache_interface import AttentionSpec, MambaSpec

-if TYPE_CHECKING:
-    from vllm.v1.core.sched.output import SchedulerOutput
-    from vllm.v1.worker.gpu_input_batch import InputBatch
-

 def _query_start_loc_to_chunk_indices_offsets(query_start_loc: torch.Tensor,
                                               chunk_size: int,
@@ -87,6 +83,8 @@ class Mamba2AttentionMetadata:
 class Mamba2AttentionMetadataBuilder(
         AttentionMetadataBuilder[Mamba2AttentionMetadata]):

+    reorder_batch_threshold: ClassVar[int] = 1
+
     def __init__(self, kv_cache_spec: AttentionSpec, layer_names: list[str],
                  vllm_config: VllmConfig, device: torch.device):
         assert isinstance(kv_cache_spec, MambaSpec)
@@ -95,12 +93,6 @@ def __init__(self, kv_cache_spec: AttentionSpec, layer_names: list[str],
         assert self.chunk_size is not None, (
             "chunk_size needs to be set in the model config for Mamba2 models")

-    def reorder_batch(self, input_batch: "InputBatch",
-                      scheduler_output: "SchedulerOutput") -> bool:
-        return reorder_batch_to_split_decodes_and_prefills(input_batch,
-                                                           scheduler_output,
-                                                           decode_threshold=1)
-
     def build(self,
               common_prefix_len: int,
               common_attn_metadata: CommonAttentionMetadata,

vllm/v1/attention/backends/mla/common.py

Lines changed: 7 additions & 15 deletions
@@ -190,7 +190,7 @@
 import functools
 from abc import abstractmethod
 from dataclasses import dataclass, field
-from typing import TYPE_CHECKING, Generic, Optional, TypeVar, Union
+from typing import ClassVar, Generic, Optional, TypeVar, Union

 import torch

@@ -210,10 +210,11 @@
 from vllm.platforms import current_platform
 from vllm.utils import cdiv, round_down
 from vllm.utils.flashinfer import has_nvidia_artifactory
-from vllm.v1.attention.backends.utils import (
-    AttentionMetadataBuilder, CommonAttentionMetadata,
-    get_per_layer_parameters, infer_global_hyperparameters,
-    reorder_batch_to_split_decodes_and_prefills, split_decodes_and_prefills)
+from vllm.v1.attention.backends.utils import (AttentionMetadataBuilder,
+                                              CommonAttentionMetadata,
+                                              get_per_layer_parameters,
+                                              infer_global_hyperparameters,
+                                              split_decodes_and_prefills)
 from vllm.v1.kv_cache_interface import AttentionSpec

 try:
@@ -233,10 +234,6 @@
 except ImportError:
     flashinfer_available = False

-if TYPE_CHECKING:
-    from vllm.v1.core.sched.output import SchedulerOutput
-    from vllm.v1.worker.gpu_input_batch import InputBatch
-
 logger = init_logger(__name__)

 CUDNN_WORKSPACE_SIZE = 12800
@@ -403,6 +400,7 @@ class MLACommonMetadataBuilder(AttentionMetadataBuilder[M]):
     NOTE: Please read the comment at the top of the file before trying to
     understand this class
     """
+    reorder_batch_threshold: ClassVar[int] = 1

     def __init__(self,
                  kv_cache_spec: AttentionSpec,
@@ -559,12 +557,6 @@ def _build_fi_prefill_wrappers(self, prefill: FlashInferPrefillMetadata):
         prefill.prefill_main = self._fi_prefill_main
         prefill.prefill_chunks = self._fi_prefill_chunks

-    def reorder_batch(self, input_batch: "InputBatch",
-                      scheduler_output: "SchedulerOutput") -> bool:
-        return reorder_batch_to_split_decodes_and_prefills(input_batch,
-                                                           scheduler_output,
-                                                           decode_threshold=1)
-
     def _build_decode(self, block_table_tensor: torch.Tensor,
                       seq_lens: torch.Tensor):
         return MLACommonDecodeMetadata(

vllm/v1/attention/backends/rocm_aiter_fa.py

Lines changed: 0 additions & 3 deletions
@@ -251,9 +251,6 @@ def __init__(self, kv_cache_spec: AttentionSpec, layer_names: list[str],
         self.aot_sliding_window: Optional[tuple[int, int]] = None
         self.total_tokens: int = 0

-    def reorder_batch(self, input_batch, scheduler_output) -> bool:
-        return False
-
     def build_for_cudagraph_capture(
             self, common_attn_metadata: CommonAttentionMetadata):
         self.total_tokens = self.model_config.max_model_len \

vllm/v1/attention/backends/utils.py

Lines changed: 4 additions & 8 deletions
@@ -167,6 +167,10 @@ class AttentionMetadataBuilder(abc.ABC, Generic[M]):
     # Does this backend/builder support CUDA Graphs for attention.
     attn_cudagraph_support: ClassVar[AttentionCGSupport] = \
         AttentionCGSupport.NEVER
+    # Does this backend/builder reorder the batch?
+    # If not, set this to None. Otherwise set it to the query
+    # length that will be pulled into the front of the batch.
+    reorder_batch_threshold: ClassVar[Optional[int]] = None

     @abstractmethod
     def __init__(self, kv_cache_spec: AttentionSpec, layer_names: list[str],
@@ -221,14 +225,6 @@ def use_cascade_attention(
     ) -> bool:
         return False

-    def reorder_batch(self, input_batch: "InputBatch",
-                      scheduler_output: "SchedulerOutput") -> bool:
-        """
-        This method can reorder the batch if desired by the backend.
-        :return: Has the batch been reordered (default False).
-        """
-        return False
-

 @functools.lru_cache
 def get_kv_cache_layout():
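The comment added to `AttentionMetadataBuilder` defines the threshold as the query length that will be pulled to the front of the batch. As a self-contained illustration of that ordering rule (a simplified stand-in, not the actual `reorder_batch_to_split_decodes_and_prefills`, which permutes the runner's `InputBatch` in place):

```python
def split_decodes_first(query_lens: list[int],
                        decode_threshold: int = 1) -> list[int]:
    """Return request indices with 'decode' requests first.

    A request whose query length is <= decode_threshold is treated as a
    decode; all other requests are prefills and stay behind the decodes.
    """
    decodes = [i for i, n in enumerate(query_lens) if n <= decode_threshold]
    prefills = [i for i, n in enumerate(query_lens) if n > decode_threshold]
    return decodes + prefills


# Requests with query lengths [5, 1, 7, 1] and threshold 1 are reordered
# so that the two single-token decodes come first.
assert split_decodes_first([5, 1, 7, 1]) == [1, 3, 0, 2]
```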

vllm/v1/worker/cpu_model_runner.py

Lines changed: 33 additions & 1 deletion
@@ -1,16 +1,20 @@
 # SPDX-License-Identifier: Apache-2.0
 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project
 from contextlib import contextmanager
-from typing import Any
+from typing import TYPE_CHECKING, Any

 import torch
 import torch.nn as nn

 from vllm.config import VllmConfig
 from vllm.logger import init_logger
 from vllm.model_executor.model_loader import get_model
+from vllm.v1.attention.backends.cpu_attn import TorchSDPAMetadataBuilderV1
 from vllm.v1.worker.gpu_model_runner import GPUModelRunner

+if TYPE_CHECKING:
+    from vllm.v1.core.sched.output import SchedulerOutput
+
 logger = init_logger(__name__)


@@ -27,6 +31,34 @@ def __init__(self, vllm_config: VllmConfig, device: torch.device):

         self._postprocess_tenosrs()

+    def _may_reorder_batch(self, scheduler_output: "SchedulerOutput") -> None:
+        """
+        Update the order of requests in the batch based on the attention
+        backend's needs. For example, some attention backends (namely MLA) may
+        want to separate requests based on if the attention computation will be
+        compute-bound or memory-bound.
+
+        Args:
+            scheduler_output: The scheduler output.
+        """
+        # Attention free models have zero kv_cache_goups, however models
+        # like Mamba are also attention free but use the kv_cache for
+        # keeping its internal state. This is why we check the number
+        # of kv_cache groups instead of solely checking
+        # for self.model_config.is_attention_free.
+        if len(self.kv_cache_config.kv_cache_groups) == 0:
+            return
+
+        if len(self.kv_cache_config.kv_cache_groups) > 1:
+            raise ValueError("Multiple KVCacheGroups is not"
+                             "currently supported with CPU model runner.")
+
+        assert type(
+            self.attn_metadata_builders[0]) is TorchSDPAMetadataBuilderV1
+
+        self.attn_metadata_builders[0].reorder_batch(self.input_batch,
+                                                     scheduler_output)
+
     def _postprocess_tenosrs(self) -> None:
         # Note: replace device tensors with cpu tensors
         def replace_tensor(obj: Any, cpu_attr_name: str,
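The CPU runner keeps its own `_may_reorder_batch` override because it supports exactly one KV cache group, handled by `TorchSDPAMetadataBuilderV1`. A condensed sketch of the guard logic above, written as a hypothetical standalone helper:

```python
def cpu_runner_should_reorder(num_kv_cache_groups: int) -> bool:
    """Return True if the CPU runner should ask its builder to reorder.

    Zero groups means a truly attention-free model with nothing to
    reorder. Mamba-style models are also attention-free but keep state
    in the KV cache, so the group count is checked rather than relying
    on model_config.is_attention_free alone.
    """
    if num_kv_cache_groups == 0:
        return False
    if num_kv_cache_groups > 1:
        raise ValueError("Multiple KV cache groups are not currently "
                         "supported with the CPU model runner.")
    return True
```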

vllm/v1/worker/gpu_model_runner.py

Lines changed: 34 additions & 15 deletions
@@ -49,7 +49,8 @@
 from vllm.v1.attention.backends.utils import (
     AttentionCGSupport, AttentionMetadataBuilder, CommonAttentionMetadata,
     make_kv_sharing_fast_prefill_attention_metadata,
-    make_local_attention_virtual_batches)
+    make_local_attention_virtual_batches,
+    reorder_batch_to_split_decodes_and_prefills)
 from vllm.v1.core.encoder_cache_manager import compute_encoder_budget
 from vllm.v1.kv_cache_interface import (AttentionSpec,
                                         ChunkedLocalAttentionSpec,
@@ -329,6 +330,8 @@ def __init__(
         self.kv_sharing_fast_prefill_logits_indices = torch.zeros(
             self.max_num_tokens, dtype=torch.int32, device=self.device)

+        self.reorder_batch_threshold: Optional[int] = None
+
     def _may_reorder_batch(self, scheduler_output: "SchedulerOutput") -> None:
         """
         Update the order of requests in the batch based on the attention
@@ -347,20 +350,11 @@ def _may_reorder_batch(self, scheduler_output: "SchedulerOutput") -> None:
         if len(self.kv_cache_config.kv_cache_groups) == 0:
             return

-        self.attn_metadata_builders[0].reorder_batch(self.input_batch,
-                                                     scheduler_output)
-
-        # For models with multiple KV cache groups, the groups should agree on
-        # the same order of requests. We ensure this by only allowing the first
-        # group to reorder the batch and asserting that all other groups do not
-        # reorder the batch.
-        # TODO(tdoublep): make this more flexible so that any group can
-        # re-order the batch (not only the first).
-        # TODO(tdoublep): verify this during engine init instead of at runtime
-        for i in range(1, len(self.kv_cache_config.kv_cache_groups)):
-            batch_reordered = self.attn_metadata_builders[i].reorder_batch(
-                self.input_batch, scheduler_output)
-            assert not batch_reordered
+        if self.reorder_batch_threshold is not None:
+            reorder_batch_to_split_decodes_and_prefills(
+                self.input_batch,
+                scheduler_output,
+                decode_threshold=self.reorder_batch_threshold)

     # Note: used for model runner override.
     def _init_device_properties(self) -> None:
@@ -2654,6 +2648,9 @@ def initialize_attn_backend(self, kv_cache_config: KVCacheConfig) -> None:
             self.attn_backends.append(attn_backend_i)
             self.attn_metadata_builders.append(attn_metadata_builder_i)

+        # Calculate reorder batch threshold (if neeeded)
+        self.calculate_reorder_batch_threshold()
+
         if len(self.attn_backends) > 0:
             return

@@ -2688,6 +2685,28 @@ def initialize_attn_backend(self, kv_cache_config: KVCacheConfig) -> None:
             self.attn_metadata_builders.append(attn_metadata_builder)
             self.is_encoder_only_model = True

+    def calculate_reorder_batch_threshold(self) -> None:
+        """
+        Check that if any backends reorder batches; that the reordering
+        is compatible (e.g., decode threshold is the same)
+        """
+        for attn_metadata_builder_i in self.attn_metadata_builders:
+            # check that if any backends reorder batches; that the reordering
+            # is compatible (e.g., decode threshold is the same)
+            reorder_batch_threshold_i = (
+                attn_metadata_builder_i.reorder_batch_threshold)
+            if reorder_batch_threshold_i is not None:
+                if self.reorder_batch_threshold is not None:
+                    if reorder_batch_threshold_i != \
+                            self.reorder_batch_threshold:
+                        raise ValueError(
+                            f"Attention backend reorders decodes with "
+                            f"threshold {reorder_batch_threshold_i} but other "
+                            f"backend uses threshold "
+                            f"{self.reorder_batch_threshold}")
+                else:
+                    self.reorder_batch_threshold = reorder_batch_threshold_i
+
     def may_reinitialize_input_batch(self,
                                      kv_cache_config: KVCacheConfig) -> None:
         """
