From 1ccdf82067f501a397f7214d941cd7969f337327 Mon Sep 17 00:00:00 2001 From: Dorukhan Sergin Date: Tue, 7 May 2024 17:41:29 -0700 Subject: [PATCH] Add on_eval_epoch_end as a valid hook to TorchSnapshotSaver (#826) Summary: When `TorchSnapshotSaver` is used with `save_every_n_eval_epochs > 0` and `best_checkpoint_config`, the `on_eval_epoch_end` hook is invoked. https://www.internalfb.com/code/fbsource/[a8a4a7fba9a8a93af7382fa12e669c066f41024f]/fbcode/torchtnt/framework/callbacks/base_checkpointer.py?lines=273 However, `_checkpoint_impl` then raises a `RuntimeError` because `on_eval_epoch_end` is not in its list of valid hooks. This diff adds it to that list. Differential Revision: D57083777 --- torchtnt/framework/callbacks/torchsnapshot_saver.py | 7 ++++++- 1 file changed, 6 insertions(+), 1 deletion(-) diff --git a/torchtnt/framework/callbacks/torchsnapshot_saver.py b/torchtnt/framework/callbacks/torchsnapshot_saver.py index ab2a0515d6..3bc65796ca 100644 --- a/torchtnt/framework/callbacks/torchsnapshot_saver.py +++ b/torchtnt/framework/callbacks/torchsnapshot_saver.py @@ -157,7 +157,12 @@ def _checkpoint_impl( """ Checkpoint the current state of the application. """ - if hook not in ["on_train_step_end", "on_train_epoch_end", "on_train_end"]: + if hook not in [ + "on_train_step_end", + "on_train_epoch_end", + "on_train_end", + "on_eval_epoch_end", + ]: raise RuntimeError(f"Unexpected hook encountered '{hook}'") intra_epoch = False