Skip to content

Commit 5c73bd5

Browse files
diego-urgell and facebook-github-bot
authored and committed
Allow multiple metadata file names in checkpointers (#872)
Summary: Pull Request resolved: #872 Reviewed By: galrotem Differential Revision: D60246320 fbshipit-source-id: 24be55bcf6917a9c0b2eb5c539d707c843c9fbbb
1 parent e125ae9 commit 5c73bd5

File tree

3 files changed

+12
-10
lines changed

3 files changed

+12
-10
lines changed

torchtnt/framework/callbacks/base_checkpointer.py

Lines changed: 9 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,7 @@
99
import abc
1010
import logging
1111
from datetime import timedelta
12-
from typing import Any, cast, Dict, Iterable, Literal, Optional, Union
12+
from typing import Any, cast, Dict, Iterable, List, Literal, Optional, Union
1313

1414
import torch.distributed as dist
1515
from pyre_extensions import none_throws
@@ -43,7 +43,8 @@ class BaseCheckpointer(Callback, metaclass=abc.ABCMeta):
4343
2) ``restore`` which implements restoring the checkpoint given the relevant checkpoint path.
4444
4545
The subclass may override the ``metadata_fnames`` attribute to specify the filename(s) of the metadata file that will be written within the checkpoint directory.
46-
This will be used by this base class to ensure the integrity of the checkpoint.
46+
This will be used by this base class to ensure the integrity of the checkpoint. This is a list because some checkpointers may allow more than one valid
47+
``metadata_fnames``, depending on storage or optimization configurations.
4748
4849
Args:
4950
dirpath: Parent directory to save checkpoints to.
@@ -67,7 +68,8 @@ class BaseCheckpointer(Callback, metaclass=abc.ABCMeta):
6768
checkpoint will be saved, without the metric value in the checkpoint name
6869
"""
6970

70-
metadata_fname: Optional[str] = None
71+
# No metadata file is checked by default. This can be overridden by subclasses.
72+
metadata_fnames: List[str] = []
7173

7274
def __init__(
7375
self,
@@ -112,7 +114,7 @@ def __init__(
112114
dirpath,
113115
best_checkpoint_config,
114116
keep_last_n_checkpoints,
115-
metadata_fnames=[self.metadata_fname] if self.metadata_fname else None,
117+
metadata_fnames=self.metadata_fnames,
116118
process_group=self._process_group,
117119
)
118120

@@ -385,11 +387,11 @@ def restore_from_latest(
385387
True if the latest checkpoint directory was found and successfully restored, otherwise False.
386388
"""
387389
path = get_latest_checkpoint_path(
388-
dirpath, metadata_fname=cls.metadata_fname, process_group=process_group
390+
dirpath, metadata_fname=cls.metadata_fnames, process_group=process_group
389391
)
390392
if path is None:
391393
logger.info(
392-
f"Attempted to restore from the following path but no checkpoint was found: {dirpath=}, {cls.metadata_fname}"
394+
f"Attempted to restore from the following path but no checkpoint was found: {dirpath=}, {cls.metadata_fnames}"
393395
)
394396
return False
395397
logger.info(f"Restoring from path: {path}")
@@ -438,7 +440,7 @@ def restore_from_best(
438440
dirpath,
439441
metric_name=metric_name,
440442
mode=mode,
441-
metadata_fname=cls.metadata_fname,
443+
metadata_fname=cls.metadata_fnames,
442444
process_group=process_group,
443445
)
444446

torchtnt/framework/callbacks/dcp_saver.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,7 @@
99
import logging
1010
import time
1111
from concurrent.futures import Future
12-
from typing import Any, Dict, Iterable, Optional, Union
12+
from typing import Any, Dict, Iterable, List, Optional, Union
1313

1414
import torch
1515
import torch.distributed as dist
@@ -102,7 +102,7 @@ class DistributedCheckpointSaver(BaseCheckpointer):
102102
appropriately. For example, if logging validation accuracy, the unit must be responsible for maintaining the value and resetting it when the epoch ends.
103103
"""
104104

105-
metadata_fname: Optional[str] = ".metadata"
105+
metadata_fnames: List[str] = [".metadata"]
106106

107107
def __init__(
108108
self,

torchtnt/framework/callbacks/torchsnapshot_saver.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -96,7 +96,7 @@ class TorchSnapshotSaver(BaseCheckpointer):
9696
appropriately. For example, if logging validation accuracy, the unit must be responsible for maintaining the value and resetting it when the epoch ends.
9797
"""
9898

99-
metadata_fname: Optional[str] = ".snapshot_metadata"
99+
metadata_fnames: List[str] = [".snapshot_metadata"]
100100

101101
def __init__(
102102
self,

0 commit comments

Comments
 (0)