|
11 | 11 | from concurrent.futures import Future
|
12 | 12 | from typing import Any, cast, Dict, Iterable, List, Optional, Union
|
13 | 13 |
|
| 14 | +import torch |
| 15 | + |
14 | 16 | import torch.distributed as dist
|
15 | 17 | from pyre_extensions import none_throws
|
16 | 18 | from torch.distributed import checkpoint as dcp
|
|
24 | 26 | DefaultSavePlanner,
|
25 | 27 | )
|
26 | 28 | from torch.distributed.checkpoint.planner import LoadPlanner, SavePlanner
|
| 29 | + |
| 30 | +try: |
| 31 | + from torch.distributed.checkpoint.state_dict import _init_optim_state |
| 32 | +except ImportError: |
| 33 | + |
| 34 | + def noop(_: Any) -> None: |
| 35 | + return None |
| 36 | + |
| 37 | + _init_optim_state = noop |
| 38 | + |
27 | 39 | from torch.distributed.checkpoint.storage import StorageReader, StorageWriter
|
28 | 40 |
|
29 | 41 | from torchtnt.framework.callbacks._checkpoint_utils import (
|
@@ -351,6 +363,15 @@ def restore_with_id(
|
351 | 363 | predict_dataloader,
|
352 | 364 | )
|
353 | 365 |
|
| 366 | + # necessary for loading optimizers since their states are initialized lazily |
| 367 | + for obj in app_state.values(): |
| 368 | + # sometimes optimizers are actually held in a wrapper which handles calling |
| 369 | + # state_dict and load_state_dict, as is the case for |
| 370 | + # `torchtnt.utils.prepare_module.FSDPOptimizerWrapper`; this handles that case. |
| 371 | + optimizer = getattr(obj, "optimizer", obj) |
| 372 | + if isinstance(optimizer, torch.optim.Optimizer): |
| 373 | + _init_optim_state(optimizer) |
| 374 | + |
354 | 375 | with get_or_create_gloo_pg(candidate_pg=process_group) as pg:
|
355 | 376 | dcp.load(
|
356 | 377 | {"app_state": MultiStateful(app_state)},
|
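For context (not part of the diff itself): PyTorch optimizers allocate their per-parameter state lazily, on the first `step()` call, so a freshly constructed optimizer exposes an empty `state_dict()["state"]`. The new loop calls `_init_optim_state` to materialize that state before `dcp.load` restores it in place; on older torch versions where the helper is unavailable, the no-op fallback leaves behavior unchanged. A minimal sketch of the lazy-state behavior (illustrative only, not code from this change):

```python
# Illustrative sketch only: demonstrates the lazy optimizer-state
# initialization that the new loop works around.
import torch

model = torch.nn.Linear(4, 2)
optim = torch.optim.Adam(model.parameters())

# A freshly constructed optimizer has no per-parameter state yet, so loading
# a checkpointed optimizer state dict into it can fail or silently mismatch.
print(optim.state_dict()["state"])  # {}

# One backward/step pair materializes the state (step, exp_avg, exp_avg_sq for Adam).
model(torch.randn(1, 4)).sum().backward()
optim.step()
print(list(optim.state_dict()["state"][0].keys()))  # ['step', 'exp_avg', 'exp_avg_sq']
```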
|