Skip to content

Commit 1fe0a5d

Browse files
JKSenthil and facebook-github-bot
authored and committed
include dataloader in eval epoch end during FIT (#957)
Summary: Pull Request resolved: #957 Reviewed By: diego-urgell Differential Revision: D67813439 fbshipit-source-id: f1fdbbbb0b9784d0e4bdf6dffd9f52eec1915bab
1 parent de119c5 commit 1fe0a5d

File tree

2 files changed

+99
-1
lines changed

2 files changed

+99
-1
lines changed

tests/framework/callbacks/test_dcp_saver.py

Lines changed: 92 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -748,6 +748,98 @@ def test_save_restore_fit_eval_every_n_epochs(self) -> None:
748748
expected_keys_with_dl,
749749
)
750750

751+
def test_save_restore_fit_save_every_n_eval_epochs(self) -> None:
    """During FIT with ``save_every_n_eval_epochs``, checkpoints taken at
    eval-epoch boundaries should include dataloader state (they are still
    intra-epoch with respect to the train epoch), and restoring from them
    should consume exactly the same app_state keys that were saved.
    """
    input_dim = 2
    dataset_len = 10
    batch_size = 2

    my_unit = DummyAutoUnit(module=nn.Linear(input_dim, 2))
    my_unit.output_mean = DummyMeanMetric()
    my_unit.loss = 0.1

    train_dataloader = generate_dummy_stateful_dataloader(
        dataset_len, input_dim, batch_size
    )

    eval_dataloader = generate_dummy_stateful_dataloader(
        dataset_len, input_dim, batch_size
    )

    with tempfile.TemporaryDirectory() as temp_dir:
        dcp_cb = DistributedCheckpointSaver(
            temp_dir,
            knob_options=KnobOptions(1),
            save_every_n_eval_epochs=1,
            best_checkpoint_config=BestCheckpointConfig(monitored_metric="loss"),
        )

        fit(
            my_unit,
            max_epochs=1,
            evaluate_every_n_steps=1,
            train_dataloader=train_dataloader,
            eval_dataloader=eval_dataloader,
            callbacks=[dcp_cb],
        )

        generated_ckpts = os.listdir(temp_dir)
        # Since we are using FIT, the metric value should be included
        expected_ckpt_names = {
            "epoch_0_train_step_1_eval_step_5_loss=0.1",
            "epoch_0_train_step_2_eval_step_10_loss=0.1",
            "epoch_0_train_step_3_eval_step_15_loss=0.1",
            "epoch_0_train_step_4_eval_step_20_loss=0.1",
            "epoch_0_train_step_5_eval_step_25_loss=0.1",
            "epoch_1_train_step_5_eval_step_30_loss=0.1",
        }
        self.assertCountEqual(generated_ckpts, [*expected_ckpt_names])

        expected_keys = [
            "module",  # Both train and eval checkpoints save full app_state in fit
            "optimizer",
            "lr_scheduler",
            "train_progress",
            "eval_progress",
            "predict_progress",  # included because of AutoUnit
            "output_mean",
            "eval_dataloader",
            "train_dataloader",
        ]

        for ckpt_path in expected_ckpt_names:
            full_ckpt_path = os.path.join(temp_dir, ckpt_path)
            expected_keys_with_dl = list(expected_keys)
            storage_reader = FsspecReader(full_ckpt_path)
            metadata = storage_reader.read_metadata()
            if ckpt_path == "epoch_1_train_step_5_eval_step_30_loss=0.1":
                # the final checkpoint won't include the train dataloader
                # state, so drop its key (the last entry in expected_keys)
                expected_keys_with_dl = expected_keys_with_dl[:-1]
            # Get base keys after the "app_state." wrapper prefix
            appstate_keys = {
                key.split(".")[1] for key in metadata.state_dict_metadata.keys()
            }
            self.assertCountEqual(
                appstate_keys,
                expected_keys_with_dl,
                msg=f"key: {ckpt_path}, {expected_keys_with_dl=}, {appstate_keys=},",
            )

            # Now make sure that the same exact keys are used when restoring
            with patch(
                "torchtnt.framework.callbacks.dcp_saver.dcp.load"
            ) as mock_load:
                DistributedCheckpointSaver.restore(
                    full_ckpt_path,
                    my_unit,
                    train_dataloader=train_dataloader,
                    eval_dataloader=eval_dataloader,
                )
                self.assertCountEqual(
                    [*mock_load.call_args[0][0]["app_state"].state_dict().keys()],
                    expected_keys_with_dl,
                )
842+
751843
def test_save_fit_eval_every_n_steps(self) -> None:
752844
input_dim = 2
753845

torchtnt/framework/callbacks/dcp_saver.py

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -26,6 +26,7 @@
2626
DefaultSavePlanner,
2727
)
2828
from torch.distributed.checkpoint.planner import LoadPlanner, SavePlanner
29+
from torchtnt.framework.state import EntryPoint
2930

3031
try:
3132
from torch.distributed.checkpoint.state_dict import _init_optim_state
@@ -171,7 +172,12 @@ def _checkpoint_impl(
171172
]:
172173
raise RuntimeError(f"Unexpected hook encountered '{hook}'")
173174

174-
intra_epoch = "step_end" in hook
175+
# intra epoch when checkpointing during "_step_end" hook OR
176+
# when checkpointing during "on_eval_epoch_end" hook and the entry point is fit
177+
# since it is still intra epoch with respect to the train epoch
178+
intra_epoch = "step_end" in hook or (
179+
"on_eval_epoch_end" == hook and state.entry_point == EntryPoint.FIT
180+
)
175181
curr_snapshot_wait = hook == "on_train_end"
176182

177183
if planner is None:

0 commit comments

Comments
 (0)