Skip to content

Commit d3e85dc

Browse files
diego-urgellfacebook-github-bot
authored andcommitted
Remove support for deprecated DCP APIs in DCPSaver callback (#890)
Summary: Pull Request resolved: #890 Reviewed By: anshulverma, JKSenthil Differential Revision: D61887203 fbshipit-source-id: 17bd899a9b88033feb0285f3a395be6edbf82d5a
1 parent 3345ed9 commit d3e85dc

File tree

2 files changed

+23
-58
lines changed

2 files changed

+23
-58
lines changed

tests/framework/callbacks/test_dcp_saver.py

Lines changed: 3 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -7,18 +7,11 @@
77

88
# pyre-strict
99

10-
import unittest
11-
12-
from torchtnt.framework.callbacks.dcp_saver import _LATEST_DCP_AVAIL
13-
from torchtnt.framework.state import State
14-
15-
if not _LATEST_DCP_AVAIL:
16-
raise unittest.SkipTest("Latest Pytorch is required to run DCP tests")
17-
1810
import math
1911
import os
2012
import shutil
2113
import tempfile
14+
import unittest
2215
from typing import Any, Dict, Iterator, List, Optional
2316
from unittest import mock
2417
from unittest.mock import MagicMock, patch
@@ -40,6 +33,8 @@
4033
)
4134
from torchtnt.framework.callbacks.checkpointer_types import KnobOptions, RestoreOptions
4235
from torchtnt.framework.callbacks.dcp_saver import DistributedCheckpointSaver
36+
37+
from torchtnt.framework.state import State
4338
from torchtnt.framework.train import train
4439
from torchtnt.utils.distributed import get_global_rank, spawn_multi_process
4540
from torchtnt.utils.env import seed

torchtnt/framework/callbacks/dcp_saver.py

Lines changed: 20 additions & 50 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,11 @@
1616
import torch.distributed as dist
1717
from pyre_extensions import none_throws
1818
from torch.distributed import checkpoint as dcp
19+
20+
from torch.distributed.checkpoint._fsspec_filesystem import (
21+
FsspecReader as Reader,
22+
FsspecWriter as Writer,
23+
)
1924
from torch.distributed.checkpoint.default_planner import (
2025
DefaultLoadPlanner,
2126
DefaultSavePlanner,
@@ -45,25 +50,8 @@
4550
from torchtnt.utils.rank_zero_log import rank_zero_info, rank_zero_warn
4651
from torchtnt.utils.stateful import MultiStateful, Stateful
4752

48-
4953
logger: logging.Logger = logging.getLogger(__name__)
5054

51-
_LATEST_DCP_AVAIL: bool = True
52-
try:
53-
from torch.distributed.checkpoint._fsspec_filesystem import (
54-
FsspecReader as Reader,
55-
FsspecWriter as Writer,
56-
)
57-
except ModuleNotFoundError:
58-
logger.warn(
59-
"To use FsspecReader / FsspecWriter, please install latest pytorch version"
60-
)
61-
_LATEST_DCP_AVAIL = False
62-
from torch.distributed.checkpoint import (
63-
FileSystemReader as Reader,
64-
FileSystemWriter as Writer,
65-
)
66-
6755

6856
class DistributedCheckpointSaver(BaseCheckpointer):
6957
"""
@@ -248,24 +236,13 @@ def _save(
248236
if storage_writer is None:
249237
storage_writer = Writer(checkpoint_id, **self.default_writer_options)
250238

251-
try:
252-
dcp.save(
253-
state_dict={"app_state": MultiStateful(app_state)},
254-
checkpoint_id=checkpoint_id,
255-
process_group=self._process_group,
256-
storage_writer=storage_writer,
257-
planner=planner,
258-
)
259-
except AttributeError as ex:
260-
logger.warning(
261-
f"Unable to save checkpoint (will retry saving using deprecated API). Error: {ex}"
262-
)
263-
dcp.save_state_dict(
264-
state_dict={"app_state": MultiStateful(app_state)},
265-
process_group=self._process_group,
266-
storage_writer=storage_writer,
267-
planner=planner,
268-
)
239+
dcp.save(
240+
state_dict={"app_state": MultiStateful(app_state)},
241+
checkpoint_id=checkpoint_id,
242+
process_group=self._process_group,
243+
storage_writer=storage_writer,
244+
planner=planner,
245+
)
269246

270247
return True
271248

@@ -397,21 +374,14 @@ def restore_with_id(
397374
if isinstance(optimizer, torch.optim.Optimizer):
398375
init_optim_state(optimizer)
399376

400-
try:
401-
dcp.load(
402-
{"app_state": MultiStateful(app_state)},
403-
checkpoint_id=checkpoint_id,
404-
storage_reader=storage_reader,
405-
planner=planner,
406-
process_group=process_group,
407-
)
408-
except AttributeError:
409-
dcp.load_state_dict(
410-
{"app_state": MultiStateful(app_state)},
411-
storage_reader=storage_reader,
412-
process_group=process_group,
413-
planner=planner,
414-
)
377+
dcp.load(
378+
{"app_state": MultiStateful(app_state)},
379+
checkpoint_id=checkpoint_id,
380+
storage_reader=storage_reader,
381+
planner=planner,
382+
process_group=process_group,
383+
)
384+
415385
rank_zero_info(
416386
f"Restored snapshot for checkpoint_id: {checkpoint_id}", logger=logger
417387
)

0 commit comments

Comments
 (0)