
Commit 4727a8a

[Attention] Remove unused reorder_batch method (#24463)
Signed-off-by: Matthew Bonanni <[email protected]>
1 parent b8f603c commit 4727a8a

File tree (6 files changed: 8 additions, 64 deletions)

  tests/v1/logits_processors/test_correctness.py
  vllm/v1/attention/backends/flashinfer.py
  vllm/v1/attention/backends/flex_attention.py
  vllm/v1/attention/backends/tree_attn.py
  vllm/v1/attention/backends/utils.py
  vllm/v1/attention/backends/xformers.py
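At a glance, the commit drops the per-backend reorder_batch() override in favor of a reorder_batch_threshold attribute that the shared builder infrastructure consumes. Below is a minimal sketch of the old and new builder shapes, using stand-in classes and a stubbed helper rather than the real vLLM types; the base-class behavior implied by the "after" form is an assumption inferred from this diff, not confirmed by it.

# Sketch only: OldBuilder/NewBuilder are stand-ins, not the real
# vLLM AttentionMetadataBuilder subclasses.

def reorder_batch_to_split_decodes_and_prefills(input_batch, scheduler_output,
                                                decode_threshold=1):
    """Placeholder for the shared helper in vllm/v1/attention/backends/utils.py."""
    return False  # the real helper reorders the batch in place and reports whether it changed


class OldBuilder:
    # Before: each backend carried its own thin override.
    def reorder_batch(self, input_batch, scheduler_output) -> bool:
        return reorder_batch_to_split_decodes_and_prefills(
            input_batch, scheduler_output, decode_threshold=1
        )


class NewBuilder:
    # After: the backend only declares a threshold; shared code (assumed)
    # decides when and how to reorder.
    def __init__(self):
        self.reorder_batch_threshold = 1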

tests/v1/logits_processors/test_correctness.py
Lines changed: 1 addition & 1 deletion

@@ -581,7 +581,7 @@ def _generate_fake_step_update(
     persistent_batch[:] = persistent_batch[0:condensed_batch_size]

     if condensed_batch_size > 1:
-        # Simulate arbitrary reorder_batch() in the kernel backend
+        # Simulate arbitrary batch ordering in the kernel backend
         # Generate a random number k of non-overlapping swap tuples
         k = random.randint(0, condensed_batch_size // 2)
         idxs = list(range(condensed_batch_size))
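The comment above describes simulating arbitrary backend reordering with k non-overlapping swaps. A standalone sketch of that idea follows; the function name and the list-based batch representation are simplifications for illustration, not the test's actual fixtures.

import random

def random_nonoverlapping_swaps(batch, rng=random):
    """Shuffle a batch with k non-overlapping index swaps, as the test simulates."""
    n = len(batch)
    if n <= 1:
        return batch
    k = rng.randint(0, n // 2)          # number of swap pairs
    idxs = list(range(n))
    rng.shuffle(idxs)
    for j in range(k):
        a, b = idxs[2 * j], idxs[2 * j + 1]   # disjoint pairs -> non-overlapping swaps
        batch[a], batch[b] = batch[b], batch[a]
    return batch

print(random_nonoverlapping_swaps(["req0", "req1", "req2", "req3", "req4"]))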

vllm/v1/attention/backends/flashinfer.py
Lines changed: 2 additions & 4 deletions

@@ -602,8 +602,7 @@ def build(
             )
         else:
             # Regular attention (common case).
-            # Decodes are at the front and prefills are at the back,
-            # according to reorder_batch()
+            # Decodes are at the front and prefills are at the back.
             num_prefills = attn_metadata.num_prefills
             num_decodes = attn_metadata.num_decodes
             if num_prefills > 0:

@@ -925,8 +924,7 @@ def forward(
         stride_order = FlashInferBackend.get_kv_cache_stride_order()
         kv_cache_permute = kv_cache.permute(*stride_order)
         # Regular attention (common case).
-        # Decodes are at the front and prefills are at the back,
-        # according to reorder_batch()
+        # Decodes are at the front and prefills are at the back.
         if num_prefill_tokens > 0:
             prefill_wrapper = attn_metadata.prefill_wrapper
             prefill_query = query[num_decode_tokens:]
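The surviving comment states the invariant the builder relies on: decode requests occupy the front of the batch and prefills the back, so the query tensor can be split with a single index. A small illustration of that slicing is below; the shapes and variable values are invented for the example, and only the split-by-num_decode_tokens idea comes from the diff.

import torch

num_decode_tokens = 3   # e.g. 3 decode requests, 1 token each (example values)
num_prefill_tokens = 8  # remaining tokens belong to prefill requests
num_heads, head_size = 4, 64

# Decodes at the front, prefills at the back (the invariant noted above).
query = torch.randn(num_decode_tokens + num_prefill_tokens, num_heads, head_size)

decode_query = query[:num_decode_tokens]    # handed to the decode path
prefill_query = query[num_decode_tokens:]   # handed to the prefill path

assert decode_query.shape[0] == num_decode_tokens
assert prefill_query.shape[0] == num_prefill_tokens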

vllm/v1/attention/backends/flex_attention.py
Lines changed: 1 addition & 10 deletions

@@ -3,7 +3,7 @@
 """Attention layer with FlexAttention."""

 from dataclasses import dataclass
-from typing import TYPE_CHECKING, Optional, Union
+from typing import Optional, Union

 import torch
 import torch._dynamo.decorators

@@ -38,10 +38,6 @@

 logger = init_logger(__name__)

-if TYPE_CHECKING:
-    from vllm.v1.core.sched.output import SchedulerOutput
-    from vllm.v1.worker.gpu_input_batch import InputBatch
-
 create_block_mask_compiled = torch.compile(
     create_block_mask, fullgraph=True, mode="reduce-overhead"
 )

@@ -600,11 +596,6 @@ def __init__(
         self.q_block_size: int = 16 if is_torch_equal_or_newer("2.9.0.dev0") else 128
         self.kv_block_size: int = 16 if is_torch_equal_or_newer("2.9.0.dev0") else 128

-    def reorder_batch(
-        self, input_batch: "InputBatch", scheduler_output: "SchedulerOutput"
-    ) -> bool:
-        return False
-
     def build(
         self,
         common_prefix_len: int,

vllm/v1/attention/backends/tree_attn.py
Lines changed: 3 additions & 14 deletions

@@ -4,10 +4,11 @@

 import ast
 from dataclasses import dataclass
-from typing import TYPE_CHECKING, Optional
+from typing import Optional

 import torch

+from vllm import _custom_ops as ops
 from vllm.attention.backends.abstract import (
     AttentionBackend,
     AttentionImpl,

@@ -20,17 +21,10 @@
 from vllm.v1.attention.backends.utils import (
     AttentionMetadataBuilder,
     CommonAttentionMetadata,
-    reorder_batch_to_split_decodes_and_prefills,
     split_decodes_and_prefills,
 )
 from vllm.v1.kv_cache_interface import AttentionSpec

-if TYPE_CHECKING:
-    from vllm.v1.core.sched.output import SchedulerOutput
-    from vllm.v1.worker.gpu_input_batch import InputBatch
-
-from vllm import _custom_ops as ops
-
 logger = init_logger(__name__)


@@ -189,12 +183,7 @@ def __init__(
             device=device,
         )

-    def reorder_batch(
-        self, input_batch: "InputBatch", scheduler_output: "SchedulerOutput"
-    ) -> bool:
-        return reorder_batch_to_split_decodes_and_prefills(
-            input_batch, scheduler_output, decode_threshold=self.tree_attn_bias.shape[0]
-        )
+        self.reorder_batch_threshold = self.tree_attn_bias.shape[0]

     def build(
         self,
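Here the explicit override collapses into one assignment: the builder records self.reorder_batch_threshold = self.tree_attn_bias.shape[0] in its constructor and leaves the actual reordering to shared code. A toy version of that constructor is sketched below; everything except the two attribute names taken from the diff is an assumption, including the stub class and the meaning attached to the bias tensor's first dimension.

import torch

class TreeAttnBuilderSketch:
    """Stand-in for the tree-attention metadata builder touched in this diff."""

    def __init__(self, tree_attn_bias: torch.Tensor):
        # In the old code, tree_attn_bias.shape[0] was passed as decode_threshold,
        # so requests with at most that many scheduled tokens count as decodes.
        self.tree_attn_bias = tree_attn_bias
        self.reorder_batch_threshold = self.tree_attn_bias.shape[0]

builder = TreeAttnBuilderSketch(torch.zeros(4, 4))
print(builder.reorder_batch_threshold)  # 4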

vllm/v1/attention/backends/utils.py
Lines changed: 0 additions & 22 deletions

@@ -299,24 +299,6 @@ def build(
         """
         raise NotImplementedError

-    def reorder_batch(
-        self, input_batch: "InputBatch", scheduler_output: "SchedulerOutput"
-    ) -> bool:
-        """
-        Update the order of requests in the batch based on the attention
-        backend's needs. For example, some attention backends (namely MLA) may
-        want to separate requests based on if the attention computation will be
-        compute-bound or memory-bound.
-
-        Args:
-            input_batch: input batch
-            scheduler_output: scheduler output.
-
-        Returns:
-            True if the batch was modified, False otherwise.
-        """
-        raise NotImplementedError
-
     def build_for_cudagraph_capture(
         self, common_attn_metadata: CommonAttentionMetadata
     ) -> M:

@@ -828,10 +810,6 @@ def reorder_batch_to_split_decodes_and_prefills(

     for i, req_id in enumerate(input_batch.req_ids):
         num_tokens = scheduler_output.num_scheduled_tokens[req_id]
-        # for now treat 1 scheduled token as "decode" even if it's not,
-        # we should update this to something like < 8 in the future but
-        # currently the TritonMLA._forward_decode only supports
-        # num_tokens = 1
        if num_tokens <= decode_threshold:
             decodes.append(i)
             num_decode_tokens += num_tokens
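The second hunk trims an outdated comment from reorder_batch_to_split_decodes_and_prefills; the classification it performs, treating a request as "decode" when its scheduled token count is at or below decode_threshold, is unchanged. A self-contained sketch of that split over plain Python data (not the real InputBatch/SchedulerOutput objects) follows.

def split_decodes_and_prefills_sketch(req_ids, num_scheduled_tokens, decode_threshold=1):
    """Classify requests as decodes or prefills by scheduled token count."""
    decodes, prefills = [], []
    num_decode_tokens = num_prefill_tokens = 0
    for i, req_id in enumerate(req_ids):
        num_tokens = num_scheduled_tokens[req_id]
        if num_tokens <= decode_threshold:
            decodes.append(i)
            num_decode_tokens += num_tokens
        else:
            prefills.append(i)
            num_prefill_tokens += num_tokens
    # The real helper then swaps batch rows so the decode requests come first.
    return decodes, prefills, num_decode_tokens, num_prefill_tokens

print(split_decodes_and_prefills_sketch(
    ["a", "b", "c"], {"a": 1, "b": 17, "c": 1}, decode_threshold=1
))  # ([0, 2], [1], 2, 17)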

vllm/v1/attention/backends/xformers.py
Lines changed: 1 addition & 13 deletions

@@ -3,7 +3,7 @@
 """Attention layer with XFormersAttention."""

 from dataclasses import dataclass
-from typing import TYPE_CHECKING, Optional
+from typing import Optional

 import torch


@@ -19,7 +19,6 @@
 from vllm.v1.attention.backends.utils import (
     AttentionMetadataBuilder,
     CommonAttentionMetadata,
-    reorder_batch_to_split_decodes_and_prefills,
     split_decodes_and_prefills,
 )
 from vllm.v1.kv_cache_interface import AttentionSpec

@@ -35,10 +34,6 @@
 except ImportError:
     XFORMERS_AVAILABLE = False

-if TYPE_CHECKING:
-    from vllm.v1.core.sched.output import SchedulerOutput
-    from vllm.v1.worker.gpu_input_batch import InputBatch
-
 from vllm import _custom_ops as ops

 logger = init_logger(__name__)

@@ -223,13 +218,6 @@ def __init__(
         self._num_decodes = 0
         self._num_decode_tokens = 0

-    def reorder_batch(
-        self, input_batch: "InputBatch", scheduler_output: "SchedulerOutput"
-    ) -> bool:
-        return reorder_batch_to_split_decodes_and_prefills(
-            input_batch, scheduler_output, decode_threshold=self.reorder_batch_threshold
-        )
-
     def build(
         self,
         common_prefix_len: int,
