
Commit 05d1458

diego-urgell authored and facebook-github-bot committed

Reduce _generate_checkpoint_and_upkeep code complexity (#783)

Summary:
Pull Request resolved: #783

The function _generate_checkpoint_and_upkeep is central to the checkpointing logic, but it carries a linter warning for being too complex, and there is a TODO to extract some logic into a separate function. This is a small refactor to improve readability and reduce function complexity while avoiding breaking changes or regressions.

**Potential Future Changes**

While making this change, I found two small bugs that we can fix. They are documented in this Bento notebook: https://fburl.com/anp/gqoezved. I did not fix them here to avoid mixing a refactor with logic changes.

Additionally, there is a user request that we could handle in this function. It was also not addressed here; we can decide what to do and change it later: https://fb.workplace.com/groups/cu.training.framework.users/permalink/1147131179616886/

Reviewed By: galrotem

Differential Revision: D55881050

fbshipit-source-id: 46d0777a2dc5208763628fecdac8c12a5573e407
1 parent fd78425 commit 05d1458
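
For context, the function being refactored drives checkpointing when a BaseCheckpointer subclass is configured to track a metric. A minimal usage sketch, assuming the names exercised in the diffs below (BestCheckpointConfig from torchtnt and BaseCheckpointSaver, the concrete test-only subclass used in the test file); the import path and constructor arguments are assumptions, not taken from this commit:

    # Illustrative sketch only; import path and argument order are assumed from
    # the tests in this commit, not verified against the torchtnt public API.
    from torchtnt.framework.callbacks import BestCheckpointConfig

    checkpoint_cb = BaseCheckpointSaver(  # hypothetical concrete BaseCheckpointer subclass
        dirpath="checkpoint",
        best_checkpoint_config=BestCheckpointConfig("val_loss", "min"),
    )
    # During training, _generate_checkpoint_and_upkeep reads unit.val_loss via
    # _get_tracked_metric_value, appends "_val_loss=<value>" to the checkpoint
    # path, and keeps the tracked checkpoint list ordered from worst to best.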

2 files changed: +116 -60 lines


tests/framework/callbacks/test_base_checkpointer.py

Lines changed: 45 additions & 0 deletions

@@ -852,6 +852,51 @@ def test_no_assert_error_in_on_train_end(self) -> None:
             callbacks=[checkpoint_cb],
         )
 
+    def test_get_tracked_metric_value(self) -> None:
+        """
+        Tests that _get_tracked_metric_value returns the correct value
+        """
+        val_loss_unit = MyValLossUnit()
+
+        val_loss_ckpt_cb = BaseCheckpointSaver(
+            dirpath="checkpoint",
+            best_checkpoint_config=BestCheckpointConfig("val_loss", "min"),
+        )
+        val_loss = val_loss_ckpt_cb._get_tracked_metric_value(val_loss_unit)
+        self.assertEqual(0.01, val_loss)
+
+        # pyre-ignore
+        val_loss_unit.val_loss = "0.01"  # Test when returned as a string
+        val_loss_from_s = val_loss_ckpt_cb._get_tracked_metric_value(val_loss_unit)
+        self.assertEqual(0.01, val_loss_from_s)
+
+        # pyre-ignore
+        val_loss_unit.val_loss = "hola"  # Test weird metric value
+        with self.assertRaisesRegex(
+            RuntimeError,
+            (
+                "Unable to convert monitored metric val_loss to a float. Please ensure the value "
+                "can be converted to float and is not a multi-element tensor value."
+            ),
+        ):
+            val_loss = val_loss_ckpt_cb._get_tracked_metric_value(val_loss_unit)
+
+        train_loss_ckpt_cb = BaseCheckpointSaver(
+            dirpath="checkpoint",
+            best_checkpoint_config=BestCheckpointConfig("train_loss", "max"),
+        )
+        with self.assertRaisesRegex(
+            RuntimeError,
+            "Unit does not have attribute train_loss, unable to retrieve metric to checkpoint.",
+        ):
+            val_loss = train_loss_ckpt_cb._get_tracked_metric_value(val_loss_unit)
+
+        ckpt_cb = BaseCheckpointSaver(
+            dirpath="checkpoint",
+        )
+        no_metric = ckpt_cb._get_tracked_metric_value(val_loss_unit)
+        self.assertIsNone(no_metric)
+
 
 class MyValLossUnit(TrainUnit[Batch]):
     def __init__(self) -> None:
torchtnt/framework/callbacks/base_checkpointer.py

Lines changed: 71 additions & 60 deletions

@@ -14,7 +14,7 @@
 from typing import Any, cast, Iterable, List, Literal, Optional, Union
 
 import torch.distributed as dist
-
+from pyre_extensions import none_throws
 from torchtnt.framework.callback import Callback
 from torchtnt.framework.callbacks._checkpoint_utils import (
     _delete_checkpoint,
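
The newly imported none_throws helper from pyre_extensions returns its argument when it is not None and raises otherwise; the refactor below uses it to read self._best_checkpoint_config once a metric value is known to exist. A small illustration of that behavior (as commonly documented for pyre_extensions, not taken from this commit):

    from typing import Optional
    from pyre_extensions import none_throws

    config: Optional[str] = "val_loss"
    name = none_throws(config)  # returns "val_loss" unchanged
    none_throws(None)           # raises, making the implicit non-None assumption explicit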
@@ -197,85 +197,96 @@ def _generate_checkpoint_and_upkeep(
         Returns:
             True if checkpoint was successfully saved. False otherwise.
         """
-        unit = cast(TTrainUnit, unit)
-
         # 1) generate checkpoint name
+        unit = cast(TTrainUnit, unit)
         num_steps_completed = unit.train_progress.num_steps_completed
         if state.entry_point == EntryPoint.FIT:
-            num_steps_completed += cast(
-                TEvalUnit, unit
-            ).eval_progress.num_steps_completed
+            eval_unit = cast(TEvalUnit, unit)
+            num_steps_completed += eval_unit.eval_progress.num_steps_completed
         epoch = unit.train_progress.num_epochs_completed
         checkpoint_path = _get_save_path(self._dirpath, epoch, num_steps_completed)
 
-        # 1.5) Ensure the need to checkpoint again at the end of training
+        # 1.1) Make sure that last checkpoint does not already exist
         if hook == "on_train_end" and self._does_checkpoint_exist(
             checkpoint_path, process_group=self._process_group
         ):
             rank_zero_warn("Final checkpoint already exists, skipping.", logger=logger)
             return False
 
-        # 2) handle best checkpoint config on all hooks except `on_train_end`
-        # TODO: isolate this logic into its own function
-        metric_value_f: Optional[float] = None
-        best_checkpoint_config = self._best_checkpoint_config
-        if best_checkpoint_config:
-            if not hasattr(unit, best_checkpoint_config.monitored_metric):
-                raise RuntimeError(
-                    f"Unit does not have attribute {best_checkpoint_config.monitored_metric}, unable to retrieve metric to checkpoint."
-                )
+        # 1.2) If there is a tracked metric, add to the checkpoint path
+        metric_value = self._get_tracked_metric_value(unit)
+        if metric_value is not None:
+            metric_name = none_throws(self._best_checkpoint_config).monitored_metric
+            checkpoint_path += f"_{metric_name}={metric_value}"
 
-            metric_value = getattr(unit, best_checkpoint_config.monitored_metric)
-            if metric_value is not None:
-                try:
-                    metric_value_f = float(metric_value)
-                except Exception as e:
-                    raise RuntimeError(
-                        f"Unable to convert monitored metric {best_checkpoint_config.monitored_metric} to a float. Please ensure the value can be converted to float and is not a multi-element tensor value."
-                    ) from e
-
-                # update checkpoint path to include the metric value info
-                checkpoint_path += (
-                    f"_{best_checkpoint_config.monitored_metric}={metric_value_f}"
-                )
-
-        should_checkpoint = self._should_save_checkpoint(metric_value_f)
-        if not should_checkpoint:
+        # 2) Determine if checkpoint should be saved
+        if not self._should_save_checkpoint(metric_value):
             return False
 
         # 3) try to save checkpoint
-        success = self._checkpoint_impl(
-            state,
-            unit,
-            checkpoint_path=checkpoint_path,
-            hook=hook,
-        )
+        if not self._checkpoint_impl(
+            state, unit, checkpoint_path=checkpoint_path, hook=hook
+        ):
+            return False
 
-        if success:
-            # remove the checkpoint if applicable
-            # and update the tracked list of checkpoint paths
+        # 4) remove the oldest/worst checkpoint if applicable
+        if self._should_remove_checkpoint():
+            self._remove_checkpoint(state)
+
+        # 5) update the tracked list of checkpoint paths
+        if self._best_checkpoint_config and (metric_value is not None):
+            metric_mode = none_throws(self._best_checkpoint_config).mode
+            # insert the checkpoint path at the correct index to preserve ordering
+            keys = [
+                float(os.path.basename(x).split("=")[-1]) for x in self._ckpt_dirpaths
+            ]
+            if metric_mode == "min":
+                keys.reverse()
+            # Use bisect.bisect() to find the insertion point
+            idx = bisect.bisect(keys, metric_value)
+            if metric_mode == "min":
+                idx = len(self._ckpt_dirpaths) - idx
+            self._ckpt_dirpaths.insert(idx, checkpoint_path)
+
+        elif not self._best_checkpoint_config:  # no metric to track
+            self._ckpt_dirpaths.append(checkpoint_path)
 
-            if self._should_remove_checkpoint():
-                self._remove_checkpoint(state)
+        return True
 
-            if best_checkpoint_config:
-                if metric_value_f:
-                    # insert the checkpoint path at the right index to preserve ordering
-                    keys = [
-                        float(os.path.basename(x).split("=")[-1])
-                        for x in self._ckpt_dirpaths
-                    ]
-                    if best_checkpoint_config.mode == "min":
-                        keys.reverse()
-                    # Use bisect.bisect() to find the insertion point
-                    idx = bisect.bisect(keys, metric_value_f)
-                    if best_checkpoint_config.mode == "min":
-                        idx = len(self._ckpt_dirpaths) - idx
-                    self._ckpt_dirpaths.insert(idx, checkpoint_path)
-                else:
-                    self._ckpt_dirpaths.append(checkpoint_path)
+    def _get_tracked_metric_value(self, unit: TTrainUnit) -> Optional[float]:
+        """
+        If the checkpointer has a tracked metric, look the value in the unit using reflection, and cast to float.
+
+        Args:
+            unit: The training unit to look for the tracked metric in.
+
+        Returns:
+            The value of the tracked metric, or None if there is no best_checkpoint config defined.
+
+        Raises:
+            RuntimeError: If the unit does not have the attribute specified in the best_checkpoint config,
+                or if the value cannot be cast to a float.
+        """
+        if not self._best_checkpoint_config:
+            return None
+
+        monitored_metric_name = self._best_checkpoint_config.monitored_metric
+        if not hasattr(unit, monitored_metric_name):
+            raise RuntimeError(
+                f"Unit does not have attribute {monitored_metric_name}, unable to retrieve metric to checkpoint."
+            )
+
+        metric_value_f = None
+        if (metric_value := getattr(unit, monitored_metric_name)) is not None:
+            try:
+                metric_value_f = float(metric_value)
+            except ValueError as e:
+                raise RuntimeError(
+                    f"Unable to convert monitored metric {monitored_metric_name} to a float. Please ensure the value "
+                    "can be converted to float and is not a multi-element tensor value."
+                ) from e
 
-        return success
+        return metric_value_f
 
     def on_train_start(self, state: State, unit: TTrainUnit) -> None:
         # clean up the difference if surplus of checkpoints exist
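
The min-mode index arithmetic in step 5 is the subtlest part of the refactor, so here is a standalone sketch (not torchtnt code; names are illustrative) of the same insertion strategy. It keeps the checkpoint list ordered from worst to best metric value for both modes, which is what keys.reverse() followed by idx = len(self._ckpt_dirpaths) - idx achieves in the diff above:

    # Standalone illustration of the bisect-based insertion above.
    import bisect
    from typing import List

    def insert_ordered(dirpaths: List[str], new_path: str, metric_value: float, mode: str) -> None:
        """Insert new_path so dirpaths stays ordered from worst to best metric value."""
        keys = [float(p.split("=")[-1]) for p in dirpaths]
        if mode == "min":
            keys.reverse()  # stored order is descending, reverse to ascending for bisect
        idx = bisect.bisect(keys, metric_value)
        if mode == "min":
            idx = len(dirpaths) - idx  # map the ascending index back to the descending list
        dirpaths.insert(idx, new_path)

    # "min" mode: smaller is better, so the list runs from largest (worst) to smallest (best)
    paths = ["epoch_2_step_20_val_loss=0.5", "epoch_4_step_40_val_loss=0.2"]
    insert_ordered(paths, "epoch_6_step_60_val_loss=0.3", 0.3, "min")
    assert paths == [
        "epoch_2_step_20_val_loss=0.5",
        "epoch_6_step_60_val_loss=0.3",
        "epoch_4_step_40_val_loss=0.2",
    ]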
