Skip to content

Commit 1adbe97

Browse files
authored
Update jax dependency to 0.7.1 to align with tt front ends (#8)
1 parent d9f7d39 commit 1adbe97

File tree

2 files changed

+9
-22
lines changed

2 files changed

+9
-22
lines changed

.github/workflows/_build_torch_xla_3.11.yml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -38,7 +38,7 @@ jobs:
3838
eval "$(pyenv init -)"
3939
pyenv install 3.11
4040
pyenv global 3.11
41-
ln -sf $HOME/.pyenv/versions/3.11/bin/python3.11 /usr/local/bin/python3.11
41+
ln -sf $(pyenv which python3.11) /usr/local/bin/python3.11
4242
4343
# Install essential packages for Python 3.11
4444
python3.11 -m pip install --upgrade pip

setup.py

Lines changed: 8 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -113,13 +113,13 @@
113113
# 4. After the local build succeeds, create a PR and wait for the CI result. Fix
114114
# CI errors as needed until all required checks pass.
115115

116-
USE_NIGHTLY = True # Whether to use nightly or stable libtpu and JAX.
116+
USE_NIGHTLY = False # Whether to use nightly or stable libtpu and JAX.
117117

118118
_libtpu_version = '0.0.18'
119119
_libtpu_date = '20250617'
120120

121-
_jax_version = '0.6.2'
122-
_jaxlib_version = '0.6.2'
121+
_jax_version = '0.7.1'
122+
_jaxlib_version = '0.7.1'
123123
_jax_date = '20250617' # Date for jax and jaxlib.
124124

125125
if USE_NIGHTLY:
@@ -135,8 +135,7 @@
135135
_libtpu_wheel_name = f'libtpu-{_libtpu_version}-py3-none-manylinux_2_31_{platform_machine}'
136136
_libtpu_storage_directory = 'libtpu-lts-releases'
137137

138-
_libtpu_storage_path = f'https://storage.googleapis.com/{_libtpu_storage_directory}/wheels/libtpu/{_libtpu_wheel_name}.whl'
139-
138+
_libtpu_storage_path = f'https://us-python.pkg.dev/ml-oss-artifacts-published/jax/libtpu/{_libtpu_wheel_name}.whl'
140139

141140
def _get_build_mode():
142141
for i in range(1, len(sys.argv)):
@@ -423,22 +422,10 @@ def link_packages(self):
423422

424423

425424
def _get_jax_install_requirements():
426-
if not USE_NIGHTLY:
427-
# Stable versions of JAX can be directly installed from PyPI.
428-
return [
429-
f'jaxlib=={_jaxlib_version}',
430-
f'jax=={_jax_version}',
431-
]
432-
433-
# Install nightly JAX libraries from the JAX package registries.
434-
jax = f'jax @ https://us-python.pkg.dev/ml-oss-artifacts-published/jax-public-nightly-artifacts-registry/jax/jax-{_jax_version}-py3-none-any.whl'
435-
436-
jaxlib = []
437-
for python_minor_version in [9, 10, 11, 12]:
438-
jaxlib.append(
439-
f'jaxlib @ https://us-python.pkg.dev/ml-oss-artifacts-published/jax-public-nightly-artifacts-registry/jaxlib/jaxlib-{_jaxlib_version}-cp3{python_minor_version}-cp3{python_minor_version}-manylinux2014_x86_64.whl ; python_version == "3.{python_minor_version}"'
440-
)
441-
return [jax] + jaxlib
425+
return [
426+
f'jaxlib=={_jaxlib_version}',
427+
f'jax=={_jax_version}',
428+
]
442429

443430

444431
setup(

0 commit comments

Comments
 (0)