 import re
 import subprocess
 from typing import List, Set
+import warnings
 
 from packaging.version import parse, Version
 import setuptools
 
 ROOT_DIR = os.path.dirname(__file__)
 
+# Supported NVIDIA GPU architectures.
+SUPPORTED_ARCHS = ["7.0", "7.5", "8.0", "8.6", "8.9", "9.0"]
+
 # Compiler flags.
 CXX_FLAGS = ["-g", "-O2", "-std=c++17"]
 # TODO(woosuk): Should we use -O3?
@@ -38,51 +42,82 @@ def get_nvcc_cuda_version(cuda_dir: str) -> Version:
     return nvcc_cuda_version
 
 
-# Collect the compute capabilities of all available GPUs.
-device_count = torch.cuda.device_count()
-compute_capabilities: Set[int] = set()
-for i in range(device_count):
-    major, minor = torch.cuda.get_device_capability(i)
-    if major < 7:
-        raise RuntimeError(
-            "GPUs with compute capability less than 7.0 are not supported.")
-    compute_capabilities.add(major * 10 + minor)
+def get_torch_arch_list() -> Set[str]:
+    # TORCH_CUDA_ARCH_LIST can have one or more architectures,
+    # e.g. "8.0" or "7.5;8.0;8.6+PTX". Here, the "8.6+PTX" option asks the
+    # compiler to additionally include PTX code that can be runtime-compiled
+    # and executed on 8.6 or newer architectures. While the PTX code will
+    # not give the best performance on the newer architectures, it provides
+    # forward compatibility.
+    valid_arch_strs = SUPPORTED_ARCHS + [s + "+PTX" for s in SUPPORTED_ARCHS]
+    arch_list = os.environ.get("TORCH_CUDA_ARCH_LIST", None)
+    if arch_list is None:
+        return set()
+
+    # List elements are separated by ";" or space.
+    arch_list = arch_list.replace(" ", ";").split(";")
+    for arch in arch_list:
+        if arch not in valid_arch_strs:
+            raise ValueError(
+                f"Unsupported CUDA arch ({arch}). "
+                f"Valid CUDA arch strings are: {valid_arch_strs}.")
+    return set(arch_list)
+
+
+# First, check the TORCH_CUDA_ARCH_LIST environment variable.
+compute_capabilities = get_torch_arch_list()
+if not compute_capabilities:
+    # If TORCH_CUDA_ARCH_LIST is not defined or empty, target all available
+    # GPUs on the current machine.
+    device_count = torch.cuda.device_count()
+    for i in range(device_count):
+        major, minor = torch.cuda.get_device_capability(i)
+        if major < 7:
+            raise RuntimeError(
+                "GPUs with compute capability below 7.0 are not supported.")
+        compute_capabilities.add(f"{major}.{minor}")
 
-# Validate the NVCC CUDA version.
 nvcc_cuda_version = get_nvcc_cuda_version(CUDA_HOME)
+if not compute_capabilities:
+    # If no GPU is specified and none is available, add all supported
+    # architectures based on the NVCC CUDA version.
+    compute_capabilities = set(SUPPORTED_ARCHS)
+    if nvcc_cuda_version < Version("11.1"):
+        compute_capabilities.remove("8.6")
+    if nvcc_cuda_version < Version("11.8"):
+        compute_capabilities.remove("8.9")
+        compute_capabilities.remove("9.0")
+
+# Validate the NVCC CUDA version.
 if nvcc_cuda_version < Version("11.0"):
     raise RuntimeError("CUDA 11.0 or higher is required to build the package.")
-if 86 in compute_capabilities and nvcc_cuda_version < Version("11.1"):
-    raise RuntimeError(
-        "CUDA 11.1 or higher is required for GPUs with compute capability 8.6."
-    )
-if 89 in compute_capabilities and nvcc_cuda_version < Version("11.8"):
-    # CUDA 11.8 is required to generate the code targeting compute capability 8.9.
-    # However, GPUs with compute capability 8.9 can also run the code generated by
-    # the previous versions of CUDA 11 and targeting compute capability 8.0.
-    # Therefore, if CUDA 11.8 is not available, we target compute capability 8.0
-    # instead of 8.9.
-    compute_capabilities.remove(89)
-    compute_capabilities.add(80)
-if 90 in compute_capabilities and nvcc_cuda_version < Version("11.8"):
-    raise RuntimeError(
-        "CUDA 11.8 or higher is required for GPUs with compute capability 9.0."
-    )
-
-# If no GPU is available, add all supported compute capabilities.
-if not compute_capabilities:
-    compute_capabilities = {70, 75, 80}
-    if nvcc_cuda_version >= Version("11.1"):
-        compute_capabilities.add(86)
-    if nvcc_cuda_version >= Version("11.8"):
-        compute_capabilities.add(89)
-        compute_capabilities.add(90)
+if nvcc_cuda_version < Version("11.1"):
+    if any(cc.startswith("8.6") for cc in compute_capabilities):
+        raise RuntimeError(
+            "CUDA 11.1 or higher is required for compute capability 8.6.")
+if nvcc_cuda_version < Version("11.8"):
+    if any(cc.startswith("8.9") for cc in compute_capabilities):
+        # CUDA 11.8 is required to generate code targeting compute capability 8.9.
+        # However, GPUs with compute capability 8.9 can also run code generated by
+        # earlier CUDA 11 versions that targets compute capability 8.0.
+        # Therefore, if CUDA 11.8 is not available, we target compute capability 8.0
+        # instead of 8.9.
+        warnings.warn(
+            "CUDA 11.8 or higher is required for compute capability 8.9. "
+            "Targeting compute capability 8.0 instead.")
+        compute_capabilities = set(cc for cc in compute_capabilities
                                   if not cc.startswith("8.9"))
+        compute_capabilities.add("8.0+PTX")
+    if any(cc.startswith("9.0") for cc in compute_capabilities):
+        raise RuntimeError(
+            "CUDA 11.8 or higher is required for compute capability 9.0.")
 
 # Add target compute capabilities to NVCC flags.
 for capability in compute_capabilities:
-    NVCC_FLAGS += [
-        "-gencode", f"arch=compute_{capability},code=sm_{capability}"
-    ]
+    num = capability[0] + capability[2]
+    NVCC_FLAGS += ["-gencode", f"arch=compute_{num},code=sm_{num}"]
+    if capability.endswith("+PTX"):
+        NVCC_FLAGS += ["-gencode", f"arch=compute_{num},code=compute_{num}"]
 
 # Use NVCC threads to parallelize the build.
 if nvcc_cuda_version >= Version("11.2"):
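
Note (not part of the diff itself): the new flag-expansion loop strips the dot from each capability string, so "8.6" or "8.6+PTX" becomes "86", then emits a -gencode flag for that SASS target and, when the "+PTX" suffix is present, a second flag that also embeds PTX so newer GPUs can JIT-compile the kernels. The standalone sketch below mirrors that logic for illustration only; the helper name expand_gencode_flags is hypothetical and does not appear in the change.

from typing import List

def expand_gencode_flags(arch_list: str) -> List[str]:
    """Standalone sketch of the -gencode expansion performed above."""
    flags: List[str] = []
    for capability in arch_list.replace(" ", ";").split(";"):
        num = capability[0] + capability[2]  # "8.6+PTX" -> "86"
        # SASS for the exact architecture.
        flags += ["-gencode", f"arch=compute_{num},code=sm_{num}"]
        if capability.endswith("+PTX"):
            # Also embed PTX for forward compatibility on newer GPUs.
            flags += ["-gencode", f"arch=compute_{num},code=compute_{num}"]
    return flags

print(expand_gencode_flags("8.0;8.6+PTX"))
# ['-gencode', 'arch=compute_80,code=sm_80',
#  '-gencode', 'arch=compute_86,code=sm_86',
#  '-gencode', 'arch=compute_86,code=compute_86']

As a usage example, a source build could pin the targets with something like TORCH_CUDA_ARCH_LIST="8.0;8.6+PTX" pip install -e ., which get_torch_arch_list() validates against SUPPORTED_ARCHS before the loop above runs.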