 import torch.distributed as dist
 from torch.distributed import checkpoint as dcp

-from torch.distributed.checkpoint._fsspec_filesystem import FsspecReader, FsspecWriter
 from torchtnt.framework.callbacks._checkpoint_utils import (
     _prepare_app_state_for_checkpoint,
     _prepare_app_state_for_restore,
@@ -44,6 +43,22 @@

 logger: logging.Logger = logging.getLogger(__name__)

+_LATEST_DCP_AVAIL: bool = True
+try:
+    from torch.distributed.checkpoint._fsspec_filesystem import (
+        FsspecReader as Reader,
+        FsspecWriter as Writer,
+    )
+except ModuleNotFoundError:
+    logger.warn(
+        "To use FsspecReader / FsspecWriter, please install latest pytorch version"
+    )
+    _LATEST_DCP_AVAIL = False
+    from torch.distributed.checkpoint import (
+        FileSystemReader as Reader,
+        FileSystemWriter as Writer,
+    )
+

 class DistributedCheckpointSaver(BaseCheckpointer):
     """
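Note (illustrative, not part of the change): the guarded import above reduces to the pattern below; the fsspec-backed reader/writer classes only ship with recent PyTorch builds, so running the sketch shows which backend the installed version resolves to.

    # Minimal, self-contained sketch of the same import fallback.
    try:
        from torch.distributed.checkpoint._fsspec_filesystem import FsspecWriter as Writer
    except ModuleNotFoundError:
        # Older PyTorch: fall back to the local-filesystem writer.
        from torch.distributed.checkpoint import FileSystemWriter as Writer

    print(Writer.__name__)  # FsspecWriter on recent torch, FileSystemWriter otherwise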
@@ -166,17 +181,24 @@ def _async_save(self, checkpoint_id: str, app_state: Dict[str, Stateful]) -> bool:
         self._prev_snapshot = dcp.async_save(
             state_dict={"app_state": MultiStateful(app_state)},
             process_group=self._process_group,
-            storage_writer=FsspecWriter(checkpoint_id, **self.default_writer_options),
+            storage_writer=Writer(checkpoint_id, **self.default_writer_options),
         )

         return True

     def _save(self, checkpoint_id: str, app_state: Dict[str, Stateful]) -> bool:
-        dcp.save(
-            state_dict={"app_state": MultiStateful(app_state)},
-            process_group=self._process_group,
-            storage_writer=FsspecWriter(checkpoint_id, **self.default_writer_options),
-        )
+        try:
+            dcp.save(
+                state_dict={"app_state": MultiStateful(app_state)},
+                process_group=self._process_group,
+                storage_writer=Writer(checkpoint_id, **self.default_writer_options),
+            )
+        except AttributeError:
+            dcp.save_state_dict(
+                state_dict={"app_state": MultiStateful(app_state)},
+                process_group=self._process_group,
+                storage_writer=Writer(checkpoint_id, **self.default_writer_options),
+            )

         return True

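Note (illustrative, not part of the change): the `except AttributeError` above covers the DCP API rename; recent PyTorch exposes `dcp.save`, while older releases only provide `dcp.save_state_dict`, so looking up `dcp.save` on an old install raises AttributeError. A self-contained sketch of the same shim, with `_dcp_save_compat` as a hypothetical helper name:

    from typing import Any, Dict, Optional

    import torch.distributed as dist
    import torch.distributed.checkpoint as dcp

    def _dcp_save_compat(
        state_dict: Dict[str, Any],
        storage_writer: Any,
        process_group: Optional[dist.ProcessGroup] = None,
    ) -> None:
        """Save via dcp.save when available, else the legacy dcp.save_state_dict."""
        try:
            # Recent PyTorch: dcp.save is the supported entry point.
            dcp.save(
                state_dict=state_dict,
                storage_writer=storage_writer,
                process_group=process_group,
            )
        except AttributeError:
            # Older PyTorch: dcp.save does not exist, so the attribute lookup
            # raises AttributeError and the legacy API is used instead.
            dcp.save_state_dict(
                state_dict=state_dict,
                storage_writer=storage_writer,
                process_group=process_group,
            )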
@@ -217,7 +239,7 @@ def restore(
                 "Ignoring `knob_options` which was passed to DistributedCheckpointSaver.restore, but is not supported."
             )

-        storage_reader = FsspecReader(path)
+        storage_reader = Reader(path)

         restore_options = restore_options or RestoreOptions()
         app_state = _prepare_app_state_for_restore(unit, restore_options)
@@ -250,11 +272,18 @@ def restore(
             if isinstance(optimizer, torch.optim.Optimizer):
                 init_optim_state(optimizer)

-        dcp.load(
-            {"app_state": MultiStateful(app_state)},
-            storage_reader=storage_reader,
-            process_group=process_group,
-        )
+        try:
+            dcp.load(
+                {"app_state": MultiStateful(app_state)},
+                storage_reader=storage_reader,
+                process_group=process_group,
+            )
+        except AttributeError:
+            dcp.load_state_dict(
+                {"app_state": MultiStateful(app_state)},
+                storage_reader=storage_reader,
+                process_group=process_group,
+            )
         rank_zero_info(f"Restored snapshot from path: {path}", logger=logger)

     def _does_checkpoint_exist(