|
7 | 7 | # Inputs:
|
8 | 8 | # CU_VERSION (cpu, cu92, cu100)
|
9 | 9 | # NO_CUDA_PACKAGE (bool)
|
| 10 | +# BUILD_TYPE (conda, wheel) |
10 | 11 | #
|
11 | 12 | # Outputs:
|
12 | 13 | # VERSION_SUFFIX (e.g., "")
|
|
27 | 28 | # version of a Python package. But that doesn't apply if you're on OS X,
|
28 | 29 | # since the default CU_VERSION on OS X is cpu.
|
29 | 30 | setup_cuda() {
|
30 |
| - if [[ "$(uname)" == Darwin ]] || [[ -n "$NO_CUDA_PACKAGE" ]]; then |
31 |
| - if [[ "$CU_VERSION" != "cpu" ]]; then |
32 |
| - echo "CU_VERSION on OS X / package with no CUDA must be cpu" |
33 |
| - exit 1 |
| 31 | + |
| 32 | + # First, compute version suffixes. By default, assume no version suffixes |
| 33 | + export VERSION_SUFFIX="" |
| 34 | + export PYTORCH_VERSION_SUFFIX="" |
| 35 | + export WHEEL_DIR="" |
| 36 | + # Wheel builds need suffixes (but not if they're on OS X, which never has suffix) |
| 37 | + if [[ "$BUILD_TYPE" == "wheel" ]] && [[ "$(uname)" != Darwin ]]; then |
| 38 | + # The default CUDA has no suffix |
| 39 | + if [[ "$CU_VERSION" != "cu100" ]]; then |
| 40 | + export PYTORCH_VERSION_SUFFIX="+$CU_VERSION" |
34 | 41 | fi
|
35 |
| - if [[ "$(uname)" == Darwin ]]; then |
36 |
| - export PYTORCH_VERSION_SUFFIX="" |
37 |
| - else |
38 |
| - export PYTORCH_VERSION_SUFFIX="+cpu" |
| 42 | + # Match the suffix scheme of pytorch, unless this package does not have |
| 43 | + # CUDA builds (in which case, use default) |
| 44 | + if [[ -z "$NO_CUDA_PACKAGE" ]]; then |
| 45 | + export VERSION_SUFFIX="$PYTORCH_VERSION_SUFFIX" |
| 46 | + # If the suffix is non-empty, we will use a wheel subdirectory |
| 47 | + if [[ -n "$PYTORCH_VERSION_SUFFIX" ]]; then |
| 48 | + export WHEEL_DIR="$PYTORCH_VERSION_SUFFIX/" |
| 49 | + fi |
39 | 50 | fi
|
40 |
| - export VERSION_SUFFIX="" |
41 |
| - # NB: When there is no CUDA package available, we put these |
42 |
| - # packages in the top-level directory, so they are eligible |
43 |
| - # for selection even if you are otherwise trying to install |
44 |
| - # a cu100 stack. This differs from when there ARE CUDA packages |
45 |
| - # available; then we don't want the cpu package; we want |
46 |
| - # to give you as much goodies as possible. |
47 |
| - export WHEEL_DIR="" |
48 |
| - else |
49 |
| - case "$CU_VERSION" in |
50 |
| - cu100) |
51 |
| - export PYTORCH_VERSION_SUFFIX="" |
52 |
| - export CUDA_HOME=/usr/local/cuda-10.0/ |
53 |
| - export FORCE_CUDA=1 |
54 |
| - # Hard-coding gencode flags is temporary situation until |
55 |
| - # https://github.com/pytorch/pytorch/pull/23408 lands |
56 |
| - export NVCC_FLAGS="-gencode=arch=compute_35,code=sm_35 -gencode=arch=compute_50,code=sm_50 -gencode=arch=compute_60,code=sm_60 -gencode=arch=compute_70,code=sm_70 -gencode=arch=compute_75,code=sm_75 -gencode=arch=compute_50,code=compute_50" |
57 |
| - ;; |
58 |
| - cu92) |
59 |
| - export CUDA_HOME=/usr/local/cuda-9.2/ |
60 |
| - export PYTORCH_VERSION_SUFFIX="+cu92" |
61 |
| - export FORCE_CUDA=1 |
62 |
| - export NVCC_FLAGS="-gencode=arch=compute_35,code=sm_35 -gencode=arch=compute_50,code=sm_50 -gencode=arch=compute_60,code=sm_60 -gencode=arch=compute_70,code=sm_70 -gencode=arch=compute_50,code=compute_50" |
63 |
| - ;; |
64 |
| - cpu) |
65 |
| - export PYTORCH_VERSION_SUFFIX="+cpu" |
66 |
| - ;; |
67 |
| - *) |
68 |
| - echo "Unrecognized CU_VERSION=$CU_VERSION" |
69 |
| - esac |
70 |
| - export VERSION_SUFFIX="$PYTORCH_VERSION_SUFFIX" |
71 |
| - export WHEEL_DIR="$CU_VERSION/" |
72 | 51 | fi
|
| 52 | + |
| 53 | + # Now work out the CUDA settings |
| 54 | + case "$CU_VERSION" in |
| 55 | + cu100) |
| 56 | + export CUDA_HOME=/usr/local/cuda-10.0/ |
| 57 | + export FORCE_CUDA=1 |
| 58 | + # Hard-coding gencode flags is temporary situation until |
| 59 | + # https://github.com/pytorch/pytorch/pull/23408 lands |
| 60 | + export NVCC_FLAGS="-gencode=arch=compute_35,code=sm_35 -gencode=arch=compute_50,code=sm_50 -gencode=arch=compute_60,code=sm_60 -gencode=arch=compute_70,code=sm_70 -gencode=arch=compute_75,code=sm_75 -gencode=arch=compute_50,code=compute_50" |
| 61 | + ;; |
| 62 | + cu92) |
| 63 | + export CUDA_HOME=/usr/local/cuda-9.2/ |
| 64 | + export FORCE_CUDA=1 |
| 65 | + export NVCC_FLAGS="-gencode=arch=compute_35,code=sm_35 -gencode=arch=compute_50,code=sm_50 -gencode=arch=compute_60,code=sm_60 -gencode=arch=compute_70,code=sm_70 -gencode=arch=compute_50,code=compute_50" |
| 66 | + ;; |
| 67 | + cpu) |
| 68 | + ;; |
| 69 | + *) |
| 70 | + echo "Unrecognized CU_VERSION=$CU_VERSION" |
| 71 | + exit 1 |
| 72 | + ;; |
| 73 | + esac |
73 | 74 | }
|
74 | 75 |
|
75 | 76 | # Populate build version if necessary, and add version suffix
|
76 | 77 | #
|
77 | 78 | # Inputs:
|
78 | 79 | # BUILD_VERSION (e.g., 0.2.0 or empty)
|
| 80 | +# VERSION_SUFFIX (e.g., +cpu) |
79 | 81 | #
|
80 | 82 | # Outputs:
|
81 | 83 | # BUILD_VERSION (e.g., 0.2.0.dev20190807+cpu)
|
|
0 commit comments