|
113 | 113 | # 4. After the local build succeeds, create a PR and wait for the CI result. Fix
|
114 | 114 | # CI errors as needed until all required checks pass.
|
115 | 115 |
|
116 |
| -USE_NIGHTLY = True # Whether to use nightly or stable libtpu and JAX. |
| 116 | +USE_NIGHTLY = False # Whether to use nightly or stable libtpu and JAX. |
117 | 117 |
|
118 | 118 | _libtpu_version = '0.0.18'
|
119 | 119 | _libtpu_date = '20250617'
|
120 | 120 |
|
121 |
| -_jax_version = '0.6.2' |
122 |
| -_jaxlib_version = '0.6.2' |
| 121 | +_jax_version = '0.7.1' |
| 122 | +_jaxlib_version = '0.7.1' |
123 | 123 | _jax_date = '20250617' # Date for jax and jaxlib.
|
124 | 124 |
|
125 | 125 | if USE_NIGHTLY:
|
|
135 | 135 | _libtpu_wheel_name = f'libtpu-{_libtpu_version}-py3-none-manylinux_2_31_{platform_machine}'
|
136 | 136 | _libtpu_storage_directory = 'libtpu-lts-releases'
|
137 | 137 |
|
138 |
| -_libtpu_storage_path = f'https://storage.googleapis.com/{_libtpu_storage_directory}/wheels/libtpu/{_libtpu_wheel_name}.whl' |
139 |
| - |
| 138 | +_libtpu_storage_path = f'https://us-python.pkg.dev/ml-oss-artifacts-published/jax/libtpu/{_libtpu_wheel_name}.whl' |
140 | 139 |
|
141 | 140 | def _get_build_mode():
|
142 | 141 | for i in range(1, len(sys.argv)):
|
@@ -423,22 +422,10 @@ def link_packages(self):
|
423 | 422 |
|
424 | 423 |
|
425 | 424 | def _get_jax_install_requirements():
|
426 |
| - if not USE_NIGHTLY: |
427 |
| - # Stable versions of JAX can be directly installed from PyPI. |
428 |
| - return [ |
429 |
| - f'jaxlib=={_jaxlib_version}', |
430 |
| - f'jax=={_jax_version}', |
431 |
| - ] |
432 |
| - |
433 |
| - # Install nightly JAX libraries from the JAX package registries. |
434 |
| - jax = f'jax @ https://us-python.pkg.dev/ml-oss-artifacts-published/jax-public-nightly-artifacts-registry/jax/jax-{_jax_version}-py3-none-any.whl' |
435 |
| - |
436 |
| - jaxlib = [] |
437 |
| - for python_minor_version in [9, 10, 11, 12]: |
438 |
| - jaxlib.append( |
439 |
| - f'jaxlib @ https://us-python.pkg.dev/ml-oss-artifacts-published/jax-public-nightly-artifacts-registry/jaxlib/jaxlib-{_jaxlib_version}-cp3{python_minor_version}-cp3{python_minor_version}-manylinux2014_x86_64.whl ; python_version == "3.{python_minor_version}"' |
440 |
| - ) |
441 |
| - return [jax] + jaxlib |
| 425 | + return [ |
| 426 | + f'jaxlib=={_jaxlib_version}', |
| 427 | + f'jax=={_jax_version}', |
| 428 | + ] |
442 | 429 |
|
443 | 430 |
|
444 | 431 | setup(
|
|
0 commit comments