Commit d6770d1

Update setup.py (#1006)

1 parent b9cecc2, commit d6770d1

1 file changed: setup.py (+35 additions, -14 deletions)

setup.py: 35 additions & 14 deletions
@@ -22,7 +22,7 @@
 
 if CUDA_HOME is None:
     raise RuntimeError(
-        f"Cannot find CUDA_HOME. CUDA must be available to build the package.")
+        "Cannot find CUDA_HOME. CUDA must be available to build the package.")
 
 
 def get_nvcc_cuda_version(cuda_dir: str) -> Version:
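For context on the helper whose signature closes this hunk: the toolkit version is commonly obtained by shelling out to nvcc -V and reading the token after "release". The sketch below is an assumption about that approach, not the code from this commit, and the helper name is made up.

import subprocess

from packaging.version import Version, parse


def nvcc_cuda_version_sketch(cuda_dir: str) -> Version:
    # Run "<cuda_dir>/bin/nvcc -V" and capture its text output.
    nvcc_output = subprocess.check_output([cuda_dir + "/bin/nvcc", "-V"],
                                          universal_newlines=True)
    # The version number follows the token "release", e.g. "release 11.8,".
    tokens = nvcc_output.split()
    release_idx = tokens.index("release") + 1
    return parse(tokens[release_idx].split(",")[0])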
@@ -54,7 +54,8 @@ def get_nvcc_cuda_version(cuda_dir: str) -> Version:
     raise RuntimeError("CUDA 11.0 or higher is required to build the package.")
 if 86 in compute_capabilities and nvcc_cuda_version < Version("11.1"):
     raise RuntimeError(
-        "CUDA 11.1 or higher is required for GPUs with compute capability 8.6.")
+        "CUDA 11.1 or higher is required for GPUs with compute capability 8.6."
+    )
 if 89 in compute_capabilities and nvcc_cuda_version < Version("11.8"):
     # CUDA 11.8 is required to generate the code targeting compute capability 8.9.
     # However, GPUs with compute capability 8.9 can also run the code generated by
@@ -65,7 +66,8 @@ def get_nvcc_cuda_version(cuda_dir: str) -> Version:
     compute_capabilities.add(80)
 if 90 in compute_capabilities and nvcc_cuda_version < Version("11.8"):
     raise RuntimeError(
-        "CUDA 11.8 or higher is required for GPUs with compute capability 9.0.")
+        "CUDA 11.8 or higher is required for GPUs with compute capability 9.0."
+    )
 
 # If no GPU is available, add all supported compute capabilities.
 if not compute_capabilities:
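The compute_capabilities set guarded by these checks is normally filled in from the GPUs visible at build time. A minimal sketch of that detection, assuming torch is importable; the major*10+minor encoding matches the integer values (80, 86, 90) used above.

import torch

compute_capabilities = set()
# Record each visible GPU's compute capability as major * 10 + minor,
# e.g. an A100 reports (8, 0) and is stored as 80.
for i in range(torch.cuda.device_count()):
    major, minor = torch.cuda.get_device_capability(i)
    compute_capabilities.add(major * 10 + minor)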
@@ -78,7 +80,9 @@ def get_nvcc_cuda_version(cuda_dir: str) -> Version:
 
 # Add target compute capabilities to NVCC flags.
 for capability in compute_capabilities:
-    NVCC_FLAGS += ["-gencode", f"arch=compute_{capability},code=sm_{capability}"]
+    NVCC_FLAGS += [
+        "-gencode", f"arch=compute_{capability},code=sm_{capability}"
+    ]
 
 # Use NVCC threads to parallelize the build.
 if nvcc_cuda_version >= Version("11.2"):
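To make the reformatted loop concrete, here is an illustrative expansion of a few example capabilities into the -gencode flags that nvcc eventually receives:

NVCC_FLAGS = []
for capability in sorted({80, 86, 90}):
    NVCC_FLAGS += [
        "-gencode", f"arch=compute_{capability},code=sm_{capability}"
    ]
# NVCC_FLAGS is now:
#   ["-gencode", "arch=compute_80,code=sm_80",
#    "-gencode", "arch=compute_86,code=sm_86",
#    "-gencode", "arch=compute_90,code=sm_90"]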
@@ -91,39 +95,54 @@ def get_nvcc_cuda_version(cuda_dir: str) -> Version:
 cache_extension = CUDAExtension(
     name="vllm.cache_ops",
     sources=["csrc/cache.cpp", "csrc/cache_kernels.cu"],
-    extra_compile_args={"cxx": CXX_FLAGS, "nvcc": NVCC_FLAGS},
+    extra_compile_args={
+        "cxx": CXX_FLAGS,
+        "nvcc": NVCC_FLAGS,
+    },
 )
 ext_modules.append(cache_extension)
 
 # Attention kernels.
 attention_extension = CUDAExtension(
     name="vllm.attention_ops",
     sources=["csrc/attention.cpp", "csrc/attention/attention_kernels.cu"],
-    extra_compile_args={"cxx": CXX_FLAGS, "nvcc": NVCC_FLAGS},
+    extra_compile_args={
+        "cxx": CXX_FLAGS,
+        "nvcc": NVCC_FLAGS,
+    },
 )
 ext_modules.append(attention_extension)
 
 # Positional encoding kernels.
 positional_encoding_extension = CUDAExtension(
     name="vllm.pos_encoding_ops",
     sources=["csrc/pos_encoding.cpp", "csrc/pos_encoding_kernels.cu"],
-    extra_compile_args={"cxx": CXX_FLAGS, "nvcc": NVCC_FLAGS},
+    extra_compile_args={
+        "cxx": CXX_FLAGS,
+        "nvcc": NVCC_FLAGS,
+    },
 )
 ext_modules.append(positional_encoding_extension)
 
 # Layer normalization kernels.
 layernorm_extension = CUDAExtension(
     name="vllm.layernorm_ops",
     sources=["csrc/layernorm.cpp", "csrc/layernorm_kernels.cu"],
-    extra_compile_args={"cxx": CXX_FLAGS, "nvcc": NVCC_FLAGS},
+    extra_compile_args={
+        "cxx": CXX_FLAGS,
+        "nvcc": NVCC_FLAGS,
+    },
 )
 ext_modules.append(layernorm_extension)
 
 # Activation kernels.
 activation_extension = CUDAExtension(
     name="vllm.activation_ops",
     sources=["csrc/activation.cpp", "csrc/activation_kernels.cu"],
-    extra_compile_args={"cxx": CXX_FLAGS, "nvcc": NVCC_FLAGS},
+    extra_compile_args={
+        "cxx": CXX_FLAGS,
+        "nvcc": NVCC_FLAGS,
+    },
 )
 ext_modules.append(activation_extension)
 
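These CUDAExtension entries only compile when setup() is run with PyTorch's BuildExtension as the build_ext command. A self-contained sketch of that wiring; the project name, module name, and source path are placeholders, not vLLM's:

import setuptools
from torch.utils.cpp_extension import BuildExtension, CUDAExtension

setuptools.setup(
    name="example_ops",                      # placeholder project name
    ext_modules=[
        CUDAExtension(
            name="example_ops.kernels",      # import path of the built module
            sources=["csrc/kernels.cu"],     # placeholder CUDA source
            extra_compile_args={
                "cxx": ["-O2"],
                "nvcc": ["-O2"],
            },
        ),
    ],
    cmdclass={"build_ext": BuildExtension},  # let PyTorch drive nvcc
)

After an install, the compiled module in this sketch would be importable as example_ops.kernels, analogous to vllm.cache_ops and the other extensions above.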
@@ -138,8 +157,8 @@ def find_version(filepath: str):
     Adapted from https://github.com/ray-project/ray/blob/0b190ee1160eeca9796bc091e07eaebf4c85b511/python/setup.py
     """
     with open(filepath) as fp:
-        version_match = re.search(
-            r"^__version__ = ['\"]([^'\"]*)['\"]", fp.read(), re.M)
+        version_match = re.search(r"^__version__ = ['\"]([^'\"]*)['\"]",
+                                  fp.read(), re.M)
     if version_match:
         return version_match.group(1)
     raise RuntimeError("Unable to find version string.")
@@ -162,7 +181,8 @@ def get_requirements() -> List[str]:
     version=find_version(get_path("vllm", "__init__.py")),
     author="vLLM Team",
     license="Apache 2.0",
-    description="A high-throughput and memory-efficient inference and serving engine for LLMs",
+    description=("A high-throughput and memory-efficient inference and "
+                 "serving engine for LLMs"),
     long_description=read_readme(),
     long_description_content_type="text/markdown",
     url="https://github.com/vllm-project/vllm",
@@ -174,11 +194,12 @@ def get_requirements() -> List[str]:
         "Programming Language :: Python :: 3.8",
         "Programming Language :: Python :: 3.9",
         "Programming Language :: Python :: 3.10",
+        "Programming Language :: Python :: 3.11",
         "License :: OSI Approved :: Apache Software License",
         "Topic :: Scientific/Engineering :: Artificial Intelligence",
     ],
-    packages=setuptools.find_packages(
-        exclude=("assets", "benchmarks", "csrc", "docs", "examples", "tests")),
+    packages=setuptools.find_packages(exclude=("benchmarks", "csrc", "docs",
+                                               "examples", "tests")),
     python_requires=">=3.8",
     install_requires=get_requirements(),
     ext_modules=ext_modules,
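One detail of find_packages worth noting: bare names in exclude= only drop those exact top-level packages; glob patterns such as "tests.*" would be needed to also drop nested packages. A small sketch (the printed list depends on where it runs and is only an example):

import setuptools

packages = setuptools.find_packages(exclude=("benchmarks", "csrc", "docs",
                                             "examples", "tests"))
# From the repository root this would yield the vllm packages,
# e.g. ["vllm", "vllm.core", ...]; excluded directories never appear.
print(packages)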
