
Commit 30b44a1

GPU Model Runner V2 (#25266)
Authored by Woosuk Kwon
Signed-off-by: Woosuk Kwon <[email protected]>
1 parent 1f400c5 commit 30b44a1

File tree

18 files changed: +2639 additions, -12 deletions


.github/CODEOWNERS

Lines changed: 3 additions & 0 deletions
@@ -35,6 +35,9 @@ CMakeLists.txt @tlrmchlsmth @LucasWilkinson
 /vllm/v1/kv_cache_interface.py @heheda12345
 /vllm/v1/offloading @ApostaC
 
+# Model runner V2
+/vllm/v1/worker/gpu @WoosukKwon
+
 # Test ownership
 /.buildkite/lm-eval-harness @mgoin
 /tests/distributed/test_multi_node_assignment.py @youkaichao

vllm/envs.py

Lines changed: 5 additions & 0 deletions
@@ -231,6 +231,7 @@
     VLLM_DISABLE_SHARED_EXPERTS_STREAM: bool = False
     VLLM_SHARED_EXPERTS_STREAM_TOKEN_THRESHOLD: int = 256
     VLLM_COMPILE_CACHE_SAVE_FORMAT: Literal["binary", "unpacked"] = "binary"
+    VLLM_USE_V2_MODEL_RUNNER: bool = False
 
 
 def get_default_cache_root():
@@ -1522,6 +1523,10 @@ def get_vllm_port() -> int | None:
     "VLLM_COMPILE_CACHE_SAVE_FORMAT": env_with_choices(
         "VLLM_COMPILE_CACHE_SAVE_FORMAT", "binary", ["binary", "unpacked"]
     ),
+    # Flag to enable v2 model runner.
+    "VLLM_USE_V2_MODEL_RUNNER": lambda: bool(
+        int(os.getenv("VLLM_USE_V2_MODEL_RUNNER", "0"))
+    ),
 }
 
 # --8<-- [end:env-vars-definition]
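
For reference, a minimal sketch of what the new flag means in practice: the value is parsed exactly as in the lambda above, so any non-zero integer enables the V2 runner and it stays disabled by default. The standalone parsing below mirrors the envs.py entry; it is an illustration, not code from this commit.

import os

# Same parsing rule as the new environment_variables entry:
# unset or "0" -> False, any other integer string -> True.
use_v2_model_runner = bool(int(os.getenv("VLLM_USE_V2_MODEL_RUNNER", "0")))

# Inside vLLM the flag is then read as envs.VLLM_USE_V2_MODEL_RUNNER
# (see the scheduler change below).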

vllm/v1/attention/backends/flashinfer.py

Lines changed: 3 additions & 0 deletions
@@ -593,6 +593,9 @@ def _get_workspace_buffer(self):
         )
         return self._workspace_buffer
 
+    def set_workspace_buffer(self, workspace_buffer: torch.Tensor):
+        self._workspace_buffer = workspace_buffer
+
     def _get_prefill_wrapper(
         self,
     ) -> BatchPrefillWithPagedKVCacheWrapper | BatchDCPPrefillWrapper:
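
The new setter lets a caller hand the FlashInfer backend an externally owned workspace tensor instead of relying on the lazily allocated one from _get_workspace_buffer(). Below is a hedged sketch of how a runner might share one workspace across backends; the `backends` list and the 128 MiB size are assumptions for illustration, not part of this diff.

import torch

def share_workspace(backends, device: torch.device) -> torch.Tensor:
    # Allocate a single byte buffer once (size chosen arbitrarily here) and
    # attach it to every backend via the set_workspace_buffer() added above.
    workspace = torch.empty(128 * 1024 * 1024, dtype=torch.uint8, device=device)
    for backend in backends:
        backend.set_workspace_buffer(workspace)
    return workspace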

vllm/v1/core/sched/output.py

Lines changed: 24 additions & 0 deletions
@@ -44,11 +44,15 @@ class NewRequestData:
     lora_request: LoRARequest | None
     prompt_embeds: "torch.Tensor | None" = None
 
+    # Only used for v2 model runner.
+    prefill_token_ids: list[int] | None = None
+
     @classmethod
     def from_request(
         cls,
         request: Request,
         block_ids: tuple[list[int], ...],
+        prefill_token_ids: list[int] | None = None,
     ) -> "NewRequestData":
         return cls(
             req_id=request.request_id,
@@ -60,6 +64,7 @@
             num_computed_tokens=request.num_computed_tokens,
             lora_request=request.lora_request,
             prompt_embeds=request.prompt_embeds,
+            prefill_token_ids=prefill_token_ids,
         )
 
     def __repr__(self) -> str:
@@ -68,6 +73,7 @@ def __repr__(self) -> str:
             f"NewRequestData("
             f"req_id={self.req_id},"
             f"prompt_token_ids={self.prompt_token_ids},"
+            f"prefill_token_ids={self.prefill_token_ids},"
             f"mm_features={self.mm_features},"
             f"sampling_params={self.sampling_params},"
             f"block_ids={self.block_ids},"
@@ -183,6 +189,10 @@ class SchedulerOutput:
     # freed from the encoder cache.
     free_encoder_mm_hashes: list[str]
 
+    # Request IDs that are preempted in this step.
+    # Only used for v2 model runner.
+    preempted_req_ids: set[str] | None = None
+
     # Whether the scheduled requests have all the output tokens they
     # need to perform grammar bitmask computation.
     pending_structured_output_tokens: bool = False
@@ -193,6 +203,20 @@ class SchedulerOutput:
     # EC Cache Connector metadata
     ec_connector_metadata: ECConnectorMetadata | None = None
 
+    @classmethod
+    def make_empty(cls) -> "SchedulerOutput":
+        return cls(
+            scheduled_new_reqs=[],
+            scheduled_cached_reqs=CachedRequestData.make_empty(),
+            num_scheduled_tokens={},
+            total_num_scheduled_tokens=0,
+            scheduled_spec_decode_tokens={},
+            scheduled_encoder_inputs={},
+            num_common_prefix_blocks=[],
+            finished_req_ids=set(),
+            free_encoder_mm_hashes=[],
+        )
+
 
 @dataclass
 class GrammarOutput:
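
The new make_empty() classmethod builds a SchedulerOutput that schedules nothing: no new or cached requests, zero scheduled tokens, and empty bookkeeping containers, with defaulted fields such as preempted_req_ids left as None. A small sketch of what calling it yields; the asserts only restate the values and defaults visible in the diff.

from vllm.v1.core.sched.output import SchedulerOutput

empty = SchedulerOutput.make_empty()
assert empty.total_num_scheduled_tokens == 0
assert empty.scheduled_new_reqs == []
assert empty.finished_req_ids == set()
assert empty.preempted_req_ids is None  # default from the new field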

vllm/v1/core/sched/scheduler.py

Lines changed: 22 additions & 6 deletions
@@ -6,6 +6,7 @@
 from collections.abc import Iterable
 from typing import Any
 
+from vllm import envs
 from vllm.config import VllmConfig
 from vllm.distributed.ec_transfer.ec_connector.base import (
     ECConnectorMetadata,
@@ -187,6 +188,7 @@ def __init__(
             pcp_world_size=self.pcp_world_size,
         )
         self.use_pp = self.parallel_config.pipeline_parallel_size > 1
+        self.use_v2_model_runner = envs.VLLM_USE_V2_MODEL_RUNNER
 
     def schedule(self) -> SchedulerOutput:
         # NOTE(woosuk) on the scheduling algorithm:
@@ -658,12 +660,25 @@ def schedule(self) -> SchedulerOutput:
             )
 
         # Construct the scheduler output.
-        new_reqs_data = [
-            NewRequestData.from_request(
-                req, req_to_new_blocks[req.request_id].get_block_ids()
-            )
-            for req in scheduled_new_reqs
-        ]
+        if self.use_v2_model_runner:
+            scheduled_new_reqs = scheduled_new_reqs + scheduled_resumed_reqs
+            scheduled_resumed_reqs = []
+            new_reqs_data = [
+                NewRequestData.from_request(
+                    req,
+                    req_to_new_blocks[req.request_id].get_block_ids(),
+                    req._all_token_ids,
+                )
+                for req in scheduled_new_reqs
+            ]
+        else:
+            new_reqs_data = [
+                NewRequestData.from_request(
+                    req, req_to_new_blocks[req.request_id].get_block_ids()
+                )
+                for req in scheduled_new_reqs
+            ]
+
         with record_function_or_nullcontext("schedule: make_cached_request_data"):
             cached_reqs_data = self._make_cached_request_data(
                 scheduled_running_reqs,
@@ -685,6 +700,7 @@
             scheduled_spec_decode_tokens=scheduled_spec_decode_tokens,
             scheduled_encoder_inputs=scheduled_encoder_inputs,
             num_common_prefix_blocks=num_common_prefix_blocks,
+            preempted_req_ids={req.request_id for req in preempted_reqs},
             # finished_req_ids is an existing state in the scheduler,
             # instead of being newly scheduled in this step.
             # It contains the request IDs that are finished in between
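
With the flag enabled, resumed requests are folded into the new-request path (carrying the full token list via req._all_token_ids as prefill_token_ids), and every SchedulerOutput now reports preempted_req_ids. A hedged sketch of how a runner-side consumer could use that field; the `requests` dict is an assumed per-runner structure, not code from this commit.

from vllm.v1.core.sched.output import SchedulerOutput

def drop_inactive_requests(
    requests: dict[str, object], scheduler_output: SchedulerOutput
) -> None:
    # Preempted requests are re-sent later through the new-request path (with
    # their prefill token ids), so the runner can drop their current state
    # alongside the finished ones.
    preempted = scheduler_output.preempted_req_ids or set()
    for req_id in preempted | scheduler_output.finished_req_ids:
        requests.pop(req_id, None)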

vllm/v1/worker/gpu/README.md

Lines changed: 4 additions & 0 deletions
@@ -0,0 +1,4 @@
+# [Experimental] Model Runner V2
+
+This directory contains the new model runner which is under active development.
+Ping [Woosuk Kwon](https://github.com/WoosukKwon) for any changes.

vllm/v1/worker/gpu/__init__.py

Whitespace-only changes.

vllm/v1/worker/gpu/async_utils.py

Lines changed: 89 additions & 0 deletions
@@ -0,0 +1,89 @@
+# SPDX-License-Identifier: Apache-2.0
+# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
+from contextlib import contextmanager
+
+import numpy as np
+import torch
+
+from vllm.v1.outputs import (
+    AsyncModelRunnerOutput,
+    ModelRunnerOutput,
+    SamplerOutput,
+)
+
+
+class AsyncOutput(AsyncModelRunnerOutput):
+    def __init__(
+        self,
+        model_runner_output: ModelRunnerOutput,
+        sampler_output: SamplerOutput,
+        num_sampled_tokens: np.ndarray,
+        copy_stream: torch.cuda.Stream,
+        copy_event: torch.cuda.Event,
+    ):
+        self.model_runner_output = model_runner_output
+        self.sampler_output = sampler_output
+        self.num_sampled_tokens = num_sampled_tokens
+        self.copy_stream = copy_stream
+        self.copy_event = copy_event
+
+        default_stream = torch.cuda.current_stream()
+        with torch.cuda.stream(self.copy_stream):
+            self.copy_stream.wait_stream(default_stream)
+
+            # NOTE(woosuk): We must ensure that CPU tensors are not freed
+            # before the device-to-host copy is fully completed. For instance,
+            # operations like
+            # self.sampled_token_np = ...to("cpu", non_blocking=True).numpy()
+            # are unsafe because the underlying CPU tensor can be prematurely freed and
+            # reused by other tensors before the asynchronous copy finishes, potentially
+            # causing race conditions. To prevent this, we delay freeing by holding
+            # references until the copy event signals completion.
+            # Likewise, we also need to keep the reference to the GPU tensors.
+            # This is done by keeping the reference to sampler_output and
+            # model_runner_output.
+            self.sampled_token_ids = sampler_output.sampled_token_ids.to(
+                "cpu", non_blocking=True
+            )
+            if sampler_output.logprobs_tensors is not None:
+                self.logprobs_tensors = (
+                    sampler_output.logprobs_tensors.to_cpu_nonblocking()
+                )
+            else:
+                self.logprobs_tensors = None
+            self.prompt_logprobs_dict = {}
+            if self.model_runner_output.prompt_logprobs_dict:
+                for k, v in self.model_runner_output.prompt_logprobs_dict.items():
+                    self.prompt_logprobs_dict[k] = v.to_cpu_nonblocking()
+        self.copy_event.record(self.copy_stream)
+
+    def get_output(self) -> ModelRunnerOutput:
+        self.copy_event.synchronize()
+
+        # NOTE(woosuk): The following code is to ensure compatibility with
+        # the existing model runner.
+        # Going forward, we should keep the data structures as NumPy arrays
+        # rather than Python lists.
+        sampled_token_ids_np = self.sampled_token_ids.numpy()
+        num_reqs = sampled_token_ids_np.shape[0]
+        sampled_token_ids: list[np.ndarray] = [
+            sampled_token_ids_np[i, : self.num_sampled_tokens[i]]
+            for i in range(num_reqs)
+        ]
+        self.model_runner_output.sampled_token_ids = sampled_token_ids
+
+        if self.logprobs_tensors is not None:
+            self.model_runner_output.logprobs = self.logprobs_tensors.tolists()
+        self.model_runner_output.prompt_logprobs_dict = self.prompt_logprobs_dict
+        return self.model_runner_output
+
+
+@contextmanager
+def async_barrier(event: torch.cuda.Event | None):
+    if event is not None:
+        event.synchronize()
+    try:
+        yield
+    finally:
+        if event is not None:
+            event.record()
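
AsyncOutput stages the sampled-token and logprob tensors onto the CPU on a dedicated copy stream and records an event, so get_output() only has to wait on that event before reshaping the results; async_barrier() is a small helper that waits on an event before a critical section and re-records it afterwards. A hedged usage sketch of async_barrier follows; the buffer and event names are illustrative, and the real call sites live in V2 runner files not shown in this excerpt.

import torch

from vllm.v1.worker.gpu.async_utils import async_barrier

# One reusable event guarding a shared, pinned input buffer (illustrative).
input_copy_event = torch.cuda.Event()
input_buffer = torch.empty(8192, dtype=torch.int64, pin_memory=True)

def write_token_ids(token_ids: list[int]) -> None:
    # Wait until the previously recorded copy has finished, overwrite the
    # buffer, then record the event again on the current stream so the next
    # writer (or an async H2D copy) can synchronize on it.
    with async_barrier(input_copy_event):
        n = len(token_ids)
        input_buffer[:n] = torch.tensor(token_ids, dtype=torch.int64)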
