Skip to content

Commit 8279201

Browse files
authored
[Build] Cython compilation support fix (#14296)
Signed-off-by: Gregory Shtrasberg <[email protected]>
1 parent 23fdab0 commit 8279201

File tree

6 files changed

+46
-6
lines changed

6 files changed

+46
-6
lines changed

Dockerfile.rocm

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -40,7 +40,7 @@ ARG USE_CYTHON
4040
RUN cd vllm \
4141
&& python3 -m pip install -r requirements/rocm.txt \
4242
&& python3 setup.py clean --all \
43-
&& if [ ${USE_CYTHON} -eq "1" ]; then python3 setup_cython.py build_ext --inplace; fi \
43+
&& if [ ${USE_CYTHON} -eq "1" ]; then python3 tests/build_cython.py build_ext --inplace; fi \
4444
&& python3 setup.py bdist_wheel --dist-dir=dist
4545
FROM scratch AS export_vllm
4646
ARG COMMON_WORKDIR

pyproject.toml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -86,6 +86,7 @@ exclude = [
8686
"vllm/triton_utils/**/*.py" = ["UP006", "UP035"]
8787
"vllm/vllm_flash_attn/**/*.py" = ["UP006", "UP035"]
8888
"vllm/worker/**/*.py" = ["UP006", "UP035"]
89+
"vllm/utils.py" = ["UP006", "UP035"]
8990

9091
[tool.ruff.lint]
9192
select = [

tests/build_cython.py

Lines changed: 38 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,38 @@
1+
# SPDX-License-Identifier: Apache-2.0
2+
import Cython.Compiler.Options
3+
from Cython.Build import cythonize
4+
from setuptools import setup
5+
6+
Cython.Compiler.Options.annotate = True
7+
8+
infiles = []
9+
10+
infiles += [
11+
"vllm/engine/llm_engine.py",
12+
"vllm/transformers_utils/detokenizer.py",
13+
"vllm/engine/output_processor/single_step.py",
14+
"vllm/outputs.py",
15+
"vllm/engine/output_processor/stop_checker.py",
16+
]
17+
18+
infiles += [
19+
"vllm/core/scheduler.py",
20+
"vllm/sequence.py",
21+
"vllm/core/block_manager.py",
22+
]
23+
24+
infiles += [
25+
"vllm/model_executor/layers/sampler.py",
26+
"vllm/sampling_params.py",
27+
"vllm/utils.py",
28+
]
29+
30+
setup(ext_modules=cythonize(infiles,
31+
annotate=False,
32+
force=True,
33+
compiler_directives={
34+
'language_level': "3",
35+
'infer_types': True
36+
}))
37+
38+
# example usage: python3 build_cython.py build_ext --inplace

vllm/engine/llm_engine.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1249,7 +1249,7 @@ def _process_model_outputs(self,
12491249
return None
12501250

12511251
def _advance_to_next_step(
1252-
self, output: List[SamplerOutput],
1252+
self, output: SamplerOutput,
12531253
seq_group_metadata_list: List[SequenceGroupMetadata],
12541254
scheduled_seq_groups: List[ScheduledSequenceGroup]) -> None:
12551255
"""Given model output from a single run, append the tokens to the

vllm/model_executor/layers/sampler.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1187,7 +1187,8 @@ def _build_sampler_output(
11871187
deferred_sample_results_args=deferred_sample_results_args)
11881188

11891189

1190-
def _get_next_prompt_tokens(seq_group: SequenceGroupToSample) -> List[int]:
1190+
def _get_next_prompt_tokens(
1191+
seq_group: SequenceGroupToSample) -> tuple[int, ...]:
11911192
"""Get a list of next prompt tokens to compute logprob from a
11921193
given sequence group.
11931194

vllm/utils.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -37,7 +37,7 @@
3737
from dataclasses import dataclass, field
3838
from functools import cache, lru_cache, partial, wraps
3939
from typing import (TYPE_CHECKING, Any, Callable, Generic, Literal, NamedTuple,
40-
Optional, TypeVar, Union)
40+
Optional, Type, TypeVar, Union)
4141
from uuid import uuid4
4242

4343
import cloudpickle
@@ -1544,9 +1544,9 @@ def __len__(self):
15441544
return len(self._factory)
15451545

15461546

1547-
class ClassRegistry(UserDict[type[T], _V]):
1547+
class ClassRegistry(UserDict[Type[T], _V]):
15481548

1549-
def __getitem__(self, key: type[T]) -> _V:
1549+
def __getitem__(self, key: Type[T]) -> _V:
15501550
for cls in key.mro():
15511551
if cls in self.data:
15521552
return self.data[cls]

0 commit comments

Comments
 (0)