Commit 19f2911

LucasLLC authored and facebook-github-bot committed
Additional integrations with DCP save, including checkpoint based restoring, async_save, and MAST pre-emption (#768)
Summary:
Pull Request resolved: #768

Adds additional integrations with DCP save:
- DCP supports async checkpointing, and all knob/storage options from the TSS Saver
- Implements a dcp_saver under meta, for manifold + MAST pre-emption support

ghstack-source-id: 220492964
Reviewed By: anshulverma
Differential Revision: D55030939
fbshipit-source-id: de0bfb8d036e6b01db35971bcec0f5fa8dc3d18e
1 parent 614526a commit 19f2911
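For orientation, a minimal sketch of how the options added here might be wired into a training run, based on the constructor parameters and tests in the diff below; the checkpoint directory, `my_unit`, `dataloader`, and the numeric values are placeholders, not part of this commit:

# Sketch only: wiring the new async_checkpoint and knob_options parameters.
# `my_unit` is assumed to be any TrainUnit and `dataloader` its iterable; both are placeholders.
from torchtnt.framework.callbacks.checkpointer_types import KnobOptions
from torchtnt.framework.callbacks.dcp_saver import DistributedCheckpointSaver
from torchtnt.framework.train import train

dcp_cb = DistributedCheckpointSaver(
    "/tmp/checkpoints",              # dirpath (placeholder)
    save_every_n_train_steps=100,    # checkpoint cadence (placeholder value)
    async_checkpoint=True,           # new: saves in the background via dcp.async_save
    knob_options=KnobOptions(1),     # new: KnobOptions(1) mirrors the updated tests
)
train(my_unit, dataloader, max_epochs=2, callbacks=[dcp_cb])

When async_checkpoint is enabled, the callback waits on any previous in-flight save before starting a new one and forces a final wait at on_train_end, as implemented in _async_save and _wait in the dcp_saver.py diff below.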

File tree

2 files changed (+147, -37 lines)


tests/framework/callbacks/test_dcp_saver.py
Lines changed: 13 additions & 8 deletions

@@ -25,7 +25,7 @@
     DummyTrainUnit,
     generate_random_dataloader,
 )
-from torchtnt.framework.callbacks.checkpointer_types import RestoreOptions
+from torchtnt.framework.callbacks.checkpointer_types import KnobOptions, RestoreOptions
 from torchtnt.framework.callbacks.dcp_saver import DistributedCheckpointSaver
 from torchtnt.framework.train import train
 from torchtnt.utils.distributed import get_global_rank, spawn_multi_process
@@ -60,6 +60,7 @@ def test_save_restore(self) -> None:
             dcp_cb = DistributedCheckpointSaver(
                 temp_dir,
                 save_every_n_train_steps=save_every_n_train_steps,
+                knob_options=KnobOptions(1),
             )
             train(my_unit, dataloader, max_epochs=max_epochs, callbacks=[dcp_cb])

@@ -87,6 +88,7 @@ def test_save_restore_dataloader_state(self) -> None:
             dcp_cb = DistributedCheckpointSaver(
                 temp_dir,
                 save_every_n_train_steps=save_every_n_train_steps,
+                knob_options=KnobOptions(1),
             )
             train(
                 my_unit,
@@ -138,6 +140,7 @@ def test_restore_from_latest(self) -> None:
             dcp_cb = DistributedCheckpointSaver(
                 temp_dir,
                 save_every_n_train_steps=save_every_n_train_steps,
+                knob_options=KnobOptions(1),
             )
             train(my_unit, dataloader, max_epochs=max_epochs, callbacks=[dcp_cb])

@@ -177,6 +180,7 @@ def test_save_restore_no_train_progress(self) -> None:
             dcp_cb = DistributedCheckpointSaver(
                 temp_dir,
                 save_every_n_train_steps=save_every_n_train_steps,
+                knob_options=KnobOptions(1),
             )
             train(my_unit, dataloader, max_epochs=max_epochs, callbacks=[dcp_cb])

@@ -191,7 +195,7 @@ def test_save_restore_no_train_progress(self) -> None:
         # no train progress was restored so the progress after restoration should be the same as the progress before restoration
         self.assertEqual(restored_num_steps_completed, end_num_steps_completed)

-    @patch("torchtnt.framework.callbacks.dcp_saver.dist_cp")
+    @patch("torchtnt.framework.callbacks.dcp_saver.dcp")
     def test_save_restore_no_optimizer_restore(self, mock_dist_cp: MagicMock) -> None:
         my_unit = DummyTrainUnit(input_dim=2)
         restore_options = RestoreOptions(restore_optimizers=False)
@@ -200,13 +204,13 @@ def test_save_restore_no_optimizer_restore(self, mock_dist_cp: MagicMock) -> Non
             unit=my_unit,
             restore_options=restore_options,
         )
-        app_state = mock_dist_cp.load_state_dict.call_args.args[0]["app_state"]
+        app_state = mock_dist_cp.load.call_args.args[0]["app_state"].state_dict()
         self.assertNotIn("optimizer", app_state)
         DistributedCheckpointSaver.restore(path="path/to/snapshot", unit=my_unit)
-        app_state = mock_dist_cp.load_state_dict.call_args.args[0]["app_state"]
+        app_state = mock_dist_cp.load.call_args.args[0]["app_state"].state_dict()
         self.assertIn("optimizer", app_state)

-    @patch("torchtnt.framework.callbacks.dcp_saver.dist_cp")
+    @patch("torchtnt.framework.callbacks.dcp_saver.dcp")
     def test_save_restore_no_lr_scheduler_restore(
         self, mock_dist_cp: MagicMock
     ) -> None:
@@ -215,17 +219,17 @@ def test_save_restore_no_lr_scheduler_restore(
         DistributedCheckpointSaver.restore(
             path="path/to/snapshot", unit=my_unit, restore_options=restore_options
         )
-        app_state = mock_dist_cp.load_state_dict.call_args.args[0]["app_state"]
+        app_state = mock_dist_cp.load.call_args.args[0]["app_state"].state_dict()
         self.assertNotIn("lr_scheduler", app_state)
         DistributedCheckpointSaver.restore(path="path/to/snapshot", unit=my_unit)
-        app_state = mock_dist_cp.load_state_dict.call_args.args[0]["app_state"]
+        app_state = mock_dist_cp.load.call_args.args[0]["app_state"].state_dict()
         self.assertIn("lr_scheduler", app_state)

     @skip_if_not_distributed
     def test_save_restore_ddp(self) -> None:
         spawn_multi_process(
             2,
-            "gloo",
+            "cpu:gloo,cuda:gloo",
             self._save_restore_ddp,
         )

@@ -248,6 +252,7 @@ def _save_restore_ddp() -> None:
         dcp_cb = DistributedCheckpointSaver(
             temp_dir,
             save_every_n_epochs=save_every_n_epochs,
+            knob_options=KnobOptions(1),
         )
         temp_dir = dcp_cb.dirpath
         train(my_unit, dataloader, max_epochs=max_epochs, callbacks=[dcp_cb])
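
As a companion to the mocked tests above, a sketch of how the selective restore they verify might be invoked against a real checkpoint; the snapshot path and `my_unit` are placeholders:

# Sketch only: skip restoring optimizer state while loading everything else.
from torchtnt.framework.callbacks.checkpointer_types import RestoreOptions
from torchtnt.framework.callbacks.dcp_saver import DistributedCheckpointSaver

DistributedCheckpointSaver.restore(
    path="path/to/snapshot",    # placeholder for a real DCP checkpoint directory
    unit=my_unit,               # placeholder unit whose app_state is loaded in place
    restore_options=RestoreOptions(restore_optimizers=False),
)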

torchtnt/framework/callbacks/dcp_saver.py
Lines changed: 134 additions & 29 deletions

@@ -7,13 +7,17 @@
 # pyre-strict

 import logging
-from typing import Iterable, Optional
+import time
+from concurrent.futures import Future
+from typing import Any, Dict, Iterable, Optional, Union

+import torch
 import torch.distributed as dist
+from torch.distributed import checkpoint as dcp

-from torch.distributed import checkpoint as dist_cp
 from torch.distributed.checkpoint._fsspec_filesystem import FsspecReader, FsspecWriter
-
+from torch.distributed.checkpoint.state_dict import _init_optim_state
+from torch.distributed.checkpoint.stateful import Stateful
 from torchtnt.framework.callbacks._checkpoint_utils import (
     _prepare_app_state_for_checkpoint,
     _prepare_app_state_for_restore,
@@ -23,13 +27,20 @@
 from torchtnt.framework.callbacks.base_checkpointer import BaseCheckpointer
 from torchtnt.framework.callbacks.checkpointer_types import (
     BestCheckpointConfig,
+    KnobOptions,
     RestoreOptions,
 )
 from torchtnt.framework.state import State
-from torchtnt.framework.unit import AppStateMixin, TTrainData
+from torchtnt.framework.unit import (
+    AppStateMixin,
+    TEvalUnit,
+    TPredictUnit,
+    TTrainData,
+    TTrainUnit,
+)
 from torchtnt.framework.utils import get_timing_context
 from torchtnt.utils.rank_zero_log import rank_zero_info, rank_zero_warn
-from torchtnt.utils.stateful import MultiStateful, Stateful
+from torchtnt.utils.stateful import MultiStateful


 logger: logging.Logger = logging.getLogger(__name__)
@@ -54,10 +65,12 @@ class DistributedCheckpointSaver(BaseCheckpointer):
         keep_last_n_checkpoints: Number of most recent checkpoints to keep. If None, all checkpoints are kept. If an excess of existing checkpoints are present, the oldest ones will be deleted to clean the difference. If best checkpoint config is enabled, this param will manage the top n checkpoints instead.
         best_checkpoint_config: Configuration for saving the best checkpoint based on a monitored metric. The metric is read off the attribute of the unit prior to checkpoint.
         process_group: The process group on which the ranks will communicate on. default: ``None`` (the entire world)
+        async_checkpoint: Whether to perform asynchronous checkpointing. Default: ``True``.
+        knob_options: Additional keyword options for StorageWriter. <https://pytorch.org/docs/stable/distributed.checkpoint.html#torch.distributed.checkpoint.StorageWriter/>

     Note:
-        If torch.distributed is available and default process group is initialized, dcp's `no_dist <https://pytorch.org/docs/stable/distributed.checkpoint.html#torch.distributed.checkpoint.load_state_dict/>_`
-        argument is automatically set to False. Otherwise it's set to True.
+        If torch.distributed is available and a process group is initialized, dcp assumes the intention is to save/load checkpoints in distributed fashion.
+        Additionally, a gloo process group must be initialized for async_checkpoint. For workloads that require nccl, the recommended initialization is 'cpu:gloo,cuda:nccl'

     Note:
         If checkpointing FSDP model, you can set state_dict type calling `set_state_dict_type <https://pytorch.org/docs/stable/fsdp.html#torch.distributed.fsdp.FullyShardedDataParallel.set_state_dict_type>`_ prior to starting training.
@@ -67,6 +80,8 @@ class DistributedCheckpointSaver(BaseCheckpointer):
         appropriately. For example, if logging validation accuracy, the unit must be responsible for maintaining the value and resetting it when the epoch ends.
     """

+    metadata_fname: Optional[str] = ".metadata"
+
     def __init__(
         self,
         dirpath: str,
@@ -77,6 +92,8 @@ def __init__(
         keep_last_n_checkpoints: Optional[int] = None,
         best_checkpoint_config: Optional[BestCheckpointConfig] = None,
         process_group: Optional[dist.ProcessGroup] = None,
+        async_checkpoint: bool = False,
+        knob_options: Optional[KnobOptions] = None,
     ) -> None:
         super().__init__(
             dirpath=dirpath,
@@ -87,6 +104,10 @@ def __init__(
             best_checkpoint_config=best_checkpoint_config,
             process_group=process_group,
         )
+        self._async_checkpoint = async_checkpoint
+
+        self._knob_options: KnobOptions = knob_options or KnobOptions()
+        self._prev_snapshot: Optional[Future] = None

     def _checkpoint_impl(
         self,
@@ -96,25 +117,78 @@
         checkpoint_path: str,
         hook: str,
     ) -> bool:
-        intra_epoch = False
-        if hook == "on_train_step_end":
-            intra_epoch = True
+        if hook not in ["on_train_step_end", "on_train_epoch_end", "on_train_end"]:
+            raise RuntimeError(f"Unexpected hook encountered '{hook}'")

-        storage_writer = FsspecWriter(checkpoint_path)
+        intra_epoch = hook == "on_train_step_end"
+        curr_snapshot_wait = hook == "on_train_end"

         app_state = _prepare_app_state_for_checkpoint(state, unit, intra_epoch)
-        # flag to indicate whether distributed is available
-        # determines what to set ``no_dist`` arg in DCP apis
-        pg_available: bool = dist.is_initialized()
-        with get_timing_context(state, f"{self.__class__.__name__}.save_state_dict"):
-            dist_cp.save_state_dict(
-                {"app_state": MultiStateful(app_state).state_dict()},
-                storage_writer=storage_writer,
-                process_group=self._process_group,
-                no_dist=not pg_available,
-            )
+        # TODO: evaluate whether we need to implement the equivalent of torchsnapshot.RNGState()
+        if self._async_checkpoint:
+            with get_timing_context(state, f"{self.__class__.__name__}.async_save"):
+                # TODO checkpoint is not truly successful
+                # since this is async checkpointed, so in
+                # future, add logic to set successful flag
+                # only when checkpoint is fully written
+                checkpoint_success = self._async_save(checkpoint_path, app_state)
+                if curr_snapshot_wait:
+                    self._wait()
+        else:
+            with get_timing_context(state, f"{self.__class__.__name__}.save"):
+                checkpoint_success = self._save(checkpoint_path, app_state)
+
+        return checkpoint_success
+
+    def _wait(self) -> None:
+        if self._prev_snapshot is not None:
+            self._prev_snapshot.result()
+
+    def _async_save(self, checkpoint_id: str, app_state: Dict[str, Stateful]) -> bool:
+
+        if self._prev_snapshot is not None:
+            if not self._prev_snapshot.done():
+                rank_zero_warn(
+                    (
+                        "Waiting on previous checkpoint to finish... Consider modifying checkpointing "
+                        f"frequency if this is an issue. Current value (current {self._save_every_n_train_steps})"
+                    ),
+                    logger=logger,
+                )
+                t0 = time.monotonic()
+                self._wait()
+                rank_zero_warn(
+                    f"Waiting on previous checkpoint for {time.monotonic()-t0:.3f} seconds",
+                    logger=logger,
+                )
+            else:
+                self._wait()
+
+        self._prev_snapshot = dcp.async_save(
+            state_dict={"app_state": MultiStateful(app_state)},
+            process_group=self._process_group,
+            storage_writer=FsspecWriter(checkpoint_id, **self.default_writer_options),
+        )
+
         return True

+    def _save(self, checkpoint_id: str, app_state: Dict[str, Stateful]) -> bool:
+        dcp.save(
+            state_dict={"app_state": MultiStateful(app_state)},
+            process_group=self._process_group,
+            storage_writer=FsspecWriter(checkpoint_id, **self.default_writer_options),
+        )
+
+        return True
+
+    def on_exception(
+        self,
+        state: State,
+        unit: Union[TTrainUnit, TEvalUnit, TPredictUnit],
+        exc: BaseException,
+    ) -> None:
+        self._wait()
+
     @staticmethod
     def restore(
         path: str,
@@ -123,6 +197,7 @@ def restore(
         train_dataloader: Optional[Iterable[TTrainData]] = None,
         process_group: Optional[dist.ProcessGroup] = None,
         restore_options: Optional[RestoreOptions] = None,
+        knob_options: Optional[KnobOptions] = None,
     ) -> None:
         """Utility method to restore dcp checkpoint from a path.

@@ -133,10 +208,16 @@ def restore(
             path: Path of the snapshot to restore.
             unit: An instance of :class:`~torchtnt.framework.unit.TrainUnit`, :class:`~torchtnt.framework.unit.EvalUnit`, or :class:`~torchtnt.framework.unit.PredictUnit` containing states to restore.
             train_dataloader: An optional train dataloader to restore.
-            process_group: The process group on which the ranks will communicate on. default: ``None`` (the entire world)
+            process_group: The process group on which the ranks will communicate on. default: ``None`` (the entire world) Note:
+                If torch.distributed is available and a process group is initialized, dcp assumes the intention is to save/load checkpoints in distributed fashion.
             restore_options: Controls what to filter when restoring the state.
-            no_dist: Set to true if loading in non-distributed setting
+            knob_options: Option is kept for legacy reasons but ignored in DCP
         """
+        if knob_options is not None:
+            rank_zero_warn(
+                "Ignoring `knob_options` which was passed to DistributedCheckpointSaver.restore, but is not supported."
+            )
+
         storage_reader = FsspecReader(path)

         restore_options = restore_options or RestoreOptions()
@@ -161,13 +242,37 @@ def restore(
                 "train_dataloader was passed to `restore` but no train dataloader exists in the Snapshot"
             )

-        state_dict = {"app_state": MultiStateful(app_state).state_dict()}
-        no_dist = not dist.is_initialized()
-        dist_cp.load_state_dict(
-            state_dict,
+        # necessary for loading optimizers since states are initialized lazy
+        for obj in app_state.values():
+            if isinstance(obj, torch.optim.Optimizer):
+                _init_optim_state(obj)
+
+        dcp.load(
+            {"app_state": MultiStateful(app_state)},
             storage_reader=storage_reader,
             process_group=process_group,
-            no_dist=no_dist,
         )
-        MultiStateful(app_state).load_state_dict(state_dict["app_state"])
         rank_zero_info(f"Restored snapshot from path: {path}", logger=logger)
+
+    def _does_checkpoint_exist(
+        self, checkpoint_path: str, process_group: Optional[dist.ProcessGroup] = None
+    ) -> bool:
+        # if we are still checkpointing, this might cause a collective hang.
+        # so wait here instead
+        self._wait()
+
+        return super()._does_checkpoint_exist(
+            checkpoint_path=checkpoint_path, process_group=process_group
+        )
+
+    @property
+    def default_writer_options(self) -> Dict[str, Any]:
+        # defaults are picked to to match TSS defaults
+        # TODO: expose these options in KnobOptions
+        dcp_options = {
+            "thread_count": self._knob_options.max_per_rank_io_concurrency or 16,
+            "sync_files": False,
+            "single_file_per_rank": False,
+        }
+
+        return dcp_options
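
The new docstring note above states that async_checkpoint requires a gloo process group, with 'cpu:gloo,cuda:nccl' recommended for NCCL workloads. A sketch of that initialization, assuming the usual launcher-provided environment variables (MASTER_ADDR, MASTER_PORT, RANK, WORLD_SIZE, e.g. via torchrun):

# Sketch only: register gloo for CPU collectives (needed for async saves) alongside nccl for CUDA.
import torch.distributed as dist

dist.init_process_group(backend="cpu:gloo,cuda:nccl")

With a group initialized this way, default_writer_options then feeds FsspecWriter a thread_count taken from KnobOptions.max_per_rank_io_concurrency (falling back to 16), matching the TSS defaults noted in the property above.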
