
Commit 1beb1f0

diego-urgell authored and facebook-github-bot committed
Custom restore for eval/predict checkpoints based on path (#919)
Summary: Pull Request resolved: #919
Reviewed By: JKSenthil
Differential Revision: D63013011
fbshipit-source-id: baaac36f6ce3b5688edd7d98711edfce13277434
1 parent 8bf95f3 commit 1beb1f0
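
As the tests below exercise, `DistributedCheckpointSaver.restore` now assembles the app_state to load from what the checkpoint actually contains, keyed off its path, so restoring an eval- or predict-phase checkpoint no longer attempts to load entries such as `module` or `optimizer` state that only fit checkpoints carry. A minimal sketch of path-based phase detection, assuming the `epoch_X_train_step_Y_eval_step_Z` naming visible in the tests; `phases_from_checkpoint_path` is a hypothetical helper for illustration, not torchtnt API:

import os
import re
from typing import List


def phases_from_checkpoint_path(ckpt_path: str) -> List[str]:
    # Hypothetical helper (not torchtnt API): infer which phases a
    # checkpoint covers from its directory name, e.g.
    # "epoch_1_train_step_5_eval_step_2" -> ["train", "eval"].
    name = os.path.basename(os.path.normpath(ckpt_path))
    return [
        phase
        for phase in ("train", "eval", "predict")
        if re.search(rf"{phase}_step_\d+", name)
    ]

For "epoch_0_train_step_2_eval_step_0" this yields ["train", "eval"], matching the fit checkpoints in the tests below, which carry both progress objects.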

File tree

3 files changed: +260 −69

tests/framework/callbacks/test_dcp_saver.py

Lines changed: 143 additions & 43 deletions
@@ -145,7 +145,7 @@ def test_save_restore_dataloader_state(self) -> None:
         self.assertEqual(stateful_dataloader.load_state_dict_call_count, 1)
         self.assertEqual(
             log.output[0],
-            "WARNING:torchtnt.utils.rank_zero_log:train_dataloader was passed to `restore` but no train dataloader exists in the Snapshot",
+            "WARNING:torchtnt.framework.callbacks.dcp_saver:dataloader (train) was passed to `restore` but no dataloader exists in checkpoint metadata.",
         )

     def test_restore_from_latest(self) -> None:
@@ -500,7 +500,7 @@ def test_save_restore_multi_optimizers(self) -> None:
         my_unit_clone = DummyMultiOptimUnit(input_dim=input_dim)
         dcp_cb.restore_from_latest(temp_dir, my_unit_clone)

-    def test_save_predict(self) -> None:
+    def test_save_restore_predict(self) -> None:
         input_dim = 2
         dataset_len = 10
         batch_size = 2
@@ -537,19 +537,51 @@ def test_save_predict(self) -> None:
         ckpt_path = none_throws(get_latest_checkpoint_path(temp_dir))
         self.assertEqual(ckpt_path, os.path.join(temp_dir, expected_ckpts[-1]))

+        expected_keys = [
+            "predict_progress",
+            "predict_dataloader",
+            "output_mean",
+        ]
+
         storage_reader = FsspecReader(ckpt_path)
         metadata = storage_reader.read_metadata()
         self.assertCountEqual(
             # Get base keys after the app_state wrapper
             {key.split(".")[1] for key in metadata.state_dict_metadata.keys()},
-            [
-                "predict_progress",
-                "predict_dataloader",
-                "output_mean",
-            ],
+            expected_keys,
+        )
+
+        # Now make sure that the same exact keys are used when restoring
+        with patch("torchtnt.framework.callbacks.dcp_saver.dcp.load") as mock_load:
+            DistributedCheckpointSaver.restore(
+                ckpt_path, my_unit, predict_dataloader=dataloader
+            )
+            self.assertCountEqual(
+                [*mock_load.call_args[0][0]["app_state"].state_dict().keys()],
+                expected_keys,
+            )
+
+        # Double check that the module parameters are not overwritten when loading ckpt
+        my_unit = DummyPredictUnit(input_dim=input_dim)
+        my_unit.module.weight.data.fill_(0.0)
+        my_unit.module.bias.data.fill_(1.0)
+
+        DistributedCheckpointSaver.restore(
+            ckpt_path, my_unit, predict_dataloader=dataloader
+        )
+
+        self.assertTrue(
+            torch.allclose(
+                my_unit.module.weight.data, torch.zeros(input_dim, input_dim)
+            )
+        )
+        self.assertTrue(
+            torch.allclose(
+                my_unit.module.bias.data, torch.ones(input_dim, input_dim)
+            )
         )

-    def test_save_evaluate(self) -> None:
+    def test_save_restore_evaluate(self) -> None:
         input_dim = 2
         dataset_len = 10
         batch_size = 2
@@ -580,18 +612,49 @@ def test_save_evaluate(self) -> None:
         ckpt_path = none_throws(get_latest_checkpoint_path(temp_dir))
         self.assertEqual(ckpt_path, os.path.join(temp_dir, expected_ckpts[-1]))

+        expected_keys = [
+            "eval_progress",
+            "eval_dataloader",
+        ]
         storage_reader = FsspecReader(ckpt_path)
         metadata = storage_reader.read_metadata()
         self.assertCountEqual(
             # Get base keys after the app_state wrapper
             {key.split(".")[1] for key in metadata.state_dict_metadata.keys()},
-            [
-                "eval_progress",
-                "eval_dataloader",
-            ],
+            expected_keys,
         )

-    def test_save_fit_eval_every_n_epochs(self) -> None:
+        # Now make sure that the same exact keys are used when restoring
+        with patch("torchtnt.framework.callbacks.dcp_saver.dcp.load") as mock_load:
+            DistributedCheckpointSaver.restore(
+                ckpt_path, my_unit, eval_dataloader=dataloader
+            )
+            self.assertCountEqual(
+                [*mock_load.call_args[0][0]["app_state"].state_dict().keys()],
+                expected_keys,
+            )
+
+        # Double check that the module parameters are not overwritten when loading ckpt
+        my_unit = DummyEvalUnit(input_dim=input_dim)
+        my_unit.module.weight.data.fill_(0.0)
+        my_unit.module.bias.data.fill_(1.0)
+
+        DistributedCheckpointSaver.restore(
+            ckpt_path, my_unit, eval_dataloader=dataloader
+        )
+
+        self.assertTrue(
+            torch.allclose(
+                my_unit.module.weight.data, torch.zeros(input_dim, input_dim)
+            )
+        )
+        self.assertTrue(
+            torch.allclose(
+                my_unit.module.bias.data, torch.ones(input_dim, input_dim)
+            )
+        )
+
+    def test_save_restore_fit_eval_every_n_epochs(self) -> None:
         input_dim = 2
         dataset_len = 10
         batch_size = 2
@@ -625,33 +688,52 @@ def test_save_fit_eval_every_n_epochs(self) -> None:
         )

         generated_ckpts = os.listdir(temp_dir)
-        expected_ckpts = [
-            "epoch_0_train_step_2_eval_step_0",
-            "epoch_0_train_step_4_eval_step_0",
-            "epoch_1_train_step_5_eval_step_2",
-            "epoch_1_train_step_5_eval_step_4",
+        expected_ckpts_to_dl_mapping: Dict[str, str] = {
+            "epoch_0_train_step_2_eval_step_0": "train_dataloader",
+            "epoch_0_train_step_4_eval_step_0": "train_dataloader",
+            "epoch_1_train_step_5_eval_step_2": "eval_dataloader",
+            "epoch_1_train_step_5_eval_step_4": "eval_dataloader",
+        }
+        self.assertCountEqual(
+            generated_ckpts, [*expected_ckpts_to_dl_mapping.keys()]
+        )
+
+        expected_keys = [
+            "module",  # Both train and eval checkpoints save full app_state in fit
+            "optimizer",
+            "lr_scheduler",
+            "train_progress",
+            "eval_progress",
+            "predict_progress",  # included because of AutoUnit
+            "output_mean",
         ]
-        self.assertCountEqual(generated_ckpts, expected_ckpts)

-        expected_dataloader = ["train_dataloader"] * 2 + ["eval_dataloader"] * 2
-        for ckpt_path, dl_key in zip(expected_ckpts, expected_dataloader):
-            storage_reader = FsspecReader(os.path.join(temp_dir, ckpt_path))
+        for ckpt_path, dl_key in expected_ckpts_to_dl_mapping.items():
+            full_ckpt_path = os.path.join(temp_dir, ckpt_path)
+            expected_keys_with_dl = expected_keys + [dl_key]
+            storage_reader = FsspecReader(full_ckpt_path)
             metadata = storage_reader.read_metadata()
             self.assertCountEqual(
                 # Get base keys after the app_state wrapper
                 {key.split(".")[1] for key in metadata.state_dict_metadata.keys()},
-                [
-                    "module",  # Both train and eval checkpoints save full app_state in fit
-                    "optimizer",
-                    "lr_scheduler",
-                    "train_progress",
-                    "eval_progress",
-                    "predict_progress",  # included because of AutoUnit
-                    dl_key,
-                    "output_mean",
-                ],
+                expected_keys_with_dl,
             )

+            # Now make sure that the same exact keys are used when restoring
+            with patch(
+                "torchtnt.framework.callbacks.dcp_saver.dcp.load"
+            ) as mock_load:
+                DistributedCheckpointSaver.restore(
+                    full_ckpt_path,
+                    my_unit,
+                    train_dataloader=train_dataloader,
+                    eval_dataloader=eval_dataloader,
+                )
+                self.assertCountEqual(
+                    [*mock_load.call_args[0][0]["app_state"].state_dict().keys()],
+                    expected_keys_with_dl,
+                )
+
     def test_save_fit_eval_every_n_steps(self) -> None:
         input_dim = 2

@@ -710,24 +792,42 @@ def test_save_fit_eval_every_n_steps(self) -> None:
             generated_ckpts, [*expected_ckpts_to_dl_mapping.keys()]
         )

+        expected_keys = [
+            "module",  # Both train and eval checkpoints save full app_state in fit
+            "optimizer",
+            "lr_scheduler",
+            "train_progress",
+            "eval_progress",
+            "predict_progress",  # included because of AutoUnit
+            "output_mean",
+        ]
+
         for ckpt_path, expected_dls in expected_ckpts_to_dl_mapping.items():
-            storage_reader = FsspecReader(os.path.join(temp_dir, ckpt_path))
+            expected_keys_with_dls = [*expected_keys, *expected_dls]
+            full_ckpt_path = os.path.join(temp_dir, ckpt_path)
+            storage_reader = FsspecReader(full_ckpt_path)
             metadata = storage_reader.read_metadata()
             self.assertCountEqual(
                 # Get base keys after the app_state wrapper
                 {key.split(".")[1] for key in metadata.state_dict_metadata.keys()},
-                [
-                    "module",  # Both train and eval checkpoints save full app_state in fit
-                    "optimizer",
-                    "lr_scheduler",
-                    "train_progress",
-                    "eval_progress",
-                    "predict_progress",  # included because of AutoUnit
-                    "output_mean",
-                    *expected_dls,
-                ],
+                expected_keys_with_dls,
             )

+            # Now make sure that the same exact keys are used when restoring
+            with patch(
+                "torchtnt.framework.callbacks.dcp_saver.dcp.load"
+            ) as mock_load:
+                DistributedCheckpointSaver.restore(
+                    full_ckpt_path,
+                    my_unit,
+                    train_dataloader=train_dataloader,
+                    eval_dataloader=eval_dataloader,
+                )
+                self.assertCountEqual(
+                    [*mock_load.call_args[0][0]["app_state"].state_dict().keys()],
+                    expected_keys_with_dls,
+                )
+

 class DummyStatefulDataLoader:
     def __init__(self, dataloader: DataLoader) -> None:

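The restore call patterns these tests pin down, shown as a standalone sketch (checkpoint path, unit, and dataloader construction are elided; `ckpt_path`, `full_ckpt_path`, `my_unit`, and the dataloaders stand in for the objects built in the tests):

from torchtnt.framework.callbacks.dcp_saver import DistributedCheckpointSaver

# Predict checkpoint: only predict-phase entries (predict_progress,
# predict_dataloader, output_mean) are loaded; module weights already
# set on the unit are left untouched.
DistributedCheckpointSaver.restore(ckpt_path, my_unit, predict_dataloader=dataloader)

# Fit checkpoint: the full app_state (module, optimizer, lr_scheduler,
# progress objects) plus whichever dataloader entry the checkpoint holds.
DistributedCheckpointSaver.restore(
    full_ckpt_path,
    my_unit,
    train_dataloader=train_dataloader,
    eval_dataloader=eval_dataloader,
)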