Commit b029de9

[Optimization] Make new_block_ids None if empty (#23262)
Signed-off-by: Woosuk Kwon <[email protected]>
1 parent bbea1ce commit b029de9
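
At a glance: when a running request receives no new KV-cache blocks in a scheduling step (the common decode case), the scheduler now sends None instead of a tuple of empty per-group lists, and the workers skip the block-table update with a single check. Below is a minimal sketch of that idea; the names (NUM_KV_GROUPS, make_new_block_ids, apply_new_block_ids) are hypothetical stand-ins, not vLLM APIs.

    from typing import Optional

    NUM_KV_GROUPS = 2  # hypothetical number of KV cache groups

    def make_new_block_ids(
            new_blocks: tuple[list[int], ...]) -> Optional[tuple[list[int], ...]]:
        # Producer side: collapse "no new blocks in any group" into None
        # instead of shipping a tuple of empty lists every step.
        if all(len(group) == 0 for group in new_blocks):
            return None
        return new_blocks

    def apply_new_block_ids(block_table: list[list[int]],
                            new_block_ids: Optional[tuple[list[int], ...]]) -> None:
        # Consumer side: one None check replaces per-group iteration.
        if new_block_ids is None:
            return
        for row, new_ids in zip(block_table, new_block_ids):
            row.extend(new_ids)

    # Typical decode step: no new blocks were allocated.
    table = [[0, 1], [7]]
    apply_new_block_ids(table, make_new_block_ids(([], [])))
    assert table == [[0, 1], [7]]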

5 files changed: +57 additions, -27 deletions


vllm/v1/core/kv_cache_manager.py

Lines changed: 26 additions & 4 deletions
@@ -2,7 +2,7 @@
 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project
 
 from dataclasses import dataclass
-from typing import Optional
+from typing import Literal, Optional, overload
 
 from vllm.distributed.kv_events import KVCacheEvent
 from vllm.logger import init_logger
@@ -37,7 +37,24 @@ def __add__(self, other: "KVCacheBlocks") -> "KVCacheBlocks":
             tuple(blk1 + blk2
                   for blk1, blk2 in zip(self.blocks, other.blocks)))
 
-    def get_block_ids(self) -> tuple[list[int], ...]:
+    @overload
+    def get_block_ids(
+        self,
+        allow_none: Literal[False] = False,
+    ) -> tuple[list[int], ...]:
+        ...
+
+    @overload
+    def get_block_ids(
+        self,
+        allow_none: Literal[True] = True,
+    ) -> Optional[tuple[list[int], ...]]:
+        ...
+
+    def get_block_ids(
+        self,
+        allow_none: bool = False,
+    ):
         """
         Converts the KVCacheBlocks instance to block_ids.
 
@@ -46,6 +63,8 @@ def get_block_ids(self) -> tuple[list[int], ...]:
         * the outer tuple corresponds to KV cache groups
         * each inner list contains the block_ids of the blocks in that group
         """
+        if allow_none and all(len(group) == 0 for group in self.blocks):
+            return None
         return tuple([blk.block_id for blk in group] for group in self.blocks)
 
     def get_unhashed_block_ids(self) -> list[int]:
@@ -348,10 +367,13 @@ def take_events(self) -> list[KVCacheEvent]:
         """
         return self.block_pool.take_events()
 
+    def get_blocks(self, request_id: str) -> KVCacheBlocks:
+        """Get the blocks of a request."""
+        return KVCacheBlocks(self.coordinator.get_blocks(request_id))
+
     def get_block_ids(self, request_id: str) -> tuple[list[int], ...]:
         """Get the block ids of a request."""
-        return KVCacheBlocks(
-            self.coordinator.get_blocks(request_id)).get_block_ids()
+        return self.get_blocks(request_id).get_block_ids()
 
     def cache_blocks(self, request: Request, num_computed_tokens: int) -> None:
         """Cache the blocks for the request, if enabled."""

vllm/v1/core/sched/output.py

Lines changed: 1 addition & 1 deletion
@@ -91,7 +91,7 @@ class CachedRequestData:
     # NOTE(woosuk): new_token_ids is only used for pipeline parallelism.
     # When PP is not used, new_token_ids will be empty.
     new_token_ids: list[list[int]]
-    new_block_ids: list[tuple[list[int], ...]]
+    new_block_ids: list[Optional[tuple[list[int], ...]]]
     num_computed_tokens: list[int]
 
     @property
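
Since CachedRequestData is rebuilt every scheduler step for every running request, and decode steps usually allocate no new blocks, replacing the per-request empty tuple[list[int], ...] with None saves one tuple-of-lists allocation per request per step and makes the "nothing changed" case a single identity check. A rough, illustrative micro-benchmark; the two-group assumption and the framing are mine, not from the commit:

    import timeit

    NUM_GROUPS = 2  # hypothetical KV cache group count

    # Old shape: allocate a tuple of empty per-group lists for every
    # request, every step, even when nothing changed.
    t_empty = timeit.timeit(
        lambda: tuple([] for _ in range(NUM_GROUPS)), number=1_000_000)

    # New shape: the "nothing changed" case reuses the None singleton.
    t_none = timeit.timeit(lambda: None, number=1_000_000)

    print(f"empty tuples: {t_empty:.2f}s  vs  None: {t_none:.2f}s per 1M steps")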

vllm/v1/core/sched/scheduler.py

Lines changed: 12 additions & 12 deletions
@@ -19,7 +19,7 @@
 from vllm.multimodal import MULTIMODAL_REGISTRY, MultiModalRegistry
 from vllm.v1.core.encoder_cache_manager import (EncoderCacheManager,
                                                 compute_encoder_budget)
-from vllm.v1.core.kv_cache_manager import KVCacheManager
+from vllm.v1.core.kv_cache_manager import KVCacheBlocks, KVCacheManager
 from vllm.v1.core.sched.interface import SchedulerInterface
 from vllm.v1.core.sched.output import (CachedRequestData, NewRequestData,
                                        SchedulerOutput)
@@ -185,7 +185,7 @@ def schedule(self) -> SchedulerOutput:
         # uses structured decoding.
         structured_output_request_ids: dict[str, int] = {}
 
-        req_to_new_block_ids: dict[str, tuple[list[int], ...]] = {}
+        req_to_new_blocks: dict[str, KVCacheBlocks] = {}
         num_scheduled_tokens: dict[str, int] = {}
         token_budget = self.max_num_scheduled_tokens
         # Encoder-related.
@@ -288,8 +288,7 @@ def schedule(self) -> SchedulerOutput:
                 # Therefore, we might introduce some additional
                 # cycle to fill in the bitmask, which could be a big no-op.
                 structured_output_request_ids[request.request_id] = req_index
-            req_to_new_block_ids[request.request_id] = (
-                new_blocks.get_block_ids())
+            req_to_new_blocks[request.request_id] = new_blocks
             num_scheduled_tokens[request.request_id] = num_new_tokens
             token_budget -= num_new_tokens
             req_index += 1
@@ -496,8 +495,8 @@ def schedule(self) -> SchedulerOutput:
 
                 if self.lora_config and request.lora_request:
                     scheduled_loras.add(request.lora_request.lora_int_id)
-                req_to_new_block_ids[request.request_id] = (
-                    self.kv_cache_manager.get_block_ids(request.request_id))
+                req_to_new_blocks[request.request_id] = (
+                    self.kv_cache_manager.get_blocks(request.request_id))
                 num_scheduled_tokens[request.request_id] = num_new_tokens
                 token_budget -= num_new_tokens
                 request.status = RequestStatus.RUNNING
@@ -546,16 +545,16 @@ def schedule(self) -> SchedulerOutput:
         )
         # Construct the scheduler output.
         new_reqs_data = [
-            NewRequestData.from_request(req,
-                                        req_to_new_block_ids[req.request_id])
+            NewRequestData.from_request(
+                req, req_to_new_blocks[req.request_id].get_block_ids())
             for req in scheduled_new_reqs
         ]
         cached_reqs_data = self._make_cached_request_data(
             scheduled_running_reqs,
             scheduled_resumed_reqs,
             num_scheduled_tokens,
             scheduled_spec_decode_tokens,
-            req_to_new_block_ids,
+            req_to_new_blocks,
         )
         scheduler_output = SchedulerOutput(
             scheduled_new_reqs=new_reqs_data,
@@ -628,11 +627,11 @@ def _make_cached_request_data(
         resumed_reqs: list[Request],
         num_scheduled_tokens: dict[str, int],
         spec_decode_tokens: dict[str, list[int]],
-        req_to_new_block_ids: dict[str, tuple[list[int], ...]],
+        req_to_new_blocks: dict[str, KVCacheBlocks],
     ) -> CachedRequestData:
         req_ids: list[str] = []
         new_token_ids: list[list[int]] = []
-        new_block_ids: list[tuple[list[int], ...]] = []
+        new_block_ids: list[Optional[tuple[list[int], ...]]] = []
         num_computed_tokens: list[int] = []
 
         use_connector = self.connector is not None
@@ -655,7 +654,8 @@ def _make_cached_request_data(
                 # out of bounds errors. TODO: Remove this once the KVConnector
                 # is updated to handle token IDs properly.
                 new_token_ids.append([])
-            new_block_ids.append(req_to_new_block_ids[req_id])
+            new_block_ids.append(
+                req_to_new_blocks[req_id].get_block_ids(allow_none=True))
             num_computed_tokens.append(req.num_computed_tokens)
         # Because resumed_reqs is usually empty, it is more efficient to do
         # in-place appending so that we don't need to allocate a new list.
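
The scheduler-side refactor is mostly about deferring work: schedule() now carries the richer KVCacheBlocks objects in req_to_new_blocks and converts to plain block IDs only at the edges, via get_block_ids() for new requests (which always need concrete IDs) and get_block_ids(allow_none=True) for cached requests (which usually have none). A minimal sketch of this defer-then-convert-at-the-edge shape, with a hypothetical FakeBlocks stand-in for KVCacheBlocks:

    from typing import Optional

    class FakeBlocks:
        """Hypothetical stand-in for KVCacheBlocks."""
        def __init__(self, ids_per_group: tuple[list[int], ...]) -> None:
            self._ids = ids_per_group

        def get_block_ids(self, allow_none: bool = False):
            if allow_none and all(not g for g in self._ids):
                return None
            return self._ids

    def make_cached_payload(
        req_to_new_blocks: dict[str, FakeBlocks],
    ) -> dict[str, Optional[tuple[list[int], ...]]]:
        # Conversion happens once, here at the edge, instead of eagerly
        # inside the scheduling loop.
        return {
            req_id: blocks.get_block_ids(allow_none=True)
            for req_id, blocks in req_to_new_blocks.items()
        }

    payload = make_cached_payload({
        "req-0": FakeBlocks(([], [])),        # decode step: no new blocks
        "req-1": FakeBlocks(([42], [7, 8])),  # prefill step: new blocks
    })
    assert payload == {"req-0": None, "req-1": ([42], [7, 8])}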

vllm/v1/worker/gpu_model_runner.py

Lines changed: 9 additions & 5 deletions
@@ -574,11 +574,13 @@ def _update_states(self, scheduler_output: "SchedulerOutput") -> None:
 
             # Update the block IDs.
             if not resumed_from_preemption:
-                # Append the new blocks to the existing block IDs.
-                for block_ids, new_ids in zip(req_state.block_ids,
-                                              new_block_ids):
-                    block_ids.extend(new_ids)
+                if new_block_ids is not None:
+                    # Append the new blocks to the existing block IDs.
+                    for block_ids, new_ids in zip(req_state.block_ids,
+                                                  new_block_ids):
+                        block_ids.extend(new_ids)
             else:
+                assert new_block_ids is not None
                 # The request is resumed from preemption.
                 # Replace the existing block IDs with the new ones.
                 req_state.block_ids = new_block_ids
@@ -594,7 +596,9 @@ def _update_states(self, scheduler_output: "SchedulerOutput") -> None:
             # Update the persistent batch.
             self.input_batch.num_computed_tokens_cpu[req_index] = (
                 num_computed_tokens)
-            self.input_batch.block_table.append_row(new_block_ids, req_index)
+            if new_block_ids is not None:
+                self.input_batch.block_table.append_row(
+                    new_block_ids, req_index)
 
             # For the last rank, we don't need to update the token_ids_cpu
             # because the sampled tokens are already cached.
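
On the model-runner side the contract becomes: new_block_ids is None means "nothing to append", while a request resumed from preemption must always carry concrete IDs, hence the assert. A condensed sketch of that contract; the ReqState container and update_blocks helper are hypothetical simplifications of the runner's cached state:

    from dataclasses import dataclass
    from typing import Optional

    @dataclass
    class ReqState:
        """Hypothetical stand-in for the runner's cached request state."""
        block_ids: list[list[int]]

    def update_blocks(state: ReqState,
                      new_block_ids: Optional[tuple[list[int], ...]],
                      resumed_from_preemption: bool) -> None:
        if not resumed_from_preemption:
            if new_block_ids is None:
                return  # common decode step: no block-table work at all
            for row, new_ids in zip(state.block_ids, new_block_ids):
                row.extend(new_ids)
        else:
            # A preempted request was re-allocated from scratch, so the
            # scheduler must have sent its full block list.
            assert new_block_ids is not None
            state.block_ids = [list(ids) for ids in new_block_ids]

    s = ReqState(block_ids=[[0, 1]])
    update_blocks(s, None, resumed_from_preemption=False)    # no-op
    update_blocks(s, ([2],), resumed_from_preemption=False)  # append
    assert s.block_ids == [[0, 1, 2]]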

vllm/v1/worker/tpu_model_runner.py

Lines changed: 9 additions & 5 deletions
@@ -418,11 +418,13 @@ def _update_states(self, scheduler_output: "SchedulerOutput") -> bool:
             # Update the cached states.
             req_state.num_computed_tokens = num_computed_tokens
             if not resumed_from_preemption:
-                # Append the new blocks to the existing block IDs.
-                for block_ids, new_ids in zip(req_state.block_ids,
-                                              new_block_ids):
-                    block_ids.extend(new_ids)
+                if new_block_ids is not None:
+                    # Append the new blocks to the existing block IDs.
+                    for block_ids, new_ids in zip(req_state.block_ids,
+                                                  new_block_ids):
+                        block_ids.extend(new_ids)
             else:
+                assert new_block_ids is not None
                 # The request is resumed from preemption.
                 # Replace the existing block IDs with the new ones.
                 req_state.block_ids = new_block_ids
@@ -438,7 +440,9 @@ def _update_states(self, scheduler_output: "SchedulerOutput") -> bool:
             # Update the persistent batch.
             self.input_batch.num_computed_tokens_cpu[req_index] = (
                 num_computed_tokens)
-            self.input_batch.block_table.append_row(new_block_ids, req_index)
+            if new_block_ids is not None:
+                self.input_batch.block_table.append_row(
+                    new_block_ids, req_index)
 
             # Add the new or resumed requests to the persistent batch.
             # The smaller empty indices are filled first.
