|
18 | 18 |
|
19 | 19 | import torch
|
20 | 20 | from torch import nn
|
21 |
| -from torch.distributed.checkpoint import FileSystemWriter |
22 |
| -from torch.distributed.checkpoint.default_planner import DefaultSavePlanner |
23 |
| -from torch.distributed.checkpoint.metadata import STATE_DICT_TYPE |
| 21 | +from torch.distributed.checkpoint import FileSystemReader, FileSystemWriter |
| 22 | +from torch.distributed.checkpoint.default_planner import ( |
| 23 | + DefaultLoadPlanner, |
| 24 | + DefaultSavePlanner, |
| 25 | +) |
| 26 | +from torch.distributed.checkpoint.metadata import Metadata, STATE_DICT_TYPE |
24 | 27 | from torch.utils.data import DataLoader
|
25 | 28 | from torchsnapshot.test_utils import assert_state_dict_eq, check_state_dict_eq
|
26 | 29 | from torchtnt.framework._test_utils import (
|
@@ -348,6 +351,42 @@ def test_save_planner_storage_components(self, mock_dist_cp: MagicMock) -> None:
|
348 | 351 | self.assertIsInstance(planner, DummySavePlanner)
|
349 | 352 | self.assertIsInstance(storage_writer, DummyStorageWriter)
|
350 | 353 |
|
@patch("torchtnt.framework.callbacks.dcp_saver.dcp")
def test_restore_default_planner_storage_components(
    self, mock_dist_cp: MagicMock
) -> None:
    """
    Restore without passing a planner or storage reader and check the
    keyword arguments that reach the patched ``dcp.load``: the planner
    must be a ``DefaultLoadPlanner`` and the reader an ``FsspecReader``.
    """
    from torch.distributed.checkpoint._fsspec_filesystem import FsspecReader

    DistributedCheckpointSaver.restore(
        path="path/to/snapshot",
        unit=DummyTrainUnit(input_dim=2),
        restore_options=RestoreOptions(restore_optimizers=False),
    )

    # call_args[1] is the kwargs dict of the most recent dcp.load call.
    load_kwargs = mock_dist_cp.load.call_args[1]
    used_planner = load_kwargs["planner"]
    used_reader = load_kwargs["storage_reader"]

    self.assertIsInstance(used_planner, DefaultLoadPlanner)
    self.assertIsInstance(used_reader, FsspecReader)
| 372 | + |
@patch("torchtnt.framework.callbacks.dcp_saver.dcp")
def test_restore_planner_storage_components(self, mock_dist_cp: MagicMock) -> None:
    """
    Restore with a caller-supplied planner and storage reader and check
    that the patched ``dcp.load`` receives objects of those same types.
    """
    train_unit = DummyTrainUnit(input_dim=2)
    options = RestoreOptions(restore_optimizers=False)
    custom_planner = DummyLoadPlanner()
    custom_reader = DummyStorageReader(path="path/to/snapshot")

    DistributedCheckpointSaver.restore(
        path="path/to/snapshot",
        unit=train_unit,
        restore_options=options,
        planner=custom_planner,
        storage_reader=custom_reader,
    )

    # call_args[1] is the kwargs dict of the most recent dcp.load call.
    load_kwargs = mock_dist_cp.load.call_args[1]
    used_planner = load_kwargs["planner"]
    used_reader = load_kwargs["storage_reader"]

    self.assertIsInstance(used_planner, DummyLoadPlanner)
    self.assertIsInstance(used_reader, DummyStorageReader)
| 389 | + |
351 | 390 |
|
352 | 391 | class DummyStatefulDataLoader:
|
353 | 392 | def __init__(self, dataloader: DataLoader) -> None:
|
@@ -375,9 +414,30 @@ def set_up_planner(self, state_dict: STATE_DICT_TYPE, is_coordinator: bool) -> N
|
375 | 414 | super().set_up_planner(state_dict, is_coordinator)
|
376 | 415 |
|
377 | 416 |
|
class DummyLoadPlanner(DefaultLoadPlanner):
    """
    Distinct ``DefaultLoadPlanner`` subclass used by the tests to tell a
    caller-supplied planner apart from the default one. Every method simply
    forwards to the parent implementation.
    """

    def __init__(self) -> None:
        super().__init__()

    def set_up_planner(
        self,
        state_dict: STATE_DICT_TYPE,
        metadata: Metadata,
        is_coordinator: bool,
    ) -> None:
        """Forward the planner setup to ``DefaultLoadPlanner`` unchanged."""
        super().set_up_planner(
            state_dict=state_dict,
            metadata=metadata,
            is_coordinator=is_coordinator,
        )
| 428 | + |
| 429 | + |
class DummyStorageWriter(FileSystemWriter):
    """
    Distinct ``FileSystemWriter`` subclass used by the tests to tell a
    caller-supplied storage writer apart from the default one.
    """

    def __init__(self, path: str) -> None:
        super().__init__(path)

    def set_up_storage_writer(self, is_coordinator: bool) -> None:
        """Deliberate no-op: the parent's writer setup is not run."""
        return None
|
| 436 | + |
| 437 | + |
class DummyStorageReader(FileSystemReader):
    """
    Distinct ``FileSystemReader`` subclass used by the tests to tell a
    caller-supplied storage reader apart from the default one.

    The override below targets the reader hook ``set_up_storage_reader``.
    The previous version defined ``set_up_storage_writer``, a writer hook
    that is never invoked on a storage reader, so it overrode nothing.
    """

    def __init__(self, path: str) -> None:
        super().__init__(path)

    def set_up_storage_reader(self, metadata: Metadata, is_coordinator: bool) -> None:
        """
        Forward reader setup to ``FileSystemReader``.

        Delegating (rather than a no-op) keeps the effective behaviour the
        same as before the rename, when the inherited hook was what ran.

        Args:
            metadata: checkpoint metadata handed over by the load pipeline.
            is_coordinator: whether this rank coordinates the load.
        """
        super().set_up_storage_reader(metadata, is_coordinator)
0 commit comments