Commit c5b9adb

JKSenthil authored and facebook-github-bot committed
replace _init_optim_state w/ tnt's util
Summary: The `_init_optim_state()` util from DCP is identical to the implementation in TorchTNT. However, since the DCP version is only available in PyTorch >= 2.3 / nightlies, it isn't compatible with stable PyTorch users. This diff swaps the implementations.

Reviewed By: galrotem

Differential Revision: D56446429

fbshipit-source-id: 7ba85410f80994fff73e7dc053968a8b59f44990
1 parent fcd8b22 commit c5b9adb

File tree: 1 file changed (+3, -4 lines)


torchtnt/framework/callbacks/dcp_saver.py

Lines changed: 3 additions & 4 deletions
@@ -16,8 +16,6 @@
 from torch.distributed import checkpoint as dcp
 
 from torch.distributed.checkpoint._fsspec_filesystem import FsspecReader, FsspecWriter
-from torch.distributed.checkpoint.state_dict import _init_optim_state
-from torch.distributed.checkpoint.stateful import Stateful
 from torchtnt.framework.callbacks._checkpoint_utils import (
     _prepare_app_state_for_checkpoint,
     _prepare_app_state_for_restore,
@@ -39,8 +37,9 @@
     TTrainUnit,
 )
 from torchtnt.framework.utils import get_timing_context
+from torchtnt.utils.optimizer import init_optim_state
 from torchtnt.utils.rank_zero_log import rank_zero_info, rank_zero_warn
-from torchtnt.utils.stateful import MultiStateful
+from torchtnt.utils.stateful import MultiStateful, Stateful
 
 
 logger: logging.Logger = logging.getLogger(__name__)
@@ -249,7 +248,7 @@ def restore(
         # `torchtnt.utils.prepare_module.FSDPOptimizerWrapper`, this handles that case.
         optimizer = getattr(obj, "optimizer", obj)
         if isinstance(optimizer, torch.optim.Optimizer):
-            _init_optim_state(optimizer)
+            init_optim_state(optimizer)
 
         dcp.load(
             {"app_state": MultiStateful(app_state)},
