diff --git a/torchtnt/framework/callbacks/torchsnapshot_saver.py b/torchtnt/framework/callbacks/torchsnapshot_saver.py index ab2a0515d6..3bc65796ca 100644 --- a/torchtnt/framework/callbacks/torchsnapshot_saver.py +++ b/torchtnt/framework/callbacks/torchsnapshot_saver.py @@ -157,7 +157,12 @@ def _checkpoint_impl( """ Checkpoint the current state of the application. """ - if hook not in ["on_train_step_end", "on_train_epoch_end", "on_train_end"]: + if hook not in [ + "on_train_step_end", + "on_train_epoch_end", + "on_train_end", + "on_eval_epoch_end", + ]: raise RuntimeError(f"Unexpected hook encountered '{hook}'") intra_epoch = False