Skip to content

Commit a1ed206

Browse files
authored
Refactor version suffix so conda packages don't get suffixes. (#1218) (#1219)
Signed-off-by: Edward Z. Yang <[email protected]>
1 parent 66bc6f9 commit a1ed206

File tree

3 files changed

+44
-40
lines changed

3 files changed

+44
-40
lines changed

packaging/build_conda.sh

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@ set -ex
44
script_dir="$( cd "$( dirname "${BASH_SOURCE[0]}" )" >/dev/null 2>&1 && pwd )"
55
. "$script_dir/pkg_helpers.bash"
66

7+
export BUILD_TYPE=conda
78
setup_env 0.4.0
89
export SOURCE_ROOT_DIR="$PWD"
910
setup_conda_pytorch_constraint

packaging/build_wheel.sh

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@ set -ex
44
script_dir="$( cd "$( dirname "${BASH_SOURCE[0]}" )" >/dev/null 2>&1 && pwd )"
55
. "$script_dir/pkg_helpers.bash"
66

7+
export BUILD_TYPE=wheel
78
setup_env 0.4.0
89
setup_wheel_python
910
pip_install numpy pyyaml future ninja

packaging/pkg_helpers.bash

Lines changed: 42 additions & 40 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@
77
# Inputs:
88
# CU_VERSION (cpu, cu92, cu100)
99
# NO_CUDA_PACKAGE (bool)
10+
# BUILD_TYPE (conda, wheel)
1011
#
1112
# Outputs:
1213
# VERSION_SUFFIX (e.g., "")
@@ -27,55 +28,56 @@
2728
# version of a Python package. But that doesn't apply if you're on OS X,
2829
# since the default CU_VERSION on OS X is cpu.
2930
setup_cuda() {
30-
if [[ "$(uname)" == Darwin ]] || [[ -n "$NO_CUDA_PACKAGE" ]]; then
31-
if [[ "$CU_VERSION" != "cpu" ]]; then
32-
echo "CU_VERSION on OS X / package with no CUDA must be cpu"
33-
exit 1
31+
32+
# First, compute version suffixes. By default, assume no version suffixes
33+
export VERSION_SUFFIX=""
34+
export PYTORCH_VERSION_SUFFIX=""
35+
export WHEEL_DIR=""
36+
# Wheel builds need suffixes (but not if they're on OS X, which never has suffix)
37+
if [[ "$BUILD_TYPE" == "wheel" ]] && [[ "$(uname)" != Darwin ]]; then
38+
# The default CUDA has no suffix
39+
if [[ "$CU_VERSION" != "cu100" ]]; then
40+
export PYTORCH_VERSION_SUFFIX="+$CU_VERSION"
3441
fi
35-
if [[ "$(uname)" == Darwin ]]; then
36-
export PYTORCH_VERSION_SUFFIX=""
37-
else
38-
export PYTORCH_VERSION_SUFFIX="+cpu"
42+
# Match the suffix scheme of pytorch, unless this package does not have
43+
# CUDA builds (in which case, use default)
44+
if [[ -z "$NO_CUDA_PACKAGE" ]]; then
45+
export VERSION_SUFFIX="$PYTORCH_VERSION_SUFFIX"
46+
# If the suffix is non-empty, we will use a wheel subdirectory
47+
if [[ -n "$PYTORCH_VERSION_SUFFIX" ]]; then
48+
export WHEEL_DIR="$PYTORCH_VERSION_SUFFIX/"
49+
fi
3950
fi
40-
export VERSION_SUFFIX=""
41-
# NB: When there is no CUDA package available, we put these
42-
# packages in the top-level directory, so they are eligible
43-
# for selection even if you are otherwise trying to install
44-
# a cu100 stack. This differs from when there ARE CUDA packages
45-
# available; then we don't want the cpu package; we want
46-
# to give you as much goodies as possible.
47-
export WHEEL_DIR=""
48-
else
49-
case "$CU_VERSION" in
50-
cu100)
51-
export PYTORCH_VERSION_SUFFIX=""
52-
export CUDA_HOME=/usr/local/cuda-10.0/
53-
export FORCE_CUDA=1
54-
# Hard-coding gencode flags is temporary situation until
55-
# https://github.com/pytorch/pytorch/pull/23408 lands
56-
export NVCC_FLAGS="-gencode=arch=compute_35,code=sm_35 -gencode=arch=compute_50,code=sm_50 -gencode=arch=compute_60,code=sm_60 -gencode=arch=compute_70,code=sm_70 -gencode=arch=compute_75,code=sm_75 -gencode=arch=compute_50,code=compute_50"
57-
;;
58-
cu92)
59-
export CUDA_HOME=/usr/local/cuda-9.2/
60-
export PYTORCH_VERSION_SUFFIX="+cu92"
61-
export FORCE_CUDA=1
62-
export NVCC_FLAGS="-gencode=arch=compute_35,code=sm_35 -gencode=arch=compute_50,code=sm_50 -gencode=arch=compute_60,code=sm_60 -gencode=arch=compute_70,code=sm_70 -gencode=arch=compute_50,code=compute_50"
63-
;;
64-
cpu)
65-
export PYTORCH_VERSION_SUFFIX="+cpu"
66-
;;
67-
*)
68-
echo "Unrecognized CU_VERSION=$CU_VERSION"
69-
esac
70-
export VERSION_SUFFIX="$PYTORCH_VERSION_SUFFIX"
71-
export WHEEL_DIR="$CU_VERSION/"
7251
fi
52+
53+
# Now work out the CUDA settings
54+
case "$CU_VERSION" in
55+
cu100)
56+
export CUDA_HOME=/usr/local/cuda-10.0/
57+
export FORCE_CUDA=1
58+
# Hard-coding gencode flags is temporary situation until
59+
# https://github.com/pytorch/pytorch/pull/23408 lands
60+
export NVCC_FLAGS="-gencode=arch=compute_35,code=sm_35 -gencode=arch=compute_50,code=sm_50 -gencode=arch=compute_60,code=sm_60 -gencode=arch=compute_70,code=sm_70 -gencode=arch=compute_75,code=sm_75 -gencode=arch=compute_50,code=compute_50"
61+
;;
62+
cu92)
63+
export CUDA_HOME=/usr/local/cuda-9.2/
64+
export FORCE_CUDA=1
65+
export NVCC_FLAGS="-gencode=arch=compute_35,code=sm_35 -gencode=arch=compute_50,code=sm_50 -gencode=arch=compute_60,code=sm_60 -gencode=arch=compute_70,code=sm_70 -gencode=arch=compute_50,code=compute_50"
66+
;;
67+
cpu)
68+
;;
69+
*)
70+
echo "Unrecognized CU_VERSION=$CU_VERSION"
71+
exit 1
72+
;;
73+
esac
7374
}
7475

7576
# Populate build version if necessary, and add version suffix
7677
#
7778
# Inputs:
7879
# BUILD_VERSION (e.g., 0.2.0 or empty)
80+
# VERSION_SUFFIX (e.g., +cpu)
7981
#
8082
# Outputs:
8183
# BUILD_VERSION (e.g., 0.2.0.dev20190807+cpu)

0 commit comments

Comments
 (0)