Skip to content

Commit 66ad14f

Browse files
authored
Change update_deps script so that latest stable version can be pulled instead of latest nightly (#9424)
1 parent 1c00dea commit 66ad14f

File tree

2 files changed

+148
-47
lines changed

2 files changed

+148
-47
lines changed

WORKSPACE

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -46,7 +46,7 @@ new_local_repository(
4646

4747
# To build PyTorch/XLA with a new revison of OpenXLA, update the xla_hash to
4848
# the openxla git commit hash and note the date of the commit.
49-
xla_hash = '9084478fa71d7661ad9b4c95f24107b53c8a4709' # Committed on 2025-06-17.
49+
xla_hash = '3d5ece64321630dade7ff733ae1353fc3c83d9cc' # Committed on 2025-06-17.
5050

5151
http_archive(
5252
name = "xla",

scripts/update_deps.py

Lines changed: 147 additions & 46 deletions
Original file line numberDiff line numberDiff line change
@@ -4,23 +4,24 @@
44
Usage:
55
66
scripts/update_deps.py
7+
scripts/update_deps.py --use_latest
78
8-
updates the versions of OpenXLA, libtpu, and JAX as used in
9-
PyTorch/XLA. In particular, it:
9+
By default, updates to the latest stable JAX release and its corresponding
10+
OpenXLA and libtpu versions.
1011
11-
- updates OpenXLA to the latest commit,
12-
- updates libtpu to the latest nightly build, and
13-
- updates JAX to the latest nightly build.
12+
With --use_latest, updates to the latest nightly builds of OpenXLA,
13+
libtpu, and JAX.
1414
"""
1515

16+
import argparse
17+
import json
1618
import logging
1719
import os
1820
import platform
1921
import re
2022
import sys
21-
from typing import Optional
22-
from html.parser import HTMLParser
2323
import urllib.request
24+
from html.parser import HTMLParser
2425

2526
logger = logging.getLogger(__name__)
2627

@@ -107,34 +108,85 @@ def clean_tmp_dir() -> None:
107108
os.system(f'mkdir -p {_TMP_DIR}')
108109

109110

110-
def get_last_xla_commit_and_date() -> tuple[str, str]:
111-
"""Finds the latest commit in the master branch of https://github.com/openxla/xla.
111+
def get_xla_commit_and_date(commit: str | None = None) -> tuple[str, str]:
112+
"""Find a date, commit pair from https://github.com/openxla/xla.
113+
If commit is specified, use that commit.
114+
If no commit is specified take the latest commit from main.
112115
113116
Returns:
114-
A tuple of the latest commit SHA and its date (YYYY-MM-DD).
117+
A tuple of the commit SHA and its date (YYYY-MM-DD).
115118
"""
116119

117-
# Get the latest commit in the master branch of openxla.
118120
clean_tmp_dir()
119-
os.system(
120-
f'git clone --depth=1 https://github.com/openxla/xla {_TMP_DIR}/xla')
121-
commit = os.popen(f'cd {_TMP_DIR}/xla && git rev-parse HEAD').read().strip()
122-
123-
# Get the date of the commit, in the format of YYYY-MM-DD.
124-
date = os.popen(
125-
f'cd {_TMP_DIR}/xla && git show -s --format=%cd --date=short {commit}'
126-
).read().strip()
121+
if commit is None:
122+
# Clone the repository to a depth of 1 (just the main branch).
123+
os.system(
124+
f'git clone --depth=1 https://github.com/openxla/xla {_TMP_DIR}/xla')
125+
commit = os.popen(f'cd {_TMP_DIR}/xla && git rev-parse HEAD').read().strip()
126+
date = os.popen(
127+
f'cd {_TMP_DIR}/xla && git show -s --format=%cd --date=short {commit}'
128+
).read().strip()
129+
logger.info(f'Found latest XLA commit {commit} on date {date}')
130+
else:
131+
# Clone the repository history, but no blobs to save space.
132+
os.system(
133+
f'git clone --bare --filter=blob:none https://github.com/openxla/xla.git {_TMP_DIR}/xla.git'
134+
)
135+
date = os.popen(
136+
f'git --git-dir={_TMP_DIR}/xla.git show -s --format=%cd --date=short {commit}'
137+
).read().strip()
138+
if not date:
139+
logging.error(f"Unable to local XLA commit {commit}")
140+
logger.info(f'Given XLA commit {commit}, determined date {date}')
141+
127142
return commit, date
128143

129144

130-
def update_openxla() -> bool:
131-
"""Updates the OpenXLA version in the WORKSPACE file to the latest commit.
145+
def get_latest_stable_jax_info() -> tuple[str, str, str] | None:
146+
"""Gets info about the latest stable JAX release from GitHub.
132147
133148
Returns:
134-
True if the WORKSPACE file was updated, False otherwise.
149+
A tuple of (JAX version, JAX release date, XLA commit hash).
135150
"""
151+
url = 'https://api.github.com/repos/google/jax/releases/latest'
152+
try:
153+
with urllib.request.urlopen(url) as response:
154+
data = json.loads(response.read().decode())
155+
except Exception as e:
156+
logger.error(f'Failed to fetch {url}: {e}')
157+
return None
158+
159+
tag_name = data['tag_name'] # e.g., "jax-v0.4.28"
160+
jax_version = tag_name.replace('jax-v', '') # e.g., "0.4.28"
161+
162+
published_at = data['published_at'] # e.g., "2024-04-26T22:58:34Z"
163+
release_date = published_at.split('T')[0] # e.g., "2024-04-26"
164+
165+
# The XLA commit is in third_party/xla/workspace.bzl in the JAX repo.
166+
workspace_bzl_url = f'https://raw.githubusercontent.com/google/jax/{tag_name}/third_party/xla/workspace.bzl'
167+
try:
168+
with urllib.request.urlopen(workspace_bzl_url) as response:
169+
workspace_content = response.read().decode()
170+
except Exception as e:
171+
logger.error(f'Failed to fetch {workspace_bzl_url}: {e}')
172+
return None
136173

137-
commit, date = get_last_xla_commit_and_date()
174+
match = re.search(r'XLA_COMMIT = "([a-f0-9]{40})"', workspace_content)
175+
if not match:
176+
logger.error(f'Could not find XLA_COMMIT in {workspace_bzl_url}.')
177+
return None
178+
xla_commit = match.group(1)
179+
180+
return jax_version, release_date, xla_commit
181+
182+
183+
def update_openxla(commit: str | None = None) -> bool:
184+
"""Updates the OpenXLA version in the WORKSPACE file.
185+
186+
Returns:
187+
True if the WORKSPACE file was updated, False otherwise.
188+
"""
189+
commit, date = get_xla_commit_and_date(commit=commit)
138190

139191
with open(_WORKSPACE_PATH, 'r') as f:
140192
ws_lines = f.readlines()
@@ -158,14 +210,18 @@ def update_openxla() -> bool:
158210
return True
159211

160212

161-
def find_latest_nightly(html_lines: list[str],
162-
build_re: str) -> Optional[tuple[str, str, str]]:
213+
def find_latest_nightly(
214+
html_lines: list[str],
215+
build_re: str,
216+
target_date: str | None = None) -> tuple[str, str, str] | None:
163217
"""Finds the latest nightly build from the list of HTML lines.
164218
165219
Args:
166220
html_lines: A list of HTML lines to search for the nightly build.
167221
build_re: A regular expression for matching the nightly build line.
168222
It must have 3 capture groups: the version, the date, and the name suffix.
223+
target_date: If specified, find the latest build on or before this date
224+
(YYYYMMDD). Otherwise, find the latest build overall.
169225
170226
Returns:
171227
A tuple of the version, date, and suffix of the latest nightly build,
@@ -180,7 +236,10 @@ def find_latest_nightly(html_lines: list[str],
180236
if m:
181237
found_build = True
182238
version, date, suffix = m.groups()
183-
if date > latest_date:
239+
if target_date is None:
240+
if date > latest_date:
241+
latest_version, latest_date, latest_suffix = version, date, suffix
242+
elif date <= target_date and date > latest_date:
184243
latest_version, latest_date, latest_suffix = version, date, suffix
185244

186245
if found_build:
@@ -189,8 +248,12 @@ def find_latest_nightly(html_lines: list[str],
189248
return None
190249

191250

192-
def find_latest_libtpu_nightly() -> Optional[tuple[str, str, str]]:
193-
"""Finds the latest libtpu nightly build for the current platform.
251+
def find_libtpu_build(
252+
target_date: str | None = None) -> tuple[str, str, str] | None:
253+
"""Finds a libtpu nightly build for the current platform.
254+
255+
Args:
256+
target_date: If specified, find build for this date. Otherwise, find latest.
194257
195258
Returns:
196259
A tuple of the version, date, and suffix of the latest libtpu nightly build,
@@ -209,7 +272,7 @@ def find_latest_libtpu_nightly() -> Optional[tuple[str, str, str]]:
209272
return find_latest_nightly(
210273
html_lines,
211274
r'.*<a href=.*?>libtpu/libtpu-(.*?)\.dev(\d{8})\+nightly-(.*?)_' +
212-
_PLATFORM + r'\.whl</a>')
275+
_PLATFORM + r'\.whl</a>', target_date)
213276

214277

215278
def fetch_pep503_page(url: str) -> list[tuple[str, str]]:
@@ -233,7 +296,7 @@ def fetch_pep503_page(url: str) -> list[tuple[str, str]]:
233296
return []
234297

235298

236-
def find_latest_jax_nightly() -> Optional[tuple[str, str, str]]:
299+
def find_latest_jax_nightly() -> tuple[str, str, str] | None:
237300
"""Finds the latest JAX nightly build using the new package index.
238301
239302
Returns:
@@ -309,15 +372,19 @@ def parse_version_date(url: str, pattern: str) -> list[tuple[str, str]]:
309372
return latest_jax_version, latest_jaxlib_version, latest_jax_date
310373

311374

312-
def update_libtpu() -> bool:
313-
"""Updates the libtpu version in setup.py to the latest nightly build.
375+
def update_libtpu(target_date: str | None = None) -> bool:
376+
"""Updates the libtpu version in setup.py.
314377
315378
Returns:
316379
True if the setup.py file was updated, False otherwise.
317380
"""
318381

319-
result = find_latest_libtpu_nightly()
382+
result = find_libtpu_build(target_date)
320383
if not result:
384+
if target_date:
385+
logger.error(f'Could not find libtpu build for date {target_date}.')
386+
else:
387+
logger.error('Could not find latest libtpu nightly build.')
321388
return False
322389

323390
version, date, suffix = result
@@ -360,18 +427,24 @@ def update_libtpu() -> bool:
360427
return success
361428

362429

363-
def update_jax() -> bool:
364-
"""Updates the jax/jaxlib versions in setup.py to the latest nightly build.
430+
def update_jax(use_latest: bool) -> bool:
431+
"""Updates the jax/jaxlib versions in setup.py.
365432
366433
Returns:
367434
True if the setup.py file was updated, False otherwise.
368435
"""
369-
370-
result = find_latest_jax_nightly()
371-
if not result:
372-
return False
373-
374-
jax_version, jaxlib_version, date = result
436+
if use_latest:
437+
result = find_latest_jax_nightly()
438+
if not result:
439+
return False
440+
jax_version, jaxlib_version, date = result
441+
else:
442+
jax_info = get_latest_stable_jax_info()
443+
if not jax_info:
444+
return False
445+
jax_version, release_date, _ = jax_info
446+
jaxlib_version = jax_version
447+
date = release_date.replace('-', '')
375448

376449
with open(_SETUP_PATH, 'r') as f:
377450
setup_lines = f.readlines()
@@ -408,12 +481,40 @@ def update_jax() -> bool:
408481
def main() -> None:
409482
logging.basicConfig(level=logging.INFO)
410483

411-
openxla_updated = update_openxla()
412-
libtpu_updated = update_libtpu()
413-
jax_updated = update_jax()
484+
parser = argparse.ArgumentParser(
485+
description="Updates third party dependencies.")
486+
parser.add_argument(
487+
'--use_latest',
488+
action='store_true',
489+
default=False,
490+
help='Update to latest nightly versions instead of latest stable versions.'
491+
)
492+
args = parser.parse_args()
493+
494+
if args.use_latest:
495+
logger.info('Updating to latest nightly versions...')
496+
openxla_updated = update_openxla()
497+
libtpu_updated = update_libtpu()
498+
jax_updated = update_jax(use_latest=True)
499+
if not (openxla_updated and libtpu_updated and jax_updated):
500+
sys.exit(1)
501+
else:
502+
logger.info('Updating to latest stable versions...')
503+
jax_info = get_latest_stable_jax_info()
504+
if not jax_info:
505+
sys.exit(1)
506+
507+
jax_version, jax_release_date, xla_commit = jax_info
508+
logger.info(
509+
f'Found latest stable JAX release {jax_version} from {jax_release_date}, with XLA commit {xla_commit}'
510+
)
414511

415-
if not (openxla_updated and libtpu_updated and jax_updated):
416-
sys.exit(1)
512+
openxla_updated = update_openxla(xla_commit)
513+
libtpu_updated = update_libtpu(
514+
target_date=jax_release_date.replace('-', ''))
515+
jax_updated = update_jax(use_latest=False)
516+
if not (openxla_updated and libtpu_updated and jax_updated):
517+
sys.exit(1)
417518

418519

419520
if __name__ == '__main__':

0 commit comments

Comments
 (0)