2222
2323if CUDA_HOME is None :
2424 raise RuntimeError (
25- f "Cannot find CUDA_HOME. CUDA must be available to build the package." )
25+ "Cannot find CUDA_HOME. CUDA must be available to build the package." )
2626
2727
2828def get_nvcc_cuda_version (cuda_dir : str ) -> Version :
@@ -54,7 +54,8 @@ def get_nvcc_cuda_version(cuda_dir: str) -> Version:
5454 raise RuntimeError ("CUDA 11.0 or higher is required to build the package." )
5555if 86 in compute_capabilities and nvcc_cuda_version < Version ("11.1" ):
5656 raise RuntimeError (
57- "CUDA 11.1 or higher is required for GPUs with compute capability 8.6." )
57+ "CUDA 11.1 or higher is required for GPUs with compute capability 8.6."
58+ )
5859if 89 in compute_capabilities and nvcc_cuda_version < Version ("11.8" ):
5960 # CUDA 11.8 is required to generate the code targeting compute capability 8.9.
6061 # However, GPUs with compute capability 8.9 can also run the code generated by
@@ -65,7 +66,8 @@ def get_nvcc_cuda_version(cuda_dir: str) -> Version:
6566 compute_capabilities .add (80 )
6667if 90 in compute_capabilities and nvcc_cuda_version < Version ("11.8" ):
6768 raise RuntimeError (
68- "CUDA 11.8 or higher is required for GPUs with compute capability 9.0." )
69+ "CUDA 11.8 or higher is required for GPUs with compute capability 9.0."
70+ )
6971
7072# If no GPU is available, add all supported compute capabilities.
7173if not compute_capabilities :
@@ -78,7 +80,9 @@ def get_nvcc_cuda_version(cuda_dir: str) -> Version:
7880
7981# Add target compute capabilities to NVCC flags.
8082for capability in compute_capabilities :
81- NVCC_FLAGS += ["-gencode" , f"arch=compute_{ capability } ,code=sm_{ capability } " ]
83+ NVCC_FLAGS += [
84+ "-gencode" , f"arch=compute_{ capability } ,code=sm_{ capability } "
85+ ]
8286
8387# Use NVCC threads to parallelize the build.
8488if nvcc_cuda_version >= Version ("11.2" ):
@@ -91,39 +95,54 @@ def get_nvcc_cuda_version(cuda_dir: str) -> Version:
9195cache_extension = CUDAExtension (
9296 name = "vllm.cache_ops" ,
9397 sources = ["csrc/cache.cpp" , "csrc/cache_kernels.cu" ],
94- extra_compile_args = {"cxx" : CXX_FLAGS , "nvcc" : NVCC_FLAGS },
98+ extra_compile_args = {
99+ "cxx" : CXX_FLAGS ,
100+ "nvcc" : NVCC_FLAGS ,
101+ },
95102)
96103ext_modules .append (cache_extension )
97104
98105# Attention kernels.
99106attention_extension = CUDAExtension (
100107 name = "vllm.attention_ops" ,
101108 sources = ["csrc/attention.cpp" , "csrc/attention/attention_kernels.cu" ],
102- extra_compile_args = {"cxx" : CXX_FLAGS , "nvcc" : NVCC_FLAGS },
109+ extra_compile_args = {
110+ "cxx" : CXX_FLAGS ,
111+ "nvcc" : NVCC_FLAGS ,
112+ },
103113)
104114ext_modules .append (attention_extension )
105115
106116# Positional encoding kernels.
107117positional_encoding_extension = CUDAExtension (
108118 name = "vllm.pos_encoding_ops" ,
109119 sources = ["csrc/pos_encoding.cpp" , "csrc/pos_encoding_kernels.cu" ],
110- extra_compile_args = {"cxx" : CXX_FLAGS , "nvcc" : NVCC_FLAGS },
120+ extra_compile_args = {
121+ "cxx" : CXX_FLAGS ,
122+ "nvcc" : NVCC_FLAGS ,
123+ },
111124)
112125ext_modules .append (positional_encoding_extension )
113126
114127# Layer normalization kernels.
115128layernorm_extension = CUDAExtension (
116129 name = "vllm.layernorm_ops" ,
117130 sources = ["csrc/layernorm.cpp" , "csrc/layernorm_kernels.cu" ],
118- extra_compile_args = {"cxx" : CXX_FLAGS , "nvcc" : NVCC_FLAGS },
131+ extra_compile_args = {
132+ "cxx" : CXX_FLAGS ,
133+ "nvcc" : NVCC_FLAGS ,
134+ },
119135)
120136ext_modules .append (layernorm_extension )
121137
122138# Activation kernels.
123139activation_extension = CUDAExtension (
124140 name = "vllm.activation_ops" ,
125141 sources = ["csrc/activation.cpp" , "csrc/activation_kernels.cu" ],
126- extra_compile_args = {"cxx" : CXX_FLAGS , "nvcc" : NVCC_FLAGS },
142+ extra_compile_args = {
143+ "cxx" : CXX_FLAGS ,
144+ "nvcc" : NVCC_FLAGS ,
145+ },
127146)
128147ext_modules .append (activation_extension )
129148
@@ -138,8 +157,8 @@ def find_version(filepath: str):
138157 Adapted from https://github.com/ray-project/ray/blob/0b190ee1160eeca9796bc091e07eaebf4c85b511/python/setup.py
139158 """
140159 with open (filepath ) as fp :
141- version_match = re .search (
142- r"^__version__ = ['\"]([^'\"]*)['\"]" , fp .read (), re .M )
160+ version_match = re .search (r"^__version__ = ['\"]([^'\"]*)['\"]" ,
161+ fp .read (), re .M )
143162 if version_match :
144163 return version_match .group (1 )
145164 raise RuntimeError ("Unable to find version string." )
@@ -162,7 +181,8 @@ def get_requirements() -> List[str]:
162181 version = find_version (get_path ("vllm" , "__init__.py" )),
163182 author = "vLLM Team" ,
164183 license = "Apache 2.0" ,
165- description = "A high-throughput and memory-efficient inference and serving engine for LLMs" ,
184+ description = ("A high-throughput and memory-efficient inference and "
185+ "serving engine for LLMs" ),
166186 long_description = read_readme (),
167187 long_description_content_type = "text/markdown" ,
168188 url = "https://github.com/vllm-project/vllm" ,
@@ -174,11 +194,12 @@ def get_requirements() -> List[str]:
174194 "Programming Language :: Python :: 3.8" ,
175195 "Programming Language :: Python :: 3.9" ,
176196 "Programming Language :: Python :: 3.10" ,
197+ "Programming Language :: Python :: 3.11" ,
177198 "License :: OSI Approved :: Apache Software License" ,
178199 "Topic :: Scientific/Engineering :: Artificial Intelligence" ,
179200 ],
180- packages = setuptools .find_packages (
181- exclude = ( "assets" , "benchmarks" , "csrc" , "docs" , "examples" , "tests" )),
201+ packages = setuptools .find_packages (exclude = ( "benchmarks" , "csrc" , "docs" ,
202+ "examples" , "tests" )),
182203 python_requires = ">=3.8" ,
183204 install_requires = get_requirements (),
184205 ext_modules = ext_modules ,
0 commit comments