
Commit 0f72333

diego-urgell authored and facebook-github-bot committed
Don't take final ckpt if no more training was done in FIT (#846)
Summary: Pull Request resolved: #846

Reviewed By: JKSenthil

Differential Revision: D58397317

fbshipit-source-id: 31b8f7382059f04cd35f26eafbd957bd53e7f3f0
1 parent f9f566b commit 0f72333

2 files changed: +62 −7 lines


tests/framework/callbacks/test_base_checkpointer.py

Lines changed: 38 additions & 0 deletions

```diff
@@ -500,6 +500,44 @@ def test_save_on_train_end(self) -> None:
             ],
         )
 
+    def test_save_on_train_end_on_fit(self) -> None:
+        input_dim = 2
+        dataset_len = 10
+        batch_size = 2
+        max_epochs = 6
+
+        for save_every_n_eval_epochs, expected_last_ckpt in [
+            (None, "epoch_6_train_step_30_eval_step_25"),
+            (2, "epoch_6_train_step_30_eval_step_30"),
+        ]:
+            my_unit = DummyAutoUnit(module=nn.Linear(input_dim, 2))
+            train_dataloader = generate_random_dataloader(
+                dataset_len, input_dim, batch_size
+            )
+            eval_dataloader = generate_random_dataloader(
+                dataset_len, input_dim, batch_size
+            )
+            with tempfile.TemporaryDirectory() as temp_dir:
+                checkpoint_cb = BaseCheckpointSaver(
+                    temp_dir,
+                    save_every_n_epochs=2,
+                    save_every_n_eval_epochs=save_every_n_eval_epochs,
+                )
+                fit(
+                    my_unit,
+                    train_dataloader=train_dataloader,
+                    eval_dataloader=eval_dataloader,
+                    max_epochs=max_epochs,
+                    evaluate_every_n_epochs=1,
+                    callbacks=[checkpoint_cb],
+                )
+                expected_path = os.path.join(temp_dir, expected_last_ckpt)
+                self.assertTrue(os.path.exists(expected_path))
+                self.assertEqual(
+                    checkpoint_cb._checkpoint_manager._ckpt_paths[-1].path,
+                    expected_path,
+                )
+
     @skip_if_not_distributed
     def test_directory_sync_collective(self) -> None:
         spawn_multi_process(
```
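The expected checkpoint names in this test follow from simple step arithmetic: 10 samples at a batch size of 2 give 5 steps per epoch, so 6 train epochs produce 30 train steps and 6 eval epochs produce 30 eval steps. The snippet below is an illustrative reconstruction of that arithmetic, assuming the `epoch_{e}_train_step_{t}_eval_step_{v}` naming visible in the expected paths; it is not torchtnt's actual path-generation code.

```python
# Illustrative arithmetic only -- not torchtnt's checkpoint-path logic.
# Assumes the epoch_{e}_train_step_{t}_eval_step_{v} naming seen in the test.
dataset_len, batch_size, max_epochs = 10, 2, 6

steps_per_epoch = dataset_len // batch_size  # 5 steps per (train or eval) epoch
train_steps = steps_per_epoch * max_epochs  # 30 train steps after 6 epochs
eval_steps_before_final_eval = steps_per_epoch * (max_epochs - 1)  # 25
eval_steps_total = steps_per_epoch * max_epochs  # 30

# save_every_n_eval_epochs=None: the last kept checkpoint is the epoch-6
# train-epoch checkpoint (25 eval steps completed at that point); the
# on_train_end checkpoint is skipped because train progress is unchanged.
print(f"epoch_{max_epochs}_train_step_{train_steps}_eval_step_{eval_steps_before_final_eval}")
# epoch_6_train_step_30_eval_step_25

# save_every_n_eval_epochs=2: eval-epoch checkpointing also captures the final
# eval epoch, so the last checkpoint reflects all 30 eval steps.
print(f"epoch_{max_epochs}_train_step_{train_steps}_eval_step_{eval_steps_total}")
# epoch_6_train_step_30_eval_step_30
```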

torchtnt/framework/callbacks/base_checkpointer.py

Lines changed: 24 additions & 7 deletions

```diff
@@ -16,7 +16,7 @@
 from torchtnt.framework.callback import Callback
 from torchtnt.framework.callbacks._checkpoint_utils import _get_step_phase_mapping
 from torchtnt.framework.callbacks.checkpointer_types import RestoreOptions
-from torchtnt.framework.state import State
+from torchtnt.framework.state import EntryPoint, State
 from torchtnt.framework.unit import AppStateMixin, TEvalUnit, TTrainData, TTrainUnit
 from torchtnt.utils.checkpoint import (
     BestCheckpointConfig,
@@ -25,6 +25,7 @@
     get_best_checkpoint_path,
     get_latest_checkpoint_path,
     MetricData,
+    Phase,
 )
 from torchtnt.utils.distributed import PGWrapper
 from torchtnt.utils.rank_zero_log import rank_zero_info, rank_zero_warn
@@ -177,12 +178,28 @@ def _generate_checkpoint_and_upkeep(
         if not self._checkpoint_manager.should_save_checkpoint(checkpoint_path):
             return False
 
-        # 2.1) Make sure that last checkpoint does not already exist
-        if hook == "on_train_end" and self._does_checkpoint_exist(
-            checkpoint_path, self._process_group
-        ):
-            rank_zero_warn("Final checkpoint already exists, skipping.", logger=logger)
-            return False
+        if hook == "on_train_end":
+            # 2.1) Make sure that last checkpoint does not already exist
+            if self._does_checkpoint_exist(checkpoint_path, self._process_group):
+                rank_zero_warn(
+                    "Final checkpoint already exists, skipping.", logger=logger
+                )
+                return False
+
+            # 2.2) If doing fit without eval checkpointing, only consider training progress when
+            # checking if last checkpoint exists.
+            if (
+                state.entry_point == EntryPoint.FIT
+                and self._save_every_n_eval_epochs is None
+                and self._checkpoint_manager._ckpt_paths
+                and self._checkpoint_manager._ckpt_paths[-1].step[Phase.TRAIN]
+                == cast(TTrainUnit, unit).train_progress.num_steps_completed
+            ):
+                rank_zero_info(
+                    "Omitting final checkpoint since train progress is unchanged, and eval checkpointing is not configured.",
+                    logger=logger,
+                )
+                return False
 
         # 3) try to save checkpoint
         if not self._checkpoint_impl(
```
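To read the new 2.2 guard in isolation: on `on_train_end` during a `fit`, if eval checkpointing is not configured and the most recent saved checkpoint already covers the unit's completed train steps, the final checkpoint is skipped. The helper below is a hypothetical sketch that only mirrors that condition; the names `should_skip_final_checkpoint`, `last_ckpt_train_step`, and `completed_train_steps` are not part of torchtnt.

```python
from typing import Optional


def should_skip_final_checkpoint(
    is_fit: bool,
    save_every_n_eval_epochs: Optional[int],
    last_ckpt_train_step: Optional[int],
    completed_train_steps: int,
) -> bool:
    """Hypothetical helper mirroring the 2.2 guard added in this commit.

    Skip the on_train_end checkpoint when running under FIT, eval
    checkpointing is not configured, and the latest saved checkpoint
    already reflects all completed train steps.
    """
    return (
        is_fit
        and save_every_n_eval_epochs is None
        and last_ckpt_train_step is not None
        and last_ckpt_train_step == completed_train_steps
    )


# Example matching the new test: the epoch-6 checkpoint already covers all 30
# train steps, so the final checkpoint is omitted.
assert should_skip_final_checkpoint(True, None, 30, 30)
# With eval checkpointing configured, the final checkpoint is still considered.
assert not should_skip_final_checkpoint(True, 2, 30, 30)
```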
