Skip to content

Commit e51719a

Browse files
lucas-tuckerlucast2021
andauthored
mypy type checking for vllm/worker (#11418)
Signed-off-by: lucast2021 <[email protected]> Co-authored-by: lucast2021 <[email protected]>
1 parent f30581c commit e51719a

File tree

3 files changed

+9
-9
lines changed

3 files changed

+9
-9
lines changed

vllm/worker/cpu_worker.py

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -333,9 +333,8 @@ def execute_worker(
333333
def prepare_worker_input(
334334
self, execute_model_req: ExecuteModelRequest) -> WorkerInput:
335335
assert execute_model_req is not None
336-
virtual_engine = execute_model_req.virtual_engine
336+
virtual_engine: int = execute_model_req.virtual_engine
337337
num_seq_groups: int = len(execute_model_req.seq_group_metadata_list)
338-
blocks_to_copy = execute_model_req.blocks_to_copy
339338
blocks_to_copy = torch.tensor(execute_model_req.blocks_to_copy,
340339
device="cpu",
341340
dtype=torch.int64).view(-1, 2)

vllm/worker/multi_step_model_runner.py

Lines changed: 7 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -406,8 +406,9 @@ def _async_process_outputs(self, model_input: StatefulModelInput,
406406
if not cont:
407407
break
408408

409-
def _final_process_outputs(self, model_input: StatefulModelInput,
410-
output_proc_callback: Optional[Callable]):
409+
def _final_process_outputs(
410+
self, model_input: StatefulModelInput,
411+
output_proc_callback: Optional[Callable]) -> List[SamplerOutput]:
411412
assert model_input.frozen_model_input is not None
412413

413414
has_async_callback = output_proc_callback is not None
@@ -594,8 +595,8 @@ def execute_model(
594595
# should be [SamplerOutput]
595596
return output
596597

597-
def _update_sampling_metadata(self, sampling_metadata, num_seqs,
598-
num_queries):
598+
def _update_sampling_metadata(self, sampling_metadata: SamplingMetadata,
599+
num_seqs: Optional[int], num_queries: int):
599600

600601
assert sampling_metadata.num_prompts == 0
601602
assert len(sampling_metadata.seq_groups) == num_queries
@@ -850,13 +851,13 @@ def _pythonize_sampler_output(
850851
seq_ids = seq_group.seq_ids
851852
next_token_ids = sample_result
852853
parent_ids = [0]
854+
seq_outputs: List[SequenceOutput]
853855

854856
if cache is not None:
855857
completion_seq_group_output: CompletionSequenceGroupOutput = \
856858
cache.cached_completion_seq_group_output.get_object()
857859
completion_seq_group_output.samples.clear()
858-
seq_outputs: List[
859-
SequenceOutput] = completion_seq_group_output.samples
860+
seq_outputs = completion_seq_group_output.samples
860861
else:
861862
seq_outputs = []
862863

vllm/worker/worker_base.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -452,7 +452,7 @@ def init_worker(self, *args, **kwargs):
452452
self.worker = worker_class(*args, **kwargs)
453453
assert self.worker is not None
454454

455-
def execute_method(self, method, *args, **kwargs):
455+
def execute_method(self, method: str, *args, **kwargs):
456456
try:
457457
target = self if self.worker is None else self.worker
458458
executor = getattr(target, method)

0 commit comments

Comments
 (0)