7
7
8
8
# pyre-strict
9
9
10
+ import unittest
11
+
12
+ from torchtnt .framework .callbacks .dcp_saver import _LATEST_DCP_AVAIL
13
+
14
+ if not _LATEST_DCP_AVAIL :
15
+ raise unittest .SkipTest ("Latest Pytorch is required to run DCP tests" )
16
+
10
17
import math
11
18
import os
12
19
import shutil
13
20
import tempfile
14
- import unittest
15
- from typing import Any , Dict , Iterator , List
21
+ from typing import Any , Dict , Iterator , List , Optional
16
22
from unittest import mock
17
23
from unittest .mock import MagicMock , patch
18
24
23
29
DefaultLoadPlanner ,
24
30
DefaultSavePlanner ,
25
31
)
26
- from torch .distributed .checkpoint .metadata import Metadata , STATE_DICT_TYPE
32
+ from torch .distributed .checkpoint .metadata import Metadata , STATE_DICT_TYPE , StorageMeta
27
33
from torch .utils .data import DataLoader
28
34
from torchsnapshot .test_utils import assert_state_dict_eq , check_state_dict_eq
29
35
from torchtnt .framework ._test_utils import (
32
38
generate_random_dataloader ,
33
39
)
34
40
from torchtnt .framework .callbacks .checkpointer_types import KnobOptions , RestoreOptions
35
- from torchtnt .framework .callbacks .dcp_saver import (
36
- _LATEST_DCP_AVAIL ,
37
- DistributedCheckpointSaver ,
38
- )
41
+ from torchtnt .framework .callbacks .dcp_saver import DistributedCheckpointSaver
39
42
from torchtnt .framework .train import train
40
43
from torchtnt .utils .distributed import get_global_rank , spawn_multi_process
41
44
from torchtnt .utils .env import seed
42
45
from torchtnt .utils .test_utils import skip_if_not_distributed
43
46
44
- if not _LATEST_DCP_AVAIL :
45
- raise unittest .SkipTest ("Latest Pytorch is required to run DCP tests" )
46
-
47
47
48
48
class DistributedCheckpointSaverTest (unittest .TestCase ):
49
49
def test_save_restore (self ) -> None :
@@ -410,8 +410,13 @@ class DummySavePlanner(DefaultSavePlanner):
410
410
def __init__ (self ) -> None :
411
411
super ().__init__ ()
412
412
413
- def set_up_planner (self , state_dict : STATE_DICT_TYPE , is_coordinator : bool ) -> None :
414
- super ().set_up_planner (state_dict , is_coordinator )
413
+ def set_up_planner (
414
+ self ,
415
+ state_dict : STATE_DICT_TYPE ,
416
+ storage_meta : Optional [StorageMeta ],
417
+ is_coordinator : bool ,
418
+ ) -> None :
419
+ super ().set_up_planner (state_dict , storage_meta , is_coordinator )
415
420
416
421
417
422
class DummyLoadPlanner (DefaultLoadPlanner ):
@@ -421,7 +426,7 @@ def __init__(self) -> None:
421
426
def set_up_planner (
422
427
self ,
423
428
state_dict : STATE_DICT_TYPE ,
424
- metadata : Metadata ,
429
+ metadata : Optional [ Metadata ] ,
425
430
is_coordinator : bool ,
426
431
) -> None :
427
432
super ().set_up_planner (state_dict , metadata , is_coordinator )
0 commit comments