@@ -16,7 +16,7 @@
 from torchtnt.framework.callback import Callback
 from torchtnt.framework.callbacks._checkpoint_utils import _get_step_phase_mapping
 from torchtnt.framework.callbacks.checkpointer_types import RestoreOptions
-from torchtnt.framework.state import State
+from torchtnt.framework.state import EntryPoint, State
 from torchtnt.framework.unit import AppStateMixin, TEvalUnit, TTrainData, TTrainUnit
 from torchtnt.utils.checkpoint import (
     BestCheckpointConfig,
@@ -25,6 +25,7 @@
     get_best_checkpoint_path,
     get_latest_checkpoint_path,
     MetricData,
+    Phase,
 )
 from torchtnt.utils.distributed import PGWrapper
 from torchtnt.utils.rank_zero_log import rank_zero_info, rank_zero_warn
@@ -177,12 +178,28 @@ def _generate_checkpoint_and_upkeep(
         if not self._checkpoint_manager.should_save_checkpoint(checkpoint_path):
             return False
 
-        # 2.1) Make sure that last checkpoint does not already exist
-        if hook == "on_train_end" and self._does_checkpoint_exist(
-            checkpoint_path, self._process_group
-        ):
-            rank_zero_warn("Final checkpoint already exists, skipping.", logger=logger)
-            return False
+        if hook == "on_train_end":
+            # 2.1) Make sure that last checkpoint does not already exist
+            if self._does_checkpoint_exist(checkpoint_path, self._process_group):
+                rank_zero_warn(
+                    "Final checkpoint already exists, skipping.", logger=logger
+                )
+                return False
+
+            # 2.2) If doing fit without eval checkpointing, only consider training progress when
+            # checking if last checkpoint exists.
+            if (
+                state.entry_point == EntryPoint.FIT
+                and self._save_every_n_eval_epochs is None
+                and self._checkpoint_manager._ckpt_paths
+                and self._checkpoint_manager._ckpt_paths[-1].step[Phase.TRAIN]
+                == cast(TTrainUnit, unit).train_progress.num_steps_completed
+            ):
+                rank_zero_info(
+                    "Omitting final checkpoint since train progress is unchanged, and eval checkpointing is not configured.",
+                    logger=logger,
+                )
+                return False
 
         # 3) try to save checkpoint
         if not self._checkpoint_impl(