Skip to content

Commit 287d2cc

Browse files
jiaxuzhu92facebook-github-bot
authored andcommitted
Add restore_module_strict to RestoreOptions (#833)
Summary: Pull Request resolved: #833 As title, expose `restore_module_strict` in `RestoreOptions`. Reviewed By: JKSenthil Differential Revision: D57062022 fbshipit-source-id: 20f81a02c58ee5cbee0736c6bfe4030f4ae8f056
1 parent f929896 commit 287d2cc

File tree

2 files changed

+4
-1
lines changed

2 files changed

+4
-1
lines changed

torchtnt/framework/callbacks/checkpointer_types.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -40,9 +40,11 @@ class RestoreOptions:
4040
restore_eval_progress: Whether to restore the evaluation progress state.
4141
restore_optimizers: Whether to restore the optimizer states.
4242
restore_lr_schedulers: Whether to restore the lr scheduler states.
43+
strict: Whether to strictly restore app state and the module state dict.
4344
"""
4445

4546
restore_train_progress: bool = True
4647
restore_eval_progress: bool = True
4748
restore_optimizers: bool = True
4849
restore_lr_schedulers: bool = True
50+
strict: bool = True

torchtnt/framework/callbacks/torchsnapshot_saver.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -341,7 +341,8 @@ def restore(
341341

342342
knob_options = knob_options or KnobOptions()
343343
with _override_knobs(knob_options):
344-
snapshot.restore(app_state, strict=strict)
344+
strict = strict or restore_options.strict
345+
snapshot.restore(app_state, strict=restore_options.strict)
345346
rank_zero_info(f"Restored snapshot from path: {path}", logger=logger)
346347

347348

0 commit comments

Comments
 (0)