Skip to content

Commit 8456d0c

Browse files
JKSenthil authored and facebook-github-bot committed
fix dcp pyre errors (#832)
Summary: Pull Request resolved: #832 Fixes pyre error in dcp tests Also moves latest pytorch check to the top, otherwise import error is raised on StorageMeta in stable unit tests Reviewed By: diego-urgell Differential Revision: D57342263 fbshipit-source-id: 79d66a5637313eb399d995158baa0ae5c1821d27
1 parent 5e1db7f commit 8456d0c

File tree

1 file changed

+18
-13
lines changed

1 file changed

+18
-13
lines changed

tests/framework/callbacks/test_dcp_saver.py

Lines changed: 18 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -7,12 +7,18 @@
77

88
# pyre-strict
99

10+
import unittest
11+
12+
from torchtnt.framework.callbacks.dcp_saver import _LATEST_DCP_AVAIL
13+
14+
if not _LATEST_DCP_AVAIL:
15+
raise unittest.SkipTest("Latest Pytorch is required to run DCP tests")
16+
1017
import math
1118
import os
1219
import shutil
1320
import tempfile
14-
import unittest
15-
from typing import Any, Dict, Iterator, List
21+
from typing import Any, Dict, Iterator, List, Optional
1622
from unittest import mock
1723
from unittest.mock import MagicMock, patch
1824

@@ -23,7 +29,7 @@
2329
DefaultLoadPlanner,
2430
DefaultSavePlanner,
2531
)
26-
from torch.distributed.checkpoint.metadata import Metadata, STATE_DICT_TYPE
32+
from torch.distributed.checkpoint.metadata import Metadata, STATE_DICT_TYPE, StorageMeta
2733
from torch.utils.data import DataLoader
2834
from torchsnapshot.test_utils import assert_state_dict_eq, check_state_dict_eq
2935
from torchtnt.framework._test_utils import (
@@ -32,18 +38,12 @@
3238
generate_random_dataloader,
3339
)
3440
from torchtnt.framework.callbacks.checkpointer_types import KnobOptions, RestoreOptions
35-
from torchtnt.framework.callbacks.dcp_saver import (
36-
_LATEST_DCP_AVAIL,
37-
DistributedCheckpointSaver,
38-
)
41+
from torchtnt.framework.callbacks.dcp_saver import DistributedCheckpointSaver
3942
from torchtnt.framework.train import train
4043
from torchtnt.utils.distributed import get_global_rank, spawn_multi_process
4144
from torchtnt.utils.env import seed
4245
from torchtnt.utils.test_utils import skip_if_not_distributed
4346

44-
if not _LATEST_DCP_AVAIL:
45-
raise unittest.SkipTest("Latest Pytorch is required to run DCP tests")
46-
4747

4848
class DistributedCheckpointSaverTest(unittest.TestCase):
4949
def test_save_restore(self) -> None:
@@ -410,8 +410,13 @@ class DummySavePlanner(DefaultSavePlanner):
410410
def __init__(self) -> None:
411411
super().__init__()
412412

413-
def set_up_planner(self, state_dict: STATE_DICT_TYPE, is_coordinator: bool) -> None:
414-
super().set_up_planner(state_dict, is_coordinator)
413+
def set_up_planner(
414+
self,
415+
state_dict: STATE_DICT_TYPE,
416+
storage_meta: Optional[StorageMeta],
417+
is_coordinator: bool,
418+
) -> None:
419+
super().set_up_planner(state_dict, storage_meta, is_coordinator)
415420

416421

417422
class DummyLoadPlanner(DefaultLoadPlanner):
@@ -421,7 +426,7 @@ def __init__(self) -> None:
421426
def set_up_planner(
422427
self,
423428
state_dict: STATE_DICT_TYPE,
424-
metadata: Metadata,
429+
metadata: Optional[Metadata],
425430
is_coordinator: bool,
426431
) -> None:
427432
super().set_up_planner(state_dict, metadata, is_coordinator)

0 commit comments

Comments (0)