@@ -145,7 +145,7 @@ def test_save_restore_dataloader_state(self) -> None:
145
145
self .assertEqual (stateful_dataloader .load_state_dict_call_count , 1 )
146
146
self .assertEqual (
147
147
log .output [0 ],
148
- "WARNING:torchtnt.utils.rank_zero_log:train_dataloader was passed to `restore` but no train dataloader exists in the Snapshot " ,
148
+ "WARNING:torchtnt.framework.callbacks.dcp_saver:dataloader (train) was passed to `restore` but no dataloader exists in checkpoint metadata. " ,
149
149
)
150
150
151
151
def test_restore_from_latest (self ) -> None :
@@ -500,7 +500,7 @@ def test_save_restore_multi_optimizers(self) -> None:
500
500
my_unit_clone = DummyMultiOptimUnit (input_dim = input_dim )
501
501
dcp_cb .restore_from_latest (temp_dir , my_unit_clone )
502
502
503
- def test_save_predict (self ) -> None :
503
+ def test_save_restore_predict (self ) -> None :
504
504
input_dim = 2
505
505
dataset_len = 10
506
506
batch_size = 2
@@ -537,19 +537,51 @@ def test_save_predict(self) -> None:
537
537
ckpt_path = none_throws (get_latest_checkpoint_path (temp_dir ))
538
538
self .assertEqual (ckpt_path , os .path .join (temp_dir , expected_ckpts [- 1 ]))
539
539
540
+ expected_keys = [
541
+ "predict_progress" ,
542
+ "predict_dataloader" ,
543
+ "output_mean" ,
544
+ ]
545
+
540
546
storage_reader = FsspecReader (ckpt_path )
541
547
metadata = storage_reader .read_metadata ()
542
548
self .assertCountEqual (
543
549
# Get base keys after the app_state wrapper
544
550
{key .split ("." )[1 ] for key in metadata .state_dict_metadata .keys ()},
545
- [
546
- "predict_progress" ,
547
- "predict_dataloader" ,
548
- "output_mean" ,
549
- ],
551
+ expected_keys ,
552
+ )
553
+
554
+ # Now make sure that the same exact keys are used when restoring
555
+ with patch ("torchtnt.framework.callbacks.dcp_saver.dcp.load" ) as mock_load :
556
+ DistributedCheckpointSaver .restore (
557
+ ckpt_path , my_unit , predict_dataloader = dataloader
558
+ )
559
+ self .assertCountEqual (
560
+ [* mock_load .call_args [0 ][0 ]["app_state" ].state_dict ().keys ()],
561
+ expected_keys ,
562
+ )
563
+
564
+ # Double check that the module parameters are not overwritten when loading ckpt
565
+ my_unit = DummyPredictUnit (input_dim = input_dim )
566
+ my_unit .module .weight .data .fill_ (0.0 )
567
+ my_unit .module .bias .data .fill_ (1.0 )
568
+
569
+ DistributedCheckpointSaver .restore (
570
+ ckpt_path , my_unit , predict_dataloader = dataloader
571
+ )
572
+
573
+ self .assertTrue (
574
+ torch .allclose (
575
+ my_unit .module .weight .data , torch .zeros (input_dim , input_dim )
576
+ )
577
+ )
578
+ self .assertTrue (
579
+ torch .allclose (
580
+ my_unit .module .bias .data , torch .ones (input_dim , input_dim )
581
+ )
550
582
)
551
583
552
- def test_save_evaluate (self ) -> None :
584
+ def test_save_restore_evaluate (self ) -> None :
553
585
input_dim = 2
554
586
dataset_len = 10
555
587
batch_size = 2
@@ -580,18 +612,49 @@ def test_save_evaluate(self) -> None:
580
612
ckpt_path = none_throws (get_latest_checkpoint_path (temp_dir ))
581
613
self .assertEqual (ckpt_path , os .path .join (temp_dir , expected_ckpts [- 1 ]))
582
614
615
+ expected_keys = [
616
+ "eval_progress" ,
617
+ "eval_dataloader" ,
618
+ ]
583
619
storage_reader = FsspecReader (ckpt_path )
584
620
metadata = storage_reader .read_metadata ()
585
621
self .assertCountEqual (
586
622
# Get base keys after the app_state wrapper
587
623
{key .split ("." )[1 ] for key in metadata .state_dict_metadata .keys ()},
588
- [
589
- "eval_progress" ,
590
- "eval_dataloader" ,
591
- ],
624
+ expected_keys ,
592
625
)
593
626
594
- def test_save_fit_eval_every_n_epochs (self ) -> None :
627
+ # Now make sure that the same exact keys are used when restoring
628
+ with patch ("torchtnt.framework.callbacks.dcp_saver.dcp.load" ) as mock_load :
629
+ DistributedCheckpointSaver .restore (
630
+ ckpt_path , my_unit , eval_dataloader = dataloader
631
+ )
632
+ self .assertCountEqual (
633
+ [* mock_load .call_args [0 ][0 ]["app_state" ].state_dict ().keys ()],
634
+ expected_keys ,
635
+ )
636
+
637
+ # Double check that the module parameters are not overwritten when loading ckpt
638
+ my_unit = DummyEvalUnit (input_dim = input_dim )
639
+ my_unit .module .weight .data .fill_ (0.0 )
640
+ my_unit .module .bias .data .fill_ (1.0 )
641
+
642
+ DistributedCheckpointSaver .restore (
643
+ ckpt_path , my_unit , predict_dataloader = dataloader
644
+ )
645
+
646
+ self .assertTrue (
647
+ torch .allclose (
648
+ my_unit .module .weight .data , torch .zeros (input_dim , input_dim )
649
+ )
650
+ )
651
+ self .assertTrue (
652
+ torch .allclose (
653
+ my_unit .module .bias .data , torch .ones (input_dim , input_dim )
654
+ )
655
+ )
656
+
657
+ def test_save_restore_fit_eval_every_n_epochs (self ) -> None :
595
658
input_dim = 2
596
659
dataset_len = 10
597
660
batch_size = 2
@@ -625,33 +688,52 @@ def test_save_fit_eval_every_n_epochs(self) -> None:
625
688
)
626
689
627
690
generated_ckpts = os .listdir (temp_dir )
628
- expected_ckpts = [
629
- "epoch_0_train_step_2_eval_step_0" ,
630
- "epoch_0_train_step_4_eval_step_0" ,
631
- "epoch_1_train_step_5_eval_step_2" ,
632
- "epoch_1_train_step_5_eval_step_4" ,
691
+ expected_ckpts_to_dl_mapping : Dict [str , str ] = {
692
+ "epoch_0_train_step_2_eval_step_0" : "train_dataloader" ,
693
+ "epoch_0_train_step_4_eval_step_0" : "train_dataloader" ,
694
+ "epoch_1_train_step_5_eval_step_2" : "eval_dataloader" ,
695
+ "epoch_1_train_step_5_eval_step_4" : "eval_dataloader" ,
696
+ }
697
+ self .assertCountEqual (
698
+ generated_ckpts , [* expected_ckpts_to_dl_mapping .keys ()]
699
+ )
700
+
701
+ expected_keys = [
702
+ "module" , # Both train and eval checkpoints save full app_state in fit
703
+ "optimizer" ,
704
+ "lr_scheduler" ,
705
+ "train_progress" ,
706
+ "eval_progress" ,
707
+ "predict_progress" , # included because of AutoUnit
708
+ "output_mean" ,
633
709
]
634
- self .assertCountEqual (generated_ckpts , expected_ckpts )
635
710
636
- expected_dataloader = ["train_dataloader" ] * 2 + ["eval_dataloader" ] * 2
637
- for ckpt_path , dl_key in zip (expected_ckpts , expected_dataloader ):
638
- storage_reader = FsspecReader (os .path .join (temp_dir , ckpt_path ))
711
+ for ckpt_path , dl_key in expected_ckpts_to_dl_mapping .items ():
712
+ full_ckpt_path = os .path .join (temp_dir , ckpt_path )
713
+ expected_keys_with_dl = expected_keys + [dl_key ]
714
+ storage_reader = FsspecReader (full_ckpt_path )
639
715
metadata = storage_reader .read_metadata ()
640
716
self .assertCountEqual (
641
717
# Get base keys after the app_state wrapper
642
718
{key .split ("." )[1 ] for key in metadata .state_dict_metadata .keys ()},
643
- [
644
- "module" , # Both train and eval checkpoints save full app_state in fit
645
- "optimizer" ,
646
- "lr_scheduler" ,
647
- "train_progress" ,
648
- "eval_progress" ,
649
- "predict_progress" , # included because of AutoUnit
650
- dl_key ,
651
- "output_mean" ,
652
- ],
719
+ expected_keys_with_dl ,
653
720
)
654
721
722
+ # Now make sure that the same exact keys are used when restoring
723
+ with patch (
724
+ "torchtnt.framework.callbacks.dcp_saver.dcp.load"
725
+ ) as mock_load :
726
+ DistributedCheckpointSaver .restore (
727
+ full_ckpt_path ,
728
+ my_unit ,
729
+ train_dataloader = train_dataloader ,
730
+ eval_dataloader = eval_dataloader ,
731
+ )
732
+ self .assertCountEqual (
733
+ [* mock_load .call_args [0 ][0 ]["app_state" ].state_dict ().keys ()],
734
+ expected_keys_with_dl ,
735
+ )
736
+
655
737
def test_save_fit_eval_every_n_steps (self ) -> None :
656
738
input_dim = 2
657
739
@@ -710,24 +792,42 @@ def test_save_fit_eval_every_n_steps(self) -> None:
710
792
generated_ckpts , [* expected_ckpts_to_dl_mapping .keys ()]
711
793
)
712
794
795
+ expected_keys = [
796
+ "module" , # Both train and eval checkpoints save full app_state in fit
797
+ "optimizer" ,
798
+ "lr_scheduler" ,
799
+ "train_progress" ,
800
+ "eval_progress" ,
801
+ "predict_progress" , # included because of AutoUnit
802
+ "output_mean" ,
803
+ ]
804
+
713
805
for ckpt_path , expected_dls in expected_ckpts_to_dl_mapping .items ():
714
- storage_reader = FsspecReader (os .path .join (temp_dir , ckpt_path ))
806
+ expected_keys_with_dls = [* expected_keys , * expected_dls ]
807
+ full_ckpt_path = os .path .join (temp_dir , ckpt_path )
808
+ storage_reader = FsspecReader (full_ckpt_path )
715
809
metadata = storage_reader .read_metadata ()
716
810
self .assertCountEqual (
717
811
# Get base keys after the app_state wrapper
718
812
{key .split ("." )[1 ] for key in metadata .state_dict_metadata .keys ()},
719
- [
720
- "module" , # Both train and eval checkpoints save full app_state in fit
721
- "optimizer" ,
722
- "lr_scheduler" ,
723
- "train_progress" ,
724
- "eval_progress" ,
725
- "predict_progress" , # included because of AutoUnit
726
- "output_mean" ,
727
- * expected_dls ,
728
- ],
813
+ expected_keys_with_dls ,
729
814
)
730
815
816
+ # Now make sure that the same exact keys are used when restoring
817
+ with patch (
818
+ "torchtnt.framework.callbacks.dcp_saver.dcp.load"
819
+ ) as mock_load :
820
+ DistributedCheckpointSaver .restore (
821
+ full_ckpt_path ,
822
+ my_unit ,
823
+ train_dataloader = train_dataloader ,
824
+ eval_dataloader = eval_dataloader ,
825
+ )
826
+ self .assertCountEqual (
827
+ [* mock_load .call_args [0 ][0 ]["app_state" ].state_dict ().keys ()],
828
+ expected_keys_with_dls ,
829
+ )
830
+
731
831
732
832
class DummyStatefulDataLoader :
733
833
def __init__ (self , dataloader : DataLoader ) -> None :
0 commit comments