Skip to content

Commit 3d2c56b

Browse files
Make mypy behave like a proper pre-commit hook (#25313)
Signed-off-by: Harry Mellor <[email protected]> Co-authored-by: gemini-code-assist[bot] <176961590+gemini-code-assist[bot]@users.noreply.github.com>
1 parent 64c824c commit 3d2c56b

File tree

9 files changed

+166
-87
lines changed

9 files changed

+166
-87
lines changed

.github/CODEOWNERS

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -72,6 +72,7 @@ mkdocs.yaml @hmellor
7272
# Linting
7373
.markdownlint.yaml @hmellor
7474
.pre-commit-config.yaml @hmellor
75+
/tools/pre_commit @hmellor
7576

7677
# CPU
7778
/vllm/v1/worker/cpu* @bigPYJ1151

.pre-commit-config.yaml

Lines changed: 14 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -60,38 +60,32 @@ repos:
6060
files: ^requirements/test\.(in|txt)$
6161
- id: mypy-local
6262
name: Run mypy for local Python installation
63-
entry: tools/mypy.sh 0 "local"
64-
language: python
65-
types: [python]
66-
additional_dependencies: &mypy_deps [mypy==1.11.1, types-cachetools, types-setuptools, types-PyYAML, types-requests, pydantic]
63+
entry: python tools/pre_commit/mypy.py 0 "local"
6764
stages: [pre-commit] # Don't run in CI
65+
<<: &mypy_common
66+
language: python
67+
types_or: [python, pyi]
68+
require_serial: true
69+
additional_dependencies: [mypy==1.11.1, regex, types-cachetools, types-setuptools, types-PyYAML, types-requests, types-torch, pydantic]
6870
- id: mypy-3.9 # TODO: Use https://github.com/pre-commit/mirrors-mypy when mypy setup is less awkward
6971
name: Run mypy for Python 3.9
70-
entry: tools/mypy.sh 1 "3.9"
71-
language: python
72-
types: [python]
73-
additional_dependencies: *mypy_deps
72+
entry: python tools/pre_commit/mypy.py 1 "3.9"
73+
<<: *mypy_common
7474
stages: [manual] # Only run in CI
7575
- id: mypy-3.10 # TODO: Use https://github.com/pre-commit/mirrors-mypy when mypy setup is less awkward
7676
name: Run mypy for Python 3.10
77-
entry: tools/mypy.sh 1 "3.10"
78-
language: python
79-
types: [python]
80-
additional_dependencies: *mypy_deps
77+
entry: python tools/pre_commit/mypy.py 1 "3.10"
78+
<<: *mypy_common
8179
stages: [manual] # Only run in CI
8280
- id: mypy-3.11 # TODO: Use https://github.com/pre-commit/mirrors-mypy when mypy setup is less awkward
8381
name: Run mypy for Python 3.11
84-
entry: tools/mypy.sh 1 "3.11"
85-
language: python
86-
types: [python]
87-
additional_dependencies: *mypy_deps
82+
entry: python tools/pre_commit/mypy.py 1 "3.11"
83+
<<: *mypy_common
8884
stages: [manual] # Only run in CI
8985
- id: mypy-3.12 # TODO: Use https://github.com/pre-commit/mirrors-mypy when mypy setup is less awkward
9086
name: Run mypy for Python 3.12
91-
entry: tools/mypy.sh 1 "3.12"
92-
language: python
93-
types: [python]
94-
additional_dependencies: *mypy_deps
87+
entry: python tools/pre_commit/mypy.py 1 "3.12"
88+
<<: *mypy_common
9589
stages: [manual] # Only run in CI
9690
- id: shellcheck
9791
name: Lint shell scripts

pyproject.toml

Lines changed: 0 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -110,27 +110,6 @@ ignore_missing_imports = true
110110
check_untyped_defs = true
111111
follow_imports = "silent"
112112

113-
# After fixing type errors resulting from follow_imports: "skip" -> "silent",
114-
# move the directory here and remove it from tools/mypy.sh
115-
files = [
116-
"vllm/*.py",
117-
"vllm/assets",
118-
"vllm/entrypoints",
119-
"vllm/inputs",
120-
"vllm/logging_utils",
121-
"vllm/multimodal",
122-
"vllm/platforms",
123-
"vllm/transformers_utils",
124-
"vllm/triton_utils",
125-
"vllm/usage",
126-
]
127-
# TODO(woosuk): Include the code from Megatron and HuggingFace.
128-
exclude = [
129-
"vllm/model_executor/parallel_utils/|vllm/model_executor/models/",
130-
# Ignore triton kernels in ops.
131-
'vllm/attention/ops/.*\.py$'
132-
]
133-
134113
[tool.isort]
135114
skip_glob = [
136115
".buildkite/*",

tools/mypy.sh

Lines changed: 0 additions & 35 deletions
This file was deleted.

tools/pre_commit/mypy.py

Lines changed: 140 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,140 @@
1+
# SPDX-License-Identifier: Apache-2.0
2+
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
3+
"""
4+
Run mypy on changed files.
5+
6+
This script is designed to be used as a pre-commit hook. It runs mypy
7+
on files that have been changed. It groups files into different mypy calls
8+
based on their directory to avoid import following issues.
9+
10+
Usage:
11+
python tools/pre_commit/mypy.py <ci> <python_version> <changed_files...>
12+
13+
Args:
14+
ci: "1" if running in CI, "0" otherwise. In CI, the main group of files is
15+
checked with the configured follow_imports ("silent"); all other runs use "skip".
16+
python_version: Python version to use (e.g., "3.10") or "local" to use
17+
the local Python version.
18+
changed_files: List of changed files to check.
19+
"""
20+
21+
import subprocess
22+
import sys
23+
from typing import Optional
24+
25+
import regex as re
26+
27+
# Paths checked together in the main mypy invocation (the "" group).
# Plain entries are repository-relative path prefixes. An entry containing
# "*" is a glob that must match the whole path, where "*" matches any run of
# characters inside a single path component (it never crosses a "/").
FILES = [
    "vllm/*.py",
    "vllm/assets",
    "vllm/entrypoints",
    "vllm/inputs",
    "vllm/logging_utils",
    "vllm/multimodal",
    "vllm/platforms",
    "vllm/transformers_utils",
    "vllm/triton_utils",
    "vllm/usage",
]

# After fixing errors resulting from changing follow_imports
# from "skip" to "silent", move the following directories to FILES
SEPARATE_GROUPS = [
    "tests",
    "vllm/attention",
    "vllm/compilation",
    "vllm/distributed",
    "vllm/engine",
    "vllm/executor",
    # NOTE(review): "vllm/inputs" is also listed in FILES, which is tested
    # first, so this group never receives any files. Kept so the keys returned
    # by group_files() stay unchanged - confirm and drop if unintended.
    "vllm/inputs",
    "vllm/lora",
    "vllm/model_executor",
    "vllm/plugins",
    "vllm/worker",
    "vllm/v1",
]

# Path prefixes that are never passed to mypy.
# TODO(woosuk): Include the code from Megatron and HuggingFace.
EXCLUDE = [
    "vllm/model_executor/parallel_utils",
    "vllm/model_executor/models",
    "vllm/model_executor/layers/fla/ops",
    # Ignore triton kernels in ops.
    "vllm/attention/ops",
]


def _compile_path_matcher(entries: list[str]):
    """Compile path entries into a single regex for use with ``.match()``.

    Plain entries match any path that starts with them. Entries containing
    "*" are globs that must match the whole path; "*" stands for any run of
    characters that does not include "/".
    """
    alternatives = []
    for entry in entries:
        # Escape the literal pieces so characters such as "." are not read as
        # regex metacharacters, then re-join them with the glob wildcard.
        literal_pieces = [re.escape(piece) for piece in entry.split("*")]
        pattern = "[^/]*".join(literal_pieces)
        if "*" in entry:
            # A glob names files, not a directory prefix: match to the end.
            pattern += "$"
        alternatives.append(pattern)
    if not alternatives:
        # An empty alternation would match every path; match nothing instead.
        return re.compile(r"(?!)")
    # Group the alternation so the "^" anchor applies to every alternative.
    return re.compile(f"^(?:{'|'.join(alternatives)})")


def group_files(changed_files: list[str]) -> dict[str, list[str]]:
    """
    Group changed files into different mypy calls.

    Files matching EXCLUDE are dropped. Files matching FILES go into the main
    group, keyed by "". Any other file goes into the first SEPARATE_GROUPS
    entry whose prefix it starts with; files matching nothing are dropped.

    Args:
        changed_files: List of changed files.

    Returns:
        A dictionary mapping file group names to lists of changed files. It
        always contains "" and every SEPARATE_GROUPS entry, even when the
        corresponding list is empty.
    """
    exclude_pattern = _compile_path_matcher(EXCLUDE)
    files_pattern = _compile_path_matcher(FILES)
    # Compile the per-directory patterns once rather than once per file.
    group_patterns = {
        directory: _compile_path_matcher([directory])
        for directory in SEPARATE_GROUPS
    }

    file_groups: dict[str, list[str]] = {"": []}
    file_groups.update({k: [] for k in SEPARATE_GROUPS})
    for changed_file in changed_files:
        # Skip files which should be ignored completely
        if exclude_pattern.match(changed_file):
            continue
        # Group files by mypy call
        if files_pattern.match(changed_file):
            file_groups[""].append(changed_file)
            continue
        for directory, pattern in group_patterns.items():
            if pattern.match(changed_file):
                file_groups[directory].append(changed_file)
                break
    return file_groups
95+
96+
97+
def mypy(targets: list[str], python_version: Optional[str],
         follow_imports: Optional[str], file_group: str) -> int:
    """
    Invoke mypy once for a single group of files.

    Args:
        targets: Files or directories handed to mypy as positional arguments.
        python_version: Value for --python-version (e.g., "3.10"); when None
            the option is not passed.
        follow_imports: Value for --follow-imports; when None the option is
            not passed.
        file_group: Name of the file group, only used in the printed command.

    Returns:
        The exit status of the mypy process.
    """
    command = ["mypy"]
    optional_flags = (
        ("--python-version", python_version),
        ("--follow-imports", follow_imports),
    )
    for flag, value in optional_flags:
        if value is not None:
            command.extend((flag, value))
    # Echo the command with the group name in place of the full target list.
    print(f"$ {' '.join(command)} {file_group}")
    completed = subprocess.run(command + targets, check=False)
    return completed.returncode
120+
121+
122+
def main() -> int:
    """
    Run mypy once per non-empty file group and combine the exit statuses.

    Reads ``<ci> <python_version> <changed_files...>`` from ``sys.argv``.

    Returns:
        0 if every mypy call succeeded, a non-zero value if any call failed,
        or 2 when the required arguments are missing.
    """
    if len(sys.argv) < 3:
        # Fail with a readable message instead of an IndexError traceback.
        print(
            "usage: python tools/pre_commit/mypy.py "
            "<ci> <python_version> <changed_files...>",
            file=sys.stderr,
        )
        return 2

    ci = sys.argv[1] == "1"
    python_version = sys.argv[2]
    file_groups = group_files(sys.argv[3:])

    if python_version == "local":
        python_version = f"{sys.version_info.major}.{sys.version_info.minor}"

    returncode = 0
    for file_group, changed_files in file_groups.items():
        if not changed_files:
            continue
        # Only the main group in CI runs without --follow-imports (leaving the
        # choice to mypy's own configuration); everything else uses "skip".
        follow_imports = None if ci and file_group == "" else "skip"
        # OR the statuses together so any failing group yields non-zero.
        returncode |= mypy(changed_files, python_version, follow_imports,
                           file_group)
    return returncode
137+
138+
139+
if __name__ == "__main__":
140+
sys.exit(main())

vllm/entrypoints/llm.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1468,7 +1468,7 @@ def get_metrics(self) -> list["Metric"]:
14681468

14691469
def _validate_and_add_requests(
14701470
self,
1471-
prompts: Union[PromptType, Sequence[PromptType]],
1471+
prompts: Union[PromptType, Sequence[PromptType], DataPrompt],
14721472
params: Union[SamplingParams, Sequence[SamplingParams], PoolingParams,
14731473
Sequence[PoolingParams]],
14741474
*,
@@ -1478,7 +1478,7 @@ def _validate_and_add_requests(
14781478
) -> None:
14791479
if isinstance(prompts, (str, dict)):
14801480
# Convert a single prompt to a list.
1481-
prompts = [prompts]
1481+
prompts = [prompts] # type: ignore[list-item]
14821482

14831483
num_requests = len(prompts)
14841484
if isinstance(params, Sequence) and len(params) != num_requests:

vllm/entrypoints/renderer.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -280,7 +280,7 @@ def _validate_and_normalize_truncate_tokens(
280280
if truncate_prompt_tokens < 0:
281281
truncate_prompt_tokens = self.model_config.max_model_len
282282

283-
if max_length is not None and truncate_prompt_tokens > max_length:
283+
if max_length is not None and truncate_prompt_tokens > max_length: # type: ignore[operator]
284284
raise ValueError(
285285
f"truncate_prompt_tokens ({truncate_prompt_tokens}) "
286286
f"cannot be greater than max_length ({max_length}). "

vllm/utils/__init__.py

Lines changed: 5 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -551,9 +551,10 @@ async def _batch_encode_loop(self, queue: asyncio.Queue, can_batch: bool):
551551
# If every request uses identical kwargs we can run a single
552552
# batched tokenizer call for a big speed-up.
553553
if can_batch and len(prompts) > 1:
554-
encode_fn = partial(self.tokenizer, prompts, **kwargs)
554+
batch_encode_fn = partial(self.tokenizer, prompts,
555+
**kwargs)
555556
results = await self._loop.run_in_executor(
556-
self._executor, encode_fn)
557+
self._executor, batch_encode_fn)
557558

558559
for i, fut in enumerate(result_futures):
559560
if not fut.done():
@@ -889,7 +890,7 @@ def get_open_port() -> int:
889890

890891
def get_open_ports_list(count: int = 5) -> list[int]:
891892
"""Get a list of open ports."""
892-
ports = set()
893+
ports = set[int]()
893894
while len(ports) < count:
894895
ports.add(get_open_port())
895896
return list(ports)
@@ -1279,7 +1280,7 @@ def as_list(maybe_list: Iterable[T]) -> list[T]:
12791280

12801281
def as_iter(obj: Union[T, Iterable[T]]) -> Iterable[T]:
12811282
if isinstance(obj, str) or not isinstance(obj, Iterable):
1282-
obj = [obj]
1283+
return [obj] # type: ignore[list-item]
12831284
return obj
12841285

12851286

vllm/utils/tensor_schema.py

Lines changed: 3 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -22,9 +22,8 @@ def __init__(
2222
self.dims = dims
2323
self.dynamic_dims = dynamic_dims if dynamic_dims else set()
2424

25-
def resolve(self, **bindings: dict[str,
26-
int]) -> tuple[Union[int, str], ...]:
27-
resolved = []
25+
def resolve(self, **bindings: int) -> tuple[Union[int, str], ...]:
26+
resolved = list[Union[int, str]]()
2827
for dim in self.dims:
2928
if isinstance(dim, str) and dim in bindings:
3029
resolved.append(bindings[dim])
@@ -159,7 +158,7 @@ def _validate_tensor_shape_expected(
159158

160159
def validate(self) -> None:
161160
type_hints = get_type_hints(self.__class__, include_extras=True)
162-
shape_env = {}
161+
shape_env = dict[str, int]()
163162

164163
for field_name, field_type in type_hints.items():
165164
# Check if field is missing

0 commit comments

Comments
 (0)