@@ -16,6 +16,11 @@
 import torch.distributed as dist
 from pyre_extensions import none_throws
 from torch.distributed import checkpoint as dcp
+
+from torch.distributed.checkpoint._fsspec_filesystem import (
+    FsspecReader as Reader,
+    FsspecWriter as Writer,
+)
 from torch.distributed.checkpoint.default_planner import (
     DefaultLoadPlanner,
     DefaultSavePlanner,
@@ -45,25 +50,8 @@
 from torchtnt.utils.rank_zero_log import rank_zero_info, rank_zero_warn
 from torchtnt.utils.stateful import MultiStateful, Stateful
 
-
 logger: logging.Logger = logging.getLogger(__name__)
 
-_LATEST_DCP_AVAIL: bool = True
-try:
-    from torch.distributed.checkpoint._fsspec_filesystem import (
-        FsspecReader as Reader,
-        FsspecWriter as Writer,
-    )
-except ModuleNotFoundError:
-    logger.warn(
-        "To use FsspecReader / FsspecWriter, please install latest pytorch version"
-    )
-    _LATEST_DCP_AVAIL = False
-    from torch.distributed.checkpoint import (
-        FileSystemReader as Reader,
-        FileSystemWriter as Writer,
-    )
-
 
 class DistributedCheckpointSaver(BaseCheckpointer):
     """
@@ -248,24 +236,13 @@ def _save(
         if storage_writer is None:
             storage_writer = Writer(checkpoint_id, **self.default_writer_options)
 
-        try:
-            dcp.save(
-                state_dict={"app_state": MultiStateful(app_state)},
-                checkpoint_id=checkpoint_id,
-                process_group=self._process_group,
-                storage_writer=storage_writer,
-                planner=planner,
-            )
-        except AttributeError as ex:
-            logger.warning(
-                f"Unable to save checkpoint (will retry saving using deprecated API). Error: {ex}"
-            )
-            dcp.save_state_dict(
-                state_dict={"app_state": MultiStateful(app_state)},
-                process_group=self._process_group,
-                storage_writer=storage_writer,
-                planner=planner,
-            )
+        dcp.save(
+            state_dict={"app_state": MultiStateful(app_state)},
+            checkpoint_id=checkpoint_id,
+            process_group=self._process_group,
+            storage_writer=storage_writer,
+            planner=planner,
+        )
 
         return True
 
@@ -397,21 +374,14 @@ def restore_with_id(
             if isinstance(optimizer, torch.optim.Optimizer):
                 init_optim_state(optimizer)
 
-        try:
-            dcp.load(
-                {"app_state": MultiStateful(app_state)},
-                checkpoint_id=checkpoint_id,
-                storage_reader=storage_reader,
-                planner=planner,
-                process_group=process_group,
-            )
-        except AttributeError:
-            dcp.load_state_dict(
-                {"app_state": MultiStateful(app_state)},
-                storage_reader=storage_reader,
-                process_group=process_group,
-                planner=planner,
-            )
+        dcp.load(
+            {"app_state": MultiStateful(app_state)},
+            checkpoint_id=checkpoint_id,
+            storage_reader=storage_reader,
+            planner=planner,
+            process_group=process_group,
+        )
+
         rank_zero_info(
             f"Restored snapshot for checkpoint_id: {checkpoint_id}", logger=logger
         )
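Note for reviewers: the diff consolidates on the current dcp.save / dcp.load entry points and imports the fsspec storage layer unconditionally. Below is a minimal single-process round-trip sketch of that API for context; the toy model, the /tmp/ckpt path, and the omission of process_group and planner are illustrative assumptions, not torchtnt's actual call sites.

# Round-trip sketch, assuming a torch build where dcp.save/dcp.load accept
# checkpoint_id (roughly torch >= 2.2) and fsspec is installed.
import torch
import torch.distributed.checkpoint as dcp
from torch.distributed.checkpoint._fsspec_filesystem import (
    FsspecReader,
    FsspecWriter,
)

model = torch.nn.Linear(4, 2)
checkpoint_id = "/tmp/ckpt"  # placeholder; any fsspec-resolvable URI works

# Save: the writer resolves checkpoint_id through fsspec and writes the
# checkpoint shards under it. Without an initialized process group, DCP
# warns and saves in single-process mode.
dcp.save(
    state_dict={"model": model.state_dict()},
    checkpoint_id=checkpoint_id,
    storage_writer=FsspecWriter(checkpoint_id),
)

# Load: DCP restores values in place into the tensors of the passed dict;
# the dict is then handed back to the module.
state_dict = {"model": model.state_dict()}
dcp.load(
    state_dict,
    checkpoint_id=checkpoint_id,
    storage_reader=FsspecReader(checkpoint_id),
)
model.load_state_dict(state_dict["model"])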