Skip to content

Commit e53c8eb

Browse files
JKSenthilfacebook-github-bot
authored andcommitted
TorchSnapshotSaver.restore non-strict option
Summary: `strict` flag added to TorchSnapshotSaver.restore for when unit contains new stateful items that do not exist in previous snapshots by default, strict is True to keep default behavior. To enable this feature, can set strict to false Reviewed By: jiaxuzhu92 Differential Revision: D55501641 fbshipit-source-id: c475f82701f6b28490224b87f45a9de345de7023
1 parent c988b86 commit e53c8eb

File tree

2 files changed

+52
-1
lines changed

2 files changed

+52
-1
lines changed

tests/framework/callbacks/test_torchsnapshot_saver.py

Lines changed: 33 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -227,6 +227,39 @@ def test_save_restore_no_lr_scheduler_restore(
227227
app_state = mock_torchsnapshot.Snapshot().restore.call_args.args[0]
228228
self.assertIn("lr_scheduler", app_state)
229229

230+
def test_restore_strict(self) -> None:
231+
my_unit = DummyAutoUnit(module=torch.nn.Linear(2, 2))
232+
with tempfile.TemporaryDirectory() as temp_dir:
233+
state = get_dummy_train_state()
234+
snapshot_cb = TorchSnapshotSaver(
235+
temp_dir, save_every_n_train_steps=1, async_checkpoint=False
236+
)
237+
snapshot_cb.on_train_step_end(state, my_unit)
238+
239+
# add a new parameter to the module
240+
my_unit.module2 = torch.nn.Linear(2, 2)
241+
242+
with self.assertRaisesRegex(
243+
AssertionError,
244+
"module2 is absent in both manifest and flattened.",
245+
):
246+
TorchSnapshotSaver.restore(
247+
path=os.path.join(temp_dir, "epoch_0_step_0"),
248+
unit=my_unit,
249+
strict=True,
250+
)
251+
252+
with self.assertLogs(level="WARNING") as log:
253+
TorchSnapshotSaver.restore(
254+
path=os.path.join(temp_dir, "epoch_0_step_0"),
255+
unit=my_unit,
256+
strict=False,
257+
)
258+
self.assertEqual(
259+
log.output[0],
260+
"WARNING:torchtnt.utils.rank_zero_log:module2 was passed to `restore` but does not exists in the snapshot",
261+
)
262+
230263
@skip_if_not_distributed
231264
def test_save_restore_ddp(self) -> None:
232265
spawn_multi_process(

torchtnt/framework/callbacks/torchsnapshot_saver.py

Lines changed: 19 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -267,6 +267,7 @@ def restore(
267267
restore_options: Optional[RestoreOptions] = None,
268268
storage_options: Optional[Dict[str, Any]] = None,
269269
knob_options: Optional[KnobOptions] = None,
270+
strict: bool = True,
270271
) -> None:
271272
"""Utility method to restore snapshot state from a path.
272273
@@ -281,6 +282,7 @@ def restore(
281282
restore_options: Controls what to filter when restoring the state.
282283
storage_options: Additional keyword options for the storage plugin to use, to be passed to `torchsnapshot.Snapshot <https://pytorch.org/torchsnapshot/stable/api_reference.html#torchsnapshot.Snapshot>`_. See each storage plugin's documentation for customizations.
283284
knob_options: Additional keyword options for the snapshot knobs
285+
strict: If ``False``, allows loading a snapshot even if not all keys exist in the unit's app_state.
284286
"""
285287

286288
_validate_snapshot_available()
@@ -319,9 +321,25 @@ def restore(
319321
"train_dataloader was passed to `restore` but no train dataloader exists in the Snapshot"
320322
)
321323

324+
if not strict:
325+
# if app_state keys not in torchsnapshot checkpoint,
326+
# remove them from app_state prior to checkpoint load
327+
missing_stateful_keys = []
328+
manifest = snapshot.get_manifest()
329+
for stateful_key in app_state:
330+
found = any((f"/{stateful_key}/" in key for key in manifest.keys()))
331+
if not found:
332+
missing_stateful_keys.append(stateful_key)
333+
334+
for key in missing_stateful_keys:
335+
rank_zero_warn(
336+
f"{key} was passed to `restore` but does not exists in the snapshot"
337+
)
338+
app_state.pop(key)
339+
322340
knob_options = knob_options or KnobOptions()
323341
with _override_knobs(knob_options):
324-
snapshot.restore(app_state)
342+
snapshot.restore(app_state, strict=strict)
325343
rank_zero_info(f"Restored snapshot from path: {path}", logger=logger)
326344

327345

0 commit comments

Comments
 (0)