@@ -59,12 +59,6 @@ VISION_NIGHTLY_VERSION=dev20241218
 # Nightly version for torchtune
 TUNE_NIGHTLY_VERSION=dev20241218
 
-# Uninstall triton, as nightly will depend on pytorch-triton, which is one and the same
-(
-  set -x
-  $PIP_EXECUTABLE uninstall -y triton
-)
-
 # The pip repository that hosts nightly torch packages. cpu by default.
 # If cuda is available, based on presence of nvidia-smi, install the pytorch nightly
 # with cuda for faster execution on cuda GPUs.
@@ -74,16 +68,28 @@
 elif [[ -x "$(command -v rocminfo)" ]];
 then
   TORCH_NIGHTLY_URL="https://download.pytorch.org/whl/nightly/rocm6.2"
+elif [[ -x "$(command -v xpu-smi)" ]];
+then
+  TORCH_NIGHTLY_URL="https://download.pytorch.org/whl/nightly/xpu"
 else
   TORCH_NIGHTLY_URL="https://download.pytorch.org/whl/nightly/cpu"
 fi
 
 # pip packages needed by exir.
-REQUIREMENTS_TO_INSTALL=(
-  torch=="2.6.0.${PYTORCH_NIGHTLY_VERSION}"
-  torchvision=="0.22.0.${VISION_NIGHTLY_VERSION}"
-  torchtune=="0.5.0.${TUNE_NIGHTLY_VERSION}"
-)
+if [[ -x "$(command -v xpu-smi)" ]];
+then
+  REQUIREMENTS_TO_INSTALL=(
+    torch=="2.6.0.${PYTORCH_NIGHTLY_VERSION}"
+    torchvision=="0.22.0.${VISION_NIGHTLY_VERSION}"
+    torchtune=="0.5.0"
+  )
+else
+  REQUIREMENTS_TO_INSTALL=(
+    torch=="2.6.0.${PYTORCH_NIGHTLY_VERSION}"
+    torchvision=="0.22.0.${VISION_NIGHTLY_VERSION}"
+    torchtune=="0.5.0.${TUNE_NIGHTLY_VERSION}"
+  )
+fi
 
 #
 # First install requirements in install/requirements.txt. Older torch may be
@@ -95,6 +101,12 @@ REQUIREMENTS_TO_INSTALL=(
   $PIP_EXECUTABLE install -r install/requirements.txt --extra-index-url "${TORCH_NIGHTLY_URL}"
 )
 
+# Uninstall triton, as nightly will depend on pytorch-triton, which is one and the same
+(
+  set -x
+  $PIP_EXECUTABLE uninstall -y triton
+)
+
 # Install the requirements. --extra-index-url tells pip to look for package
 # versions on the provided URL if they aren't available on the default URL.
 (
@@ -116,8 +128,6 @@ if [[ -x "$(command -v nvidia-smi)" ]]; then
     $PYTHON_EXECUTABLE torchchat/utils/scripts/patch_triton.py
   )
 fi
-
-
 (
   set -x
   $PIP_EXECUTABLE install evaluate=="0.4.3" lm-eval=="0.4.2" psutil=="6.0.0"
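
For reference, a minimal standalone sketch of the detection cascade this patch extends: each vendor tool is probed with command -v, and the first hit selects the nightly wheel index. The CUDA URL below is an assumption for illustration only, since that branch sits outside the hunks shown above.

# Sketch, not part of the patch: accelerator probe -> nightly wheel index.
if [[ -x "$(command -v nvidia-smi)" ]]; then
  TORCH_NIGHTLY_URL="https://download.pytorch.org/whl/nightly/cu124"  # assumed CUDA variant
elif [[ -x "$(command -v rocminfo)" ]]; then
  TORCH_NIGHTLY_URL="https://download.pytorch.org/whl/nightly/rocm6.2"
elif [[ -x "$(command -v xpu-smi)" ]]; then
  TORCH_NIGHTLY_URL="https://download.pytorch.org/whl/nightly/xpu"
else
  TORCH_NIGHTLY_URL="https://download.pytorch.org/whl/nightly/cpu"
fi
echo "Selected nightly index: ${TORCH_NIGHTLY_URL}"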