22
22
23
23
if CUDA_HOME is None :
24
24
raise RuntimeError (
25
- f "Cannot find CUDA_HOME. CUDA must be available to build the package." )
25
+ "Cannot find CUDA_HOME. CUDA must be available to build the package." )
26
26
27
27
28
28
def get_nvcc_cuda_version (cuda_dir : str ) -> Version :
@@ -54,7 +54,8 @@ def get_nvcc_cuda_version(cuda_dir: str) -> Version:
54
54
raise RuntimeError ("CUDA 11.0 or higher is required to build the package." )
55
55
if 86 in compute_capabilities and nvcc_cuda_version < Version ("11.1" ):
56
56
raise RuntimeError (
57
- "CUDA 11.1 or higher is required for GPUs with compute capability 8.6." )
57
+ "CUDA 11.1 or higher is required for GPUs with compute capability 8.6."
58
+ )
58
59
if 89 in compute_capabilities and nvcc_cuda_version < Version ("11.8" ):
59
60
# CUDA 11.8 is required to generate the code targeting compute capability 8.9.
60
61
# However, GPUs with compute capability 8.9 can also run the code generated by
@@ -65,7 +66,8 @@ def get_nvcc_cuda_version(cuda_dir: str) -> Version:
65
66
compute_capabilities .add (80 )
66
67
if 90 in compute_capabilities and nvcc_cuda_version < Version ("11.8" ):
67
68
raise RuntimeError (
68
- "CUDA 11.8 or higher is required for GPUs with compute capability 9.0." )
69
+ "CUDA 11.8 or higher is required for GPUs with compute capability 9.0."
70
+ )
69
71
70
72
# If no GPU is available, add all supported compute capabilities.
71
73
if not compute_capabilities :
@@ -78,7 +80,9 @@ def get_nvcc_cuda_version(cuda_dir: str) -> Version:
78
80
79
81
# Add target compute capabilities to NVCC flags.
80
82
for capability in compute_capabilities :
81
- NVCC_FLAGS += ["-gencode" , f"arch=compute_{ capability } ,code=sm_{ capability } " ]
83
+ NVCC_FLAGS += [
84
+ "-gencode" , f"arch=compute_{ capability } ,code=sm_{ capability } "
85
+ ]
82
86
83
87
# Use NVCC threads to parallelize the build.
84
88
if nvcc_cuda_version >= Version ("11.2" ):
@@ -91,39 +95,54 @@ def get_nvcc_cuda_version(cuda_dir: str) -> Version:
91
95
cache_extension = CUDAExtension (
92
96
name = "vllm.cache_ops" ,
93
97
sources = ["csrc/cache.cpp" , "csrc/cache_kernels.cu" ],
94
- extra_compile_args = {"cxx" : CXX_FLAGS , "nvcc" : NVCC_FLAGS },
98
+ extra_compile_args = {
99
+ "cxx" : CXX_FLAGS ,
100
+ "nvcc" : NVCC_FLAGS ,
101
+ },
95
102
)
96
103
ext_modules .append (cache_extension )
97
104
98
105
# Attention kernels.
99
106
attention_extension = CUDAExtension (
100
107
name = "vllm.attention_ops" ,
101
108
sources = ["csrc/attention.cpp" , "csrc/attention/attention_kernels.cu" ],
102
- extra_compile_args = {"cxx" : CXX_FLAGS , "nvcc" : NVCC_FLAGS },
109
+ extra_compile_args = {
110
+ "cxx" : CXX_FLAGS ,
111
+ "nvcc" : NVCC_FLAGS ,
112
+ },
103
113
)
104
114
ext_modules .append (attention_extension )
105
115
106
116
# Positional encoding kernels.
107
117
positional_encoding_extension = CUDAExtension (
108
118
name = "vllm.pos_encoding_ops" ,
109
119
sources = ["csrc/pos_encoding.cpp" , "csrc/pos_encoding_kernels.cu" ],
110
- extra_compile_args = {"cxx" : CXX_FLAGS , "nvcc" : NVCC_FLAGS },
120
+ extra_compile_args = {
121
+ "cxx" : CXX_FLAGS ,
122
+ "nvcc" : NVCC_FLAGS ,
123
+ },
111
124
)
112
125
ext_modules .append (positional_encoding_extension )
113
126
114
127
# Layer normalization kernels.
115
128
layernorm_extension = CUDAExtension (
116
129
name = "vllm.layernorm_ops" ,
117
130
sources = ["csrc/layernorm.cpp" , "csrc/layernorm_kernels.cu" ],
118
- extra_compile_args = {"cxx" : CXX_FLAGS , "nvcc" : NVCC_FLAGS },
131
+ extra_compile_args = {
132
+ "cxx" : CXX_FLAGS ,
133
+ "nvcc" : NVCC_FLAGS ,
134
+ },
119
135
)
120
136
ext_modules .append (layernorm_extension )
121
137
122
138
# Activation kernels.
123
139
activation_extension = CUDAExtension (
124
140
name = "vllm.activation_ops" ,
125
141
sources = ["csrc/activation.cpp" , "csrc/activation_kernels.cu" ],
126
- extra_compile_args = {"cxx" : CXX_FLAGS , "nvcc" : NVCC_FLAGS },
142
+ extra_compile_args = {
143
+ "cxx" : CXX_FLAGS ,
144
+ "nvcc" : NVCC_FLAGS ,
145
+ },
127
146
)
128
147
ext_modules .append (activation_extension )
129
148
@@ -138,8 +157,8 @@ def find_version(filepath: str):
138
157
Adapted from https://github.com/ray-project/ray/blob/0b190ee1160eeca9796bc091e07eaebf4c85b511/python/setup.py
139
158
"""
140
159
with open (filepath ) as fp :
141
- version_match = re .search (
142
- r"^__version__ = ['\"]([^'\"]*)['\"]" , fp .read (), re .M )
160
+ version_match = re .search (r"^__version__ = ['\"]([^'\"]*)['\"]" ,
161
+ fp .read (), re .M )
143
162
if version_match :
144
163
return version_match .group (1 )
145
164
raise RuntimeError ("Unable to find version string." )
@@ -162,7 +181,8 @@ def get_requirements() -> List[str]:
162
181
version = find_version (get_path ("vllm" , "__init__.py" )),
163
182
author = "vLLM Team" ,
164
183
license = "Apache 2.0" ,
165
- description = "A high-throughput and memory-efficient inference and serving engine for LLMs" ,
184
+ description = ("A high-throughput and memory-efficient inference and "
185
+ "serving engine for LLMs" ),
166
186
long_description = read_readme (),
167
187
long_description_content_type = "text/markdown" ,
168
188
url = "https://github.com/vllm-project/vllm" ,
@@ -174,11 +194,12 @@ def get_requirements() -> List[str]:
174
194
"Programming Language :: Python :: 3.8" ,
175
195
"Programming Language :: Python :: 3.9" ,
176
196
"Programming Language :: Python :: 3.10" ,
197
+ "Programming Language :: Python :: 3.11" ,
177
198
"License :: OSI Approved :: Apache Software License" ,
178
199
"Topic :: Scientific/Engineering :: Artificial Intelligence" ,
179
200
],
180
- packages = setuptools .find_packages (
181
- exclude = ( "assets" , "benchmarks" , "csrc" , "docs" , "examples" , "tests" )),
201
+ packages = setuptools .find_packages (exclude = ( "benchmarks" , "csrc" , "docs" ,
202
+ "examples" , "tests" )),
182
203
python_requires = ">=3.8" ,
183
204
install_requires = get_requirements (),
184
205
ext_modules = ext_modules ,
0 commit comments