 from typing import Any, cast, Iterable, List, Literal, Optional, Union

 import torch.distributed as dist
-
+from pyre_extensions import none_throws
 from torchtnt.framework.callback import Callback
 from torchtnt.framework.callbacks._checkpoint_utils import (
     _delete_checkpoint,
@@ -197,85 +197,96 @@ def _generate_checkpoint_and_upkeep(
         Returns:
             True if checkpoint was successfully saved. False otherwise.
         """
-        unit = cast(TTrainUnit, unit)
-
         # 1) generate checkpoint name
+        unit = cast(TTrainUnit, unit)
         num_steps_completed = unit.train_progress.num_steps_completed
         if state.entry_point == EntryPoint.FIT:
-            num_steps_completed += cast(
-                TEvalUnit, unit
-            ).eval_progress.num_steps_completed
+            eval_unit = cast(TEvalUnit, unit)
+            num_steps_completed += eval_unit.eval_progress.num_steps_completed
         epoch = unit.train_progress.num_epochs_completed
         checkpoint_path = _get_save_path(self._dirpath, epoch, num_steps_completed)

-        # 1.5) Ensure the need to checkpoint again at the end of training
+        # 1.1) Make sure that the last checkpoint does not already exist
         if hook == "on_train_end" and self._does_checkpoint_exist(
             checkpoint_path, process_group=self._process_group
         ):
             rank_zero_warn("Final checkpoint already exists, skipping.", logger=logger)
             return False

-        # 2) handle best checkpoint config on all hooks except `on_train_end`
-        # TODO: isolate this logic into its own function
-        metric_value_f: Optional[float] = None
-        best_checkpoint_config = self._best_checkpoint_config
-        if best_checkpoint_config:
-            if not hasattr(unit, best_checkpoint_config.monitored_metric):
-                raise RuntimeError(
-                    f"Unit does not have attribute {best_checkpoint_config.monitored_metric}, unable to retrieve metric to checkpoint."
-                )
+        # 1.2) If there is a tracked metric, add it to the checkpoint path
+        metric_value = self._get_tracked_metric_value(unit)
+        if metric_value is not None:
+            metric_name = none_throws(self._best_checkpoint_config).monitored_metric
+            checkpoint_path += f"_{metric_name}={metric_value}"

-            metric_value = getattr(unit, best_checkpoint_config.monitored_metric)
-            if metric_value is not None:
-                try:
-                    metric_value_f = float(metric_value)
-                except Exception as e:
-                    raise RuntimeError(
-                        f"Unable to convert monitored metric {best_checkpoint_config.monitored_metric} to a float. Please ensure the value can be converted to float and is not a multi-element tensor value."
-                    ) from e
-
-                # update checkpoint path to include the metric value info
-                checkpoint_path += (
-                    f"_{best_checkpoint_config.monitored_metric}={metric_value_f}"
-                )
-
-        should_checkpoint = self._should_save_checkpoint(metric_value_f)
-        if not should_checkpoint:
+        # 2) Determine if checkpoint should be saved
+        if not self._should_save_checkpoint(metric_value):
             return False

         # 3) try to save checkpoint
-        success = self._checkpoint_impl(
-            state,
-            unit,
-            checkpoint_path=checkpoint_path,
-            hook=hook,
-        )
+        if not self._checkpoint_impl(
+            state, unit, checkpoint_path=checkpoint_path, hook=hook
+        ):
+            return False

-        if success:
-            # remove the checkpoint if applicable
-            # and update the tracked list of checkpoint paths
+        # 4) remove the oldest/worst checkpoint if applicable
+        if self._should_remove_checkpoint():
+            self._remove_checkpoint(state)
+
+        # 5) update the tracked list of checkpoint paths
+        if self._best_checkpoint_config and (metric_value is not None):
+            metric_mode = none_throws(self._best_checkpoint_config).mode
+            # insert the checkpoint path at the correct index to preserve ordering
+            keys = [
+                float(os.path.basename(x).split("=")[-1]) for x in self._ckpt_dirpaths
+            ]
+            if metric_mode == "min":
+                keys.reverse()
+            # Use bisect.bisect() to find the insertion point
+            idx = bisect.bisect(keys, metric_value)
+            if metric_mode == "min":
+                idx = len(self._ckpt_dirpaths) - idx
+            self._ckpt_dirpaths.insert(idx, checkpoint_path)
+
+        elif not self._best_checkpoint_config:  # no metric to track
+            self._ckpt_dirpaths.append(checkpoint_path)

-            if self._should_remove_checkpoint():
-                self._remove_checkpoint(state)
+        return True

-            if best_checkpoint_config:
-                if metric_value_f:
-                    # insert the checkpoint path at the right index to preserve ordering
-                    keys = [
-                        float(os.path.basename(x).split("=")[-1])
-                        for x in self._ckpt_dirpaths
-                    ]
-                    if best_checkpoint_config.mode == "min":
-                        keys.reverse()
-                    # Use bisect.bisect() to find the insertion point
-                    idx = bisect.bisect(keys, metric_value_f)
-                    if best_checkpoint_config.mode == "min":
-                        idx = len(self._ckpt_dirpaths) - idx
-                    self._ckpt_dirpaths.insert(idx, checkpoint_path)
-                else:
-                    self._ckpt_dirpaths.append(checkpoint_path)
+    def _get_tracked_metric_value(self, unit: TTrainUnit) -> Optional[float]:
+        """
+        If the checkpointer has a tracked metric, look up its value in the unit using reflection and cast it to float.
+
+        Args:
+            unit: The training unit to look for the tracked metric in.
+
+        Returns:
+            The value of the tracked metric, or None if there is no best_checkpoint_config defined.
+
+        Raises:
+            RuntimeError: If the unit does not have the attribute specified in the best_checkpoint_config,
+                or if the value cannot be cast to a float.
+        """
+        if not self._best_checkpoint_config:
+            return None
+
+        monitored_metric_name = self._best_checkpoint_config.monitored_metric
+        if not hasattr(unit, monitored_metric_name):
+            raise RuntimeError(
+                f"Unit does not have attribute {monitored_metric_name}, unable to retrieve metric to checkpoint."
+            )
+
+        metric_value_f = None
+        if (metric_value := getattr(unit, monitored_metric_name)) is not None:
+            try:
+                metric_value_f = float(metric_value)
+            except ValueError as e:
+                raise RuntimeError(
+                    f"Unable to convert monitored metric {monitored_metric_name} to a float. Please ensure the value "
+                    "can be converted to float and is not a multi-element tensor value."
+                ) from e

-        return success
+        return metric_value_f

     def on_train_start(self, state: State, unit: TTrainUnit) -> None:
         # clean up the difference if surplus of checkpoints exist
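
A minimal standalone sketch of the ordering logic in step 5 above: the checkpoint list stays sorted from worst to best metric for both "min" and "max" modes because the parsed keys are reversed before bisecting. The helper insert_ordered and the sample paths are illustrative only, not torchtnt APIs.

import bisect
import os
from typing import List, Literal


def _metric(path: str) -> float:
    # Paths end in "..._<metric_name>=<value>", so the value can be parsed
    # back out of the basename, mirroring step 5 above.
    return float(os.path.basename(path).split("=")[-1])


def insert_ordered(
    ckpt_dirpaths: List[str], new_path: str, mode: Literal["min", "max"]
) -> None:
    keys = [_metric(p) for p in ckpt_dirpaths]
    if mode == "min":
        # For "min" the list is kept in descending metric order (worst first),
        # so reverse the keys to make them ascending before bisecting.
        keys.reverse()
    idx = bisect.bisect(keys, _metric(new_path))
    if mode == "min":
        # Map the index in the reversed view back to the original list.
        idx = len(ckpt_dirpaths) - idx
    ckpt_dirpaths.insert(idx, new_path)


paths = ["run/epoch_1_step_10_loss=0.9", "run/epoch_2_step_20_loss=0.5"]
insert_ordered(paths, "run/epoch_3_step_30_loss=0.7", mode="min")
print(paths)  # the new checkpoint lands between the two existing ones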
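
Similarly, a hedged sketch of the reflection pattern behind the new _get_tracked_metric_value helper: a missing attribute is an error, a None value simply means there is nothing to track yet, and anything else must be convertible to float. read_metric and DummyUnit are hypothetical stand-ins for illustration, not torchtnt APIs.

from typing import Any, Optional


def read_metric(obj: Any, name: str) -> Optional[float]:
    # The attribute itself must exist on the unit...
    if not hasattr(obj, name):
        raise RuntimeError(f"Object has no attribute {name}")
    # ...but a None value is allowed and simply means "nothing to track yet".
    if (value := getattr(obj, name)) is None:
        return None
    try:
        return float(value)
    except (TypeError, ValueError) as e:
        raise RuntimeError(f"{name} cannot be converted to float") from e


class DummyUnit:
    train_loss = 0.25
    eval_accuracy = None


unit = DummyUnit()
print(read_metric(unit, "train_loss"))     # 0.25
print(read_metric(unit, "eval_accuracy"))  # None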