Skip to content

Commit f61a04f

Browse files
committed
feat: add PyTorch setup script for modular installation
Add scripts/setup_torch.sh to support PyTorch installation and configuration within containers. This script: - Downloads PyTorch source from GitHub when not mounted as a volume - Installs build dependencies for PyTorch compilation - Supports installing PyTorch wheels from PyPI (release, nightly, test) - Provides flexible configuration via INSTALL_TORCH environment variable The script supports multiple installation modes: - source: Build from source (with auto-download if not mounted) - release/nightly/test: Install wheels from PyPI - skip: Skip PyTorch installation This is part of the modular script architecture introduced in PR #115. Signed-off-by: Craig Magina <cmagina@redhat.com>
1 parent 5dd5ceb commit f61a04f

File tree

10 files changed

+227
-23
lines changed

10 files changed

+227
-23
lines changed

.github/workflows/amd-image.yml

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,7 @@ on: # yamllint disable-line rule:truthy
1111
- scripts/devinstall_software.sh
1212
- scripts/ldpretend.sh
1313
- scripts/devinstall_llvm.sh
14+
- scripts/devinstall_torch.sh
1415
- scripts/devinstall_triton.sh
1516
- scripts/devcreate_user.sh
1617
- scripts/devsetup.sh
@@ -22,6 +23,7 @@ on: # yamllint disable-line rule:truthy
2223
- scripts/devinstall_software.sh
2324
- scripts/ldpretend.sh
2425
- scripts/devinstall_llvm.sh
26+
- scripts/devinstall_torch.sh
2527
- scripts/devinstall_triton.sh
2628
- scripts/devcreate_user.sh
2729
- scripts/devsetup.sh

.github/workflows/cpu-image.yml

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,7 @@ on: # yamllint disable-line rule:truthy
1111
- scripts/devinstall_software.sh
1212
- scripts/ldpretend.sh
1313
- scripts/devinstall_llvm.sh
14+
- scripts/devinstall_torch.sh
1415
- scripts/devinstall_triton.sh
1516
- scripts/devcreate_user.sh
1617
- scripts/devsetup.sh
@@ -22,6 +23,7 @@ on: # yamllint disable-line rule:truthy
2223
- scripts/devinstall_software.sh
2324
- scripts/ldpretend.sh
2425
- scripts/devinstall_llvm.sh
26+
- scripts/devinstall_torch.sh
2527
- scripts/devinstall_triton.sh
2628
- scripts/devcreate_user.sh
2729
- scripts/devsetup.sh

.github/workflows/nvidia-image.yml

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,7 @@ on: # yamllint disable-line rule:truthy
1111
- scripts/devinstall_software.sh
1212
- scripts/ldpretend.sh
1313
- scripts/devinstall_llvm.sh
14+
- scripts/devinstall_torch.sh
1415
- scripts/devinstall_triton.sh
1516
- scripts/devcreate_user.sh
1617
- scripts/devsetup.sh
@@ -22,6 +23,7 @@ on: # yamllint disable-line rule:truthy
2223
- scripts/devinstall_software.sh
2324
- scripts/ldpretend.sh
2425
- scripts/devinstall_llvm.sh
26+
- scripts/devinstall_torch.sh
2527
- scripts/devinstall_triton.sh
2628
- scripts/devcreate_user.sh
2729
- scripts/devsetup.sh

Makefile

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -44,8 +44,10 @@ create_user ?=true
4444
# NOTE: Requires host build system to have a valid Red Hat Subscription if true
4545
INSTALL_NSIGHT ?=false
4646
llvm_path ?=
47+
torch_path ?=
4748
user_path ?=
4849
INSTALL_LLVM ?= skip # Options: source, skip
50+
INSTALL_TORCH ?= skip # Options: nightly, release, source, skip, test
4951
INSTALL_TRITON ?= source # Options: release, source, skip
5052
INSTALL_JUPYTER ?= true
5153
USE_CCACHE ?= 0
@@ -119,6 +121,9 @@ define run_container
119121
if [ -n "$(llvm_path)" ]; then \
120122
volume_arg+=" -v $(llvm_path):/workspace/llvm-project$(SELINUXFLAG)"; \
121123
fi; \
124+
if [ -n "$(torch_path)" ]; then \
125+
volume_arg+=" -v $(torch_path):/workspace/torch$(SELINUXFLAG)"; \
126+
fi; \
122127
if [ -n "$(user_path)" ]; then \
123128
volume_arg+=" -v $(user_path):/workspace/user$(SELINUXFLAG)"; \
124129
fi; \
@@ -160,7 +165,7 @@ define run_container
160165
if [ "$(CUSTOM_LLVM)" = "false" ]; then \
161166
install_llvm="-e INSTALL_LLVM=$(INSTALL_LLVM)"; \
162167
fi; \
163-
env_vars="-e USERNAME=$(USER) -e TORCH_VERSION=$(torch_version) -e CUSTOM_LLVM=$(CUSTOM_LLVM) -e INSTALL_TOOLS=$(DEMO_TOOLS) -e INSTALL_JUPYTER=$(INSTALL_JUPYTER) -e NOTEBOOK_PORT=$(NOTEBOOK_PORT) -e INSTALL_TRITON=$(INSTALL_TRITON) -e USE_CCACHE=$(USE_CCACHE) -e MAX_JOBS=$(MAX_JOBS)"; \
168+
env_vars="-e USERNAME=$(USER) -e TORCH_VERSION=$(torch_version) -e CUSTOM_LLVM=$(CUSTOM_LLVM) -e INSTALL_TOOLS=$(DEMO_TOOLS) -e INSTALL_JUPYTER=$(INSTALL_JUPYTER) -e NOTEBOOK_PORT=$(NOTEBOOK_PORT) -e INSTALL_TORCH=$(INSTALL_TORCH) -e INSTALL_TRITON=$(INSTALL_TRITON) -e USE_CCACHE=$(USE_CCACHE) -e MAX_JOBS=$(MAX_JOBS)"; \
164169
if [ "$(create_user)" = "true" ]; then \
165170
$(CTR_CMD) run -e CREATE_USER=$(create_user) $$env_vars $$install_llvm $$port_arg \
166171
-e USER_UID=`id -u $(USER)` -e USER_GID=`id -g $(USER)` $$gpu_args $$profiling_args $$keep_ns_arg \

dockerfiles/Dockerfile.triton

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -67,6 +67,7 @@ ENV BASH_ENV=/workspace/bin/activate \
6767
COPY --from=quay.io/triton-dev-containers/gosu /usr/local/bin/gosu /usr/local/bin/gosu
6868

6969
COPY scripts/devinstall_llvm.sh /workspace/bin/devinstall_llvm
70+
COPY scripts/devinstall_torch.sh /workspace/bin/devinstall_torch
7071
COPY scripts/devinstall_triton.sh /workspace/bin/devinstall_triton
7172
COPY scripts/devcreate_user.sh /workspace/bin/devcreate_user
7273
COPY scripts/devsetup.sh /workspace/bin/devsetup

dockerfiles/Dockerfile.triton-amd

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -85,6 +85,7 @@ WORKDIR /workspace
8585
COPY --from=quay.io/triton-dev-containers/gosu /usr/local/bin/gosu /usr/local/bin/gosu
8686

8787
COPY scripts/devinstall_llvm.sh /workspace/bin/devinstall_llvm
88+
COPY scripts/devinstall_torch.sh /workspace/bin/devinstall_torch
8889
COPY scripts/devinstall_triton.sh /workspace/bin/devinstall_triton
8990
COPY scripts/devcreate_user.sh /workspace/bin/devcreate_user
9091
COPY scripts/devsetup.sh /workspace/bin/devsetup

dockerfiles/Dockerfile.triton-cpu

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -64,6 +64,7 @@ WORKDIR /workspace
6464
COPY --from=quay.io/triton-dev-containers/gosu /usr/local/bin/gosu /usr/local/bin/gosu
6565

6666
COPY scripts/devinstall_llvm.sh /workspace/bin/devinstall_llvm
67+
COPY scripts/devinstall_torch.sh /workspace/bin/devinstall_torch
6768
COPY scripts/devinstall_triton.sh /workspace/bin/devinstall_triton
6869
COPY scripts/devcreate_user.sh /workspace/bin/devcreate_user
6970
COPY scripts/devsetup.sh /workspace/bin/devsetup

scripts/devinstall_torch.sh

Lines changed: 205 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,205 @@
1+
#! /bin/bash -e
2+
3+
trap "echo -e '\nScript interrupted. Exiting gracefully.'; exit 1" SIGINT
4+
5+
# Copyright (C) 2024-2025 Red Hat, Inc.
6+
#
7+
# Licensed under the Apache License, Version 2.0 (the "License");
8+
# you may not use this file except in compliance with the License.
9+
# You may obtain a copy of the License at
10+
#
11+
# http://www.apache.org/licenses/LICENSE-2.0
12+
#
13+
# Unless required by applicable law or agreed to in writing, software
14+
# distributed under the License is distributed on an "AS IS" BASIS,
15+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
16+
# See the License for the specific language governing permissions and
17+
# limitations under the License.
18+
#
19+
# SPDX-License-Identifier: Apache-2.0
20+
set -euo pipefail
21+
22+
WORKSPACE=${WORKSPACE:-${HOME}}
23+
24+
TORCH_DIR=${WORKSPACE}/torch
25+
TORCH_REPO=https://github.com/pytorch/pytorch.git
26+
27+
declare -a PIP_INSTALL_ARGS
28+
PIP_TORCH_INDEX_URL_BASE=https://download.pytorch.org/whl
29+
30+
SUDO=''
31+
if ((EUID != 0)) && command -v sudo &>/dev/null; then
32+
SUDO="sudo"
33+
elif ((EUID != 0)); then
34+
echo "ERROR: $(basename "$0") requires root privileges or sudo." >&2
35+
exit 1
36+
fi
37+
38+
pip_install() {
39+
if command -v uv &>/dev/null; then
40+
uv pip install "$@"
41+
else
42+
pip install "$@"
43+
fi
44+
}
45+
46+
# Extract the major.minor version from ROCM_VERSION, e.g. 6.4 from 6.4.4
47+
get_rocm_version() {
48+
[[ "$ROCM_VERSION" =~ ^([0-9]+\.[0-9]+) ]] && echo "${BASH_REMATCH[1]}" ||
49+
echo "$ROCM_VERSION"
50+
}
51+
52+
setup_src() {
53+
echo "Downloading Torch source code and setting up the environment for building from source..."
54+
55+
if [ ! -d "$TORCH_DIR" ]; then
56+
echo "Cloning the Torch repo $TORCH_REPO to $TORCH_DIR ..."
57+
git clone "$TORCH_REPO" "$TORCH_DIR"
58+
if [ ! -d "$TORCH_DIR" ]; then
59+
echo "$TORCH_DIR not found. ERROR Cloning repository..."
60+
exit 1
61+
else
62+
pushd "$TORCH_DIR" 1>/dev/null || exit 1
63+
git submodule sync
64+
git submodule update --init --recursive
65+
66+
if [ -n "${TORCH_GITREF:-}" ]; then
67+
git checkout ""
68+
fi
69+
70+
echo "Install pre-commit hooks into your local Torch git repo (one-time)"
71+
pip_install pre-commit
72+
pre-commit install
73+
popd 1>/dev/null
74+
fi
75+
else
76+
echo "Torch repo already present, not cloning ..."
77+
fi
78+
}
79+
80+
install_build_deps() {
81+
echo "Installing Torch build dependencies ..."
82+
83+
pushd "$TORCH_DIR" 1>/dev/null || exit 1
84+
85+
if [ -f requirements.txt ]; then
86+
pip_install --group dev
87+
pip_install mkl-static mkl-include
88+
make triton
89+
fi
90+
91+
$SUDO dnf -y install numactl-devel
92+
93+
if [ -n "${ROCM_VERSION:-}" ]; then
94+
python tools/amd_build/build_amd.py
95+
fi
96+
97+
popd 1>/dev/null
98+
}
99+
100+
install_deps() {
101+
echo "Installing Torch dependencies ..."
102+
pip_install numpy
103+
}
104+
105+
install_whl() {
106+
echo "Installing Torch ${PIP_TORCH_INDEX_URL_BUILD:-release} from PyPI ..."
107+
108+
if [ -n "${PIP_TORCH_VERSION:-}" ]; then
109+
echo "Using the specified version $PIP_TORCH_VERSION of torch"
110+
PIP_TORCH_VERSION="==$PIP_TORCH_VERSION"
111+
fi
112+
113+
if [ -n "${PIP_TORCHVISION_VERSION:-}" ]; then
114+
echo "Installing the specified version $PIP_TORCHVISION_VERSION of torchvision"
115+
PIP_TORCHVISION_VERSION="==$PIP_TORCHVISION_VERSION"
116+
fi
117+
118+
if [ -n "${PIP_TORCHAUDIO_VERSION:-}" ]; then
119+
echo "Installing the specified version $PIP_TORCHAUDIO_VERSION of torchaudio"
120+
PIP_TORCHAUDIO_VERSION="==$PIP_TORCHAUDIO_VERSION"
121+
fi
122+
123+
declare -a TORCH_PACKAGES=(
124+
"torch${PIP_TORCH_VERSION:-}"
125+
"torchvision${PIP_TORCHVISION_VERSION:-}"
126+
"torchaudio${PIP_TORCHAUDIO_VERSION:-}"
127+
)
128+
129+
if [ -n "${PIP_TORCH_INDEX_URL:-}" ]; then
130+
echo "Using the specified index, $PIP_TORCH_INDEX_URL"
131+
PIP_INSTALL_ARGS+=("--index-url" "$PIP_TORCH_INDEX_URL")
132+
elif command -v uv &>/dev/null && [ -n "${UV_TORCH_BACKEND:-}" ]; then
133+
echo "Using the specified uv backend, $UV_TORCH_BACKEND"
134+
PIP_INSTALL_ARGS+=("--torch-backend" "$UV_TORCH_BACKEND")
135+
elif ! command -v uv &>/dev/null && [ -n "${UV_TORCH_BACKEND:-}" ]; then
136+
echo "Error: UV_TORCH_BACKEND is set to $UV_TORCH_BACKEND but uv is not available."
137+
exit 1
138+
else
139+
# Set compute platform for torch wheel installation
140+
if [ -n "${ROCM_VERSION:-}" ]; then
141+
echo "Using the ROCm version $ROCM_VERSION backend"
142+
COMPUTE_PLATFORM="rocm$(get_rocm_version)"
143+
elif ((${TRITON_CPU_BACKEND:-0} == 1)); then
144+
echo "Using the CPU backend"
145+
COMPUTE_PLATFORM="cpu"
146+
elif [ -n "${CUDA_VERSION:-}" ]; then
147+
echo "Using the CUDA version $CUDA_VERSION backend"
148+
COMPUTE_PLATFORM="cu${CUDA_VERSION/[.-]/}"
149+
fi
150+
151+
if [ -n "${COMPUTE_PLATFORM:-}" ]; then
152+
[[ -n "${PIP_TORCH_INDEX_URL_BUILD:-}" ]] && PIP_TORCH_INDEX_URL_BUILD="/${PIP_TORCH_INDEX_URL_BUILD}"
153+
PIP_TORCH_INDEX_URL="${PIP_TORCH_INDEX_URL_BASE}${PIP_TORCH_INDEX_URL_BUILD:-}/${COMPUTE_PLATFORM}"
154+
PIP_INSTALL_ARGS+=("--index-url" "$PIP_TORCH_INDEX_URL")
155+
fi
156+
fi
157+
158+
pip_install -U --force-reinstall "${PIP_INSTALL_ARGS[@]}" "${TORCH_PACKAGES[@]}"
159+
160+
# Fix up LD_LIBRARY_PATH for CUDA
161+
ldpretend
162+
}
163+
164+
usage() {
165+
cat >&2 <<EOF
166+
Usage: $(basename "$0") [COMMAND]
167+
source Download Torch's source (if needed) and install the build deps
168+
release Install Torch
169+
nightly Install the Torch nightly wheel
170+
test Install the Torch test wheel
171+
EOF
172+
}
173+
174+
##
175+
## Main
176+
##
177+
if [ $# -ne 1 ]; then
178+
usage
179+
exit 1
180+
fi
181+
182+
COMMAND=${1,,}
183+
184+
case $COMMAND in
185+
source)
186+
setup_src
187+
install_build_deps
188+
install_deps
189+
;;
190+
release)
191+
install_deps
192+
install_whl
193+
;;
194+
nightly | test)
195+
PIP_TORCH_INDEX_URL_BUILD=$COMMAND
196+
install_deps
197+
install_whl
198+
;;
199+
*)
200+
usage
201+
exit 1
202+
;;
203+
esac
204+
205+

scripts/devinstall_triton.sh

Lines changed: 2 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -20,7 +20,6 @@ trap "echo -e '\nScript interrupted. Exiting gracefully.'; exit 1" SIGINT
2020
set -euo pipefail
2121

2222
declare -a PIP_INSTALL_ARGS
23-
PIP_TORCH_INDEX_URL_BASE=https://download.pytorch.org/whl
2423

2524
WORKSPACE=${WORKSPACE:-${HOME}}
2625

@@ -123,27 +122,8 @@ install_deps() {
123122
pip_install cmake ctypeslib2 matplotlib ninja \
124123
numpy pandas pybind11 pytest pyyaml scipy tabulate wheel
125124

126-
if [ -n "${TORCH_VERSION:-}" ]; then
127-
echo "Installing the specified version $TORCH_VERSION of torch"
128-
PIP_TORCH_VERSION="==$TORCH_VERSION"
129-
fi
130-
131-
if [ -n "${ROCM_VERSION:-}" ]; then
132-
echo "Installing torch for ROCm version $ROCM_VERSION"
133-
pip_install "torch${PIP_TORCH_VERSION:-}" \
134-
--index-url "${PIP_TORCH_INDEX_URL_BASE}/rocm$(get_rocm_version)"
135-
elif ((${TRITON_CPU_BACKEND:-0} == 1)); then
136-
echo "Installing torch for CPU"
137-
pip_install "torch${PIP_TORCH_VERSION:-}" \
138-
--index-url "${PIP_TORCH_INDEX_URL_BASE}/cpu"
139-
elif [ -n "${CUDA_VERSION:-}" ]; then
140-
echo "Installing torch for CUDA version $CUDA_VERSION"
141-
pip_install "torch${PIP_TORCH_VERSION:-}" \
142-
--index-url "${PIP_TORCH_INDEX_URL_BASE}/cu$(get_cuda_version)"
143-
else
144-
echo "Installing torch ..."
145-
pip_install "torch${PIP_TORCH_VERSION:-}"
146-
fi
125+
echo "Installing Torch as a Triton dependency ..."
126+
devinstall_torch release
147127
}
148128

149129
install_whl() {

scripts/devsetup.sh

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -27,6 +27,7 @@ declare -a SAVE_VARS=(
2727
"INSTALL_JUPYTER"
2828
"INSTALL_LLVM"
2929
"INSTALL_TOOLS"
30+
"INSTALL_TORCH"
3031
"INSTALL_TRITON"
3132
"MAX_JOBS"
3233
"PIP_TRITON_VERSION"
@@ -70,3 +71,7 @@ fi
7071
if [ "${INSTALL_TRITON:-skip}" != "skip" ]; then
7172
run_as_user devinstall_triton "$INSTALL_TRITON"
7273
fi
74+
75+
if [ "${INSTALL_TORCH:-skip}" != "skip" ]; then
76+
run_as_user devinstall_torch "$INSTALL_TORCH"
77+
fi

0 commit comments

Comments
 (0)