Skip to content

Commit be33668

Browse files
authored
Add torchax maximum version. (#9706)
This fixes the `torchax` CI failures (e.g. [#9705 action][1]) caused by the [`flax` dependency][2], which pulls `jax` 0.8.0 (or newer) for some Python versions. **Key Changes:** - Set upper bound `<0.8.0` in `pyproject.toml` optional dependency section - Install `torchax` before the test dependencies [1]: https://github.com/pytorch/xla/actions/runs/19229289257/job/54963835669?pr=9705 [2]: https://github.com/pytorch/xla/blob/611a5cc2675133e6b159a2be8b07b65c44656f29/torchax/dev-requirements.txt#L5
1 parent 95446d3 commit be33668

File tree

2 files changed

+5
-5
lines changed

2 files changed

+5
-5
lines changed

.github/workflows/torchax.yml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -46,8 +46,8 @@ jobs:
4646
shell: bash
4747
working-directory: torchax
4848
run: |
49-
pip install -r test-requirements.txt
5049
pip install -e .[cpu]
50+
pip install -r test-requirements.txt
5151
- name: Run tests
5252
if: needs.check_code_changes.outputs.has_code_changes == 'true'
5353
working-directory: torchax

torchax/pyproject.toml

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -40,11 +40,11 @@ classifiers = [
4040
path = "torchax/__init__.py"
4141

4242
[project.optional-dependencies]
43-
cpu = ["jax[cpu]>=0.6.2", "jax[cpu]"]
43+
cpu = ["jax[cpu]>=0.6.2, <0.8.0"]
4444
# Add libtpu index `-f https://storage.googleapis.com/libtpu-wheels/index.html -f https://storage.googleapis.com/libtpu-releases/index.html`
45-
tpu = ["jax[cpu]>=0.6.2", "jax[tpu]"]
46-
cuda = ["jax[cpu]>=0.6.2", "jax[cuda12]"]
47-
odml = ["jax[cpu]>=0.6.2", "jax[cpu]"]
45+
tpu = ["jax[cpu,tpu]>=0.6.2, <0.8.0"]
46+
cuda = ["jax[cpu,cuda12]>=0.6.2, <0.8.0"]
47+
odml = ["jax[cpu]>=0.6.2, <0.8.0"]
4848

4949
[tool.hatch.build.targets.wheel]
5050
packages = ["torchax"]

0 commit comments

Comments
 (0)