# torchax: Running PyTorch on TPU via JAX

**torchax** is a backend for PyTorch, allowing users to run
PyTorch on Google Cloud TPUs. **torchax** is also a library for providing
graph-level interoperability between PyTorch and JAX.

This means, with **torchax** you can:

* Run PyTorch code on TPUs with as little as 2 lines of code change (see the
  sketch after this list).
* Call a JAX function from a PyTorch function, passing in `jax.Array`s.
* Call a PyTorch function from a JAX function, passing in `torch.Tensor`s.
* Use JAX features such as `jax.grad`, `optax`, and GSPMD to train a PyTorch
  model.
* Use a PyTorch model as a feature extractor and use it with a JAX model.

And more.
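
For the first bullet, the change is typically just importing torchax and
enabling it; here is a minimal sketch (the full walkthrough is in "Run a model"
below):

```python
import torch
import torchax

torchax.enable_globally()              # route torch ops to JAX
x = torch.randn(4, 4, device='jax')    # tensors on the 'jax' device are backed by jax.Array
print((x @ x).sum())                   # runs via JAX, on TPU if one is available
```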

## Install

First install the CPU version of torch:

```bash
# On Linux.
pip install torch --index-url https://download.pytorch.org/whl/cpu

# Or on Mac.
pip install torch
```

Then install JAX for the accelerator you want to use:

```bash
# On Google Cloud TPU.
pip install -U jax[tpu]

# Or, on GPU machines.
pip install -U jax[cuda12]

# Or, on Linux CPU machines or Macs (see the note below).
pip install -U jax
```

NOTE: if you'd like Metal support on Apple devices, install the
Metal version of JAX: https://developer.apple.com/metal/jax/

Finally install torchax:

```bash
# Install pre-built torchax.
pip install torchax

# Or, install torchax from source.
pip install git+https://github.com/pytorch/xla.git#subdirectory=torchax
```
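
Optionally, sanity-check the installation from Python by verifying that JAX
sees your accelerator and that torchax imports:

```python
import jax
import torchax  # noqa: F401  # just checking that the import works

# On a TPU VM this should list TpuDevice entries; on GPU or CPU machines it
# lists the corresponding devices.
print(jax.devices())
```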

## Run a model

Now let's execute a model under torchax. We'll start with a simple 2-layer model.
In theory, we can use any instance of `torch.nn.Module`.

```python
import torch
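
# One possible definition of the simple 2-layer model; the exact layers are
# illustrative, since any torch.nn.Module works here.
import torch.nn as nn


class MyModel(nn.Module):
    def __init__(self):
        super().__init__()
        self.fc1 = nn.Linear(3 * 28 * 28, 128)
        self.fc2 = nn.Linear(128, 10)

    def forward(self, x):
        x = x.flatten(1)  # (N, 3 * 28 * 28)
        return self.fc2(torch.relu(self.fc1(x)))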

m = MyModel()

# Execute this model using torch.
inputs = torch.randn(3, 3, 28, 28)
print(m(inputs))
```

To execute this model with `torchax`, we need to enable torchax to capture PyTorch ops:

```python
import torchax
torchax.enable_globally()
```

Then, we can use a `jax` device:

```python
inputs = torch.randn(3, 3, 28, 28, device='jax')
res = m(inputs)
print(type(res))  # outputs torchax.tensor.Tensor
```

`torchax.tensor.Tensor` is a `torch.Tensor` subclass that holds
a `jax.Array`. You can inspect that JAX array with `res.jax()`.

## What is happening behind the scenes

We took the approach detailed in the
[new device](https://github.com/albanD/subclass_zoo/blob/main/new_device.py)
recipe by Alban (@albanD), using `jax.Array` for `raw_data`.

In other words, when a torch op is executed inside an `env` context manager,
which is enabled by `torchax.enable_globally()`, we swap out that op's
implementation for one written in JAX.

When a model's constructor runs, it will call some tensor constructor, such as
`torch.rand`, `torch.ones`, or `torch.zeros`, to create its weights. When torchax
is enabled, these constructors will create a `torchax.tensor.Tensor`, which
contains a `jax.Array`.

Then, each subsequent op will extract the `jax.Array`, call the op's JAX
implementation, and wrap the result back into a `torchax.tensor.Tensor`.

See more at [how it works](docs/how_it_works.md) and
[ops registry](docs/ops_registry.md).
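
To make this dispatch concrete, here is a small illustration (a sketch that
uses only the APIs shown above):

```python
import torch
import torchax

torchax.enable_globally()

a = torch.ones(4, 4, device='jax')   # a torchax.tensor.Tensor wrapping a jax.Array
b = (a @ a).sin()                    # each op unwraps the jax.Array, runs the JAX
                                     # implementation, and re-wraps the result
print(type(b))                       # torchax.tensor.Tensor
print(b.jax())                       # the underlying jax.Array
```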

### Executing with jax.jit

The above script will execute the model using eager mode JAX as the backend. This
does allow executing torch models on TPUs, but it is often slower than what we can
achieve with `jax.jit`.

`jax.jit` transforms a JAX function (i.e. a function that takes JAX arrays
and returns JAX arrays) into a compiled, and thus faster, version of the same
function.
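
For reference, here is `jax.jit` applied to a plain JAX function (standard JAX
usage, independent of torchax):

```python
import jax
import jax.numpy as jnp


@jax.jit
def selu(x, alpha=1.67, lmbda=1.05):
    # A pure JAX function: jax.Array in, jax.Array out.
    return lmbda * jnp.where(x > 0, x, alpha * jnp.expm1(x))


x = jnp.arange(5.0)
print(selu(x))  # compiled on the first call; cached for later calls with the same shape
```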

We have made a `jax_jit` decorator that accomplishes the same for functions
that take and return `torch.Tensor`s. To use it, the first step is to create
a functional version of this model: this means the parameters should be passed in
as inputs instead of being attributes of the class:

```python
def model_func(param, inputs):
    return torch.func.functional_call(m, param, inputs)
```

Here we use [torch.func.functional_call](https://pytorch.org/docs/stable/generated/torch.func.functional_call.html)
from PyTorch to replace the model weights with `param` and then call the
model. This is roughly equivalent to:

```python
def model_func(param, inputs):
    m.load_state_dict(param)
    return m(inputs)
```
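
The jitted call below passes `new_state_dict`, i.e. the model's weights placed
on the `jax` device. One way to build it is sketched here (an illustrative
recipe; it assumes `.to('jax')` can move existing CPU tensors, which may depend
on your torchax version):

```python
# Assumption: .to('jax') moves a CPU tensor onto the 'jax' device once torchax
# is enabled. `inputs` is already on the 'jax' device from the earlier step.
new_state_dict = {k: v.to('jax') for k, v in m.state_dict().items()}
```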

Now, we can apply `jax_jit` to `model_func`:

```python
from torchax.interop import jax_jit

model_func_jitted = jax_jit(model_func)
print(model_func_jitted(new_state_dict, inputs))
```

See more examples at [eager_mode.py](examples/eager_mode.py) and the
[examples folder](examples/).

To ease the idiom of creating a functional model and calling it with parameters,
we also provide the `JittableModule` helper class. It lets us rewrite the
above as:

```python
from torchax.interop import JittableModule

m_jitted = JittableModule(m)
res = m_jitted(...)
```

The first time `m_jitted` is called, it will trigger `jax.jit` to compile the
model for the given input shapes. Subsequent calls with the same input shapes
will be fast because the compilation is cached.

## Citation

```
@software{torchax,
  author = {Han Qi, Chun-nien Chan, Will Cromar, Manfei Bai, Kevin Gleason},
  title = {torchax: PyTorch on TPU and JAX interoperability},
  url = {https://github.com/pytorch/xla/tree/master/torchax},
  version = {0.0.4},
  date = {2025-02-24},
}
```

## Contributors

This library was created and is maintained by the PyTorch/XLA team at Google Cloud.

It benefitted from many direct and indirect contributions from outside the team.
Many of them were made by fellow Googlers using
[Google's 20% project policy](https://ebsedu.org/blog/google-tapping-workplace-actualization-20-time-rule);
others were made by partner teams at Google and other companies.

Here is the list of contributors as of 2025-02-25.

```
Han Qi (qihqi), PyTorch/XLA
Manfei Bai (manfeibai), PyTorch/XLA
Will Cromar (will-cromar), Meta
Milad Mohammadi (miladm), PyTorch/XLA
Siyuan Liu (lsy323), PyTorch/XLA
Bhavya Bahl (bhavya01), PyTorch/XLA
Pei Zhang (zpcore), PyTorch/XLA
Yifei Teng (tengyifei), PyTorch/XLA
Chunnien Chan (chunnienc), Google, ODML
Alban Desmaison (albanD), Meta, PyTorch
Simon Teo (simonteozw), Google (20%)
David Huang (dvhg), Google (20%)
Barni Seetharaman (barney-s), Google (20%)
Anish Karthik (anishfish2), Google (20%)
Yao Gu (guyao), Google (20%)
Yenkai Wang (yenkwang), Google (20%)
Greg Shikhman (commander), Google (20%)
Matin Akhlaghinia (matinehAkhlaghinia), Google (20%)
Tracy Chen (tracych477), Google (20%)
Matthias Guenther (mrguenther), Google (20%)
WenXin Dong (wenxindongwork), Google (20%)
Kevin Gleason (GleasonK), Google, StableHLO
Nupur Baghel (nupurbaghel), Google (20%)
Gwen Mittertreiner (gmittert), Google (20%)
Zeev Melumian (zmelumian), Lightricks
Vyom Sharma (vyom1611), Google (20%)
Shitong Wang (ShitongWang), Adobe
Rémi Doreau (ayshiff), Google (20%)
Lance Wang (wang2yn84), Google, CoreML
Hossein Sarshar (hosseinsarshar), Google (20%)
Daniel Vega-Myhre (danielvegamyhre), Google (20%)
Tianqi Fan (tqfan28), Google (20%)
Jim Lin (jimlinntu), Google (20%)
Fanhai Lu (FanhaiLu1), Google Cloud
DeWitt Clinton (dewitt), Google PyTorch
Aman Gupta (aman2930), Google (20%)
```