Skip to content
This repository was archived by the owner on Sep 10, 2025. It is now read-only.

Commit 7c4e42b

Browse files
committed
add xpu device
1 parent 778efd6 commit 7c4e42b

File tree

5 files changed

+65
-33
lines changed

5 files changed

+65
-33
lines changed

install/install_requirements.sh

Lines changed: 34 additions & 28 deletions
Original file line numberDiff line numberDiff line change
@@ -35,14 +35,9 @@ fi
3535
# newer version of torch nightly installed later in this script.
3636
#
3737

38-
#(
39-
# set -x
40-
# $PIP_EXECUTABLE install -r install/requirements.txt --extra-index-url https://download.pytorch.org/whl/nightly/cu121
41-
#)
42-
4338
(
4439
set -x
45-
$PIP_EXECUTABLE install -r install/requirements.txt --extra-index-url https://download.pytorch.org/whl/nightly/xpu
40+
$PIP_EXECUTABLE install -r install/requirements.txt --extra-index-url https://download.pytorch.org/whl/nightly/cu121
4641
)
4742

4843
# Since torchchat often uses main-branch features of pytorch, only the nightly
@@ -52,7 +47,12 @@ fi
5247
# NOTE: If a newly-fetched version of the executorch repo changes the value of
5348
# PYTORCH_NIGHTLY_VERSION, you should re-run this script to install the necessary
5449
# package versions.
55-
PYTORCH_NIGHTLY_VERSION=dev20241001
50+
if [[ -x "$(command -v xpu-smi)" ]];
51+
then
52+
PYTORCH_NIGHTLY_VERSION=dev20241001
53+
else
54+
PYTORCH_NIGHTLY_VERSION=dev20241002
55+
fi
5656

5757
# Nightly version for torchvision
5858
VISION_NIGHTLY_VERSION=dev20241002
@@ -69,22 +69,34 @@ TUNE_NIGHTLY_VERSION=dev20241010
6969
# The pip repository that hosts nightly torch packages. cpu by default.
7070
# If cuda is available, based on presence of nvidia-smi, install the pytorch nightly
7171
# with cuda for faster execution on cuda GPUs.
72-
#if [[ -x "$(command -v nvidia-smi)" ]];
73-
#then
74-
# TORCH_NIGHTLY_URL="https://download.pytorch.org/whl/nightly/cu121"
75-
#elif [[ -x "$(command -v rocminfo)" ]];
76-
#then
77-
# TORCH_NIGHTLY_URL="https://download.pytorch.org/whl/nightly/rocm6.2"
78-
#else
79-
# TORCH_NIGHTLY_URL="https://download.pytorch.org/whl/nightly/cpu"
80-
#fi
81-
TORCH_NIGHTLY_URL="https://download.pytorch.org/whl/nightly/xpu"
72+
if [[ -x "$(command -v nvidia-smi)" ]];
73+
then
74+
TORCH_NIGHTLY_URL="https://download.pytorch.org/whl/nightly/cu121"
75+
elif [[ -x "$(command -v rocminfo)" ]];
76+
then
77+
TORCH_NIGHTLY_URL="https://download.pytorch.org/whl/nightly/rocm6.2"
78+
elif [[ -x "$(command -v xpu-smi)" ]];
79+
then
80+
TORCH_NIGHTLY_URL="https://download.pytorch.org/whl/nightly/xpu"
81+
else
82+
TORCH_NIGHTLY_URL="https://download.pytorch.org/whl/nightly/cpu"
83+
fi
84+
8285
# pip packages needed by exir.
83-
REQUIREMENTS_TO_INSTALL=(
84-
torch=="2.6.0.${PYTORCH_NIGHTLY_VERSION}"
85-
torchvision=="0.20.0.${VISION_NIGHTLY_VERSION}"
86-
#torchtune=="0.3.0.${TUNE_NIGHTLY_VERSION}"
87-
)
86+
if [[ -x "$(command -v xpu-smi)" ]];
87+
then
88+
REQUIREMENTS_TO_INSTALL=(
89+
torch=="2.6.0.${PYTORCH_NIGHTLY_VERSION}"
90+
torchvision=="0.20.0.${VISION_NIGHTLY_VERSION}"
91+
torchtune=="0.3.1"
92+
)
93+
else
94+
REQUIREMENTS_TO_INSTALL=(
95+
torch=="2.6.0.${PYTORCH_NIGHTLY_VERSION}"
96+
torchvision=="0.20.0.${VISION_NIGHTLY_VERSION}"
97+
torchtune=="0.4.0.${TUNE_NIGHTLY_VERSION}"
98+
)
99+
fi
88100

89101
# Install the requirements. --extra-index-url tells pip to look for package
90102
# versions on the provided URL if they aren't available on the default URL.
@@ -94,12 +106,6 @@ REQUIREMENTS_TO_INSTALL=(
94106
"${REQUIREMENTS_TO_INSTALL[@]}"
95107
)
96108

97-
(
98-
set -x
99-
$PIP_EXECUTABLE install --extra-index-url "${TORCH_NIGHTLY_URL}" \
100-
torchtune=="0.3.1"
101-
)
102-
103109
(
104110
set -x
105111
$PIP_EXECUTABLE install torchao=="0.5.0"

torchchat/cli/builder.py

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -68,7 +68,13 @@ class BuilderArgs:
6868

6969
def __post_init__(self):
7070
if self.device is None:
71-
self.device = "cuda" if torch.cuda.is_available() else "cpu"
71+
# self.device = "cuda" if torch.cuda.is_available() else "cpu"
72+
if torch.cuda.is_available():
73+
self.device = "cuda"
74+
elif hasattr(torch, 'xpu') and torch.xpu.is_available():
75+
self.device = "xpu"
76+
else:
77+
self.device = "cpu"
7278

7379
if not (
7480
(self.checkpoint_path and self.checkpoint_path.is_file())

torchchat/cli/cli.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -168,7 +168,7 @@ def _add_model_config_args(parser, verb: str) -> None:
168168
type=str,
169169
default=default_device,
170170
choices=["fast", "cpu", "cuda", "mps", "xpu"],
171-
help="Hardware device to use. Options: cpu, cuda, mps",
171+
help="Hardware device to use. Options: cpu, cuda, mps, xpu",
172172
)
173173

174174

torchchat/utils/build_utils.py

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -203,6 +203,8 @@ def find_multiple(n: int, k: int) -> int:
203203
def device_sync(device="cpu"):
204204
if "cuda" in device:
205205
torch.cuda.synchronize(device)
206+
elif "xpu" in device:
207+
torch.xpu.synchronize(device)
206208
elif ("cpu" in device) or ("mps" in device):
207209
pass
208210
else:
@@ -251,7 +253,8 @@ def get_device_str(device) -> str:
251253
device = (
252254
"cuda"
253255
if torch.cuda.is_available()
254-
else "mps" if is_mps_available() else "cpu"
256+
else "mps" if is_mps_available()
257+
else "xpu" if torch.xpu.is_available() else "cpu"
255258
)
256259
return device
257260
else:
@@ -263,7 +266,8 @@ def get_device(device) -> str:
263266
device = (
264267
"cuda"
265268
if torch.cuda.is_available()
266-
else "mps" if is_mps_available() else "cpu"
269+
else "mps" if is_mps_available()
270+
else "xpu" if torch.xpu.is_available() else "cpu"
267271
)
268272
return torch.device(device)
269273

torchchat/utils/device_info.py

Lines changed: 17 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,7 @@ def get_device_info(device: str) -> str:
1414
"""Returns a human-readable description of the hardware based on a torch.device.type
1515
1616
Args:
17-
device: A torch.device.type string: one of {"cpu", "cuda"}.
17+
device: A torch.device.type string: one of {"cpu", "cuda", "xpu"}.
1818
Returns:
1919
str: A human-readable description of the hardware or an empty string if the device type is unhandled.
2020
@@ -37,4 +37,20 @@ def get_device_info(device: str) -> str:
3737
)
3838
if device == "cuda":
3939
return torch.cuda.get_device_name(0)
40+
if device == "xpu":
41+
# return (
42+
# check_output(
43+
# ["sycl-ls | grep gpu"], shell=True
44+
# )
45+
# .decode("utf-8")
46+
# .split("\n")[0]
47+
# )
48+
return (
49+
check_output(
50+
["xpu-smi discovery |grep 'Device Name:'"], shell=True
51+
)
52+
.decode("utf-8")
53+
.split("\n")[0]
54+
.split("Device Name:")[1]
55+
)
4056
return ""

0 commit comments

Comments
 (0)