Skip to content

Commit 57a4279

Browse files
saumishrfacebook-github-bot
authored andcommitted
DCP checkpoint restore backward compatibility in TorchTNT (#898)
Summary: Pull Request resolved: #898 DCP checkpoint restore backward compatibility. Current implementation assumes the checkpoint to be a ModelStore Checkpoint. There are customers who are using DCP saver with path based checkpoint id. Reviewed By: JKSenthil Differential Revision: D62542227 fbshipit-source-id: 6981cad0e4195308c0fc2e8a5da621e2fd9c024a
1 parent 97b68cc commit 57a4279

File tree

2 files changed

+32
-0
lines changed

2 files changed

+32
-0
lines changed

tests/utils/test_checkpoint.py

Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1339,6 +1339,26 @@ def test_metadata_exists(self) -> None:
13391339
os.remove(os.path.join(dirpath, SNAPSHOT_METADATA_FNAME))
13401340
self.assertFalse(_metadata_exists(fs, dirpath, SNAPSHOT_METADATA_FNAME))
13411341

1342+
def test_does_checkpoint_metadata_exist(self) -> None:
1343+
app_state = {"module": nn.Linear(2, 2)}
1344+
1345+
with tempfile.TemporaryDirectory() as temp_dir:
1346+
dirpath = os.path.join(temp_dir, "checkpoint")
1347+
Snapshot.take(dirpath, app_state=app_state)
1348+
1349+
self.assertTrue(
1350+
CheckpointManager.does_checkpoint_metadata_exist(
1351+
dirpath, SNAPSHOT_METADATA_FNAME
1352+
)
1353+
)
1354+
1355+
os.remove(os.path.join(dirpath, SNAPSHOT_METADATA_FNAME))
1356+
self.assertFalse(
1357+
CheckpointManager.does_checkpoint_metadata_exist(
1358+
dirpath, SNAPSHOT_METADATA_FNAME
1359+
)
1360+
)
1361+
13421362

13431363
class MyValLossUnit(TrainUnit[Batch]):
13441364
def __init__(self) -> None:

torchtnt/utils/checkpoint.py

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -539,6 +539,18 @@ def does_checkpoint_exist(
539539
_metadata_exists(fs, ckpt.path, fname) for fname in self._metadata_fnames
540540
)
541541

542+
@staticmethod
543+
def does_checkpoint_metadata_exist(
544+
checkpoint_path: str,
545+
metadata_fname: str,
546+
) -> bool:
547+
"""
548+
Checking whether a checkpoint metadata file exists in the directory.
549+
If the checkpointer has that metadata file, this function will returns True. Returns False otherwise.
550+
"""
551+
fs, _ = url_to_fs(checkpoint_path)
552+
return _metadata_exists(fs, checkpoint_path, metadata_fname)
553+
542554
@staticmethod
543555
@rank_zero_read_and_broadcast
544556
def _sync_dirpath_to_all_ranks(

0 commit comments

Comments
 (0)