Skip to content

Commit 7aba922

Browse files
authored
support load and save checkpoint in torchax (#9616)
This PR adds checkpointing support to torchax: (1) load a checkpoint file stored as torch tensors and convert it to JAX arrays, or load a checkpoint already stored as JAX arrays; (2) save a checkpoint as JAX arrays. Only a single worker is supported for now.
1 parent caa809f commit 7aba922

File tree

6 files changed

+203
-1
lines changed

6 files changed

+203
-1
lines changed

.github/workflows/_test.yml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -111,6 +111,7 @@ jobs:
111111
# TODO: Add these in setup.py
112112
pip install fsspec
113113
pip install rich
114+
pip install flax
114115
- name: Checkout PyTorch Repo
115116
if: inputs.has_code_changes == 'true'
116117
uses: actions/checkout@v4

.github/workflows/_tpu_ci.yml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -55,6 +55,7 @@ jobs:
5555
pip install --pre 'torch_xla[pallas]' --index-url https://us-python.pkg.dev/ml-oss-artifacts-published/jax/simple/ --find-links https://storage.googleapis.com/jax-releases/libtpu_releases.html'
5656
pip install --pre 'torch_xla[tpu]' --index-url https://us-python.pkg.dev/ml-oss-artifacts-published/jax/simple/ --find-links https://storage.googleapis.com/jax-releases/libtpu_releases.html'
5757
pip install --upgrade protobuf
58+
pip install flax
5859
- name: Run Tests (${{ matrix.test_script }})
5960
if: inputs.has_code_changes == 'true'
6061
env:

torchax/README.md

Lines changed: 36 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -182,6 +182,42 @@ The first time `m_jitted` is called, it will trigger `jax.jit` to compile the
182182
compile for the given input shapes. Subsequent calls with the same input shapes
183183
will be fast as the compilation is cached.
184184

185+
## Saving and Loading Checkpoints
186+
187+
You can use `torchax.save_checkpoint` and `torchax.load_checkpoint` to save and load your training state. The state can be a dictionary containing the model's weights, optimizer state, and any other information you want to save.
188+
189+
```python
190+
import torchax
191+
import torch
192+
import optax
193+
194+
# Assume model, optimizer, and other states are defined
195+
model = MyModel()
196+
optimizer = optax.adam(1e-3)
197+
opt_state = optimizer.init(model.parameters())
198+
weights = model.parameters()
199+
buffers = model.buffers()
200+
epoch = 10
201+
202+
state = {
203+
'weights': weights,
204+
'buffers': buffers,
205+
'opt_state': opt_state,
206+
'epoch': epoch,
207+
}
208+
209+
# Save checkpoint
210+
torchax.save_checkpoint(state, '/path/to/checkpoint.pt', step=epoch)
211+
212+
# Load checkpoint
213+
loaded_state = torchax.load_checkpoint('/path/to/checkpoint.pt')
214+
215+
# Restore state
216+
model.load_state_dict(loaded_state['weights'])
217+
opt_state = loaded_state['opt_state']
218+
epoch = loaded_state['epoch']
219+
```
220+
185221
## Citation
186222

187223
```

torchax/test/test_checkpoint.py

Lines changed: 102 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,102 @@
1+
import unittest
2+
import torch
3+
import torch.nn as nn
4+
import torchax
5+
from torchax.checkpoint import _to_torch, _to_jax
6+
import optax
7+
import tempfile
8+
import os
9+
import jax
10+
import jax.numpy as jnp
11+
import shutil
12+
13+
14+
class CheckpointTest(unittest.TestCase):
  """Round-trip tests for torchax.save_checkpoint / torchax.load_checkpoint."""

  def _make_state(self):
    """Builds a small training state: Linear weights, Adam state and epoch."""
    model = torch.nn.Linear(10, 20)
    optimizer = optax.adam(1e-3)

    torchax.enable_globally()
    try:
      params_jax, _ = torchax.extract_jax(model)
      opt_state = optimizer.init(params_jax)
    finally:
      # Always switch the global mode back off so a failure while building
      # the fixture cannot leak into the remaining tests.
      torchax.disable_globally()

    return {
        'model': model.state_dict(),
        'opt_state': opt_state,
        'epoch': 1,
    }

  def _assert_leaves_match(self, expected_tree, actual_tree):
    """Asserts that two pytrees hold the same leaves, in the same order."""
    expected_leaves = jax.tree_util.tree_leaves(expected_tree)
    actual_leaves = jax.tree_util.tree_leaves(actual_tree)
    # zip() stops at the shorter sequence, so without this check a restored
    # tree that lost leaves would still compare as equal.
    self.assertEqual(len(expected_leaves), len(actual_leaves))
    for expected, actual in zip(expected_leaves, actual_leaves):
      if isinstance(expected, (jnp.ndarray, jax.Array)):
        # The restored leaf may be a torch tensor or a numpy array; normalise
        # it to a JAX array before comparing.
        self.assertTrue(jnp.allclose(expected, jnp.asarray(actual)))
      else:
        self.assertEqual(expected, actual)

  def test_save_and_load_jax_style_checkpoint(self):
    state = self._make_state()

    with tempfile.TemporaryDirectory() as tmpdir:
      path = os.path.join(tmpdir, 'checkpoint')
      torchax.save_checkpoint(state, path, step=state['epoch'])
      loaded_state_jax = torchax.load_checkpoint(path)
      loaded_state = _to_torch(loaded_state_jax)

    self.assertEqual(state['epoch'], loaded_state['epoch'])

    # Compare model state_dict
    for key in state['model']:
      self.assertTrue(
          torch.allclose(state['model'][key], loaded_state['model'][key]))

    # Compare optimizer state
    self._assert_leaves_match(state['opt_state'], loaded_state['opt_state'])

  def test_load_pytorch_style_checkpoint(self):
    state = self._make_state()

    with tempfile.TemporaryDirectory() as tmpdir:
      path = os.path.join(tmpdir, 'checkpoint.pt')
      torch.save(state, path)
      loaded_state_jax = torchax.load_checkpoint(path)

    # Convert the original state to JAX so both sides use the same types.
    state_jax = _to_jax(state)

    self.assertEqual(state_jax['epoch'], loaded_state_jax['epoch'])

    # Compare model state_dict
    for key in state_jax['model']:
      self.assertTrue(
          jnp.allclose(state_jax['model'][key], loaded_state_jax['model'][key]))

    # Compare optimizer state
    self._assert_leaves_match(state_jax['opt_state'],
                              loaded_state_jax['opt_state'])

  def test_load_non_existent_checkpoint(self):
    with self.assertRaises(FileNotFoundError):
      torchax.load_checkpoint('/path/to/non_existent_checkpoint')
99+
100+
101+
if __name__ == '__main__':
102+
unittest.main()

torchax/torchax/__init__.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -15,9 +15,11 @@
1515
'default_env',
1616
'extract_jax',
1717
'enable_globally',
18+
'save_checkpoint',
19+
'load_checkpoint',
1820
]
1921

20-
from jax._src import xla_bridge
22+
from .checkpoint import save_checkpoint, load_checkpoint
2123

2224
os.environ.setdefault('ENABLE_RUNTIME_UPTIME_TELEMETRY', '1')
2325

torchax/torchax/checkpoint.py

Lines changed: 60 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,60 @@
1+
import torch
2+
import os
3+
from typing import Any, Dict
4+
from flax.training import checkpoints
5+
import jax
6+
import jax.numpy as jnp
7+
import numpy as np
8+
9+
10+
def _to_jax(pytree):
11+
return jax.tree_util.tree_map(
12+
lambda x: jnp.asarray(x.cpu().numpy())
13+
if isinstance(x, torch.Tensor) else x, pytree)
14+
15+
16+
def _to_torch(pytree):
17+
return jax.tree_util.tree_map(
18+
lambda x: torch.from_numpy(np.asarray(x))
19+
if isinstance(x, (jnp.ndarray, jax.Array)) else x, pytree)
20+
21+
22+
def save_checkpoint(state: Dict[str, Any], path: str, step: int):
  """Saves a checkpoint in JAX style using Flax's checkpoint writer.

  Args:
    state: A dictionary containing the state to save. torch.Tensors will be
      converted to jax.Array.
    path: The path to save the checkpoint to. This is a directory. A relative
      local path is resolved against the current working directory.
    step: The training step.
  """
  path = os.fspath(path)
  # The Orbax-backed Flax writer rejects relative directories, so local paths
  # are made absolute. URL-style paths (e.g. "gs://...") are left untouched
  # because os.path.abspath would mangle them.
  if '://' not in path:
    path = os.path.abspath(path)
  checkpoints.save_checkpoint(path, _to_jax(state), step=step, overwrite=True)
33+
34+
35+
def load_checkpoint(path: str) -> Dict[str, Any]:
  """Loads a checkpoint and returns it in JAX format.

  This function can load both PyTorch-style (single file) and JAX-style
  (directory) checkpoints.

  If the checkpoint is in PyTorch format, its torch.Tensor leaves are
  converted to JAX arrays. A JAX-style checkpoint is returned as Flax
  restores it.

  Args:
    path: The path to the checkpoint (a directory or a single file).

  Returns:
    The loaded state as a pytree.

  Raises:
    FileNotFoundError: If `path` does not exist, or if it is a directory that
      contains no checkpoint.
  """
  if os.path.isdir(path):
    # JAX-style checkpoint. The directory is made absolute because the
    # Orbax-backed Flax reader expects an absolute path.
    state = checkpoints.restore_checkpoint(os.path.abspath(path), target=None)
    if state is None:
      raise FileNotFoundError(f"No checkpoint found at {path}")
    return state

  if os.path.isfile(path):
    # PyTorch-style checkpoint.
    # SECURITY: weights_only=False unpickles arbitrary Python objects, which
    # can execute code. Only load checkpoint files from trusted sources.
    # map_location='cpu' lets files saved on an accelerator load on hosts
    # without that device; _to_jax moves tensors to the CPU anyway.
    state = torch.load(path, map_location='cpu', weights_only=False)
    return _to_jax(state)

  raise FileNotFoundError(f"No such file or directory: {path}")

0 commit comments

Comments
 (0)