Skip to content

Commit f3b9c52

Browse files
saumishrfacebook-github-bot
authored andcommitted
Configurable load planner and storage reader in the dcp_saver restore API (#824)
Summary: Pull Request resolved: #824 Configurable load planner and storage reader in the dcp_saver restore API # This Stack DCP saver is the TorchTNT callback which allows checkpointing via the Distributed Checkpointing APIs. Current implementation for the Restore API doesn't expose the Load Planner and Storage Reader in the API for clients to plug in their implementations. It enforces the default planner and FsspecReader. # This diff DCP saver Restore API now supports planner and storage reader allowing clients to plug in their implementations. Reviewed By: anshulverma Differential Revision: D56922769 fbshipit-source-id: 9965c3e78463f494b92196d48c506acc39fcb04d
1 parent ec6d9ee commit f3b9c52

File tree

2 files changed

+81
-7
lines changed

2 files changed

+81
-7
lines changed

tests/framework/callbacks/test_dcp_saver.py

Lines changed: 63 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -18,9 +18,12 @@
1818

1919
import torch
2020
from torch import nn
21-
from torch.distributed.checkpoint import FileSystemWriter
22-
from torch.distributed.checkpoint.default_planner import DefaultSavePlanner
23-
from torch.distributed.checkpoint.metadata import STATE_DICT_TYPE
21+
from torch.distributed.checkpoint import FileSystemReader, FileSystemWriter
22+
from torch.distributed.checkpoint.default_planner import (
23+
DefaultLoadPlanner,
24+
DefaultSavePlanner,
25+
)
26+
from torch.distributed.checkpoint.metadata import Metadata, STATE_DICT_TYPE
2427
from torch.utils.data import DataLoader
2528
from torchsnapshot.test_utils import assert_state_dict_eq, check_state_dict_eq
2629
from torchtnt.framework._test_utils import (
@@ -348,6 +351,42 @@ def test_save_planner_storage_components(self, mock_dist_cp: MagicMock) -> None:
348351
self.assertIsInstance(planner, DummySavePlanner)
349352
self.assertIsInstance(storage_writer, DummyStorageWriter)
350353

354+
@patch("torchtnt.framework.callbacks.dcp_saver.dcp")
355+
def test_restore_default_planner_storage_components(
356+
self, mock_dist_cp: MagicMock
357+
) -> None:
358+
from torch.distributed.checkpoint._fsspec_filesystem import FsspecReader
359+
360+
my_unit = DummyTrainUnit(input_dim=2)
361+
restore_options = RestoreOptions(restore_optimizers=False)
362+
DistributedCheckpointSaver.restore(
363+
path="path/to/snapshot",
364+
unit=my_unit,
365+
restore_options=restore_options,
366+
)
367+
planner = mock_dist_cp.load.call_args[1]["planner"]
368+
storage_reader = mock_dist_cp.load.call_args[1]["storage_reader"]
369+
370+
self.assertIsInstance(planner, DefaultLoadPlanner)
371+
self.assertIsInstance(storage_reader, FsspecReader)
372+
373+
@patch("torchtnt.framework.callbacks.dcp_saver.dcp")
374+
def test_restore_planner_storage_components(self, mock_dist_cp: MagicMock) -> None:
375+
my_unit = DummyTrainUnit(input_dim=2)
376+
restore_options = RestoreOptions(restore_optimizers=False)
377+
DistributedCheckpointSaver.restore(
378+
path="path/to/snapshot",
379+
unit=my_unit,
380+
restore_options=restore_options,
381+
planner=DummyLoadPlanner(),
382+
storage_reader=DummyStorageReader(path="path/to/snapshot"),
383+
)
384+
planner = mock_dist_cp.load.call_args[1]["planner"]
385+
storage_reader = mock_dist_cp.load.call_args[1]["storage_reader"]
386+
387+
self.assertIsInstance(planner, DummyLoadPlanner)
388+
self.assertIsInstance(storage_reader, DummyStorageReader)
389+
351390

352391
class DummyStatefulDataLoader:
353392
def __init__(self, dataloader: DataLoader) -> None:
@@ -375,9 +414,30 @@ def set_up_planner(self, state_dict: STATE_DICT_TYPE, is_coordinator: bool) -> N
375414
super().set_up_planner(state_dict, is_coordinator)
376415

377416

417+
class DummyLoadPlanner(DefaultLoadPlanner):
418+
def __init__(self) -> None:
419+
super().__init__()
420+
421+
def set_up_planner(
422+
self,
423+
state_dict: STATE_DICT_TYPE,
424+
metadata: Metadata,
425+
is_coordinator: bool,
426+
) -> None:
427+
super().set_up_planner(state_dict, metadata, is_coordinator)
428+
429+
378430
class DummyStorageWriter(FileSystemWriter):
379431
def __init__(self, path: str) -> None:
380432
super().__init__(path)
381433

382434
def set_up_storage_writer(self, is_coordinator: bool) -> None:
383435
pass
436+
437+
438+
class DummyStorageReader(FileSystemReader):
439+
def __init__(self, path: str) -> None:
440+
super().__init__(path)
441+
442+
def set_up_storage_writer(self, is_coordinator: bool) -> None:
443+
pass

torchtnt/framework/callbacks/dcp_saver.py

Lines changed: 18 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -14,9 +14,12 @@
1414
import torch
1515
import torch.distributed as dist
1616
from torch.distributed import checkpoint as dcp
17-
from torch.distributed.checkpoint.default_planner import DefaultSavePlanner
18-
from torch.distributed.checkpoint.planner import SavePlanner
19-
from torch.distributed.checkpoint.storage import StorageWriter
17+
from torch.distributed.checkpoint.default_planner import (
18+
DefaultLoadPlanner,
19+
DefaultSavePlanner,
20+
)
21+
from torch.distributed.checkpoint.planner import LoadPlanner, SavePlanner
22+
from torch.distributed.checkpoint.storage import StorageReader, StorageWriter
2023

2124
from torchtnt.framework.callbacks._checkpoint_utils import (
2225
_prepare_app_state_for_checkpoint,
@@ -255,6 +258,8 @@ def restore(
255258
process_group: Optional[dist.ProcessGroup] = None,
256259
restore_options: Optional[RestoreOptions] = None,
257260
knob_options: Optional[KnobOptions] = None,
261+
planner: Optional[LoadPlanner] = None,
262+
storage_reader: Optional[StorageReader] = None,
258263
) -> None:
259264
"""Utility method to restore dcp checkpoint from a path.
260265
@@ -269,8 +274,15 @@ def restore(
269274
If torch.distributed is available and a process group is initialized, dcp assumes the intention is to save/load checkpoints in distributed fashion.
270275
restore_options: Controls what to filter when restoring the state.
271276
knob_options: Additional keyword options for StorageWriter and StorageReader
277+
planner: Instance of LoadPlanner. If this is not specificed, the default planner will be used. (Default: ``None``)
278+
storage_reader: Instance of StorageReader used to perform reads. If this is not specified, it will automatically infer
279+
the reader based on the checkpoint_id. If checkpoint_id is also None, an exception will be raised. (Default: ``None``)
272280
"""
273-
storage_reader = Reader(path)
281+
if planner is None:
282+
planner = DefaultLoadPlanner()
283+
284+
if storage_reader is None:
285+
storage_reader = Reader(path)
274286

275287
restore_options = restore_options or RestoreOptions()
276288
app_state = _prepare_app_state_for_restore(unit, restore_options)
@@ -309,13 +321,15 @@ def restore(
309321
{"app_state": MultiStateful(app_state)},
310322
checkpoint_id=path,
311323
storage_reader=storage_reader,
324+
planner=planner,
312325
process_group=process_group,
313326
)
314327
except AttributeError:
315328
dcp.load_state_dict(
316329
{"app_state": MultiStateful(app_state)},
317330
storage_reader=storage_reader,
318331
process_group=process_group,
332+
planner=planner,
319333
)
320334
rank_zero_info(f"Restored snapshot from path: {path}", logger=logger)
321335

0 commit comments

Comments
 (0)