
Commit e38d280: Clean up torchax readme. (#9321)

Co-authored-by: Zhanyong Wan <[email protected]>
1 parent 9a517a9

1 file changed

torchax/README.md

Lines changed: 99 additions & 137 deletions
# torchax: Running PyTorch on TPU via JAX

**torchax** is a backend for PyTorch, allowing users to run
PyTorch on Google Cloud TPUs. **torchax** is also a library for providing
graph-level interoperability between PyTorch and JAX.

This means, with **torchax** you can:
* Run PyTorch code on TPUs with as little as 2 lines of code change (see the
  sketch below).
* Call a JAX function from a PyTorch function, passing in `jax.Array`s.
* Call a PyTorch function from a JAX function, passing in `torch.Tensor`s.
* Use JAX features such as `jax.grad`, `optax`, and GSPMD to train a PyTorch
  model.
* Use a PyTorch model as a feature extractor and use it with a JAX model.
* And more.

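To give a concrete flavor of the "2 lines of code change" bullet above, here is a minimal sketch; it relies only on `torchax.enable_globally()` and the `'jax'` device described later in this README:

```python
import torch
import torchax             # change 1: import torchax

torchax.enable_globally()  # change 2: route torch ops through JAX

# Tensors created on the 'jax' device are backed by jax.Array,
# so their ops run through JAX (and hence on TPU when one is attached).
x = torch.randn(4, 4, device='jax')
print(torch.matmul(x, x))
```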
## Install

First install torch CPU:

```bash
# On Linux.
pip install torch --index-url https://download.pytorch.org/whl/cpu

# Or on Mac.
pip install torch
```

Then install JAX for the accelerator you want to use:

```bash
# On Google Cloud TPU.
pip install -U jax[tpu]

# Or, on GPU machines.
pip install -U jax[cuda12]

# Or, on Linux CPU machines or Macs (see the note below).
pip install -U jax
```

NOTE: if you want Metal support for Apple devices, install the
Metal version of JAX instead: https://developer.apple.com/metal/jax/

Finally install torchax:

```bash
# Install pre-built torchax.
pip install torchax

# Or, install torchax from source.
pip install git+https://github.com/pytorch/xla.git#subdirectory=torchax
```

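As a quick sanity check of the install (a minimal sketch; `jax.devices()` is plain JAX and simply lists the accelerators JAX can see):

```python
import jax
import torchax

# Should list TPU, GPU, or CPU devices, depending on which JAX you installed.
print(jax.devices())

# Should run without errors if torchax is installed correctly.
torchax.enable_globally()
```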
## Run a model

Now let's execute a model under torchax. We'll start with a simple 2-layer model.
In theory, we can use any instance of `torch.nn.Module`.

```python
import torch
# ... (the remaining imports and the definition of MyModel are elided in this diff) ...

m = MyModel()

# Execute this model using torch.
inputs = torch.randn(3, 3, 28, 28)
print(m(inputs))
```

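The definition of `MyModel` is elided in this diff. For illustration only, a minimal 2-layer module along the following lines would fit the shapes used above (a hypothetical stand-in, not the exact model from the README):

```python
import torch
import torch.nn as nn

class MyModel(nn.Module):
  """A hypothetical 2-layer stand-in for the elided model definition."""

  def __init__(self):
    super().__init__()
    self.fc1 = nn.Linear(3 * 28 * 28, 128)
    self.fc2 = nn.Linear(128, 10)

  def forward(self, x):
    x = torch.flatten(x, start_dim=1)  # (batch, 3, 28, 28) -> (batch, 2352)
    x = torch.relu(self.fc1(x))
    return self.fc2(x)
```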
To execute this model with `torchax`, we need to enable torchax to capture PyTorch ops:

```python
import torchax
torchax.enable_globally()
```

Then, we can use a `jax` device:

```python
inputs = torch.randn(3, 3, 28, 28, device='jax')
# ... (elided in this diff) ...
print(type(res))  # outputs torchax.tensor.Tensor
```

`torchax.tensor.Tensor` is a `torch.Tensor` subclass that holds
a `jax.Array`. You can inspect that JAX array with `res.jax()`.

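For example, you can pull the underlying `jax.Array` out of a `'jax'`-device tensor and use it with plain `jax.numpy` (a small sketch that relies only on the `res.jax()` accessor described above):

```python
import torch
import torchax
import jax.numpy as jnp

torchax.enable_globally()

x = torch.ones(2, 2, device='jax')  # a torchax.tensor.Tensor
arr = x.jax()                       # the jax.Array it wraps
print(type(arr))
print(jnp.sum(arr))                 # operate on it with plain JAX
```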
## What is happening behind the scenes

We took the approach detailed in the
[new device](https://github.com/albanD/subclass_zoo/blob/main/new_device.py)
recipe by Alban (@albanD), using `jax.Array` for `raw_data`.

In other words, when a torch op is executed inside an `env` context manager,
which is enabled by `torchax.enable_globally()`, we will swap out the
implementation of that op with JAX.

When a model's constructor runs, it will call some tensor constructor, such as
`torch.rand`, `torch.ones`, or `torch.zeros`, to create its weights. When torchax
is enabled, these constructors will create a `torchax.tensor.Tensor`, which
contains a `jax.Array`.

Then, each subsequent op will extract the `jax.Array`, call the op's JAX
implementation, and wrap the result back into a `torchax.tensor.Tensor`.

See more at [how it works](docs/how_it_works.md) and
[ops registry](docs/ops_registry.md).

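As a deliberately simplified illustration of that unwrap, call-the-JAX-implementation, re-wrap flow (this is not torchax's implementation, which covers the full op set through its ops registry; the class and the tiny op table below are made up for this example):

```python
import torch
import jax.numpy as jnp

# Map a couple of torch ops to JAX counterparts, purely for illustration.
JAX_IMPLS = {
    torch.add: jnp.add,
    torch.mul: jnp.multiply,
}

class ToyJaxTensor:
  """Holds a jax.Array and reroutes a few torch ops to JAX."""

  def __init__(self, jax_array):
    self._elem = jax_array

  @classmethod
  def __torch_function__(cls, func, types, args=(), kwargs=None):
    kwargs = kwargs or {}
    jax_fn = JAX_IMPLS.get(func)
    if jax_fn is None:
      return NotImplemented
    # Unwrap to jax.Array, run the JAX implementation, wrap the result back.
    unwrapped = [a._elem if isinstance(a, ToyJaxTensor) else a for a in args]
    return ToyJaxTensor(jax_fn(*unwrapped, **kwargs))

a = ToyJaxTensor(jnp.ones((2, 2)))
b = ToyJaxTensor(jnp.full((2, 2), 3.0))
c = torch.add(a, b)  # runs jnp.add under the hood
print(type(c), c._elem)
```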
### Executing with jax.jit

The above script will execute the model using eager mode JAX as the backend. This
does allow executing torch models on TPUs, but is often slower than what we can
achieve with `jax.jit`.

`jax.jit` is a function that turns a JAX function (i.e. a function that takes JAX arrays
and returns JAX arrays) into a compiled (thus faster) version of the same function.

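For reference, this is what `jax.jit` looks like on a plain JAX function (standard JAX usage, independent of torchax):

```python
import jax
import jax.numpy as jnp

def f(x):
  return jnp.sin(x) @ x.T

f_fast = jax.jit(f)
x = jnp.ones((128, 128))
print(f_fast(x).shape)  # first call traces and compiles; later calls reuse it
```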
We have made a `jax_jit` decorator that accomplishes the same for functions
that take and return `torch.Tensor`s. To use this, the first step is to create
a functional version of the model: this means the parameters should be passed in
as input instead of being attributes of the class:

```python
def model_func(param, inputs):
  return torch.func.functional_call(m, param, inputs)
```

Here we use [torch.func.functional_call](https://pytorch.org/docs/stable/generated/torch.func.functional_call.html)
from PyTorch to replace the model weights with `param` and then call the
model. This is roughly equivalent to:

```python
def model_func(param, inputs):
  m.load_state_dict(param)
  return m(*inputs)
```

Now, we can apply `jax_jit` on `model_func`:

```python
from torchax.interop import jax_jit

model_func_jitted = jax_jit(model_func)
print(model_func_jitted(new_state_dict, inputs))
```

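Note that `new_state_dict` above needs to hold the model's parameters as `'jax'`-device tensors. One way to build it, assuming tensors can be moved with `.to('jax')` like with any other torch device (an assumption for this sketch, not something stated in this README):

```python
# Hypothetical setup for the call above: move parameters and inputs to 'jax'.
new_state_dict = {k: v.to('jax') for k, v in m.state_dict().items()}
inputs = torch.randn(3, 3, 28, 28, device='jax')
```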
See more examples at [eager_mode.py](examples/eager_mode.py) and the
[examples folder](examples/).

To ease the idiom of creating a functional model and calling it with parameters,
we also created the `JittableModule` helper class. It lets us rewrite the
above as:

```python
from torchax.interop import JittableModule

m_jitted = JittableModule(m)
res = m_jitted(...)
```

The first time `m_jitted` is called, it will trigger `jax.jit` to compile the
model for the given input shapes. Subsequent calls with the same input shapes
will be fast as the compilation is cached.

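For example, reusing the `'jax'`-device `inputs` from earlier and assuming the model's weights also live on the `'jax'` device (a sketch of the caching behavior described above):

```python
inputs = torch.randn(3, 3, 28, 28, device='jax')

res = m_jitted(inputs)  # first call: traces and compiles via jax.jit
res = m_jitted(inputs)  # same input shapes: reuses the cached compilation
```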
## Citation

```
@software{torchax,
  author = {Han Qi, Chun-nien Chan, Will Cromar, Manfei Bai, Kevin Gleanson},
  title = {torchax: PyTorch on TPU and JAX interoperability},
  url = {https://github.com/pytorch/xla/tree/master/torchax},
  version = {0.0.4},
  date = {2025-02-24},
}
```

This library is created and maintained by the PyTorch/XLA team at Google Cloud.

It benefitted from many direct and indirect contributions from outside the
team. Many of them were made by fellow Googlers using
[Google's 20% project policy](https://ebsedu.org/blog/google-tapping-workplace-actualization-20-time-rule),
others by partner teams at Google and other companies.

Here is the list of contributors as of 2025-02-25.

```
Han Qi (qihqi), PyTorch/XLA
Manfei Bai (manfeibai), PyTorch/XLA
Will Cromar (will-cromar), Meta
Milad Mohammadi (miladm), PyTorch/XLA
Siyuan Liu (lsy323), PyTorch/XLA
Bhavya Bahl (bhavya01), PyTorch/XLA
Pei Zhang (zpcore), PyTorch/XLA
Yifei Teng (tengyifei), PyTorch/XLA
Chunnien Chan (chunnienc), Google, ODML
Alban Desmaison (albanD), Meta, PyTorch
Simon Teo (simonteozw), Google (20%)
David Huang (dvhg), Google (20%)
Barni Seetharaman (barney-s), Google (20%)
Anish Karthik (anishfish2), Google (20%)
Yao Gu (guyao), Google (20%)
Yenkai Wang (yenkwang), Google (20%)
Greg Shikhman (commander), Google (20%)
Matin Akhlaghinia (matinehAkhlaghinia), Google (20%)
Tracy Chen (tracych477), Google (20%)
Matthias Guenther (mrguenther), Google (20%)
WenXin Dong (wenxindongwork), Google (20%)
Kevin Gleason (GleasonK), Google, StableHLO
Nupur Baghel (nupurbaghel), Google (20%)
Gwen Mittertreiner (gmittert), Google (20%)
Zeev Melumian (zmelumian), Lightricks
Vyom Sharma (vyom1611), Google (20%)
Shitong Wang (ShitongWang), Adobe
Rémi Doreau (ayshiff), Google (20%)
Lance Wang (wang2yn84), Google, CoreML
Hossein Sarshar (hosseinsarshar), Google (20%)
Daniel Vega-Myhre (danielvegamyhre), Google (20%)
Tianqi Fan (tqfan28), Google (20%)
Jim Lin (jimlinntu), Google (20%)
Fanhai Lu (FanhaiLu1), Google Cloud
DeWitt Clinton (dewitt), Google PyTorch
Aman Gupta (aman2930), Google (20%)
```
