Skip to content

Commit bd95382

Browse files
authored
Miscelanous cleanup (#9619)
1 parent 77d85fb commit bd95382

File tree

13 files changed

+57
-422
lines changed

13 files changed

+57
-422
lines changed

torchax/README.md

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -97,11 +97,14 @@ inputs = torch.randn(3, 3, 28, 28, device='jax')
9797
m = MyModel().to('jax')
9898
res = m(inputs)
9999
print(type(res)) # outputs torchax.tensor.Tensor
100+
print(res.jax()) # print the underlying Jax Array
100101
```
101102

102103
`torchax.tensor.Tensor` is a `torch.Tensor` subclass that holds
103104
a `jax.Array`. You can inspect that JAX array with `res.jax()`.
104105

106+
In other words, despite that the code above looks like PyTorch, it is actually running JAX!
107+
105108
## What is happening behind the scene
106109

107110
We took the approach detailed in the

torchax/dev-requirements.txt

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
-f https://download.pytorch.org/whl/torch
2-
torch==2.7.1 ; sys_platform == 'darwin' # macOS
3-
torch==2.7.1+cpu; sys_platform != 'darwin' # Non-macOS (CPU-only), like on TPU
2+
torch==2.8.0 ; sys_platform == 'darwin' # macOS
3+
torch==2.8.0+cpu; sys_platform != 'darwin' # Non-macOS (CPU-only), like on TPU
44
yapf==0.40.2 # N.B.: keep in sync with `infra/ansible/config/pip.yaml`, `.github/workflows/lintercheck.yml`
55
flax==0.10.6

torchax/examples/_diffusion.py

Lines changed: 0 additions & 106 deletions
This file was deleted.

torchax/examples/_grad_of_attention.py

Lines changed: 0 additions & 76 deletions
This file was deleted.

torchax/examples/torchbench_models/BERT_pytorch.py

Lines changed: 0 additions & 52 deletions
This file was deleted.

torchax/examples/train_gpt/requirements.txt

Lines changed: 0 additions & 4 deletions
This file was deleted.

torchax/pyproject.toml

Lines changed: 0 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -48,6 +48,3 @@ odml = ["jax[cpu]>=0.6.2", "jax[cpu]"]
4848

4949
[tool.hatch.build.targets.wheel]
5050
packages = ["torchax"]
51-
52-
[tool.pytest.ini_options]
53-
addopts="-n auto"

torchax/test-requirements.txt

Lines changed: 3 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -2,9 +2,8 @@
22
absl-py==2.2.2
33
immutabledict==4.2.1
44
pytest==8.3.5
5-
pytest-xdist==3.6.1
6-
pytest-forked==1.6.0
7-
sentencepiece==0.2.0
5+
sentencepiece
86
expecttest==0.3.0
97
optax==0.2.4
10-
tensorflow==2.19.0
8+
pytest
9+
pytest-xdist

torchax/test/test_misc.py

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -31,6 +31,17 @@ def forward(self, a, b):
3131
jnp.sin(jnp.array([1, 2, 3])) + jnp.cos(jnp.array([3, 4, 5]))))
3232

3333
def test_to_device(self):
34+
env = torchax.default_env()
35+
with env:
36+
step1 = torch.ones(
37+
100,
38+
100,
39+
)
40+
step2 = torch.triu(step1, diagonal=1)
41+
step3 = step2.to(dtype=torch.bool, device='jax')
42+
self.assertEqual(step3.device.type, 'jax')
43+
44+
def test_to_device_twice(self):
3445
env = torchax.default_env()
3546
env.config.debug_print_each_op = True
3647
with env:
@@ -40,6 +51,7 @@ def test_to_device(self):
4051
)
4152
step2 = torch.triu(step1, diagonal=1)
4253
step3 = step2.to(dtype=torch.bool, device='jax')
54+
step3.to('jax')
4355
self.assertEqual(step3.device.type, 'jax')
4456

4557

torchax/test/test_tf_integration.py

Lines changed: 0 additions & 51 deletions
This file was deleted.

0 commit comments

Comments
 (0)