@@ -748,6 +748,98 @@ def test_save_restore_fit_eval_every_n_epochs(self) -> None:
                        expected_keys_with_dl,
                    )

+    def test_save_restore_fit_save_every_n_eval_epochs(self) -> None:
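+        """Checkpoints saved every eval epoch during fit() should be restorable with the same app_state keys."""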
+        input_dim = 2
+        dataset_len = 10
+        batch_size = 2
+
+        my_unit = DummyAutoUnit(module=nn.Linear(input_dim, 2))
+        my_unit.output_mean = DummyMeanMetric()
+        my_unit.loss = 0.1
+
+        train_dataloader = generate_dummy_stateful_dataloader(
+            dataset_len, input_dim, batch_size
+        )
+
+        eval_dataloader = generate_dummy_stateful_dataloader(
+            dataset_len, input_dim, batch_size
+        )
+
+        with tempfile.TemporaryDirectory() as temp_dir:
+            dcp_cb = DistributedCheckpointSaver(
+                temp_dir,
+                knob_options=KnobOptions(1),
+                save_every_n_eval_epochs=1,
+                best_checkpoint_config=BestCheckpointConfig(monitored_metric="loss"),
+            )
+
+            fit(
+                my_unit,
+                max_epochs=1,
+                evaluate_every_n_steps=1,
+                train_dataloader=train_dataloader,
+                eval_dataloader=eval_dataloader,
+                callbacks=[dcp_cb],
+            )
+
+            generated_ckpts = os.listdir(temp_dir)
+            # Since we are using FIT, the metric value should be included
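+            # 10 samples with batch size 2 -> 5 train steps per epoch; each eval pass runs the
+            # full eval dataloader (5 eval steps), and a checkpoint is saved after every eval epoch.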
+            expected_ckpts_to_dl_mapping = {
+                "epoch_0_train_step_1_eval_step_5_loss=0.1",
+                "epoch_0_train_step_2_eval_step_10_loss=0.1",
+                "epoch_0_train_step_3_eval_step_15_loss=0.1",
+                "epoch_0_train_step_4_eval_step_20_loss=0.1",
+                "epoch_0_train_step_5_eval_step_25_loss=0.1",
+                "epoch_1_train_step_5_eval_step_30_loss=0.1",
+            }
+            self.assertCountEqual(generated_ckpts, [*expected_ckpts_to_dl_mapping])
+
+            expected_keys = [
+                "module",  # Both train and eval checkpoints save the full app_state in fit
+                "optimizer",
+                "lr_scheduler",
+                "train_progress",
+                "eval_progress",
+                "predict_progress",  # included because of AutoUnit
+                "output_mean",
+                "eval_dataloader",
+                "train_dataloader",
+            ]
+
+            for ckpt_path in expected_ckpts_to_dl_mapping:
+                full_ckpt_path = os.path.join(temp_dir, ckpt_path)
+                expected_keys_with_dl = list(expected_keys)
+                storage_reader = FsspecReader(full_ckpt_path)
+                metadata = storage_reader.read_metadata()
+                if ckpt_path == "epoch_1_train_step_5_eval_step_30_loss=0.1":
+                    # drop train_dataloader, since the final checkpoint won't have it
+                    expected_keys_with_dl = expected_keys_with_dl[:-1]
+                # Get the base keys after the app_state wrapper
+                appstate_keys = {
+                    key.split(".")[1] for key in metadata.state_dict_metadata.keys()
+                }
+                self.assertCountEqual(
+                    appstate_keys,
+                    expected_keys_with_dl,
+                    msg=f"key: {ckpt_path}, {expected_keys_with_dl=}, {appstate_keys=}",
+                )
+
+                # Now make sure that the same exact keys are used when restoring
+                with patch(
+                    "torchtnt.framework.callbacks.dcp_saver.dcp.load"
+                ) as mock_load:
+                    DistributedCheckpointSaver.restore(
+                        full_ckpt_path,
+                        my_unit,
+                        train_dataloader=train_dataloader,
+                        eval_dataloader=eval_dataloader,
+                    )
+                    self.assertCountEqual(
+                        [*mock_load.call_args[0][0]["app_state"].state_dict().keys()],
+                        expected_keys_with_dl,
+                    )
+
    def test_save_fit_eval_every_n_steps(self) -> None:
        input_dim = 2