Skip to content

Commit e125ae9

Browse files
diego-urgellfacebook-github-bot
authored andcommitted
Enable checkpoint utils handling more than one metadata file (#871)
Summary: Pull Request resolved: #871 Reviewed By: galrotem Differential Revision: D60246322 fbshipit-source-id: 12d33b2d8cd2ea55d2cd493075585f7c53d45af3
1 parent cab6afc commit e125ae9

File tree

4 files changed

+68
-19
lines changed

4 files changed

+68
-19
lines changed

tests/framework/callbacks/test_base_checkpointer.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -481,7 +481,7 @@ def test_save_on_train_end(self) -> None:
481481
self.assertTrue(os.path.exists(os.path.join(temp_dir, expected_path)))
482482

483483
with self.assertLogs(level="WARNING") as log:
484-
checkpoint_cb._checkpoint_manager._metadata_fname = ".metadata"
484+
checkpoint_cb._checkpoint_manager._metadata_fnames = [".metadata"]
485485
# create metadata file
486486
with open(os.path.join(temp_dir, expected_path, ".metadata"), "w"):
487487
pass

tests/utils/test_checkpoint.py

Lines changed: 34 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1225,6 +1225,40 @@ def test_get_checkpoint_dirpaths(self) -> None:
12251225
[],
12261226
)
12271227

1228+
def test_get_checkpoint_dirpaths_with_multiple_metadata_fnames(self) -> None:
1229+
with tempfile.TemporaryDirectory() as temp_dir:
1230+
path1 = os.path.join(temp_dir, "epoch_1_step_20")
1231+
os.mkdir(path1)
1232+
1233+
path2 = os.path.join(temp_dir, "epoch_4_eval_step_130")
1234+
os.mkdir(path2)
1235+
1236+
with open(os.path.join(path1, ".metadata"), "w"):
1237+
pass
1238+
1239+
with open(os.path.join(path2, ".manifest"), "w"):
1240+
pass
1241+
1242+
self.assertEqual(
1243+
[
1244+
str(x)
1245+
for x in get_checkpoint_dirpaths(
1246+
temp_dir, metadata_fname=[".metadata"]
1247+
)
1248+
],
1249+
[path1],
1250+
)
1251+
1252+
self.assertEqual(
1253+
{
1254+
str(x)
1255+
for x in get_checkpoint_dirpaths(
1256+
temp_dir, metadata_fname=[".manifest", ".metadata"]
1257+
)
1258+
},
1259+
{path1, path2},
1260+
)
1261+
12281262
def test_metadata_exists(self) -> None:
12291263
app_state = {"module": nn.Linear(2, 2)}
12301264
with tempfile.TemporaryDirectory() as temp_dir:

torchtnt/framework/callbacks/base_checkpointer.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -112,7 +112,7 @@ def __init__(
112112
dirpath,
113113
best_checkpoint_config,
114114
keep_last_n_checkpoints,
115-
metadata_fname=self.metadata_fname,
115+
metadata_fnames=[self.metadata_fname] if self.metadata_fname else None,
116116
process_group=self._process_group,
117117
)
118118

torchtnt/utils/checkpoint.py

Lines changed: 32 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -326,7 +326,7 @@ def __init__(
326326
dirpath: str,
327327
best_checkpoint_config: Optional[BestCheckpointConfig] = None,
328328
keep_last_n_checkpoints: Optional[int] = None,
329-
metadata_fname: Optional[str] = None,
329+
metadata_fnames: Optional[List[str]] = None,
330330
process_group: Optional[dist.ProcessGroup] = None,
331331
) -> None:
332332
"""
@@ -338,7 +338,8 @@ def __init__(
338338
dirpath: The base directory path that checkpoints are saved in. This is synced from rank 0 to every other rank upon initialization.
339339
best_checkpoint_config: Optional configuration for the best checkpoint.
340340
keep_last_n_checkpoints: Optional number of checkpoints to keep.
341-
metadata_fname: Optional name of the metadata file.
341+
metadata_fname: Optional names of the metadata files. These are used to verify checkpoint integrity. If more than one is provided, a
342+
checkpoint is considered if at least one of them exists.
342343
process_group: Optional process group to use for distributed training. gloo process groups are known
343344
to perform better.
344345
"""
@@ -348,9 +349,13 @@ def __init__(
348349

349350
self._best_checkpoint_config = best_checkpoint_config
350351
self._keep_last_n_checkpoints = keep_last_n_checkpoints
351-
self._metadata_fname = metadata_fname
352352
self._pg_wrapper = PGWrapper(process_group)
353353

354+
if metadata_fnames is None:
355+
self._metadata_fnames: List[str] = []
356+
else:
357+
self._metadata_fnames = metadata_fnames
358+
354359
self._ckpt_paths: List[CheckpointPath] = []
355360
if not self._keep_last_n_checkpoints:
356361
return
@@ -361,7 +366,7 @@ def __init__(
361366
)
362367
self._ckpt_paths = get_checkpoint_dirpaths(
363368
dirpath=dirpath,
364-
metadata_fname=self._metadata_fname,
369+
metadata_fname=self._metadata_fnames,
365370
metric_name=metric_name,
366371
process_group=process_group,
367372
)
@@ -512,12 +517,13 @@ def does_checkpoint_exist(
512517
If the checkpointer doesn't have a metadata file, this function will always return False. Check is executed in rank 0, but
513518
result is broadcasted to all ranks.
514519
"""
515-
metadata_fname = self._metadata_fname
516-
if not metadata_fname:
520+
if not self._metadata_fnames:
517521
return False
518522

519523
fs, _ = url_to_fs(self.dirpath)
520-
return _metadata_exists(fs, ckpt.path, metadata_fname)
524+
return any(
525+
_metadata_exists(fs, ckpt.path, fname) for fname in self._metadata_fnames
526+
)
521527

522528
@staticmethod
523529
@rank_zero_read_and_broadcast
@@ -542,7 +548,7 @@ def remove_checkpoint(self) -> None:
542548
@rank_zero_read_and_broadcast
543549
def get_latest_checkpoint_path(
544550
dirpath: str,
545-
metadata_fname: Optional[str] = None,
551+
metadata_fname: Optional[Union[str, List[str]]] = None,
546552
process_group: Optional[dist.ProcessGroup] = None,
547553
) -> Optional[str]:
548554
"""
@@ -551,6 +557,7 @@ def get_latest_checkpoint_path(
551557
Args:
552558
dirpath: parent directory where checkpoints are saved.
553559
metadata_fname: Checks if metadata file is present in checkpoint, disregards if it does not exist.
560+
If a list is provided, it will check that at least one of the files is present.
554561
process_group: the process group on which the ranks will communicate on. default: ``None`` (the entire world)
555562
556563
Raises:
@@ -578,7 +585,7 @@ def get_best_checkpoint_path(
578585
dirpath: str,
579586
metric_name: str,
580587
mode: Literal["min", "max"],
581-
metadata_fname: Optional[str] = None,
588+
metadata_fname: Optional[Union[str, List[str]]] = None,
582589
process_group: Optional[dist.ProcessGroup] = None,
583590
) -> Optional[str]:
584591
"""
@@ -592,6 +599,7 @@ def get_best_checkpoint_path(
592599
metric_name: Name of the metric to use to find the best checkpoint.
593600
mode: Either 'min' or 'max'. If 'min', finds and loads the lowest value metric checkpoint. If 'max', finds and loads the largest.
594601
metadata_fname: Checks if metadata file is present in checkpoint, disregards if it does not exist.
602+
If a list is provided, it will check that at least one of the files is present.
595603
process_group: the process group on which the ranks will communicate on. default: ``None`` (the entire world)
596604
597605
Note:
@@ -614,7 +622,7 @@ def get_best_checkpoint_path(
614622
@rank_zero_read_and_broadcast
615623
def get_checkpoint_dirpaths(
616624
dirpath: str,
617-
metadata_fname: Optional[str] = None,
625+
metadata_fname: Optional[Union[str, List[str]]] = None,
618626
metric_name: Optional[str] = None,
619627
process_group: Optional[dist.ProcessGroup] = None,
620628
) -> List[CheckpointPath]:
@@ -629,6 +637,7 @@ def get_checkpoint_dirpaths(
629637
Args:
630638
dirpath: parent directory where checkpoints are saved.
631639
metadata_fname: Checks if metadata file is present in checkpoint, disregards if it does not exist.
640+
If a list is provided, it will check that at least one of the files is present.
632641
metric_name: fetches all the checkpoint directories containing the metric name only.
633642
process_group: the process group on which the ranks will communicate on. default: ``None`` (the entire world)
634643
@@ -642,7 +651,7 @@ def get_checkpoint_dirpaths(
642651

643652
def _retrieve_checkpoint_dirpaths(
644653
dirpath: str,
645-
metadata_fname: Optional[str],
654+
metadata_fname: Optional[Union[str, List[str]]],
646655
metric_name: Optional[str] = None,
647656
) -> List[CheckpointPath]:
648657
"""
@@ -651,6 +660,7 @@ def _retrieve_checkpoint_dirpaths(
651660
Args:
652661
dirpath: parent directory where checkpoints are saved.
653662
metadata_fname: Checks if metadata file is present in checkpoint, disregards if it does not exist.
663+
If a list is provided, it will check that at least one of the files is present.
654664
metric_name: Name of the metric that must exist in checkpoint name.
655665
"""
656666

@@ -687,16 +697,21 @@ def _retrieve_checkpoint_dirpaths(
687697
return candidate_checkpoints
688698

689699
# Iterate through all files and directories in the specified directory
690-
# and check if metedata is present or not
700+
# and check if metadata is present or not
701+
metadata_fnames = (
702+
[metadata_fname] if isinstance(metadata_fname, str) else metadata_fname
703+
)
691704
valid_ckpt_dirpaths: List[CheckpointPath] = []
692705
for candidate in candidate_checkpoints:
693-
if not _metadata_exists(fs, candidate.path, metadata_fname):
694-
logger.warning(
695-
f"Snapshot metadata is missing from {candidate}! Skipping this path"
696-
)
706+
if any(
707+
_metadata_exists(fs, candidate.path, fname) for fname in metadata_fnames
708+
):
709+
valid_ckpt_dirpaths.append(candidate)
697710
continue
698711

699-
valid_ckpt_dirpaths.append(candidate)
712+
logger.warning(
713+
f"Snapshot metadata ({metadata_fnames}) missing from {candidate}! Skipping this path"
714+
)
700715

701716
return valid_ckpt_dirpaths
702717

0 commit comments

Comments
 (0)