
Commit c39dadb

anshulverma authored and facebook-github-bot committed
call on_checkpoint_save on the unit object after saving a checkpoint (#880)
Summary: Pull Request resolved: #880

In this diff we add an `on_checkpoint_save` method to `TrainUnit` and `EvalUnit`. This method is invoked by the checkpointer after it successfully saves a checkpoint.

Reviewed By: richardwang-at-fb

Differential Revision: D61256660

fbshipit-source-id: e9934d5f5e19d92fecd67eefae6600771c81c697
1 parent e9edb28 commit c39dadb
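To make the new hook concrete from a user's perspective, here is a minimal sketch of a unit that overrides `on_checkpoint_save` to record every successful save. `MyUnit`, its module/optimizer fields, and the `Batch` alias are hypothetical; the `TrainUnit` base class, the `State` type, and the hook signature come from this diff.

```python
from typing import List, Tuple

import torch
from torchtnt.framework.state import State
from torchtnt.framework.unit import TrainUnit

Batch = Tuple[torch.Tensor, torch.Tensor]


class MyUnit(TrainUnit[Batch]):
    def __init__(self) -> None:
        super().__init__()
        self.module = torch.nn.Linear(2, 2)
        self.optimizer = torch.optim.SGD(self.module.parameters(), lr=0.01)
        # Checkpoint IDs reported by the checkpointer via on_checkpoint_save.
        self.saved_checkpoint_ids: List[str] = []

    def train_step(self, state: State, data: Batch) -> None:
        inputs, targets = data
        loss = torch.nn.functional.mse_loss(self.module(inputs), targets)
        loss.backward()
        self.optimizer.step()
        self.optimizer.zero_grad()

    def on_checkpoint_save(self, state: State, checkpoint_id: str) -> None:
        # Invoked by the checkpointer after each successful save (see
        # base_checkpointer.py below); checkpoint_id may be a path or URL.
        self.saved_checkpoint_ids.append(checkpoint_id)
```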

File tree

3 files changed: +41 −15 lines changed


tests/framework/callbacks/test_dcp_saver.py

Lines changed: 18 additions & 15 deletions
```diff
@@ -10,6 +10,7 @@
 import unittest
 
 from torchtnt.framework.callbacks.dcp_saver import _LATEST_DCP_AVAIL
+from torchtnt.framework.state import State
 
 if not _LATEST_DCP_AVAIL:
     raise unittest.SkipTest("Latest Pytorch is required to run DCP tests")
@@ -51,35 +52,37 @@ def test_save_restore(self) -> None:
         dataset_len = 10
         batch_size = 2
         max_epochs = 2
-        expected_steps_per_epoch = math.ceil(dataset_len / batch_size)
         save_every_n_train_steps = 2
 
         my_unit = DummyTrainUnit(input_dim=input_dim)
         dataloader = generate_random_dataloader(dataset_len, input_dim, batch_size)
         expected_paths: List[str] = []
         with tempfile.TemporaryDirectory() as temp_dir:
-            cumulative_steps = 0
-            for epoch in range(max_epochs):
-                for _ in range(
-                    save_every_n_train_steps,
-                    expected_steps_per_epoch + 1,
-                    save_every_n_train_steps,
-                ):
-                    cumulative_steps += save_every_n_train_steps
-                    expected_paths.append(
-                        os.path.join(
-                            temp_dir, f"epoch_{epoch}_train_step_{cumulative_steps}"
-                        )
-                    )
+            expected_paths = [
+                f"{temp_dir}/epoch_0_train_step_2",
+                f"{temp_dir}/epoch_0_train_step_4",
+                f"{temp_dir}/epoch_1_train_step_6",
+                f"{temp_dir}/epoch_1_train_step_8",
+                f"{temp_dir}/epoch_1_train_step_10",
+                f"{temp_dir}/epoch_2_train_step_10",  # extra checkpoint on_train_end
+            ]
             dcp_cb = DistributedCheckpointSaver(
                 temp_dir,
                 save_every_n_train_steps=save_every_n_train_steps,
                 knob_options=KnobOptions(1),
             )
+
+            saved_checkpoint_paths: List[str] = []
+
+            def _checkpoint_save_callback(state: State, checkpoint_id: str) -> None:
+                saved_checkpoint_paths.append(checkpoint_id)
+
+            my_unit.on_checkpoint_save = _checkpoint_save_callback  # pyre-ignore
+
             train(my_unit, dataloader, max_epochs=max_epochs, callbacks=[dcp_cb])
 
             end_num_steps_completed = my_unit.train_progress.num_steps_completed
-            self.assertGreater(len(expected_paths), 0)
+            self.assertEqual(saved_checkpoint_paths, expected_paths)
             dcp_cb.restore(expected_paths[0], my_unit)
             restored_num_steps_completed = my_unit.train_progress.num_steps_completed
             # A snapshot is saved every n steps
```

torchtnt/framework/callbacks/base_checkpointer.py

Lines changed: 3 additions & 0 deletions
```diff
@@ -215,6 +215,9 @@ def _generate_checkpoint_and_upkeep(
         # 4) track checkpoint and clean up surplus if needed
         self._checkpoint_manager.append_checkpoint(checkpoint_path)
 
+        # 5) invoke on_checkpoint_save callback on the unit since checkpoint was saved successfully
+        unit.on_checkpoint_save(state, checkpoint_id=checkpoint_path.path)
+
         return True
 
     def _does_checkpoint_exist(
```

torchtnt/framework/unit.py

Lines changed: 20 additions & 0 deletions
```diff
@@ -341,6 +341,16 @@ def on_train_end(self, state: State) -> None:
         """
         pass
 
+    def on_checkpoint_save(self, state: State, checkpoint_id: str) -> None:
+        """Hook called after successfully saving a checkpoint.
+
+        Args:
+            state: a :class:`~torchtnt.framework.state.State` object containing metadata about the training run.
+            checkpoint_id: the ID of the checkpoint that was saved. Depending on the storage type, this may be
+                a path, a URL or a unique identifier.
+        """
+        pass
+
     def get_next_train_batch(
         self,
         state: State,
@@ -447,6 +457,16 @@ def on_eval_end(self, state: State) -> None:
         """
         pass
 
+    def on_checkpoint_save(self, state: State, checkpoint_id: str) -> None:
+        """Hook called after successfully saving a checkpoint.
+
+        Args:
+            state: a :class:`~torchtnt.framework.state.State` object containing metadata about the training run.
+            checkpoint_id: the ID of the checkpoint that was saved. Depending on the storage type, this may be
+                a path, a URL or a unique identifier.
+        """
+        pass
+
     def get_next_eval_batch(
         self,
         state: State,
```
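Finally, a hedged end-to-end sketch of how the hook fires during a run, reusing the hypothetical `MyUnit` from the earlier sketch. `DistributedCheckpointSaver`, `KnobOptions`, and `train` all appear in this diff; their exact import paths and the dataloader construction below are assumptions.

```python
import tempfile

import torch
from torch.utils.data import DataLoader, TensorDataset
from torchtnt.framework.callbacks.checkpointer_types import KnobOptions
from torchtnt.framework.callbacks.dcp_saver import DistributedCheckpointSaver
from torchtnt.framework.train import train

# 10 samples with batch size 2 -> 5 train steps per epoch, as in the test above.
dataset = TensorDataset(torch.randn(10, 2), torch.randn(10, 2))
dataloader = DataLoader(dataset, batch_size=2)

my_unit = MyUnit()
with tempfile.TemporaryDirectory() as temp_dir:
    dcp_cb = DistributedCheckpointSaver(
        temp_dir,
        save_every_n_train_steps=2,
        knob_options=KnobOptions(1),
    )
    train(my_unit, dataloader, max_epochs=2, callbacks=[dcp_cb])

# One entry per saved checkpoint, e.g. ".../epoch_0_train_step_2", ...,
# plus the extra on_train_end checkpoint ".../epoch_2_train_step_10".
print(my_unit.saved_checkpoint_ids)
```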
