Skip to content

Commit 634be4c

Browse files
JKSenthil authored and facebook-github-bot committed
make DCPSaver OSS compatible
Summary: This diff makes the dcp saver OSS compatible (with any pytorch stable version >= 2.0.0).

Reviewed By: diego-urgell

Differential Revision: D56537398

fbshipit-source-id: 1c5630b9402135c25e4678c0214f165d5d827e8c
1 parent 82a6b62 commit 634be4c

File tree

2 files changed

+49
-14
lines changed

2 files changed

+49
-14
lines changed

tests/framework/callbacks/test_dcp_saver.py

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -26,12 +26,18 @@
2626
generate_random_dataloader,
2727
)
2828
from torchtnt.framework.callbacks.checkpointer_types import KnobOptions, RestoreOptions
29-
from torchtnt.framework.callbacks.dcp_saver import DistributedCheckpointSaver
29+
from torchtnt.framework.callbacks.dcp_saver import (
30+
_LATEST_DCP_AVAIL,
31+
DistributedCheckpointSaver,
32+
)
3033
from torchtnt.framework.train import train
3134
from torchtnt.utils.distributed import get_global_rank, spawn_multi_process
3235
from torchtnt.utils.env import seed
3336
from torchtnt.utils.test_utils import skip_if_not_distributed
3437

38+
if not _LATEST_DCP_AVAIL:
39+
raise unittest.SkipTest("Latest Pytorch is required to run DCP tests")
40+
3541

3642
class DistributedCheckpointSaverTest(unittest.TestCase):
3743
def test_save_restore(self) -> None:

torchtnt/framework/callbacks/dcp_saver.py

Lines changed: 42 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,6 @@
1515
import torch.distributed as dist
1616
from torch.distributed import checkpoint as dcp
1717

18-
from torch.distributed.checkpoint._fsspec_filesystem import FsspecReader, FsspecWriter
1918
from torchtnt.framework.callbacks._checkpoint_utils import (
2019
_prepare_app_state_for_checkpoint,
2120
_prepare_app_state_for_restore,
@@ -44,6 +43,22 @@
4443

4544
logger: logging.Logger = logging.getLogger(__name__)
4645

46+
_LATEST_DCP_AVAIL: bool = True
47+
try:
48+
from torch.distributed.checkpoint._fsspec_filesystem import (
49+
FsspecReader as Reader,
50+
FsspecWriter as Writer,
51+
)
52+
except ModuleNotFoundError:
53+
logger.warn(
54+
"To use FsspecReader / FsspecWriter, please install latest pytorch version"
55+
)
56+
_LATEST_DCP_AVAIL = False
57+
from torch.distributed.checkpoint import (
58+
FileSystemReader as Reader,
59+
FileSystemWriter as Writer,
60+
)
61+
4762

4863
class DistributedCheckpointSaver(BaseCheckpointer):
4964
"""
@@ -166,17 +181,24 @@ def _async_save(self, checkpoint_id: str, app_state: Dict[str, Stateful]) -> boo
166181
self._prev_snapshot = dcp.async_save(
167182
state_dict={"app_state": MultiStateful(app_state)},
168183
process_group=self._process_group,
169-
storage_writer=FsspecWriter(checkpoint_id, **self.default_writer_options),
184+
storage_writer=Writer(checkpoint_id, **self.default_writer_options),
170185
)
171186

172187
return True
173188

174189
def _save(self, checkpoint_id: str, app_state: Dict[str, Stateful]) -> bool:
175-
dcp.save(
176-
state_dict={"app_state": MultiStateful(app_state)},
177-
process_group=self._process_group,
178-
storage_writer=FsspecWriter(checkpoint_id, **self.default_writer_options),
179-
)
190+
try:
191+
dcp.save(
192+
state_dict={"app_state": MultiStateful(app_state)},
193+
process_group=self._process_group,
194+
storage_writer=Writer(checkpoint_id, **self.default_writer_options),
195+
)
196+
except AttributeError:
197+
dcp.save_state_dict(
198+
state_dict={"app_state": MultiStateful(app_state)},
199+
process_group=self._process_group,
200+
storage_writer=Writer(checkpoint_id, **self.default_writer_options),
201+
)
180202

181203
return True
182204

@@ -217,7 +239,7 @@ def restore(
217239
"Ignoring `knob_options` which was passed to DistributedCheckpointSaver.restore, but is not supported."
218240
)
219241

220-
storage_reader = FsspecReader(path)
242+
storage_reader = Reader(path)
221243

222244
restore_options = restore_options or RestoreOptions()
223245
app_state = _prepare_app_state_for_restore(unit, restore_options)
@@ -250,11 +272,18 @@ def restore(
250272
if isinstance(optimizer, torch.optim.Optimizer):
251273
init_optim_state(optimizer)
252274

253-
dcp.load(
254-
{"app_state": MultiStateful(app_state)},
255-
storage_reader=storage_reader,
256-
process_group=process_group,
257-
)
275+
try:
276+
dcp.load(
277+
{"app_state": MultiStateful(app_state)},
278+
storage_reader=storage_reader,
279+
process_group=process_group,
280+
)
281+
except AttributeError:
282+
dcp.load_state_dict(
283+
{"app_state": MultiStateful(app_state)},
284+
storage_reader=storage_reader,
285+
process_group=process_group,
286+
)
258287
rank_zero_info(f"Restored snapshot from path: {path}", logger=logger)
259288

260289
def _does_checkpoint_exist(

0 commit comments

Comments
 (0)