Commit 1bed891

[Chore] Fix pre-commit error after #25266 (#29190)
1 parent ceca060 commit 1bed891

5 files changed: +40 −24 lines changed

vllm/v1/worker/gpu/async_utils.py

Lines changed: 11 additions & 9 deletions
@@ -7,6 +7,7 @@
 
 from vllm.v1.outputs import (
     AsyncModelRunnerOutput,
+    LogprobsTensors,
     ModelRunnerOutput,
     SamplerOutput,
 )
@@ -46,15 +47,18 @@ def __init__(
             "cpu", non_blocking=True
         )
         if sampler_output.logprobs_tensors is not None:
-            self.logprobs_tensors = (
+            self.logprobs_tensors: LogprobsTensors | None = (
                 sampler_output.logprobs_tensors.to_cpu_nonblocking()
             )
         else:
             self.logprobs_tensors = None
-        self.prompt_logprobs_dict = {}
+        self.prompt_logprobs_dict: dict[str, LogprobsTensors | None] = {}
         if self.model_runner_output.prompt_logprobs_dict:
             for k, v in self.model_runner_output.prompt_logprobs_dict.items():
-                self.prompt_logprobs_dict[k] = v.to_cpu_nonblocking()
+                if v is not None:
+                    self.prompt_logprobs_dict[k] = v.to_cpu_nonblocking()
+                else:
+                    self.prompt_logprobs_dict[k] = None
         self.copy_event.record(self.copy_stream)
 
     def get_output(self) -> ModelRunnerOutput:
@@ -64,12 +68,10 @@ def get_output(self) -> ModelRunnerOutput:
         # the existing model runner.
         # Going forward, we should keep the data structures as NumPy arrays
         # rather than Python lists.
-        sampled_token_ids_np = self.sampled_token_ids.numpy()
-        num_reqs = sampled_token_ids_np.shape[0]
-        sampled_token_ids: list[np.ndarray] = [
-            sampled_token_ids_np[i, : self.num_sampled_tokens[i]]
-            for i in range(num_reqs)
-        ]
+        sampled_token_ids: list[list[int]] = self.sampled_token_ids.tolist()
+        num_reqs = len(sampled_token_ids)
+        for i in range(num_reqs):
+            del sampled_token_ids[i][self.num_sampled_tokens[i] :]
         self.model_runner_output.sampled_token_ids = sampled_token_ids
 
         if self.logprobs_tensors is not None:
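Note on the change above: the logprobs fields now carry explicit Optional annotations, and get_output() truncates each request's sampled tokens on plain Python lists rather than NumPy views. A minimal standalone sketch of that truncation pattern, with made-up inputs in place of the runner's real attributes:

# Sketch only: truncate per-request sampled token ids with tolist() + del slicing.
import torch

sampled = torch.tensor([[1, 2, 3], [4, 5, 6]])  # [num_reqs, max_sampled_tokens]
num_sampled_tokens = [2, 1]  # valid tokens per request (illustrative values)

token_ids: list[list[int]] = sampled.tolist()
for i in range(len(token_ids)):
    # Drop the padded tail in place, mirroring the new get_output() logic.
    del token_ids[i][num_sampled_tokens[i]:]

print(token_ids)  # [[1, 2], [4]]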

vllm/v1/worker/gpu/attn_utils.py

Lines changed: 8 additions & 6 deletions
@@ -1,7 +1,7 @@
 # SPDX-License-Identifier: Apache-2.0
 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project
 from collections.abc import Sequence
-from typing import Any
+from typing import Any, cast
 
 import torch
 
@@ -13,6 +13,7 @@
     CommonAttentionMetadata,
 )
 from vllm.v1.kv_cache_interface import (
+    AttentionSpec,
     KVCacheConfig,
     KVCacheSpec,
 )
@@ -22,7 +23,8 @@
 
 def get_kv_cache_spec(vllm_config: VllmConfig) -> dict[str, KVCacheSpec]:
     kv_cache_spec: dict[str, KVCacheSpec] = {}
-    attn_layers = get_layers_from_vllm_config(vllm_config, AttentionLayerBase)
+    layer_type = cast(type[Any], AttentionLayerBase)
+    attn_layers = get_layers_from_vllm_config(vllm_config, layer_type)
     for layer_name, attn_module in attn_layers.items():
         # Skip modules that don't need KV cache (eg encoder-only attention)
         if spec := attn_module.get_kv_cache_spec(vllm_config):
@@ -35,16 +37,15 @@ def init_attn_backend(
     vllm_config: VllmConfig,
     device: torch.device,
 ):
-    attn_backends: dict[str, AttentionBackend] = {}
+    attn_backends: dict[str, type[AttentionBackend]] = {}
    attn_metadata_builders: list[AttentionMetadataBuilder] = []
     flashinfer_workspace: torch.Tensor | None = None
     for kv_cache_group_spec in kv_cache_config.kv_cache_groups:
         layer_names = kv_cache_group_spec.layer_names
         any_layer_name = next(iter(layer_names))
 
-        attn_layers = get_layers_from_vllm_config(
-            vllm_config, AttentionLayerBase, layer_names
-        )
+        layer_type = cast(type[Any], AttentionLayerBase)
+        attn_layers = get_layers_from_vllm_config(vllm_config, layer_type, layer_names)
         attn_backend = attn_layers[any_layer_name].get_attn_backend()
         for layer_name in layer_names:
             attn_backends[layer_name] = attn_backend
@@ -93,6 +94,7 @@ def _reshape_kv_cache(
     kv_caches: dict[str, torch.Tensor] = {}
     for kv_cache_group_spec in kv_cache_config.kv_cache_groups:
         kv_cache_spec = kv_cache_group_spec.kv_cache_spec
+        assert isinstance(kv_cache_spec, AttentionSpec)
         for layer_name in kv_cache_group_spec.layer_names:
             raw_tensor = kv_cache_raw_tensors[layer_name]
             assert raw_tensor.numel() % kv_cache_spec.page_size_bytes == 0
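The cast(type[Any], AttentionLayerBase) calls above exist only to satisfy the type checker; cast is a no-op at runtime. A standalone sketch of the pattern with stand-in names (FakeBase and get_layers are illustrative, not vLLM APIs):

# Sketch only: pass an abstract base class where the annotation wants type[Any].
from typing import Any, cast


class FakeBase:  # stand-in for an abstract layer base class
    pass


def get_layers(layer_type: type[Any]) -> dict[str, Any]:
    # Stand-in for get_layers_from_vllm_config: look up modules by type.
    return {"layer.0": layer_type()}


layer_type = cast(type[Any], FakeBase)  # no-op at runtime, quiets the checker
layers = get_layers(layer_type)
print(list(layers))  # ['layer.0']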

vllm/v1/worker/gpu/cudagraph_utils.py

Lines changed: 10 additions & 2 deletions
@@ -34,8 +34,16 @@ def __init__(
         self.compilation_config = vllm_config.compilation_config
         assert self.compilation_config is not None
 
-        self.cudagraph_mode = self.compilation_config.cudagraph_mode
-        self.cudagraph_sizes = sorted(self.compilation_config.cudagraph_capture_sizes)
+        if self.compilation_config.cudagraph_mode is None:
+            self.cudagraph_mode = CUDAGraphMode.NONE
+        else:
+            self.cudagraph_mode = self.compilation_config.cudagraph_mode
+        if self.compilation_config.cudagraph_capture_sizes is not None:
+            self.cudagraph_sizes = sorted(
+                self.compilation_config.cudagraph_capture_sizes
+            )
+        else:
+            self.cudagraph_sizes = []
         self.padded_sizes = self._init_padded_sizes()
 
         self.graphs: dict[int, torch.cuda.CUDAGraph] = {}
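The guards above handle the case where cudagraph_mode or cudagraph_capture_sizes is unset (None) on the compilation config. A standalone sketch of the fallback logic, using a stand-in config class rather than vLLM's CompilationConfig:

# Sketch only: fall back to safe defaults when optional config fields are None.
from dataclasses import dataclass


@dataclass
class FakeCompilationConfig:
    cudagraph_mode: str | None = None
    cudagraph_capture_sizes: list[int] | None = None


config = FakeCompilationConfig(cudagraph_capture_sizes=None)
mode = config.cudagraph_mode if config.cudagraph_mode is not None else "NONE"
sizes = (
    sorted(config.cudagraph_capture_sizes)
    if config.cudagraph_capture_sizes is not None
    else []
)
print(mode, sizes)  # NONE []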

vllm/v1/worker/gpu/model_runner.py

Lines changed: 10 additions & 6 deletions
@@ -329,8 +329,9 @@ def warmup_for_prefill(self) -> None:
         torch.cuda.synchronize()
 
     def update_states(self, scheduler_output: SchedulerOutput) -> None:
-        for req_id in scheduler_output.preempted_req_ids:
-            self.req_states.remove_request(req_id)
+        if scheduler_output.preempted_req_ids is not None:
+            for req_id in scheduler_output.preempted_req_ids:
+                self.req_states.remove_request(req_id)
         for req_id in scheduler_output.finished_req_ids:
             self.req_states.remove_request(req_id)
 
@@ -346,6 +347,9 @@ def update_states(self, scheduler_output: SchedulerOutput) -> None:
 
         # Add new requests.
         for new_req_data in scheduler_output.scheduled_new_reqs:
+            assert new_req_data.prompt_token_ids is not None
+            assert new_req_data.prefill_token_ids is not None
+            assert new_req_data.sampling_params is not None
             req_id = new_req_data.req_id
             self.req_states.add_request(
                 req_id=req_id,
@@ -398,8 +402,8 @@ def prepare_inputs(
         # Decode first, then prefill.
         # batch_idx -> req_id
         req_ids = sorted(
-            scheduler_output.num_scheduled_tokens,
-            key=scheduler_output.num_scheduled_tokens.get,
+            scheduler_output.num_scheduled_tokens.keys(),
+            key=lambda k: scheduler_output.num_scheduled_tokens[k],
         )
         num_scheduled_tokens = np.array(
             [scheduler_output.num_scheduled_tokens[i] for i in req_ids], dtype=np.int32
@@ -637,9 +641,9 @@ def postprocess(
         model_runner_output = ModelRunnerOutput(
             req_ids=input_batch.req_ids,
             req_id_to_index={req_id: i for i, req_id in enumerate(input_batch.req_ids)},
-            sampled_token_ids=None,
+            sampled_token_ids=None,  # type: ignore
             logprobs=None,
-            prompt_logprobs_dict=prompt_logprobs_dict,
+            prompt_logprobs_dict=prompt_logprobs_dict,  # type: ignore
             pooler_output=[],
             kv_connector_output=None,
             num_nans_in_logits=None,
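On the prepare_inputs() change above: sorted(d, key=d.get) trips the type checker, presumably because dict.get returns an Optional value that is not comparable, so the new code spells out the keys and uses a lambda. A quick standalone illustration with made-up request ids:

# Sketch only: sort request ids by their scheduled token counts.
num_scheduled_tokens = {"req-b": 3, "req-a": 1, "req-c": 2}

req_ids = sorted(
    num_scheduled_tokens.keys(),
    key=lambda k: num_scheduled_tokens[k],
)
print(req_ids)  # ['req-a', 'req-c', 'req-b']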

vllm/v1/worker/gpu/sampler.py

Lines changed: 1 addition & 1 deletion
@@ -8,8 +8,8 @@
 
 from vllm.config.model import LogprobsMode
 from vllm.v1.outputs import LogprobsTensors, SamplerOutput
-from vllm.v1.sample.metadata import SamplingMetadata
 from vllm.v1.sample.ops.topk_topp_sampler import apply_top_k_top_p
+from vllm.v1.worker.gpu.states import SamplingMetadata
 
 
 class Sampler:
