
Commit 30a38ec

update

Parent: 5c3306e

2 files changed: 12 additions & 7 deletions


install/install_requirements.sh

Lines changed: 7 additions & 7 deletions
@@ -59,12 +59,6 @@ VISION_NIGHTLY_VERSION=dev20241218
 # Nightly version for torchtune
 TUNE_NIGHTLY_VERSION=dev20241218
 
-# Uninstall triton, as nightly will depend on pytorch-triton, which is one and the same
-(
-  set -x
-  $PIP_EXECUTABLE uninstall -y triton
-)
-
 # The pip repository that hosts nightly torch packages. cpu by default.
 # If cuda is available, based on presence of nvidia-smi, install the pytorch nightly
 # with cuda for faster execution on cuda GPUs.
@@ -87,7 +81,7 @@ then
   REQUIREMENTS_TO_INSTALL=(
     torch=="2.6.0.${PYTORCH_NIGHTLY_VERSION}"
     torchvision=="0.22.0.${VISION_NIGHTLY_VERSION}"
-    torchtune=="0.4.0"
+    torchtune=="0.5.0"
   )
 else
   REQUIREMENTS_TO_INSTALL=(
@@ -107,6 +101,12 @@ fi
 $PIP_EXECUTABLE install -r install/requirements.txt --extra-index-url "${TORCH_NIGHTLY_URL}"
 )
 
+# Uninstall triton, as nightly will depend on pytorch-triton, which is one and the same
+(
+  set -x
+  $PIP_EXECUTABLE uninstall -y triton
+)
+
 # Install the requirements. --extra-index-url tells pip to look for package
 # versions on the provided URL if they aren't available on the default URL.
 (
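The uninstall block is moved, not removed: it now runs after the pip install of install/requirements.txt, presumably so that a standalone triton wheel pulled in by that step is also removed in favor of the nightly's pytorch-triton. A minimal post-install sanity check in Python (the check itself is an illustrative assumption, not part of the commit):

# Post-install sanity check (an assumption for illustration, not part of
# the commit): the standalone "triton" distribution should be gone, while
# the nightly's "pytorch-triton" may remain.
from importlib.metadata import distributions

installed = {dist.metadata["Name"] for dist in distributions()}
assert "triton" not in installed, "standalone triton is still installed"
print("pytorch-triton present:", "pytorch-triton" in installed)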

torchchat/generate.py

Lines changed: 5 additions & 0 deletions
@@ -1291,6 +1291,9 @@ def callback(x, *, done_generating=False):
         )
         if torch.cuda.is_available():
             print(f"Memory used: {torch.cuda.max_memory_reserved() / 1e9:.02f} GB")
+        if torch.xpu.is_available():
+            print(f"Memory used: {torch.xpu.max_memory_reserved() / 1e9:.02f} GB")
+
 
 
 class DistributedGenerator(LocalGenerator):
@@ -1617,6 +1620,8 @@ def run_generator(
     )
     if torch.cuda.is_available():
         torch.cuda.reset_peak_memory_stats()
+    if torch.xpu.is_available():
+        torch.xpu.reset_peak_memory_stats()
 
     for _ in gen.chat(generator_args):
         pass
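Both hunks apply the same pattern: each torch.cuda peak-memory call gains a torch.xpu twin, since torch.xpu mirrors the torch.cuda memory-stats API, so runs on Intel GPUs report memory the same way CUDA runs do. A standalone sketch of that pattern (the helper name is hypothetical, not taken from the commit):

import torch

def peak_memory_backend():
    # Hypothetical helper for illustration: return whichever accelerator
    # module is available; both expose the same memory-stats functions.
    if torch.cuda.is_available():
        return torch.cuda
    if torch.xpu.is_available():
        return torch.xpu
    return None

backend = peak_memory_backend()
if backend is not None:
    backend.reset_peak_memory_stats()  # clear peak counters before the run
# ... run generation here ...
if backend is not None:
    print(f"Memory used: {backend.max_memory_reserved() / 1e9:.02f} GB")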
