Skip to content

Commit 041ebe1

Browse files
diego-urgell authored and facebook-github-bot committed
Don't fail job if CheckpointManager fails deleting older checkpoints (#882)
Summary: Pull Request resolved: #882
Reviewed By: anshulverma, JKSenthil
Differential Revision: D61307267
fbshipit-source-id: 4dc97353c34b6dcfbf04b374784107bc636d7f0f
1 parent 123453e commit 041ebe1

File tree

2 files changed: +48 -3 lines changed

tests/utils/test_checkpoint.py

Lines changed: 38 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -10,7 +10,7 @@
1010
import shutil
1111
import tempfile
1212
import unittest
13-
from unittest.mock import patch
13+
from unittest.mock import MagicMock, patch
1414

1515
import torch
1616

@@ -787,6 +787,43 @@ def test_remove_worst_checkpoint(self) -> None:
787787
self.assertTrue(os.path.exists(os.path.join(temp_dir, "epoch_0_step_1")))
788788
self.assertEqual(ckpt_manager._ckpt_paths, [CheckpointPath(temp_dir, 0, 1)])
789789

790+
@patch(
791+
"fsspec.implementations.local.LocalFileSystem.rm",
792+
side_effect=Exception("OSError: [Errno 2] No such file or directory"),
793+
)
794+
def test_remove_worst_checkpoint_exception(self, mock_url_to_fs: MagicMock) -> None:
795+
with tempfile.TemporaryDirectory() as temp_dir:
796+
os.mkdir(os.path.join(temp_dir, "epoch_0_train_step_0"))
797+
os.mkdir(os.path.join(temp_dir, "epoch_0_train_step_1"))
798+
799+
ckpt_manager = CheckpointManager(temp_dir, keep_last_n_checkpoints=2)
800+
801+
log_container = []
802+
with patch(
803+
"torchtnt.utils.checkpoint.logging.Logger.error", log_container.append
804+
):
805+
ckpt_manager.append_checkpoint(
806+
CheckpointPath(temp_dir, 0, {Phase.TRAIN: 2})
807+
)
808+
809+
self.assertEqual(
810+
log_container,
811+
[
812+
(
813+
f"Failed to remove checkpoint '{temp_dir}/epoch_0_train_step_0' for bookkeeping purposes. "
814+
"Do not use it to restore since it may be corrupted! Exception: OSError: [Errno 2] No such file or directory"
815+
)
816+
],
817+
)
818+
# Make sure we are not tracking the oldest one anymore, even if it was not deleted
819+
self.assertEqual(
820+
ckpt_manager._ckpt_paths,
821+
[
822+
CheckpointPath(temp_dir, 0, {Phase.TRAIN: 1}),
823+
CheckpointPath(temp_dir, 0, {Phase.TRAIN: 2}),
824+
],
825+
)
826+
790827

791828
class CheckpointUtilsTest(unittest.TestCase):
792829
@staticmethod

torchtnt/utils/checkpoint.py

Lines changed: 10 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -487,7 +487,7 @@ def append_checkpoint(self, ckpt: CheckpointPath) -> None:
487487
ckpt: The checkpoint to save.
488488
state: The training state.
489489
"""
490-
# Remove oldest checkpoint if needed
490+
# Remove oldest/worst checkpoint if needed
491491
max_ckpts = self._keep_last_n_checkpoints
492492
if max_ckpts and len(self._ckpt_paths) >= max_ckpts:
493493
self.remove_checkpoint()
@@ -542,7 +542,15 @@ def remove_checkpoint(self) -> None:
542542
worst_ckpt_path = self._ckpt_paths.pop(0)
543543
if self._pg_wrapper.get_rank() == 0:
544544
fs, _ = url_to_fs(self.dirpath)
545-
fs.rm(worst_ckpt_path.path, recursive=True)
545+
try:
546+
fs.rm(worst_ckpt_path.path, recursive=True)
547+
except Exception as exc:
548+
logger.error(
549+
(
550+
f"Failed to remove checkpoint '{worst_ckpt_path}' for bookkeeping purposes. "
551+
f"Do not use it to restore since it may be corrupted! Exception: {exc}"
552+
)
553+
)
546554

547555

548556
@rank_zero_read_and_broadcast

0 commit comments

Comments
 (0)