Commit 8b4ab19

JKSenthil authored and facebook-github-bot committed
re-introduce init_optim_state (#946)
Summary: Pull Request resolved: #946
Reviewed By: galrotem
Differential Revision: D65922636
fbshipit-source-id: 0965aa03ce000d1c1d544635238e3465f0a5dc5e
1 parent 4f91877 commit 8b4ab19

2 files changed: +24 −1 lines changed

torchtnt/framework/_test_utils.py

Lines changed: 3 additions & 1 deletion
@@ -146,7 +146,9 @@ def __init__(self, input_dim: int) -> None:
         ]
         self.applied_optims: List[torch.optim.Optimizer] = []
         for module, optim in zip(self.modules, self.optims):
-            self.applied_optims.append(optim(module.parameters(), lr=0.1))
+            o = optim(module.parameters(), lr=0.1)
+            self.applied_optims.append(o)
+            setattr(self, f"optimizer_{optim.__name__}", o)
 
     def train_step(
         self, state: State, data: Batch
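
The new setattr line above presumably matters because TorchTNT units track optimizer attributes as part of their application state, so exposing each optimizer under a stable name makes it visible to checkpoint callbacks such as the DCP saver. A simplified sketch of that kind of attribute tracking (hypothetical TrackingUnit class, not the actual AppStateMixin implementation):

import torch
import torch.nn as nn

# Simplified sketch of attribute-based optimizer tracking (hypothetical
# TrackingUnit class, not the actual TorchTNT AppStateMixin): assigning an
# optimizer as an attribute is what makes it discoverable for save/restore.
class TrackingUnit:
    def __init__(self) -> None:
        # bypass __setattr__ while creating the registry itself
        object.__setattr__(self, "_tracked_optims", {})

    def __setattr__(self, name: str, value: object) -> None:
        if isinstance(value, torch.optim.Optimizer):
            self._tracked_optims[name] = value
        object.__setattr__(self, name, value)


unit = TrackingUnit()
module = nn.Linear(2, 2)
# mirrors the new test-utils line: setattr(self, f"optimizer_{optim.__name__}", o)
setattr(unit, f"optimizer_{torch.optim.SGD.__name__}", torch.optim.SGD(module.parameters(), lr=0.1))
print(list(unit._tracked_optims))  # ['optimizer_SGD'] -- visible to checkpoint callbacks

With the optimizers discoverable by name, the restore path changed in dcp_saver.py below has concrete optimizer objects to initialize and load into.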

torchtnt/framework/callbacks/dcp_saver.py

Lines changed: 21 additions & 0 deletions
@@ -11,6 +11,8 @@
 from concurrent.futures import Future
 from typing import Any, cast, Dict, Iterable, List, Optional, Union
 
+import torch
+
 import torch.distributed as dist
 from pyre_extensions import none_throws
 from torch.distributed import checkpoint as dcp
@@ -24,6 +26,16 @@
     DefaultSavePlanner,
 )
 from torch.distributed.checkpoint.planner import LoadPlanner, SavePlanner
+
+try:
+    from torch.distributed.checkpoint.state_dict import _init_optim_state
+except ImportError:
+
+    def noop(_: Any) -> None:
+        return None
+
+    _init_optim_state = noop
+
 from torch.distributed.checkpoint.storage import StorageReader, StorageWriter
 
 from torchtnt.framework.callbacks._checkpoint_utils import (
@@ -351,6 +363,15 @@ def restore_with_id(
             predict_dataloader,
         )
 
+        # necessary for loading optimizers since states are initialized lazily
+        for obj in app_state.values():
+            # sometimes optimizers are actually held in a wrapper which handles calling
+            # state_dict and load_state_dict, as is the case for
+            # `torchtnt.utils.prepare_module.FSDPOptimizerWrapper`; this handles that case
+            optimizer = getattr(obj, "optimizer", obj)
+            if isinstance(optimizer, torch.optim.Optimizer):
+                _init_optim_state(optimizer)
+
         with get_or_create_gloo_pg(candidate_pg=process_group) as pg:
             dcp.load(
                 {"app_state": MultiStateful(app_state)},
