|
10 | 10 | import shutil
|
11 | 11 | import tempfile
|
12 | 12 | import unittest
|
13 |
| -from unittest.mock import patch |
| 13 | +from unittest.mock import MagicMock, patch |
14 | 14 |
|
15 | 15 | import torch
|
16 | 16 |
|
@@ -787,6 +787,43 @@ def test_remove_worst_checkpoint(self) -> None:
|
787 | 787 | self.assertTrue(os.path.exists(os.path.join(temp_dir, "epoch_0_step_1")))
|
788 | 788 | self.assertEqual(ckpt_manager._ckpt_paths, [CheckpointPath(temp_dir, 0, 1)])
|
789 | 789 |
|
| 790 | + @patch( |
| 791 | + "fsspec.implementations.local.LocalFileSystem.rm", |
| 792 | + side_effect=Exception("OSError: [Errno 2] No such file or directory"), |
| 793 | + ) |
| 794 | + def test_remove_worst_checkpoint_exception(self, mock_url_to_fs: MagicMock) -> None: |
| 795 | + with tempfile.TemporaryDirectory() as temp_dir: |
| 796 | + os.mkdir(os.path.join(temp_dir, "epoch_0_train_step_0")) |
| 797 | + os.mkdir(os.path.join(temp_dir, "epoch_0_train_step_1")) |
| 798 | + |
| 799 | + ckpt_manager = CheckpointManager(temp_dir, keep_last_n_checkpoints=2) |
| 800 | + |
| 801 | + log_container = [] |
| 802 | + with patch( |
| 803 | + "torchtnt.utils.checkpoint.logging.Logger.error", log_container.append |
| 804 | + ): |
| 805 | + ckpt_manager.append_checkpoint( |
| 806 | + CheckpointPath(temp_dir, 0, {Phase.TRAIN: 2}) |
| 807 | + ) |
| 808 | + |
| 809 | + self.assertEqual( |
| 810 | + log_container, |
| 811 | + [ |
| 812 | + ( |
| 813 | + f"Failed to remove checkpoint '{temp_dir}/epoch_0_train_step_0' for bookkeeping purposes. " |
| 814 | + "Do not use it to restore since it may be corrupted! Exception: OSError: [Errno 2] No such file or directory" |
| 815 | + ) |
| 816 | + ], |
| 817 | + ) |
| 818 | + # Make sure we are not tracking the oldest one anymore, even if it was not deleted |
| 819 | + self.assertEqual( |
| 820 | + ckpt_manager._ckpt_paths, |
| 821 | + [ |
| 822 | + CheckpointPath(temp_dir, 0, {Phase.TRAIN: 1}), |
| 823 | + CheckpointPath(temp_dir, 0, {Phase.TRAIN: 2}), |
| 824 | + ], |
| 825 | + ) |
| 826 | + |
790 | 827 |
|
791 | 828 | class CheckpointUtilsTest(unittest.TestCase):
|
792 | 829 | @staticmethod
|
|
0 commit comments