Skip to content

Commit aa34ffb

Browse files
saumishr authored and facebook-github-bot committed
Config to allow partial loading of the checkpoint (#855)
Summary:
Pull Request resolved: #855

Enable the TNT restore option to allow partial loading of checkpoints for DCP.

Reviewed By: anshulverma, JKSenthil

Differential Revision: D59074582

fbshipit-source-id: 6275e582665611f84b0ac594789dd4bc408bd77c
1 parent 12c5637 commit aa34ffb

File tree

2 files changed

+33
-4
lines changed

2 files changed

+33
-4
lines changed

tests/framework/callbacks/test_dcp_saver.py

Lines changed: 27 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -392,6 +392,33 @@ def test_restore_planner_storage_components(self, mock_dist_cp: MagicMock) -> No
392392
self.assertIsInstance(planner, DummyLoadPlanner)
393393
self.assertIsInstance(storage_reader, DummyStorageReader)
394394

395+
@patch("torchtnt.framework.callbacks.dcp_saver.dcp")
396+
def test_restore_allow_partial_loading(self, mock_dist_cp: MagicMock) -> None:
397+
my_unit = DummyTrainUnit(input_dim=2)
398+
restore_options = RestoreOptions(strict=False)
399+
DistributedCheckpointSaver.restore(
400+
path="path/to/snapshot",
401+
unit=my_unit,
402+
restore_options=restore_options,
403+
)
404+
405+
allow_partial_load = mock_dist_cp.load.call_args[1][
406+
"planner"
407+
].allow_partial_load
408+
self.assertTrue(allow_partial_load)
409+
410+
restore_options = RestoreOptions(strict=True)
411+
DistributedCheckpointSaver.restore(
412+
path="path/to/snapshot",
413+
unit=my_unit,
414+
restore_options=restore_options,
415+
)
416+
417+
allow_partial_load = mock_dist_cp.load.call_args[1][
418+
"planner"
419+
].allow_partial_load
420+
self.assertFalse(allow_partial_load)
421+
395422

396423
class DummyStatefulDataLoader:
397424
def __init__(self, dataloader: DataLoader) -> None:

torchtnt/framework/callbacks/dcp_saver.py

Lines changed: 6 additions & 4 deletions
Original file line number | Diff line number | Diff line change
@@ -283,14 +283,16 @@ def restore(
283283
storage_reader: Instance of StorageReader used to perform reads. If this is not specified, it will automatically infer
284284
the reader based on the checkpoint_id. If checkpoint_id is also None, an exception will be raised. (Default: ``None``)
285285
"""
286-
if planner is None:
287-
planner = DefaultLoadPlanner()
286+
287+
restore_options = restore_options or RestoreOptions()
288+
app_state = _prepare_app_state_for_restore(unit, restore_options)
288289

289290
if storage_reader is None:
290291
storage_reader = Reader(path)
291292

292-
restore_options = restore_options or RestoreOptions()
293-
app_state = _prepare_app_state_for_restore(unit, restore_options)
293+
if planner is None:
294+
allow_partial_load = not restore_options.strict
295+
planner = DefaultLoadPlanner(allow_partial_load=allow_partial_load)
294296

295297
if train_dataloader is not None:
296298
if not isinstance(train_dataloader, Stateful):

0 commit comments

Comments
 (0)