import unittest
import torch
import torchax
from torchax.checkpoint import _to_torch, _to_jax
import optax
import tempfile
import os
import jax
import jax.numpy as jnp


class CheckpointTest(unittest.TestCase):

  def test_save_and_load_jax_style_checkpoint(self):
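    """Round-trips a state dict through torchax JAX-style checkpointing.

    The state mixes a torch state_dict, an optax optimizer state, and an
    epoch counter; the loaded tree is converted back with _to_torch before
    comparing against the original.
    """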
    model = torch.nn.Linear(10, 20)
    optimizer = optax.adam(1e-3)

    # extract_jax yields a pytree of JAX arrays for the model parameters,
    # which optax uses to build the initial optimizer state.
    torchax.enable_globally()
    params_jax, _ = torchax.extract_jax(model)
    opt_state = optimizer.init(params_jax)
    torchax.disable_globally()

    epoch = 1
    state = {
        'model': model.state_dict(),
        'opt_state': opt_state,
        'epoch': epoch,
    }
    with tempfile.TemporaryDirectory() as tmpdir:
      path = os.path.join(tmpdir, 'checkpoint')
      torchax.save_checkpoint(state, path, step=epoch)
      loaded_state_jax = torchax.load_checkpoint(path)
      # The checkpoint is loaded as JAX arrays; convert back to torch tensors
      # so it can be compared against the original state.
      loaded_state = _to_torch(loaded_state_jax)

      self.assertEqual(state['epoch'], loaded_state['epoch'])

      # Compare model state_dict
      for key in state['model']:
        self.assertTrue(
            torch.allclose(state['model'][key], loaded_state['model'][key]))

      # Compare optimizer state leaf by leaf
      original_leaves = jax.tree_util.tree_leaves(state['opt_state'])
      loaded_leaves = jax.tree_util.tree_leaves(loaded_state['opt_state'])
      for original_leaf, loaded_leaf in zip(original_leaves, loaded_leaves):
        if isinstance(original_leaf, (jnp.ndarray, jax.Array)):
          # Coerce the loaded leaf to a JAX array before the numerical check.
          self.assertTrue(jnp.allclose(original_leaf, jnp.asarray(loaded_leaf)))
        else:
          self.assertEqual(original_leaf, loaded_leaf)

  def test_load_pytorch_style_checkpoint(self):
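    """Loads a torch.save checkpoint via torchax.load_checkpoint.

    The original state is converted with _to_jax so both trees are compared
    in the same (JAX) representation.
    """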
    model = torch.nn.Linear(10, 20)
    optimizer = optax.adam(1e-3)

    torchax.enable_globally()
    params_jax, _ = torchax.extract_jax(model)
    opt_state = optimizer.init(params_jax)
    torchax.disable_globally()

    epoch = 1
    state = {
        'model': model.state_dict(),
        'opt_state': opt_state,
        'epoch': epoch,
    }

    with tempfile.TemporaryDirectory() as tmpdir:
      path = os.path.join(tmpdir, 'checkpoint.pt')
      torch.save(state, path)
      loaded_state_jax = torchax.load_checkpoint(path)

      # Convert the original state to JAX for comparison
      state_jax = _to_jax(state)

      self.assertEqual(state_jax['epoch'], loaded_state_jax['epoch'])

      # Compare model state_dict
      for key in state_jax['model']:
        self.assertTrue(
            jnp.allclose(state_jax['model'][key],
                         loaded_state_jax['model'][key]))

      # Compare optimizer state
      original_leaves = jax.tree_util.tree_leaves(state_jax['opt_state'])
      loaded_leaves = jax.tree_util.tree_leaves(loaded_state_jax['opt_state'])
      for original_leaf, loaded_leaf in zip(original_leaves, loaded_leaves):
        if isinstance(original_leaf, (jnp.ndarray, jax.Array)):
          self.assertTrue(jnp.allclose(original_leaf, loaded_leaf))
        else:
          self.assertEqual(original_leaf, loaded_leaf)

  def test_load_non_existent_checkpoint(self):
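    """load_checkpoint raises FileNotFoundError for a missing path."""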
    with self.assertRaises(FileNotFoundError):
      torchax.load_checkpoint('/path/to/non_existent_checkpoint')


if __name__ == '__main__':
  unittest.main()