4
4
Usage:
5
5
6
6
scripts/update_deps.py
7
+ scripts/update_deps.py --use_latest
7
8
8
- updates the versions of OpenXLA, libtpu, and JAX as used in
9
- PyTorch/XLA. In particular, it:
9
+ By default, updates to the latest stable JAX release and its corresponding
10
+ OpenXLA and libtpu versions.
10
11
11
- - updates OpenXLA to the latest commit,
12
- - updates libtpu to the latest nightly build, and
13
- - updates JAX to the latest nightly build.
12
+ With --use_latest, updates to the latest nightly builds of OpenXLA,
13
+ libtpu, and JAX.
14
14
"""
15
15
16
+ import argparse
17
+ import json
16
18
import logging
17
19
import os
18
20
import platform
19
21
import re
20
22
import sys
21
- from typing import Optional
22
- from html .parser import HTMLParser
23
23
import urllib .request
24
+ from html .parser import HTMLParser
24
25
25
26
logger = logging .getLogger (__name__ )
26
27
@@ -107,34 +108,85 @@ def clean_tmp_dir() -> None:
107
108
os .system (f'mkdir -p { _TMP_DIR } ' )
108
109
109
110
110
- def get_last_xla_commit_and_date () -> tuple [str , str ]:
111
- """Finds the latest commit in the master branch of https://github.com/openxla/xla.
111
+ def get_xla_commit_and_date (commit : str | None = None ) -> tuple [str , str ]:
112
+ """Find a date, commit pair from https://github.com/openxla/xla.
113
+ If commit is specified, use that commit.
114
+ If no commit is specified take the latest commit from main.
112
115
113
116
Returns:
114
- A tuple of the latest commit SHA and its date (YYYY-MM-DD).
117
+ A tuple of the commit SHA and its date (YYYY-MM-DD).
115
118
"""
116
119
117
- # Get the latest commit in the master branch of openxla.
118
120
clean_tmp_dir ()
119
- os .system (
120
- f'git clone --depth=1 https://github.com/openxla/xla { _TMP_DIR } /xla' )
121
- commit = os .popen (f'cd { _TMP_DIR } /xla && git rev-parse HEAD' ).read ().strip ()
122
-
123
- # Get the date of the commit, in the format of YYYY-MM-DD.
124
- date = os .popen (
125
- f'cd { _TMP_DIR } /xla && git show -s --format=%cd --date=short { commit } '
126
- ).read ().strip ()
121
+ if commit is None :
122
+ # Clone the repository to a depth of 1 (just the main branch).
123
+ os .system (
124
+ f'git clone --depth=1 https://github.com/openxla/xla { _TMP_DIR } /xla' )
125
+ commit = os .popen (f'cd { _TMP_DIR } /xla && git rev-parse HEAD' ).read ().strip ()
126
+ date = os .popen (
127
+ f'cd { _TMP_DIR } /xla && git show -s --format=%cd --date=short { commit } '
128
+ ).read ().strip ()
129
+ logger .info (f'Found latest XLA commit { commit } on date { date } ' )
130
+ else :
131
+ # Clone the repository history, but no blobs to save space.
132
+ os .system (
133
+ f'git clone --bare --filter=blob:none https://github.com/openxla/xla.git { _TMP_DIR } /xla.git'
134
+ )
135
+ date = os .popen (
136
+ f'git --git-dir={ _TMP_DIR } /xla.git show -s --format=%cd --date=short { commit } '
137
+ ).read ().strip ()
138
+ if not date :
139
+ logging .error (f"Unable to local XLA commit { commit } " )
140
+ logger .info (f'Given XLA commit { commit } , determined date { date } ' )
141
+
127
142
return commit , date
128
143
129
144
130
- def update_openxla () -> bool :
131
- """Updates the OpenXLA version in the WORKSPACE file to the latest commit .
145
+ def get_latest_stable_jax_info () -> tuple [ str , str , str ] | None :
146
+ """Gets info about the latest stable JAX release from GitHub .
132
147
133
148
Returns:
134
- True if the WORKSPACE file was updated, False otherwise .
149
+ A tuple of (JAX version, JAX release date, XLA commit hash) .
135
150
"""
151
+ url = 'https://api.github.com/repos/google/jax/releases/latest'
152
+ try :
153
+ with urllib .request .urlopen (url ) as response :
154
+ data = json .loads (response .read ().decode ())
155
+ except Exception as e :
156
+ logger .error (f'Failed to fetch { url } : { e } ' )
157
+ return None
158
+
159
+ tag_name = data ['tag_name' ] # e.g., "jax-v0.4.28"
160
+ jax_version = tag_name .replace ('jax-v' , '' ) # e.g., "0.4.28"
161
+
162
+ published_at = data ['published_at' ] # e.g., "2024-04-26T22:58:34Z"
163
+ release_date = published_at .split ('T' )[0 ] # e.g., "2024-04-26"
164
+
165
+ # The XLA commit is in third_party/xla/workspace.bzl in the JAX repo.
166
+ workspace_bzl_url = f'https://raw.githubusercontent.com/google/jax/{ tag_name } /third_party/xla/workspace.bzl'
167
+ try :
168
+ with urllib .request .urlopen (workspace_bzl_url ) as response :
169
+ workspace_content = response .read ().decode ()
170
+ except Exception as e :
171
+ logger .error (f'Failed to fetch { workspace_bzl_url } : { e } ' )
172
+ return None
136
173
137
- commit , date = get_last_xla_commit_and_date ()
174
+ match = re .search (r'XLA_COMMIT = "([a-f0-9]{40})"' , workspace_content )
175
+ if not match :
176
+ logger .error (f'Could not find XLA_COMMIT in { workspace_bzl_url } .' )
177
+ return None
178
+ xla_commit = match .group (1 )
179
+
180
+ return jax_version , release_date , xla_commit
181
+
182
+
183
+ def update_openxla (commit : str | None = None ) -> bool :
184
+ """Updates the OpenXLA version in the WORKSPACE file.
185
+
186
+ Returns:
187
+ True if the WORKSPACE file was updated, False otherwise.
188
+ """
189
+ commit , date = get_xla_commit_and_date (commit = commit )
138
190
139
191
with open (_WORKSPACE_PATH , 'r' ) as f :
140
192
ws_lines = f .readlines ()
@@ -158,14 +210,18 @@ def update_openxla() -> bool:
158
210
return True
159
211
160
212
161
- def find_latest_nightly (html_lines : list [str ],
162
- build_re : str ) -> Optional [tuple [str , str , str ]]:
213
+ def find_latest_nightly (
214
+ html_lines : list [str ],
215
+ build_re : str ,
216
+ target_date : str | None = None ) -> tuple [str , str , str ] | None :
163
217
"""Finds the latest nightly build from the list of HTML lines.
164
218
165
219
Args:
166
220
html_lines: A list of HTML lines to search for the nightly build.
167
221
build_re: A regular expression for matching the nightly build line.
168
222
It must have 3 capture groups: the version, the date, and the name suffix.
223
+ target_date: If specified, find the latest build on or before this date
224
+ (YYYYMMDD). Otherwise, find the latest build overall.
169
225
170
226
Returns:
171
227
A tuple of the version, date, and suffix of the latest nightly build,
@@ -180,7 +236,10 @@ def find_latest_nightly(html_lines: list[str],
180
236
if m :
181
237
found_build = True
182
238
version , date , suffix = m .groups ()
183
- if date > latest_date :
239
+ if target_date is None :
240
+ if date > latest_date :
241
+ latest_version , latest_date , latest_suffix = version , date , suffix
242
+ elif date <= target_date and date > latest_date :
184
243
latest_version , latest_date , latest_suffix = version , date , suffix
185
244
186
245
if found_build :
@@ -189,8 +248,12 @@ def find_latest_nightly(html_lines: list[str],
189
248
return None
190
249
191
250
192
- def find_latest_libtpu_nightly () -> Optional [tuple [str , str , str ]]:
193
- """Finds the latest libtpu nightly build for the current platform.
251
+ def find_libtpu_build (
252
+ target_date : str | None = None ) -> tuple [str , str , str ] | None :
253
+ """Finds a libtpu nightly build for the current platform.
254
+
255
+ Args:
256
+ target_date: If specified, find build for this date. Otherwise, find latest.
194
257
195
258
Returns:
196
259
A tuple of the version, date, and suffix of the latest libtpu nightly build,
@@ -209,7 +272,7 @@ def find_latest_libtpu_nightly() -> Optional[tuple[str, str, str]]:
209
272
return find_latest_nightly (
210
273
html_lines ,
211
274
r'.*<a href=.*?>libtpu/libtpu-(.*?)\.dev(\d{8})\+nightly-(.*?)_' +
212
- _PLATFORM + r'\.whl</a>' )
275
+ _PLATFORM + r'\.whl</a>' , target_date )
213
276
214
277
215
278
def fetch_pep503_page (url : str ) -> list [tuple [str , str ]]:
@@ -233,7 +296,7 @@ def fetch_pep503_page(url: str) -> list[tuple[str, str]]:
233
296
return []
234
297
235
298
236
- def find_latest_jax_nightly () -> Optional [ tuple [str , str , str ]] :
299
+ def find_latest_jax_nightly () -> tuple [str , str , str ] | None :
237
300
"""Finds the latest JAX nightly build using the new package index.
238
301
239
302
Returns:
@@ -309,15 +372,19 @@ def parse_version_date(url: str, pattern: str) -> list[tuple[str, str]]:
309
372
return latest_jax_version , latest_jaxlib_version , latest_jax_date
310
373
311
374
312
- def update_libtpu () -> bool :
313
- """Updates the libtpu version in setup.py to the latest nightly build .
375
+ def update_libtpu (target_date : str | None = None ) -> bool :
376
+ """Updates the libtpu version in setup.py.
314
377
315
378
Returns:
316
379
True if the setup.py file was updated, False otherwise.
317
380
"""
318
381
319
- result = find_latest_libtpu_nightly ( )
382
+ result = find_libtpu_build ( target_date )
320
383
if not result :
384
+ if target_date :
385
+ logger .error (f'Could not find libtpu build for date { target_date } .' )
386
+ else :
387
+ logger .error ('Could not find latest libtpu nightly build.' )
321
388
return False
322
389
323
390
version , date , suffix = result
@@ -360,18 +427,24 @@ def update_libtpu() -> bool:
360
427
return success
361
428
362
429
363
- def update_jax () -> bool :
364
- """Updates the jax/jaxlib versions in setup.py to the latest nightly build .
430
+ def update_jax (use_latest : bool ) -> bool :
431
+ """Updates the jax/jaxlib versions in setup.py.
365
432
366
433
Returns:
367
434
True if the setup.py file was updated, False otherwise.
368
435
"""
369
-
370
- result = find_latest_jax_nightly ()
371
- if not result :
372
- return False
373
-
374
- jax_version , jaxlib_version , date = result
436
+ if use_latest :
437
+ result = find_latest_jax_nightly ()
438
+ if not result :
439
+ return False
440
+ jax_version , jaxlib_version , date = result
441
+ else :
442
+ jax_info = get_latest_stable_jax_info ()
443
+ if not jax_info :
444
+ return False
445
+ jax_version , release_date , _ = jax_info
446
+ jaxlib_version = jax_version
447
+ date = release_date .replace ('-' , '' )
375
448
376
449
with open (_SETUP_PATH , 'r' ) as f :
377
450
setup_lines = f .readlines ()
@@ -408,12 +481,40 @@ def update_jax() -> bool:
408
481
def main () -> None :
409
482
logging .basicConfig (level = logging .INFO )
410
483
411
- openxla_updated = update_openxla ()
412
- libtpu_updated = update_libtpu ()
413
- jax_updated = update_jax ()
484
+ parser = argparse .ArgumentParser (
485
+ description = "Updates third party dependencies." )
486
+ parser .add_argument (
487
+ '--use_latest' ,
488
+ action = 'store_true' ,
489
+ default = False ,
490
+ help = 'Update to latest nightly versions instead of latest stable versions.'
491
+ )
492
+ args = parser .parse_args ()
493
+
494
+ if args .use_latest :
495
+ logger .info ('Updating to latest nightly versions...' )
496
+ openxla_updated = update_openxla ()
497
+ libtpu_updated = update_libtpu ()
498
+ jax_updated = update_jax (use_latest = True )
499
+ if not (openxla_updated and libtpu_updated and jax_updated ):
500
+ sys .exit (1 )
501
+ else :
502
+ logger .info ('Updating to latest stable versions...' )
503
+ jax_info = get_latest_stable_jax_info ()
504
+ if not jax_info :
505
+ sys .exit (1 )
506
+
507
+ jax_version , jax_release_date , xla_commit = jax_info
508
+ logger .info (
509
+ f'Found latest stable JAX release { jax_version } from { jax_release_date } , with XLA commit { xla_commit } '
510
+ )
414
511
415
- if not (openxla_updated and libtpu_updated and jax_updated ):
416
- sys .exit (1 )
512
+ openxla_updated = update_openxla (xla_commit )
513
+ libtpu_updated = update_libtpu (
514
+ target_date = jax_release_date .replace ('-' , '' ))
515
+ jax_updated = update_jax (use_latest = False )
516
+ if not (openxla_updated and libtpu_updated and jax_updated ):
517
+ sys .exit (1 )
417
518
418
519
419
520
if __name__ == '__main__' :
0 commit comments