Skip to content

Commit fee607e

Browse files
committed
Cleanup based on other comments
- Moved to using local variables inside install_whl cleaning up old globals from when it wasn't a function - Fixed the devinstall_triton torch dep install to respect the user's INSTALL_TORCH Signed-off-by: Craig Magina <cmagina@redhat.com>
1 parent adc8396 commit fee607e

File tree

2 files changed

+36
-23
lines changed

2 files changed

+36
-23
lines changed

scripts/devinstall_torch.sh

Lines changed: 27 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -24,9 +24,6 @@ WORKSPACE=${WORKSPACE:-${HOME}}
2424
TORCH_DIR=${WORKSPACE}/torch
2525
TORCH_REPO=https://github.com/pytorch/pytorch.git
2626

27-
declare -a PIP_INSTALL_ARGS
28-
PIP_TORCH_INDEX_URL_BASE=https://download.pytorch.org/whl
29-
3027
SUDO=''
3128
if ((EUID != 0)) && command -v sudo &>/dev/null; then
3229
SUDO="sudo"
@@ -102,7 +99,22 @@ install_deps() {
10299
}
103100

104101
install_whl() {
105-
echo "Installing Torch ${PIP_TORCH_INDEX_URL_BUILD:-release} from PyPI ..."
102+
local pip_build="$1"
103+
local pip_torch_index_url_base
104+
local -a pip_install_args
105+
local compute_platform
106+
107+
108+
pip_torch_index_url_base="https://download.pytorch.org/whl"
109+
110+
case "$pip_build" in
111+
release) ;;
112+
nightly | test)
113+
pip_torch_index_url_base="${pip_torch_index_url_base}/${pip_build}"
114+
;;
115+
esac
116+
117+
echo "Installing Torch $pip_build from PyPI ..."
106118

107119
if [ -n "${PIP_TORCH_VERSION:-}" ]; then
108120
echo "Using the specified version $PIP_TORCH_VERSION of torch"
@@ -127,34 +139,33 @@ install_whl() {
127139

128140
if [ -n "${PIP_TORCH_INDEX_URL:-}" ]; then
129141
echo "Using the specified index, $PIP_TORCH_INDEX_URL"
130-
PIP_INSTALL_ARGS+=("--index-url" "$PIP_TORCH_INDEX_URL")
142+
pip_install_args+=("--index-url" "$PIP_TORCH_INDEX_URL")
131143
elif command -v uv &>/dev/null && [ -n "${UV_TORCH_BACKEND:-}" ]; then
132144
echo "Using the specified uv backend, $UV_TORCH_BACKEND"
133-
PIP_INSTALL_ARGS+=("--torch-backend" "$UV_TORCH_BACKEND")
145+
pip_install_args+=("--torch-backend" "$UV_TORCH_BACKEND")
134146
elif ! command -v uv &>/dev/null && [ -n "${UV_TORCH_BACKEND:-}" ]; then
135147
echo "Error: UV_TORCH_BACKEND is set to $UV_TORCH_BACKEND but uv is not available."
136148
exit 1
137149
else
138150
# Set compute platform for torch wheel installation
139151
if [ -n "${ROCM_VERSION:-}" ]; then
140152
echo "Using the ROCm version $ROCM_VERSION backend"
141-
COMPUTE_PLATFORM="rocm$(get_rocm_version)"
153+
compute_platform="rocm$(get_rocm_version)"
142154
elif ((${TRITON_CPU_BACKEND:-0} == 1)); then
143155
echo "Using the CPU backend"
144-
COMPUTE_PLATFORM="cpu"
156+
compute_platform="cpu"
145157
elif [ -n "${CUDA_VERSION:-}" ]; then
146158
echo "Using the CUDA version $CUDA_VERSION backend"
147-
COMPUTE_PLATFORM="cu${CUDA_VERSION/[.-]/}"
159+
compute_platform="cu${CUDA_VERSION/[.-]/}"
148160
fi
149161

150-
if [ -n "${COMPUTE_PLATFORM:-}" ]; then
151-
[[ -n "${PIP_TORCH_INDEX_URL_BUILD:-}" ]] && PIP_TORCH_INDEX_URL_BUILD="/${PIP_TORCH_INDEX_URL_BUILD}"
152-
PIP_TORCH_INDEX_URL="${PIP_TORCH_INDEX_URL_BASE}${PIP_TORCH_INDEX_URL_BUILD:-}/${COMPUTE_PLATFORM}"
153-
PIP_INSTALL_ARGS+=("--index-url" "$PIP_TORCH_INDEX_URL")
162+
if [ -n "${compute_platform:-}" ]; then
163+
PIP_TORCH_INDEX_URL="${pip_torch_index_url_base}/${compute_platform}"
164+
pip_install_args+=("--index-url" "$PIP_TORCH_INDEX_URL")
154165
fi
155166
fi
156167

157-
pip_install -U --force-reinstall "${PIP_INSTALL_ARGS[@]}" "${TORCH_PACKAGES[@]}"
168+
pip_install -U --force-reinstall "${pip_install_args[@]}" "${TORCH_PACKAGES[@]}"
158169

159170
# Fix up LD_LIBRARY_PATH for CUDA
160171
ldpretend
@@ -186,14 +197,9 @@ source)
186197
install_build_deps
187198
install_deps
188199
;;
189-
release)
190-
install_deps
191-
install_whl
192-
;;
193-
nightly | test)
194-
PIP_TORCH_INDEX_URL_BUILD=$COMMAND
200+
nightly | release | test)
195201
install_deps
196-
install_whl
202+
install_whl "$COMMAND"
197203
;;
198204
*)
199205
usage

scripts/devinstall_triton.sh

Lines changed: 9 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -122,8 +122,15 @@ install_deps() {
122122
pip_install cmake ctypeslib2 matplotlib ninja \
123123
numpy pandas pybind11 pytest pyyaml scipy tabulate wheel
124124

125-
echo "Installing Torch as a Triton dependency ..."
126-
devinstall_torch release
125+
if [ "${INSTALL_TORCH:-}" != "source" ]; then
126+
if [ -n "${INSTALL_TORCH:-}" ]; then
127+
echo "Installing Torch $INSTALL_TORCH as a dependency ..."
128+
devinstall_torch "${INSTALL_TORCH}"
129+
else
130+
echo "Installing Torch as a dependency ..."
131+
devinstall_torch release
132+
fi
133+
fi
127134
}
128135

129136
install_whl() {

0 commit comments

Comments
 (0)