Skip to content

Commit 1f06115

Browse files
diego-urgellfacebook-github-bot
authored andcommitted
Extract does_checkpoint_exist into util (#906)
Summary: Pull Request resolved: #906 Reviewed By: anshulverma, schwarzmx Differential Revision: D63036738 fbshipit-source-id: 3aa053e4c788e8f42a0f03a2fa2995510836172a
1 parent 6d99aae commit 1f06115

File tree

2 files changed

+54
-8
lines changed

2 files changed

+54
-8
lines changed

tests/utils/test_checkpoint.py

Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -29,6 +29,7 @@
2929
BestCheckpointConfig,
3030
CheckpointManager,
3131
CheckpointPath,
32+
does_checkpoint_exist,
3233
get_best_checkpoint_path,
3334
get_checkpoint_dirpaths,
3435
get_latest_checkpoint_path,
@@ -1419,6 +1420,27 @@ def test_does_checkpoint_metadata_exist(self) -> None:
14191420
)
14201421
)
14211422

1423+
def test_does_checkpoint_exist(self) -> None:
1424+
with tempfile.TemporaryDirectory() as temp_dir:
1425+
ckpt_1 = os.path.join(temp_dir, "checkpoint_1")
1426+
os.mkdir(ckpt_1)
1427+
1428+
self.assertFalse(does_checkpoint_exist(ckpt_1, metadata_fname=None))
1429+
1430+
with open(os.path.join(ckpt_1, ".metadata"), "w"):
1431+
pass
1432+
1433+
self.assertFalse(does_checkpoint_exist(ckpt_1, metadata_fname="manifest"))
1434+
self.assertTrue(does_checkpoint_exist(ckpt_1, metadata_fname=".metadata"))
1435+
self.assertTrue(
1436+
does_checkpoint_exist(ckpt_1, metadata_fname=["manifest", ".metadata"])
1437+
)
1438+
self.assertFalse(
1439+
does_checkpoint_exist(
1440+
ckpt_1, metadata_fname=["manifest", ".state_dict_info"]
1441+
)
1442+
)
1443+
14221444

14231445
class MyValLossUnit(TrainUnit[Batch]):
14241446
def __init__(self) -> None:

torchtnt/utils/checkpoint.py

Lines changed: 32 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -539,21 +539,18 @@ def append_checkpoint(self, ckpt: CheckpointPath) -> None:
539539
# No metric tracked, most recents goes last
540540
self._ckpt_paths.append(ckpt)
541541

542-
@rank_zero_read_and_broadcast
543542
def does_checkpoint_exist(
544-
self, ckpt: CheckpointPath, process_group: Optional[dist.ProcessGroup] = None
543+
self,
544+
ckpt: CheckpointPath,
545+
process_group: Optional[dist.ProcessGroup] = None,
545546
) -> bool:
546547
"""
547548
Checking whether a checkpoint already exists by verifying whether the optional metadata file is present in the directory.
548549
If the checkpointer doesn't have a metadata file, this function will always return False. Check is executed in rank 0, but
549550
result is broadcasted to all ranks.
550551
"""
551-
if not self._metadata_fnames:
552-
return False
553-
554-
fs, _ = url_to_fs(self.dirpath)
555-
return any(
556-
_metadata_exists(fs, ckpt.path, fname) for fname in self._metadata_fnames
552+
return does_checkpoint_exist(
553+
ckpt.path, self._metadata_fnames, process_group=process_group
557554
)
558555

559556
@staticmethod
@@ -596,6 +593,33 @@ def remove_checkpoint(self) -> None:
596593
)
597594

598595

596+
@rank_zero_read_and_broadcast
597+
def does_checkpoint_exist(
598+
ckpt_path: str,
599+
metadata_fname: Union[str, List[str]],
600+
process_group: Optional[dist.ProcessGroup] = None,
601+
) -> bool:
602+
"""
603+
Checking whether a checkpoint already exists by verifying whether the optional metadata file is present in the directory.
604+
Will return False if the metadata_fname is None. Check is executed in rank 0, but
605+
result is broadcasted to all ranks.
606+
607+
Args:
608+
ckpt: The checkpoint to check.
609+
metadata_fname: File to check for existence. If a list is provided, it will check that at least one of the files is present.
610+
process_group: Optional process group on which the ranks will communicate on. By default, the entire world is used.
611+
"""
612+
if not metadata_fname:
613+
return False
614+
else:
615+
metadata_fnames = (
616+
[metadata_fname] if isinstance(metadata_fname, str) else metadata_fname
617+
)
618+
619+
fs, _ = url_to_fs(ckpt_path)
620+
return any(_metadata_exists(fs, ckpt_path, fname) for fname in metadata_fnames)
621+
622+
599623
@rank_zero_read_and_broadcast
600624
def get_latest_checkpoint_path(
601625
dirpath: str,

0 commit comments

Comments
 (0)