Skip to content

Commit a425bd9

Browse files
authored
[Setup] Enable TORCH_CUDA_ARCH_LIST for selecting target GPUs (#1074)
1 parent bbbf865 commit a425bd9

File tree

1 file changed

+73
-38
lines changed

1 file changed

+73
-38
lines changed

setup.py

Lines changed: 73 additions & 38 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@
33
import re
44
import subprocess
55
from typing import List, Set
6+
import warnings
67

78
from packaging.version import parse, Version
89
import setuptools
@@ -11,6 +12,9 @@
1112

1213
ROOT_DIR = os.path.dirname(__file__)
1314

15+
# Supported NVIDIA GPU architectures.
16+
SUPPORTED_ARCHS = ["7.0", "7.5", "8.0", "8.6", "8.9", "9.0"]
17+
1418
# Compiler flags.
1519
CXX_FLAGS = ["-g", "-O2", "-std=c++17"]
1620
# TODO(woosuk): Should we use -O3?
@@ -38,51 +42,82 @@ def get_nvcc_cuda_version(cuda_dir: str) -> Version:
3842
return nvcc_cuda_version
3943

4044

41-
# Collect the compute capabilities of all available GPUs.
42-
device_count = torch.cuda.device_count()
43-
compute_capabilities: Set[int] = set()
44-
for i in range(device_count):
45-
major, minor = torch.cuda.get_device_capability(i)
46-
if major < 7:
47-
raise RuntimeError(
48-
"GPUs with compute capability less than 7.0 are not supported.")
49-
compute_capabilities.add(major * 10 + minor)
45+
def get_torch_arch_list() -> Set[str]:
46+
# TORCH_CUDA_ARCH_LIST can have one or more architectures,
47+
# e.g. "8.0" or "7.5,8.0,8.6+PTX". Here, the "8.6+PTX" option asks the
48+
# compiler to additionally include PTX code that can be runtime-compiled
49+
# and executed on the 8.6 or newer architectures. While the PTX code will
50+
# not give the best performance on the newer architectures, it provides
51+
# forward compatibility.
52+
valid_arch_strs = SUPPORTED_ARCHS + [s + "+PTX" for s in SUPPORTED_ARCHS]
53+
arch_list = os.environ.get("TORCH_CUDA_ARCH_LIST", None)
54+
if arch_list is None:
55+
return set()
56+
57+
# List are separated by ; or space.
58+
arch_list = arch_list.replace(" ", ";").split(";")
59+
for arch in arch_list:
60+
if arch not in valid_arch_strs:
61+
raise ValueError(
62+
f"Unsupported CUDA arch ({arch}). "
63+
f"Valid CUDA arch strings are: {valid_arch_strs}.")
64+
return set(arch_list)
65+
66+
67+
# First, check the TORCH_CUDA_ARCH_LIST environment variable.
68+
compute_capabilities = get_torch_arch_list()
69+
if not compute_capabilities:
70+
# If TORCH_CUDA_ARCH_LIST is not defined or empty, target all available
71+
# GPUs on the current machine.
72+
device_count = torch.cuda.device_count()
73+
for i in range(device_count):
74+
major, minor = torch.cuda.get_device_capability(i)
75+
if major < 7:
76+
raise RuntimeError(
77+
"GPUs with compute capability below 7.0 are not supported.")
78+
compute_capabilities.add(f"{major}.{minor}")
5079

51-
# Validate the NVCC CUDA version.
5280
nvcc_cuda_version = get_nvcc_cuda_version(CUDA_HOME)
81+
if not compute_capabilities:
82+
# If no GPU is specified nor available, add all supported architectures
83+
# based on the NVCC CUDA version.
84+
compute_capabilities = set(SUPPORTED_ARCHS)
85+
if nvcc_cuda_version < Version("11.1"):
86+
compute_capabilities.remove("8.6")
87+
if nvcc_cuda_version < Version("11.8"):
88+
compute_capabilities.remove("8.9")
89+
compute_capabilities.remove("9.0")
90+
91+
# Validate the NVCC CUDA version.
5392
if nvcc_cuda_version < Version("11.0"):
5493
raise RuntimeError("CUDA 11.0 or higher is required to build the package.")
55-
if 86 in compute_capabilities and nvcc_cuda_version < Version("11.1"):
56-
raise RuntimeError(
57-
"CUDA 11.1 or higher is required for GPUs with compute capability 8.6."
58-
)
59-
if 89 in compute_capabilities and nvcc_cuda_version < Version("11.8"):
60-
# CUDA 11.8 is required to generate the code targeting compute capability 8.9.
61-
# However, GPUs with compute capability 8.9 can also run the code generated by
62-
# the previous versions of CUDA 11 and targeting compute capability 8.0.
63-
# Therefore, if CUDA 11.8 is not available, we target compute capability 8.0
64-
# instead of 8.9.
65-
compute_capabilities.remove(89)
66-
compute_capabilities.add(80)
67-
if 90 in compute_capabilities and nvcc_cuda_version < Version("11.8"):
68-
raise RuntimeError(
69-
"CUDA 11.8 or higher is required for GPUs with compute capability 9.0."
70-
)
71-
72-
# If no GPU is available, add all supported compute capabilities.
73-
if not compute_capabilities:
74-
compute_capabilities = {70, 75, 80}
75-
if nvcc_cuda_version >= Version("11.1"):
76-
compute_capabilities.add(86)
77-
if nvcc_cuda_version >= Version("11.8"):
78-
compute_capabilities.add(89)
79-
compute_capabilities.add(90)
94+
if nvcc_cuda_version < Version("11.1"):
95+
if any(cc.startswith("8.6") for cc in compute_capabilities):
96+
raise RuntimeError(
97+
"CUDA 11.1 or higher is required for compute capability 8.6.")
98+
if nvcc_cuda_version < Version("11.8"):
99+
if any(cc.startswith("8.9") for cc in compute_capabilities):
100+
# CUDA 11.8 is required to generate the code targeting compute capability 8.9.
101+
# However, GPUs with compute capability 8.9 can also run the code generated by
102+
# the previous versions of CUDA 11 and targeting compute capability 8.0.
103+
# Therefore, if CUDA 11.8 is not available, we target compute capability 8.0
104+
# instead of 8.9.
105+
warnings.warn(
106+
"CUDA 11.8 or higher is required for compute capability 8.9. "
107+
"Targeting compute capability 8.0 instead.")
108+
compute_capabilities = set(cc for cc in compute_capabilities
109+
if not cc.startswith("8.9"))
110+
compute_capabilities.add("8.0+PTX")
111+
if any(cc.startswith("9.0") for cc in compute_capabilities):
112+
raise RuntimeError(
113+
"CUDA 11.8 or higher is required for compute capability 9.0.")
80114

81115
# Add target compute capabilities to NVCC flags.
82116
for capability in compute_capabilities:
83-
NVCC_FLAGS += [
84-
"-gencode", f"arch=compute_{capability},code=sm_{capability}"
85-
]
117+
num = capability[0] + capability[2]
118+
NVCC_FLAGS += ["-gencode", f"arch=compute_{num},code=sm_{num}"]
119+
if capability.endswith("+PTX"):
120+
NVCC_FLAGS += ["-gencode", f"arch=compute_{num},code=compute_{num}"]
86121

87122
# Use NVCC threads to parallelize the build.
88123
if nvcc_cuda_version >= Version("11.2"):

0 commit comments

Comments
 (0)