Skip to content

Commit e09eb4a

Browse files
cmaginaclaude
andcommitted
feat: add PyTorch setup script for modular installation
Add scripts/setup_torch.sh to support PyTorch installation and configuration within containers. This script: - Downloads PyTorch source from GitHub when not mounted as a volume - Installs build dependencies for PyTorch compilation - Supports installing PyTorch wheels from PyPI (release, nightly, test) - Provides flexible configuration via INSTALL_TORCH environment variable The script supports multiple installation modes: - source: Build from source (with auto-download if not mounted) - release/nightly/test: Install wheels from PyPI - skip: Skip PyTorch installation This is part of the modular script architecture introduced in PR #115. Co-Authored-By: Claude Sonnet 4.5 <noreply@anthropic.com>
1 parent 0a65511 commit e09eb4a

File tree

1 file changed

+180
-0
lines changed

1 file changed

+180
-0
lines changed

scripts/setup_torch.sh

Lines changed: 180 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,180 @@
1+
#! /bin/bash -e
2+
3+
trap "echo -e '\nScript interrupted. Exiting gracefully.'; exit 1" SIGINT
4+
5+
# Copyright (C) 2024-2025 Red Hat, Inc.
6+
#
7+
# Licensed under the Apache License, Version 2.0 (the "License");
8+
# you may not use this file except in compliance with the License.
9+
# You may obtain a copy of the License at
10+
#
11+
# http://www.apache.org/licenses/LICENSE-2.0
12+
#
13+
# Unless required by applicable law or agreed to in writing, software
14+
# distributed under the License is distributed on an "AS IS" BASIS,
15+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
16+
# See the License for the specific language governing permissions and
17+
# limitations under the License.
18+
#
19+
# SPDX-License-Identifier: Apache-2.0
20+
set -euo pipefail
21+
22+
WORKSPACE=${WORKSPACE:-${HOME}}
23+
24+
TORCH_DIR=${WORKSPACE}/torch
25+
TORCH_REPO=https://github.com/pytorch/pytorch.git
26+
27+
PIP_TORCH_INDEX_URL_BASE=https://download.pytorch.org/whl
28+
29+
setup_src() {
30+
if [ ! -d "$TORCH_DIR" ]; then
31+
echo "Cloning the Torch repo $TORCH_REPO to $TORCH_DIR ..."
32+
git clone "$TORCH_REPO" "$TORCH_DIR"
33+
if [ ! -d "$TORCH_DIR" ]; then
34+
echo "$TORCH_DIR not found. ERROR Cloning repository..."
35+
exit 1
36+
else
37+
pushd "$TORCH_DIR" 1>/dev/null || exit 1
38+
git submodule sync
39+
git submodule update --init --recursive
40+
41+
if [ -n "${TORCH_GITREF:-}" ]; then
42+
git checkout ""
43+
fi
44+
45+
echo "Install pre-commit hooks into your local Torch git repo (one-time)"
46+
uv pip install pre-commit
47+
pre-commit install
48+
popd 1>/dev/null
49+
fi
50+
else
51+
echo "Torch repo already present, not cloning ..."
52+
fi
53+
}
54+
55+
install_build_deps() {
56+
pushd "$TORCH_DIR" 1>/dev/null || exit 1
57+
58+
if [ -f requirements.txt ]; then
59+
echo "Installing Torch dependencies ..."
60+
uv pip install --group dev
61+
uv pip install mkl-static mkl-include
62+
make triton
63+
fi
64+
65+
${SUDO:-} dnf -y install numactl-devel
66+
67+
if [ -n "${ROCM_VERSION:-}" ]; then
68+
python tools/amd_build/build_amd.py
69+
fi
70+
71+
popd 1>/dev/null
72+
}
73+
74+
install_deps() {
75+
echo "Installing torch dependencies ..."
76+
uv pip install numpy
77+
}
78+
79+
usage() {
80+
cat >&2 <<EOF
81+
Usage: $(basename "$0") [COMMAND]
82+
source Download Torch's source (if needed) and install the build deps
83+
release Install Torch
84+
nightly Install the Torch nightly wheel
85+
test Install the Torch test wheel
86+
EOF
87+
}
88+
89+
##
90+
## Main
91+
##
92+
93+
if [ $# -ne 1 ]; then
94+
usage
95+
exit 1
96+
fi
97+
98+
if command -v sudo &>/dev/null; then
99+
SUDO=sudo
100+
export SUDO
101+
fi
102+
103+
COMMAND=${1,,}
104+
105+
case $COMMAND in
106+
source)
107+
echo "Setting up the environment for building Torch ..."
108+
setup_src
109+
install_build_deps
110+
install_deps
111+
exit $?
112+
;;
113+
release)
114+
echo "Installing the Torch release build ..."
115+
if [ -n "${UV_TORCH_BACKEND:-}" ]; then
116+
echo "Using specified uv backend, $UV_TORCH_BACKEND"
117+
UV_TORCH_BACKEND="--torch-backend=$UV_TORCH_BACKEND"
118+
elif [ -n "${ROCM_VERSION:-}" ]; then
119+
TORCH_ROCM_VERSION=$(echo "$ROCM_VERSION" | sed -e 's/\([0-9]\.[0-9]\).*/\1/')
120+
121+
echo "Using the uv ROCm version $TORCH_ROCM_VERSION backend"
122+
UV_TORCH_BACKEND="--torch-backend=rocm${TORCH_ROCM_VERSION}"
123+
elif [ ${TRITON_CPU_BACKEND:-0} -eq 1 ]; then
124+
echo "Using the uv CPU backend"
125+
UV_TORCH_BACKEND="--torch-backend=cpu"
126+
elif [ -n "${CUDA_VERSION:-}" ]; then
127+
TORCH_CUDA_VERSION=$(echo "$CUDA_VERSION" | sed -e 's/\([0-9]*\)[.-]\([0-9]\)/\1\2/')
128+
129+
echo "Using the uv CUDA version $TORCH_CUDA_VERSION backend"
130+
UV_TORCH_BACKEND="--torch-backend=cu${TORCH_CUDA_VERSION}"
131+
else
132+
echo "Using the uv auto backend"
133+
UV_TORCH_BACKEND="--torch-backend=auto"
134+
fi
135+
;;
136+
nightly | test)
137+
echo "Installing Torch ..."
138+
PIP_TORCH_INDEX_URL_BUILD=/$COMMAND
139+
;;
140+
*)
141+
usage
142+
exit 1
143+
;;
144+
esac
145+
146+
if [ -n "${PIP_TORCH_INDEX_URL:-}" ]; then
147+
echo "Using the specified index, $PIP_TORCH_INDEX_URL"
148+
PIP_TORCH_INDEX_URL="--index-url $PIP_TORCH_INDEX_URL"
149+
else
150+
PIP_TORCH_INDEX_URL="--index-url ${PIP_TORCH_INDEX_URL_BASE}${PIP_TORCH_INDEX_URL_BUILD:-}"
151+
fi
152+
153+
if [ -n "${PIP_TORCH_INDEX_URL_BUILD:-}" ]; then
154+
echo "Using the ${PIP_TORCH_INDEX_URL_BUILD///} build ..."
155+
if [ -n "${ROCM_VERSION:-}" ]; then
156+
TORCH_ROCM_VERSION=$(echo "$ROCM_VERSION" | sed -e 's/\([0-9]\.[0-9]\).*/\1/')
157+
158+
echo "Using the ROCm version $TORCH_ROCM_VERSION backend"
159+
PIP_TORCH_INDEX_URL="${PIP_TORCH_INDEX_URL}/rocm${TORCH_ROCM_VERSION}"
160+
elif [ ${TRITON_CPU_BACKEND:-0} -eq 1 ]; then
161+
echo "Using the CPU backend"
162+
PIP_TORCH_INDEX_URL="${PIP_TORCH_INDEX_URL}/cpu"
163+
elif [ -n "${CUDA_VERSION:-}" ]; then
164+
TORCH_CUDA_VERSION=$(echo "$CUDA_VERSION" | sed -e 's/\([0-9]*\)[.-]\([0-9]\)/\1\2/')
165+
166+
echo "Using the CUDA version $TORCH_CUDA_VERSION backend"
167+
PIP_TORCH_INDEX_URL="${PIP_TORCH_INDEX_URL}/cu${TORCH_CUDA_VERSION}"
168+
fi
169+
fi
170+
171+
if [ -n "${PIP_TORCH_VERSION:-}" ]; then
172+
echo "Installing the specified version $PIP_TORCH_VERSION"
173+
PIP_TORCH_VERSION="==$PIP_TORCH_VERSION"
174+
fi
175+
176+
uv pip install torch${PIP_TORCH_VERSION:-} ${UV_TORCH_BACKEND:-} \
177+
$PIP_TORCH_INDEX_URL
178+
179+
# Fix up LD_LIBRARY_PATH for CUDA
180+
"${WORKSPACE}"/ldpretend.sh

0 commit comments

Comments
 (0)