
Commit ec6d9ee

saumishr authored and facebook-github-bot committed
Configurable planner and storage writer in the dcp saver save API (#821)
Summary:
Pull Request resolved: #821

Configurable planner and storage writer in the DCP saver save API.

# This Stack
The DCP saver is the TorchTNT callback that enables checkpointing via the Distributed Checkpointing (DCP) APIs. The current implementation does not expose the SavePlanner and StorageWriter in the API for clients to plug in their own implementations; it enforces the default planner and FsspecWriter.

# This diff
- The DCP save and async save APIs now accept a planner and a storage writer, allowing clients to plug in their own implementations.
- Introduces a knob option to plug in a storage writer component with storage efficiency optimizations.

Reviewed By: JKSenthil

Differential Revision: D56921724

fbshipit-source-id: b60c34c6df38e02c0af9a1db4ffd243d382fd621
1 parent e14d0cf commit ec6d9ee
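
In effect, the change threads optional `planner` and `storage_writer` arguments from the callback's checkpoint hooks down to `dcp.save` / `dcp.async_save`, falling back to `DefaultSavePlanner` and the Fsspec-based `Writer` when they are not supplied. A minimal sketch of the resulting call shape, assuming a single process and using `FileSystemWriter` as a stand-in for TorchTNT's Fsspec-based default (illustrative, not the callback code itself):

# Sketch of the call shape this diff enables (assumes a recent torch with DCP;
# dcp.save runs in non-distributed mode when no process group is initialized).
import tempfile

import torch.nn as nn
from torch.distributed import checkpoint as dcp
from torch.distributed.checkpoint import FileSystemWriter
from torch.distributed.checkpoint.default_planner import DefaultSavePlanner

module = nn.Linear(2, 2)

with tempfile.TemporaryDirectory() as checkpoint_id:
    dcp.save(
        state_dict={"app_state": module.state_dict()},
        checkpoint_id=checkpoint_id,
        # Both components are now caller-configurable; these mirror the
        # defaults the saver falls back to when none are provided.
        storage_writer=FileSystemWriter(checkpoint_id),
        planner=DefaultSavePlanner(),
    )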

File tree

3 files changed: +133, -15 lines

tests/framework/callbacks/test_dcp_saver.py

Lines changed: 75 additions & 0 deletions
@@ -18,6 +18,9 @@

 import torch
 from torch import nn
+from torch.distributed.checkpoint import FileSystemWriter
+from torch.distributed.checkpoint.default_planner import DefaultSavePlanner
+from torch.distributed.checkpoint.metadata import STATE_DICT_TYPE
 from torch.utils.data import DataLoader
 from torchsnapshot.test_utils import assert_state_dict_eq, check_state_dict_eq
 from torchtnt.framework._test_utils import (
@@ -289,6 +292,62 @@ def _save_restore_ddp() -> None:
         if get_global_rank() == 0:
             shutil.rmtree(temp_dir)  # delete temp directory

+    @patch("torchtnt.framework.callbacks.dcp_saver.dcp")
+    def test_save_default_planner_storage_components(
+        self, mock_dist_cp: MagicMock
+    ) -> None:
+        from torch.distributed.checkpoint._fsspec_filesystem import FsspecWriter
+
+        input_dim = 2
+        save_every_n_train_steps = 1
+
+        my_unit = DummyTrainUnit(input_dim=input_dim)
+
+        with tempfile.TemporaryDirectory() as temp_dir:
+            dcp_cb = DistributedCheckpointSaver(
+                temp_dir,
+                save_every_n_train_steps=save_every_n_train_steps,
+                knob_options=KnobOptions(1),
+            )
+
+            dcp_cb._save(
+                checkpoint_id=temp_dir,
+                app_state=my_unit.module.state_dict(),
+            )
+
+            planner = mock_dist_cp.save.call_args_list[0][1]["planner"]
+            storage_writer = mock_dist_cp.save.call_args_list[0][1]["storage_writer"]
+
+            self.assertIsInstance(planner, DefaultSavePlanner)
+            self.assertIsInstance(storage_writer, FsspecWriter)
+
+    @patch("torchtnt.framework.callbacks.dcp_saver.dcp")
+    def test_save_planner_storage_components(self, mock_dist_cp: MagicMock) -> None:
+        input_dim = 2
+        save_every_n_train_steps = 1
+
+        my_unit = DummyTrainUnit(input_dim=input_dim)
+
+        with tempfile.TemporaryDirectory() as temp_dir:
+            dcp_cb = DistributedCheckpointSaver(
+                temp_dir,
+                save_every_n_train_steps=save_every_n_train_steps,
+                knob_options=KnobOptions(1),
+            )
+
+            dcp_cb._save(
+                checkpoint_id=temp_dir,
+                app_state=my_unit.module.state_dict(),
+                planner=DummySavePlanner(),
+                storage_writer=DummyStorageWriter(path=temp_dir),
+            )
+
+            planner = mock_dist_cp.save.call_args_list[0][1]["planner"]
+            storage_writer = mock_dist_cp.save.call_args_list[0][1]["storage_writer"]
+
+            self.assertIsInstance(planner, DummySavePlanner)
+            self.assertIsInstance(storage_writer, DummyStorageWriter)
+

 class DummyStatefulDataLoader:
     def __init__(self, dataloader: DataLoader) -> None:
@@ -306,3 +365,19 @@ def load_state_dict(self, state_dict: Dict[str, Any]) -> None:

     def __iter__(self) -> Iterator[object]:
         return iter(self.dataloader)
+
+
+class DummySavePlanner(DefaultSavePlanner):
+    def __init__(self) -> None:
+        super().__init__()
+
+    def set_up_planner(self, state_dict: STATE_DICT_TYPE, is_coordinator: bool) -> None:
+        super().set_up_planner(state_dict, is_coordinator)
+
+
+class DummyStorageWriter(FileSystemWriter):
+    def __init__(self, path: str) -> None:
+        super().__init__(path)
+
+    def set_up_storage_writer(self, is_coordinator: bool) -> None:
+        pass
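
A note on the test mechanics: patching `torchtnt.framework.callbacks.dcp_saver.dcp` replaces the whole module with a `MagicMock`, so `dcp.save` records its invocations, and each entry of `call_args_list` is an `(args, kwargs)` pair, which makes `[0][1]` the kwargs dict of the first call. A standalone illustration of that pattern (hypothetical mock, not from this diff):

# Standalone illustration of the call_args_list pattern used above
# (hypothetical mock; not part of the TorchTNT code).
from unittest.mock import MagicMock

dcp = MagicMock()
dcp.save(checkpoint_id="/tmp/ckpt", planner="custom_planner")

args, kwargs = dcp.save.call_args_list[0]  # first recorded call
assert kwargs["planner"] == "custom_planner"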

torchtnt/framework/callbacks/checkpointer_types.py

Lines changed: 9 additions & 2 deletions
@@ -14,13 +14,20 @@
 @dataclass
 class KnobOptions:
     """
-    Controls the knobs in TorchSnapshot.
+    Controls the knobs for Checkpoints.

     Args:
-        max_per_rank_io_concurrency: Maximum number of concurrent IO operations per rank. Defaults to 16.
+        max_per_rank_io_concurrency: Maximum number of concurrent IO operations per rank in checkpointing.
+            Defaults to 16.
+        enable_storage_optimization: Enable storage efficiency optimizations for Distributed Checkpointing.
     """

+    # use a more conservative number of concurrent IO operations per rank in Checkpointing
+    # the default value of 16 is too bandwidth hungry for most users
     max_per_rank_io_concurrency: Optional[int] = None
+    # This is a no-op and for future use. This would enable storage efficiency optimizations:
+    # e.g. Compression, Batching, Quantization etc.
+    enable_storage_optimization: bool = False


 @dataclass
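
The new `enable_storage_optimization` field is declared but intentionally inert for now, reserved for future optimizations such as compression, batching, or quantization. Constructing the options follows the usual dataclass pattern (a sketch; the import path is assumed from this file's location):

# Sketch: constructing KnobOptions (import path assumed from the file location).
from torchtnt.framework.callbacks.checkpointer_types import KnobOptions

knobs = KnobOptions(
    max_per_rank_io_concurrency=4,     # more conservative than the implicit default of 16
    enable_storage_optimization=True,  # currently a no-op, reserved for future use
)

The tests above construct `KnobOptions(1)` positionally, which sets `max_per_rank_io_concurrency` to 1 since it is the first declared field.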

torchtnt/framework/callbacks/dcp_saver.py

Lines changed: 49 additions & 13 deletions
@@ -14,6 +14,9 @@
 import torch
 import torch.distributed as dist
 from torch.distributed import checkpoint as dcp
+from torch.distributed.checkpoint.default_planner import DefaultSavePlanner
+from torch.distributed.checkpoint.planner import SavePlanner
+from torch.distributed.checkpoint.storage import StorageWriter

 from torchtnt.framework.callbacks._checkpoint_utils import (
     _prepare_app_state_for_checkpoint,
@@ -127,6 +130,8 @@ def _checkpoint_impl(
         *,
         checkpoint_path: str,
         hook: str,
+        planner: Optional[SavePlanner] = None,
+        storage_writer: Optional[StorageWriter] = None,
     ) -> bool:
         if hook not in ["on_train_step_end", "on_train_epoch_end", "on_train_end"]:
             raise RuntimeError(f"Unexpected hook encountered '{hook}'")
@@ -142,20 +147,36 @@ def _checkpoint_impl(
                 # since this is async checkpointed, so in
                 # future, add logic to set successful flag
                 # only when checkpoint is fully written
-                checkpoint_success = self._async_save(checkpoint_path, app_state)
+                checkpoint_success = self._async_save(
+                    checkpoint_path, app_state, planner, storage_writer
+                )
                 if curr_snapshot_wait:
                     self._wait()
         else:
             with get_timing_context(state, f"{self.__class__.__name__}.save"):
-                checkpoint_success = self._save(checkpoint_path, app_state)
+                checkpoint_success = self._save(
+                    checkpoint_path, app_state, planner, storage_writer
+                )

         return checkpoint_success

     def _wait(self) -> None:
         if self._prev_snapshot is not None:
             self._prev_snapshot.result()

-    def _async_save(self, checkpoint_id: str, app_state: Dict[str, Stateful]) -> bool:
+    def _async_save(
+        self,
+        checkpoint_id: str,
+        app_state: Dict[str, Stateful],
+        planner: Optional[SavePlanner] = None,
+        storage_writer: Optional[StorageWriter] = None,
+    ) -> bool:
+
+        if planner is None:
+            planner = DefaultSavePlanner()
+
+        if storage_writer is None:
+            storage_writer = Writer(checkpoint_id, **self.default_writer_options)

         if self._prev_snapshot is not None:
             if not self._prev_snapshot.done():
@@ -177,24 +198,42 @@ def _async_save(self, checkpoint_id: str, app_state: Dict[str, Stateful]) -> boo

         self._prev_snapshot = dcp.async_save(
             state_dict={"app_state": MultiStateful(app_state)},
+            checkpoint_id=checkpoint_id,
             process_group=self._process_group,
-            storage_writer=Writer(checkpoint_id, **self.default_writer_options),
+            storage_writer=storage_writer,
+            planner=planner,
         )

         return True

-    def _save(self, checkpoint_id: str, app_state: Dict[str, Stateful]) -> bool:
+    def _save(
+        self,
+        checkpoint_id: str,
+        app_state: Dict[str, Stateful],
+        planner: Optional[SavePlanner] = None,
+        storage_writer: Optional[StorageWriter] = None,
+    ) -> bool:
+        # Initialize DefaultSavePlanner and FsspecWriter if not provided
+        if planner is None:
+            planner = DefaultSavePlanner()
+
+        if storage_writer is None:
+            storage_writer = Writer(checkpoint_id, **self.default_writer_options)
+
         try:
             dcp.save(
                 state_dict={"app_state": MultiStateful(app_state)},
+                checkpoint_id=checkpoint_id,
                 process_group=self._process_group,
-                storage_writer=Writer(checkpoint_id, **self.default_writer_options),
+                storage_writer=storage_writer,
+                planner=planner,
             )
         except AttributeError:
             dcp.save_state_dict(
                 state_dict={"app_state": MultiStateful(app_state)},
                 process_group=self._process_group,
-                storage_writer=Writer(checkpoint_id, **self.default_writer_options),
+                storage_writer=storage_writer,
+                planner=planner,
             )

         return True
@@ -229,13 +268,8 @@ def restore(
             process_group: The process group on which the ranks will communicate on. default: ``None`` (the entire world) Note:
                 If torch.distributed is available and a process group is initialized, dcp assumes the intention is to save/load checkpoints in distributed fashion.
             restore_options: Controls what to filter when restoring the state.
-            knob_options: Option is kept for legacy reasons but ignored in DCP
+            knob_options: Additional keyword options for StorageWriter and StorageReader
         """
-        if knob_options is not None:
-            rank_zero_warn(
-                "Ignoring `knob_options` which was passed to DistributedCheckpointSaver.restore, but is not supported."
-            )
-
         storage_reader = Reader(path)

         restore_options = restore_options or RestoreOptions()
@@ -250,6 +284,7 @@
         # request to restore the dataloader state only if
         # the persisted snapshot state includes the dataloader entry
         metadata = storage_reader.read_metadata()
+
         for key in metadata.state_dict_metadata.keys():
             if _TRAIN_DL_STATE_KEY in key:
                 app_state[_TRAIN_DL_STATE_KEY] = train_dataloader
@@ -272,6 +307,7 @@
         try:
             dcp.load(
                 {"app_state": MultiStateful(app_state)},
+                checkpoint_id=path,
                 storage_reader=storage_reader,
                 process_group=process_group,
             )
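
With these hooks in place, a client can subclass the defaults and hand its own components to the save path, exactly as the new tests do with their dummy planner and writer. A minimal sketch of such components (behavior-preserving subclasses; the class names here are hypothetical):

# Sketch: custom components a client might plug in (patterned on the test
# doubles above; names are hypothetical).
from torch.distributed.checkpoint import FileSystemWriter
from torch.distributed.checkpoint.default_planner import DefaultSavePlanner
from torch.distributed.checkpoint.metadata import STATE_DICT_TYPE


class AuditingSavePlanner(DefaultSavePlanner):
    """Records that planning ran, then defers to the default behavior."""

    def set_up_planner(self, state_dict: STATE_DICT_TYPE, is_coordinator: bool) -> None:
        print(f"set_up_planner called (coordinator={is_coordinator})")
        super().set_up_planner(state_dict, is_coordinator)


class AuditingStorageWriter(FileSystemWriter):
    """Same on-disk behavior as FileSystemWriter, with a setup-time log line."""

    def set_up_storage_writer(self, is_coordinator: bool) -> None:
        print(f"set_up_storage_writer called (coordinator={is_coordinator})")
        super().set_up_storage_writer(is_coordinator)

These would then be passed through the save path, e.g. `planner=AuditingSavePlanner(), storage_writer=AuditingStorageWriter(path=...)`, just as the tests pass their dummies to `_save`.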
