
Commit 843835c

JKSenthil authored and facebook-github-bot committed
remove init_optim_state in dcp checkpointer (#901)
Summary: Pull Request resolved: #901
Reviewed By: diego-urgell
Differential Revision: D59661542
fbshipit-source-id: a25d5fb991da45187f6a56f4423cb0adad8afe7b
1 parent 57a4279 commit 843835c

3 files changed: +56 additions, -12 deletions

3 files changed

+56
-12
lines changed

tests/framework/callbacks/test_dcp_saver.py (19 additions, 0 deletions)

@@ -28,6 +28,7 @@
 from torchsnapshot.test_utils import assert_state_dict_eq, check_state_dict_eq
 from torchtnt.framework._test_utils import (
     DummyAutoUnit,
+    DummyMultiOptimUnit,
     DummyTrainUnit,
     generate_random_dataloader,
     get_dummy_train_state,
@@ -471,6 +472,24 @@ def test_gloo_pg_restore(
         self.assertEqual(process_group, None)
         mock_destroy_process_group.assert_not_called()

+    def test_save_restore_multi_optimizers(self) -> None:
+        input_dim = 2
+        dataset_len = 10
+        batch_size = 2
+        max_epochs = 1
+
+        my_unit = DummyMultiOptimUnit(input_dim=input_dim)
+        dataloader = generate_random_dataloader(dataset_len, input_dim, batch_size)
+        with tempfile.TemporaryDirectory() as temp_dir:
+            dcp_cb = DistributedCheckpointSaver(
+                temp_dir,
+                knob_options=KnobOptions(1),
+            )
+            train(my_unit, dataloader, max_epochs=max_epochs, callbacks=[dcp_cb])
+
+            my_unit_clone = DummyMultiOptimUnit(input_dim=input_dim)
+            dcp_cb.restore_from_latest(temp_dir, my_unit_clone)
+

 class DummyStatefulDataLoader:
     def __init__(self, dataloader: DataLoader) -> None:
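Note: the new test only checks that the save/restore round trip completes without error. A possible follow-up assertion, not part of this commit, could compare optimizer state across the two units with the assert_state_dict_eq helper already imported at the top of the file; the sketch below assumes that helper accepts a test case followed by two state dicts.

# Hypothetical extra check (not in this commit), placed at the end of
# test_save_restore_multi_optimizers: compare optimizer state after restore.
for original, restored in zip(my_unit.applied_optims, my_unit_clone.applied_optims):
    assert_state_dict_eq(self, original.state_dict(), restored.state_dict())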

torchtnt/framework/_test_utils.py (37 additions, 1 deletion)

@@ -7,7 +7,7 @@

 # pyre-strict

-from typing import Iterable, Iterator, Optional, Tuple
+from typing import Iterable, Iterator, List, Optional, Tuple

 import torch
 from torch import nn, Tensor
@@ -116,6 +116,42 @@ def train_step(
         return loss, outputs


+class DummyMultiOptimUnit(TrainUnit[Batch]):
+    def __init__(self, input_dim: int) -> None:
+        super().__init__()
+        # initialize module, loss_fn, & optimizer
+
+        self.modules: List[nn.Module] = [nn.Linear(input_dim, 2) for _ in range(6)]
+        self.loss_fn = nn.CrossEntropyLoss()
+        self.optims = [
+            torch.optim.SGD,
+            torch.optim.Adam,
+            torch.optim.AdamW,
+            torch.optim.Adadelta,
+            torch.optim.NAdam,
+            torch.optim.RMSprop,
+        ]
+        self.applied_optims: List[torch.optim.Optimizer] = []
+        for module, optim in zip(self.modules, self.optims):
+            self.applied_optims.append(optim(module.parameters(), lr=0.1))
+
+    def train_step(
+        self, state: State, data: Batch
+    ) -> Tuple[torch.Tensor, torch.Tensor]:
+        inputs, targets = data
+
+        outputs = [module(inputs) for module in self.modules]
+        losses = [self.loss_fn(output, targets) for output in outputs]
+        loss = torch.stack(losses).sum()
+        loss.backward()
+
+        for optim in self.applied_optims:
+            optim.step()
+            optim.zero_grad()
+
+        return loss, outputs[0]
+
+
 class DummyFitUnit(TrainUnit[Batch], EvalUnit[Batch]):
     def __init__(self, input_dim: int) -> None:
         super().__init__()
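DummyMultiOptimUnit pairs each of six small linear modules with a different optimizer class, so one checkpoint round trip covers several distinct optimizer state layouts rather than a single SGD instance. Below is a minimal sketch, not part of the commit, of driving the unit outside the train() loop; it assumes the generate_random_dataloader and get_dummy_train_state helpers defined in this same module.

# Minimal sketch (not part of the commit) of calling train_step directly.
unit = DummyMultiOptimUnit(input_dim=2)
state = get_dummy_train_state()
dataloader = generate_random_dataloader(10, 2, 2)  # num_samples, input_dim, batch_size
inputs, targets = next(iter(dataloader))
loss, preds = unit.train_step(state, (inputs, targets))
print(loss.item(), preds.shape)  # summed loss, predictions from the first module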

torchtnt/framework/callbacks/dcp_saver.py (0 additions, 11 deletions)

@@ -12,7 +12,6 @@
 from datetime import timedelta
 from typing import Any, Dict, Iterable, List, Optional, Union

-import torch
 import torch.distributed as dist
 from pyre_extensions import none_throws
 from torch.distributed import checkpoint as dcp
@@ -46,7 +45,6 @@
 )
 from torchtnt.framework.utils import get_timing_context
 from torchtnt.utils.checkpoint import BestCheckpointConfig
-from torchtnt.utils.optimizer import init_optim_state
 from torchtnt.utils.rank_zero_log import rank_zero_info, rank_zero_warn
 from torchtnt.utils.stateful import MultiStateful, Stateful

@@ -323,15 +321,6 @@ def restore_with_id(
                 "train_dataloader was passed to `restore` but no train dataloader exists in the Snapshot"
             )

-        # necessary for loading optimizers since states are initialized lazy
-        for obj in app_state.values():
-            # sometimes optimizers are actually held in a wrapper which handles calling
-            # state_dict and load_state_dict, as is the case for
-            # `torchtnt.utils.prepare_module.FSDPOptimizerWrapper`, this handles that case.
-            optimizer = getattr(obj, "optimizer", obj)
-            if isinstance(optimizer, torch.optim.Optimizer):
-                init_optim_state(optimizer)
-
         dcp.load(
             {"app_state": MultiStateful(app_state)},
             checkpoint_id=checkpoint_id,
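Context for the removed block: PyTorch optimizers create per-parameter state lazily on the first step() call, so a freshly constructed optimizer has nothing for an in-place checkpoint load to populate; init_optim_state existed to materialize that state before dcp.load ran. The snippet below only illustrates that lazy behavior with plain PyTorch and is not the torchtnt implementation.

# Illustration only: optimizer state is empty until the first step().
import torch
from torch import nn

model = nn.Linear(2, 2)
optim = torch.optim.Adam(model.parameters(), lr=0.1)
print(optim.state_dict()["state"])  # {} -- nothing to load into yet

# Stepping once with zero gradients creates exp_avg/exp_avg_sq buffers
# without changing the parameters.
for p in model.parameters():
    p.grad = torch.zeros_like(p)
optim.step()
optim.zero_grad()
print(list(optim.state_dict()["state"].keys()))  # [0, 1]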
