@@ -16,8 +16,6 @@
 from torch.distributed import checkpoint as dcp
 
 from torch.distributed.checkpoint._fsspec_filesystem import FsspecReader, FsspecWriter
-from torch.distributed.checkpoint.state_dict import _init_optim_state
-from torch.distributed.checkpoint.stateful import Stateful
 from torchtnt.framework.callbacks._checkpoint_utils import (
     _prepare_app_state_for_checkpoint,
     _prepare_app_state_for_restore,
@@ -39,8 +37,9 @@
     TTrainUnit,
 )
 from torchtnt.framework.utils import get_timing_context
+from torchtnt.utils.optimizer import init_optim_state
 from torchtnt.utils.rank_zero_log import rank_zero_info, rank_zero_warn
-from torchtnt.utils.stateful import MultiStateful
+from torchtnt.utils.stateful import MultiStateful, Stateful
 
 
 logger: logging.Logger = logging.getLogger(__name__)
@@ -249,7 +248,7 @@ def restore(
                 # `torchtnt.utils.prepare_module.FSDPOptimizerWrapper`, this handles that case.
                 optimizer = getattr(obj, "optimizer", obj)
                 if isinstance(optimizer, torch.optim.Optimizer):
-                    _init_optim_state(optimizer)
+                    init_optim_state(optimizer)
 
         dcp.load(
             {"app_state": MultiStateful(app_state)},
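The `restore` hunk swaps PyTorch's private `_init_optim_state` for torchtnt's own `init_optim_state`. The call matters because a freshly constructed `torch.optim.Optimizer` has an empty `state` dict until its first `step()`, so `dcp.load` would otherwise have no per-parameter entries to load the checkpointed optimizer state into. The sketch below illustrates that idea only; `init_optim_state_sketch` is a hypothetical name, and the zero-grad/zero-lr fake step is an assumption about how such a helper can work, not the actual torchtnt implementation.

```python
import torch


def init_optim_state_sketch(optimizer: torch.optim.Optimizer) -> None:
    """Materialize optimizer.state so a checkpoint has somewhere to load into.

    Illustrative sketch only; the real torchtnt.utils.optimizer.init_optim_state
    may differ in detail.
    """
    if optimizer.state:
        # State already exists (e.g. training already took a step); nothing to do.
        return

    # Give every trainable parameter a zero gradient so step() visits all of them.
    for group in optimizer.param_groups:
        for param in group["params"]:
            if param.requires_grad and param.grad is None:
                param.grad = torch.zeros_like(param)

    # Temporarily zero the learning rate so the fake step cannot move the weights,
    # then run one step to allocate per-parameter state (exp_avg, step counters, ...).
    saved_lrs = [group.get("lr") for group in optimizer.param_groups]
    for group in optimizer.param_groups:
        if "lr" in group:
            group["lr"] = 0.0
    optimizer.step()
    for group, lr in zip(optimizer.param_groups, saved_lrs):
        if lr is not None:
            group["lr"] = lr
    optimizer.zero_grad(set_to_none=True)
```

Whatever values the fake step writes into `optimizer.state` are immediately overwritten by `dcp.load`; only the structure of the state dict has to match the checkpoint.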