You signed in with another tab or window. Reload to refresh your session.You signed out in another tab or window. Reload to refresh your session.You switched accounts on another tab or window. Reload to refresh your session.Dismiss alert
Summary:
`strict` flag added to TorchSnapshotSaver.restore for when unit contains new stateful items that do not exist in previous snapshots
by default, strict is True to keep default behavior. To enable this feature, can set strict to false
Reviewed By: jiaxuzhu92
Differential Revision: D55501641
fbshipit-source-id: c475f82701f6b28490224b87f45a9de345de7023
Copy file name to clipboardExpand all lines: torchtnt/framework/callbacks/torchsnapshot_saver.py
+19-1Lines changed: 19 additions & 1 deletion
Original file line number
Diff line number
Diff line change
@@ -267,6 +267,7 @@ def restore(
267
267
restore_options: Optional[RestoreOptions] =None,
268
268
storage_options: Optional[Dict[str, Any]] =None,
269
269
knob_options: Optional[KnobOptions] =None,
270
+
strict: bool=True,
270
271
) ->None:
271
272
"""Utility method to restore snapshot state from a path.
272
273
@@ -281,6 +282,7 @@ def restore(
281
282
restore_options: Controls what to filter when restoring the state.
282
283
storage_options: Additional keyword options for the storage plugin to use, to be passed to `torchsnapshot.Snapshot <https://pytorch.org/torchsnapshot/stable/api_reference.html#torchsnapshot.Snapshot>`_. See each storage plugin's documentation for customizations.
283
284
knob_options: Additional keyword options for the snapshot knobs
285
+
strict: If ``False``, allows loading a snapshot even if not all keys exist in the unit's app_state.
284
286
"""
285
287
286
288
_validate_snapshot_available()
@@ -319,9 +321,25 @@ def restore(
319
321
"train_dataloader was passed to `restore` but no train dataloader exists in the Snapshot"
320
322
)
321
323
324
+
ifnotstrict:
325
+
# if app_state keys not in torchsnapshot checkpoint,
326
+
# remove them from app_state prior to checkpoint load
0 commit comments