Commit c444003

LucasLLC authored and facebook-github-bot committed
Adds check for FSDPOptimizer wrapper (#788)

Summary: Pull Request resolved: #788

DCP expects the entire optimizer to be initialized in the state dict before calling load. Unfortunately, since optimizer states are sometimes lazily initialized, dcp_saver has to check for optimizer objects and ensure they are properly initialized. This diff adds a check for optimizers that are hidden inside wrappers.

An alternative would be to implement this in the FSDP wrapper under utils/prepare_module, but I found that change to be higher risk, and then we'd also have the change in both places.

ghstack-source-id: 222855240
exported-using-ghexport

Reviewed By: JKSenthil

Differential Revision: D56075363

fbshipit-source-id: 68a4086ce322d9453bdd7954d56a455163d7189e
1 parent 34b04ae commit c444003

File tree

1 file changed: 6 additions, 2 deletions

torchtnt/framework/callbacks/dcp_saver.py

Lines changed: 6 additions & 2 deletions

@@ -244,8 +244,12 @@ def restore(
 
         # necessary for loading optimizers since states are initialized lazy
         for obj in app_state.values():
-            if isinstance(obj, torch.optim.Optimizer):
-                _init_optim_state(obj)
+            # sometimes optimizers are actually held in a wrapper which handles calling
+            # state_dict and load_state_dict, as is the case for
+            # `torchtnt.utils.prepare_module.FSDPOptimizerWrapper`; this handles that case.
+            optimizer = getattr(obj, "optimizer", obj)
+            if isinstance(optimizer, torch.optim.Optimizer):
+                _init_optim_state(optimizer)
 
         dcp.load(
             {"app_state": MultiStateful(app_state)},
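The unwrap-then-check pattern added in this diff can be sketched in isolation. The following is a minimal, self-contained illustration; the `Optimizer` and `OptimizerWrapper` classes and the placeholder init step are hypothetical stand-ins, not the real torch or torchtnt APIs:

```python
class Optimizer:
    """Stand-in for torch.optim.Optimizer; real optimizers populate
    `state` lazily, which is why DCP needs it initialized before load."""
    def __init__(self):
        self.state = {}

class OptimizerWrapper:
    """Stand-in for a wrapper like FSDPOptimizerWrapper that handles
    state_dict/load_state_dict and holds the real optimizer in `.optimizer`."""
    def __init__(self, optimizer):
        self.optimizer = optimizer

def init_optim_states(app_state):
    """Unwrap each object if it holds a `.optimizer`, then type-check
    before initializing, mirroring the diff above."""
    initialized = []
    for obj in app_state.values():
        # Fall back to the object itself if there is no `.optimizer` attribute.
        optimizer = getattr(obj, "optimizer", obj)
        if isinstance(optimizer, Optimizer):
            # Placeholder for the real _init_optim_state call.
            optimizer.state.setdefault("initialized", True)
            initialized.append(optimizer)
    return initialized

raw = Optimizer()
wrapped = OptimizerWrapper(Optimizer())
found = init_optim_states({"optim": raw, "wrapped": wrapped, "other": object()})
print(len(found))  # → 2: both the bare and the wrapped optimizer are found
```

Because `getattr(obj, "optimizer", obj)` returns the object itself when no `.optimizer` attribute exists, a single code path covers bare optimizers, wrapped optimizers, and unrelated app-state entries.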
