Skip to content

Commit 1ccdf82

Browse files
dorukhanserginfacebook-github-bot
authored andcommitted
Add on_eval_epoch_end as a valid hook to TorchSnapshotSaver (#826)
Summary: When `TorchSnapshotSaver` is used with `save_every_n_eval_epochs > 0` and `best_checkpoint_config`, this hook is invoked. https://www.internalfb.com/code/fbsource/[a8a4a7fba9a8a93af7382fa12e669c066f41024f]/fbcode/torchtnt/framework/callbacks/base_checkpointer.py?lines=273 However, it fails due to not being considered a valid hook. This diff fixes that Differential Revision: D57083777
1 parent c62c630 commit 1ccdf82

File tree

1 file changed

+6
-1
lines changed

1 file changed

+6
-1
lines changed

torchtnt/framework/callbacks/torchsnapshot_saver.py

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -157,7 +157,12 @@ def _checkpoint_impl(
157157
"""
158158
Checkpoint the current state of the application.
159159
"""
160-
if hook not in ["on_train_step_end", "on_train_epoch_end", "on_train_end"]:
160+
if hook not in [
161+
"on_train_step_end",
162+
"on_train_epoch_end",
163+
"on_train_end",
164+
"on_eval_epoch_end",
165+
]:
161166
raise RuntimeError(f"Unexpected hook encountered '{hook}'")
162167

163168
intra_epoch = False

0 commit comments

Comments
 (0)