Skip to content

Commit 42c7c32

Browse files
authored
feat: add PyTorch setup script for modular installation (#117)
* feat: add PyTorch setup script for modular installation Add scripts/setup_torch.sh to support PyTorch installation and configuration within containers. This script: - Downloads PyTorch source from GitHub when not mounted as a volume - Installs build dependencies for PyTorch compilation - Supports installing PyTorch wheels from PyPI (release, nightly, test) - Provides flexible configuration via INSTALL_TORCH environment variable The script supports multiple installation modes: - source: Build from source (with auto-download if not mounted) - release/nightly/test: Install wheels from PyPI - skip: Skip PyTorch installation This is part of the modular script architecture introduced in PR #115. Signed-off-by: Craig Magina <cmagina@redhat.com>
1 parent 45d870d commit 42c7c32

File tree

10 files changed

+256
-26
lines changed

10 files changed

+256
-26
lines changed

.github/workflows/amd-image.yml

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,7 @@ on: # yamllint disable-line rule:truthy
1111
- scripts/devinstall_software.sh
1212
- scripts/ldpretend.sh
1313
- scripts/devinstall_llvm.sh
14+
- scripts/devinstall_torch.sh
1415
- scripts/devinstall_triton.sh
1516
- scripts/devcreate_user.sh
1617
- scripts/devsetup.sh
@@ -22,6 +23,7 @@ on: # yamllint disable-line rule:truthy
2223
- scripts/devinstall_software.sh
2324
- scripts/ldpretend.sh
2425
- scripts/devinstall_llvm.sh
26+
- scripts/devinstall_torch.sh
2527
- scripts/devinstall_triton.sh
2628
- scripts/devcreate_user.sh
2729
- scripts/devsetup.sh

.github/workflows/cpu-image.yml

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,7 @@ on: # yamllint disable-line rule:truthy
1111
- scripts/devinstall_software.sh
1212
- scripts/ldpretend.sh
1313
- scripts/devinstall_llvm.sh
14+
- scripts/devinstall_torch.sh
1415
- scripts/devinstall_triton.sh
1516
- scripts/devcreate_user.sh
1617
- scripts/devsetup.sh
@@ -22,6 +23,7 @@ on: # yamllint disable-line rule:truthy
2223
- scripts/devinstall_software.sh
2324
- scripts/ldpretend.sh
2425
- scripts/devinstall_llvm.sh
26+
- scripts/devinstall_torch.sh
2527
- scripts/devinstall_triton.sh
2628
- scripts/devcreate_user.sh
2729
- scripts/devsetup.sh

.github/workflows/nvidia-image.yml

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,7 @@ on: # yamllint disable-line rule:truthy
1111
- scripts/devinstall_software.sh
1212
- scripts/ldpretend.sh
1313
- scripts/devinstall_llvm.sh
14+
- scripts/devinstall_torch.sh
1415
- scripts/devinstall_triton.sh
1516
- scripts/devcreate_user.sh
1617
- scripts/devsetup.sh
@@ -22,6 +23,7 @@ on: # yamllint disable-line rule:truthy
2223
- scripts/devinstall_software.sh
2324
- scripts/ldpretend.sh
2425
- scripts/devinstall_llvm.sh
26+
- scripts/devinstall_torch.sh
2527
- scripts/devinstall_triton.sh
2628
- scripts/devcreate_user.sh
2729
- scripts/devsetup.sh

Makefile

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -36,12 +36,14 @@ TRITON_CPU_BACKEND ?=0
3636
TRITON_TAG ?= latest
3737
triton_path ?=$(source_dir)
3838
llvm_path ?=
39+
torch_path ?=
3940
user_path ?=
4041
gitconfig_path ?="$(HOME)/.gitconfig"
4142
USERNAME ?=triton
4243
# NOTE: Requires host build system to have a valid Red Hat Subscription if true
4344
INSTALL_NSIGHT ?=false
4445
INSTALL_LLVM ?= skip # Options: source, skip
46+
INSTALL_TORCH ?= skip # Options: nightly, release, source, skip, test
4547
INSTALL_TRITON ?= source # Options: release, source, skip
4648
INSTALL_JUPYTER ?= true
4749
USE_CCACHE ?= 0
@@ -97,6 +99,9 @@ define run_container
9799
if [ -n "$(llvm_path)" ]; then \
98100
volume_arg+=" -v $(llvm_path):/workspace/llvm-project$(SELINUXFLAG)"; \
99101
fi; \
102+
if [ -n "$(torch_path)" ]; then \
103+
volume_arg+=" -v $(torch_path):/workspace/torch$(SELINUXFLAG)"; \
104+
fi; \
100105
if [ -n "$(user_path)" ]; then \
101106
volume_arg+=" -v $(user_path):/workspace/user$(SELINUXFLAG)"; \
102107
fi; \
@@ -135,7 +140,7 @@ define run_container
135140
else \
136141
port_arg=""; \
137142
fi; \
138-
env_vars="-e USERNAME=$(USER) -e USER_UID=`id -u $(USER)` -e USER_GID=`id -g $(USER)` -e TORCH_VERSION=$(torch_version) -e INSTALL_LLVM=$(INSTALL_LLVM) -e INSTALL_TOOLS=$(DEMO_TOOLS) -e INSTALL_JUPYTER=$(INSTALL_JUPYTER) -e NOTEBOOK_PORT=$(NOTEBOOK_PORT) -e INSTALL_TRITON=$(INSTALL_TRITON) -e USE_CCACHE=$(USE_CCACHE) -e MAX_JOBS=$(MAX_JOBS)"; \
143+
env_vars="-e USERNAME=$(USER) -e USER_UID=`id -u $(USER)` -e USER_GID=`id -g $(USER)` -e TORCH_VERSION=$(torch_version) -e INSTALL_LLVM=$(INSTALL_LLVM) -e INSTALL_TOOLS=$(DEMO_TOOLS) -e INSTALL_JUPYTER=$(INSTALL_JUPYTER) -e NOTEBOOK_PORT=$(NOTEBOOK_PORT) -e INSTALL_TORCH=$(INSTALL_TORCH) -e INSTALL_TRITON=$(INSTALL_TRITON) -e USE_CCACHE=$(USE_CCACHE) -e MAX_JOBS=$(MAX_JOBS)"; \
139144
if [ "$(STRIPPED_CMD)" = "docker" ]; then \
140145
$(CTR_CMD) run $$env_vars $$gpu_args $$profiling_args $$port_arg \
141146
-ti $$volume_arg $$gitconfig_arg $(IMAGE_REPO)/$(strip $(1)):$(TRITON_TAG) bash; \

dockerfiles/Dockerfile.triton

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -52,6 +52,7 @@ ENV BASH_ENV=/workspace/bin/activate \
5252
COPY --from=quay.io/triton-dev-containers/gosu /usr/local/bin/gosu /usr/local/bin/gosu
5353

5454
COPY scripts/devinstall_llvm.sh /workspace/bin/devinstall_llvm
55+
COPY scripts/devinstall_torch.sh /workspace/bin/devinstall_torch
5556
COPY scripts/devinstall_triton.sh /workspace/bin/devinstall_triton
5657
COPY scripts/devcreate_user.sh /workspace/bin/devcreate_user
5758
COPY scripts/devsetup.sh /workspace/bin/devsetup

dockerfiles/Dockerfile.triton-amd

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -70,6 +70,7 @@ WORKDIR /workspace
7070
COPY --from=quay.io/triton-dev-containers/gosu /usr/local/bin/gosu /usr/local/bin/gosu
7171

7272
COPY scripts/devinstall_llvm.sh /workspace/bin/devinstall_llvm
73+
COPY scripts/devinstall_torch.sh /workspace/bin/devinstall_torch
7374
COPY scripts/devinstall_triton.sh /workspace/bin/devinstall_triton
7475
COPY scripts/devcreate_user.sh /workspace/bin/devcreate_user
7576
COPY scripts/devsetup.sh /workspace/bin/devsetup

dockerfiles/Dockerfile.triton-cpu

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -50,6 +50,7 @@ WORKDIR /workspace
5050
COPY --from=quay.io/triton-dev-containers/gosu /usr/local/bin/gosu /usr/local/bin/gosu
5151

5252
COPY scripts/devinstall_llvm.sh /workspace/bin/devinstall_llvm
53+
COPY scripts/devinstall_torch.sh /workspace/bin/devinstall_torch
5354
COPY scripts/devinstall_triton.sh /workspace/bin/devinstall_triton
5455
COPY scripts/devcreate_user.sh /workspace/bin/devcreate_user
5556
COPY scripts/devsetup.sh /workspace/bin/devsetup

scripts/devinstall_torch.sh

Lines changed: 219 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,219 @@
1+
#! /bin/bash -e
2+
3+
trap "echo -e '\nScript interrupted. Exiting gracefully.'; exit 1" SIGINT
4+
5+
# Copyright (C) 2024-2025 Red Hat, Inc.
6+
#
7+
# Licensed under the Apache License, Version 2.0 (the "License");
8+
# you may not use this file except in compliance with the License.
9+
# You may obtain a copy of the License at
10+
#
11+
# http://www.apache.org/licenses/LICENSE-2.0
12+
#
13+
# Unless required by applicable law or agreed to in writing, software
14+
# distributed under the License is distributed on an "AS IS" BASIS,
15+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
16+
# See the License for the specific language governing permissions and
17+
# limitations under the License.
18+
#
19+
# SPDX-License-Identifier: Apache-2.0
20+
set -euo pipefail
21+
22+
WORKSPACE=${WORKSPACE:-${HOME}}
23+
24+
TORCH_DIR=${WORKSPACE}/torch
25+
TORCH_REPO=https://github.com/pytorch/pytorch.git
26+
27+
SUDO=''
28+
if ((EUID != 0)) && command -v sudo &>/dev/null; then
29+
SUDO="sudo"
30+
fi
31+
32+
pip_install() {
33+
if command -v uv &>/dev/null; then
34+
uv pip install "$@"
35+
else
36+
pip install "$@"
37+
fi
38+
}
39+
40+
# Remove the dashes or periods from the CUDA version, e.g. 128 from 12-8
41+
get_cuda_version() {
42+
echo "${CUDA_VERSION//[.-]/}"
43+
}
44+
45+
# Extract the major.minor version from ROCM_VERSION, e.g. 6.4 from 6.4.4
46+
get_rocm_version() {
47+
[[ "$ROCM_VERSION" =~ ^([0-9]+\.[0-9]+) ]] && echo "${BASH_REMATCH[1]}" ||
48+
echo "$ROCM_VERSION"
49+
}
50+
51+
setup_src() {
52+
echo "Downloading Torch source code and setting up the environment for building from source..."
53+
54+
if [ ! -d "$TORCH_DIR" ]; then
55+
echo "Cloning the Torch repo $TORCH_REPO to $TORCH_DIR ..."
56+
git clone "$TORCH_REPO" "$TORCH_DIR"
57+
if [ ! -d "$TORCH_DIR" ]; then
58+
echo "$TORCH_DIR not found. ERROR Cloning repository..."
59+
exit 1
60+
else
61+
pushd "$TORCH_DIR" 1>/dev/null || exit 1
62+
63+
if [ -n "${TORCH_GITREF:-}" ]; then
64+
git checkout "$TORCH_GITREF"
65+
fi
66+
67+
git submodule sync
68+
git submodule update --init --recursive
69+
70+
echo "Install pre-commit hooks into your local Torch git repo (one-time)"
71+
pip_install pre-commit
72+
pre-commit install
73+
popd 1>/dev/null
74+
fi
75+
else
76+
echo "Torch repo already present, not cloning ..."
77+
fi
78+
}
79+
80+
install_build_deps() {
81+
echo "Installing Torch build dependencies ..."
82+
83+
pushd "$TORCH_DIR" 1>/dev/null || exit 1
84+
85+
if [ -f requirements.txt ]; then
86+
pip_install --group dev
87+
pip_install mkl-static mkl-include
88+
fi
89+
90+
if ((EUID == 0)) || [ -n "${SUDO:-}" ]; then
91+
$SUDO dnf -y install numactl-devel
92+
else
93+
echo "ERROR: Can't install some build deps without root or sudo permissions." >&2
94+
exit 1
95+
fi
96+
97+
if [ -n "${ROCM_VERSION:-}" ]; then
98+
python tools/amd_build/build_amd.py
99+
fi
100+
101+
popd 1>/dev/null
102+
}
103+
104+
install_deps() {
105+
echo "Installing Torch dependencies ..."
106+
pip_install numpy
107+
}
108+
109+
install_whl() {
110+
local pip_build="$1"
111+
112+
local compute_platform
113+
local pip_torch_index_url_base
114+
local -a pip_install_args
115+
116+
pip_torch_index_url_base="https://download.pytorch.org/whl"
117+
118+
case "$pip_build" in
119+
release) ;;
120+
nightly | test)
121+
pip_torch_index_url_base="${pip_torch_index_url_base}/${pip_build}"
122+
;;
123+
esac
124+
125+
echo "Installing Torch $pip_build from PyPI ..."
126+
127+
if [ -n "${PIP_TORCH_VERSION:-}" ]; then
128+
echo "Using the specified version $PIP_TORCH_VERSION of torch"
129+
PIP_TORCH_VERSION="==$PIP_TORCH_VERSION"
130+
fi
131+
132+
if [ -n "${PIP_TORCHVISION_VERSION:-}" ]; then
133+
echo "Installing the specified version $PIP_TORCHVISION_VERSION of torchvision"
134+
PIP_TORCHVISION_VERSION="==$PIP_TORCHVISION_VERSION"
135+
fi
136+
137+
if [ -n "${PIP_TORCHAUDIO_VERSION:-}" ]; then
138+
echo "Installing the specified version $PIP_TORCHAUDIO_VERSION of torchaudio"
139+
PIP_TORCHAUDIO_VERSION="==$PIP_TORCHAUDIO_VERSION"
140+
fi
141+
142+
declare -a TORCH_PACKAGES=(
143+
"torch${PIP_TORCH_VERSION:-}"
144+
"torchvision${PIP_TORCHVISION_VERSION:-}"
145+
"torchaudio${PIP_TORCHAUDIO_VERSION:-}"
146+
)
147+
148+
if [ -n "${PIP_TORCH_INDEX_URL:-}" ]; then
149+
echo "Using the specified index, $PIP_TORCH_INDEX_URL"
150+
pip_install_args+=("--index-url" "$PIP_TORCH_INDEX_URL")
151+
elif command -v uv &>/dev/null && [ -n "${UV_TORCH_BACKEND:-}" ]; then
152+
echo "Using the specified uv backend, $UV_TORCH_BACKEND"
153+
pip_install_args+=("--torch-backend" "$UV_TORCH_BACKEND")
154+
elif ! command -v uv &>/dev/null && [ -n "${UV_TORCH_BACKEND:-}" ]; then
155+
echo "Error: UV_TORCH_BACKEND is set to $UV_TORCH_BACKEND but uv is not available."
156+
exit 1
157+
else
158+
# Set compute platform for torch wheel installation
159+
if [ -n "${ROCM_VERSION:-}" ]; then
160+
echo "Using the ROCm version $ROCM_VERSION backend"
161+
compute_platform="rocm$(get_rocm_version)"
162+
elif ((${TRITON_CPU_BACKEND:-0} == 1)); then
163+
echo "Using the CPU backend"
164+
compute_platform="cpu"
165+
elif [ -n "${CUDA_VERSION:-}" ]; then
166+
echo "Using the CUDA version $CUDA_VERSION backend"
167+
compute_platform="cu$(get_cuda_version)"
168+
fi
169+
170+
if [ -n "${compute_platform:-}" ]; then
171+
PIP_TORCH_INDEX_URL="${pip_torch_index_url_base}/${compute_platform}"
172+
pip_install_args+=("--index-url" "$PIP_TORCH_INDEX_URL")
173+
else
174+
PIP_TORCH_INDEX_URL="${pip_torch_index_url_base}"
175+
pip_install_args+=("--index-url" "$PIP_TORCH_INDEX_URL")
176+
fi
177+
fi
178+
179+
pip_install -U --force-reinstall "${pip_install_args[@]}" "${TORCH_PACKAGES[@]}"
180+
181+
# Fix up LD_LIBRARY_PATH for CUDA
182+
ldpretend
183+
}
184+
185+
usage() {
186+
cat >&2 <<EOF
187+
Usage: $(basename "$0") [COMMAND]
188+
source Download Torch's source (if needed) and install the build deps
189+
release Install Torch
190+
nightly Install the Torch nightly wheel
191+
test Install the Torch test wheel
192+
EOF
193+
}
194+
195+
##
196+
## Main
197+
##
198+
if [ $# -ne 1 ]; then
199+
usage
200+
exit 1
201+
fi
202+
203+
COMMAND=${1,,}
204+
205+
case $COMMAND in
206+
source)
207+
setup_src
208+
install_build_deps
209+
install_deps
210+
;;
211+
nightly | release | test)
212+
install_deps
213+
install_whl "$COMMAND"
214+
;;
215+
*)
216+
usage
217+
exit 1
218+
;;
219+
esac

scripts/devinstall_triton.sh

Lines changed: 12 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -19,9 +19,6 @@ trap "echo -e '\nScript interrupted. Exiting gracefully.'; exit 1" SIGINT
1919
# SPDX-License-Identifier: Apache-2.0
2020
set -euo pipefail
2121

22-
declare -a PIP_INSTALL_ARGS
23-
PIP_TORCH_INDEX_URL_BASE=https://download.pytorch.org/whl
24-
2522
WORKSPACE=${WORKSPACE:-${HOME}}
2623

2724
TRITON_DIR=${WORKSPACE}/triton
@@ -108,30 +105,20 @@ install_deps() {
108105
pip_install cmake ctypeslib2 matplotlib ninja \
109106
numpy pandas pybind11 pytest pyyaml scipy tabulate wheel
110107

111-
if [ -n "${TORCH_VERSION:-}" ]; then
112-
echo "Installing the specified version $TORCH_VERSION of torch"
113-
PIP_TORCH_VERSION="==$TORCH_VERSION"
114-
fi
115-
116-
if [ -n "${ROCM_VERSION:-}" ]; then
117-
echo "Installing torch for ROCm version $ROCM_VERSION"
118-
pip_install "torch${PIP_TORCH_VERSION:-}" \
119-
--index-url "${PIP_TORCH_INDEX_URL_BASE}/rocm$(get_rocm_version)"
120-
elif ((${TRITON_CPU_BACKEND:-0} == 1)); then
121-
echo "Installing torch for CPU"
122-
pip_install "torch${PIP_TORCH_VERSION:-}" \
123-
--index-url "${PIP_TORCH_INDEX_URL_BASE}/cpu"
124-
elif [ -n "${CUDA_VERSION:-}" ]; then
125-
echo "Installing torch for CUDA version $CUDA_VERSION"
126-
pip_install "torch${PIP_TORCH_VERSION:-}" \
127-
--index-url "${PIP_TORCH_INDEX_URL_BASE}/cu$(get_cuda_version)"
128-
else
129-
echo "Installing torch ..."
130-
pip_install "torch${PIP_TORCH_VERSION:-}"
108+
if [ "${INSTALL_TORCH:-}" != "source" ]; then
109+
if [ -n "${INSTALL_TORCH:-}" ] && [ "${INSTALL_TORCH}" != "skip" ]; then
110+
echo "Installing Torch $INSTALL_TORCH as a dependency ..."
111+
devinstall_torch "${INSTALL_TORCH}"
112+
else
113+
echo "Installing Torch as a dependency ..."
114+
devinstall_torch release
115+
fi
131116
fi
132117
}
133118

134119
install_whl() {
120+
local -a pip_install_args
121+
135122
echo "Installing Triton from PyPI ..."
136123

137124
if command -v uv &>/dev/null; then
@@ -151,7 +138,7 @@ install_whl() {
151138
UV_TORCH_BACKEND=auto
152139
fi
153140

154-
PIP_INSTALL_ARGS+=("--torch-backend" "$UV_TORCH_BACKEND")
141+
pip_install_args+=("--torch-backend" "$UV_TORCH_BACKEND")
155142
elif ! command -v uv &>/dev/null && [ -n "${UV_TORCH_BACKEND:-}" ]; then
156143
echo "Error: UV_TORCH_BACKEND is set to $UV_TORCH_BACKEND but uv is not available."
157144
exit 1
@@ -162,7 +149,7 @@ install_whl() {
162149
PIP_TRITON_VERSION="==$PIP_TRITON_VERSION"
163150
fi
164151

165-
pip_install -U --force-reinstall "${PIP_INSTALL_ARGS[@]}" "triton${PIP_TRITON_VERSION:-}"
152+
pip_install -U --force-reinstall "${pip_install_args[@]}" "triton${PIP_TRITON_VERSION:-}"
166153

167154
# Fix up LD_LIBRARY_PATH for CUDA
168155
ldpretend

0 commit comments

Comments
 (0)