|
10 | 10 | import unittest
|
11 | 11 |
|
12 | 12 | from torchtnt.framework.callbacks.dcp_saver import _LATEST_DCP_AVAIL
|
| 13 | +from torchtnt.framework.state import State |
13 | 14 |
|
14 | 15 | if not _LATEST_DCP_AVAIL:
|
15 | 16 | raise unittest.SkipTest("Latest Pytorch is required to run DCP tests")
|
@@ -51,35 +52,37 @@ def test_save_restore(self) -> None:
|
51 | 52 | dataset_len = 10
|
52 | 53 | batch_size = 2
|
53 | 54 | max_epochs = 2
|
54 |
| - expected_steps_per_epoch = math.ceil(dataset_len / batch_size) |
55 | 55 | save_every_n_train_steps = 2
|
56 | 56 |
|
57 | 57 | my_unit = DummyTrainUnit(input_dim=input_dim)
|
58 | 58 | dataloader = generate_random_dataloader(dataset_len, input_dim, batch_size)
|
59 | 59 | expected_paths: List[str] = []
|
60 | 60 | with tempfile.TemporaryDirectory() as temp_dir:
|
61 |
| - cumulative_steps = 0 |
62 |
| - for epoch in range(max_epochs): |
63 |
| - for _ in range( |
64 |
| - save_every_n_train_steps, |
65 |
| - expected_steps_per_epoch + 1, |
66 |
| - save_every_n_train_steps, |
67 |
| - ): |
68 |
| - cumulative_steps += save_every_n_train_steps |
69 |
| - expected_paths.append( |
70 |
| - os.path.join( |
71 |
| - temp_dir, f"epoch_{epoch}_train_step_{cumulative_steps}" |
72 |
| - ) |
73 |
| - ) |
| 61 | + expected_paths = [ |
| 62 | + f"{temp_dir}/epoch_0_train_step_2", |
| 63 | + f"{temp_dir}/epoch_0_train_step_4", |
| 64 | + f"{temp_dir}/epoch_1_train_step_6", |
| 65 | + f"{temp_dir}/epoch_1_train_step_8", |
| 66 | + f"{temp_dir}/epoch_1_train_step_10", |
| 67 | + f"{temp_dir}/epoch_2_train_step_10", # extra checkpoint on_train_end |
| 68 | + ] |
74 | 69 | dcp_cb = DistributedCheckpointSaver(
|
75 | 70 | temp_dir,
|
76 | 71 | save_every_n_train_steps=save_every_n_train_steps,
|
77 | 72 | knob_options=KnobOptions(1),
|
78 | 73 | )
|
| 74 | + |
| 75 | + saved_checkpoint_paths: List[str] = [] |
| 76 | + |
| 77 | + def _checkpoint_save_callback(state: State, checkpoint_id: str) -> None: |
| 78 | + saved_checkpoint_paths.append(checkpoint_id) |
| 79 | + |
| 80 | + my_unit.on_checkpoint_save = _checkpoint_save_callback # pyre-ignore |
| 81 | + |
79 | 82 | train(my_unit, dataloader, max_epochs=max_epochs, callbacks=[dcp_cb])
|
80 | 83 |
|
81 | 84 | end_num_steps_completed = my_unit.train_progress.num_steps_completed
|
82 |
| - self.assertGreater(len(expected_paths), 0) |
| 85 | + self.assertEqual(saved_checkpoint_paths, expected_paths) |
83 | 86 | dcp_cb.restore(expected_paths[0], my_unit)
|
84 | 87 | restored_num_steps_completed = my_unit.train_progress.num_steps_completed
|
85 | 88 | # A snapshot is saved every n steps
|
|
0 commit comments