Skip to content

Commit 35f9d92

Browse files
diego-urgellfacebook-github-bot
authored andcommitted
Fix test_checkpoint_utils multiprocess test (#791)
Summary: Pull Request resolved: #791 Reviewed By: JKSenthil Differential Revision: D56260189 fbshipit-source-id: 32d5035a7c654308695e8e6f5126e4e0da75f610
1 parent 5534617 commit 35f9d92

File tree

1 file changed

+7
-6
lines changed

1 file changed

+7
-6
lines changed

tests/framework/callbacks/test_checkpoint_utils.py

Lines changed: 7 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,6 @@
1414
import torch
1515
import torch.distributed as dist
1616
from torch import nn
17-
from torch.distributed import launcher
1817
from torchsnapshot import Snapshot
1918
from torchsnapshot.snapshot import SNAPSHOT_METADATA_FNAME
2019
from torchtnt.framework._test_utils import DummyTrainUnit, get_dummy_train_state
@@ -34,7 +33,7 @@
3433
from torchtnt.utils.distributed import get_global_rank, PGWrapper, spawn_multi_process
3534
from torchtnt.utils.env import init_from_env
3635
from torchtnt.utils.fsspec import get_filesystem
37-
from torchtnt.utils.test_utils import get_pet_launch_config, skip_if_not_distributed
36+
from torchtnt.utils.test_utils import skip_if_not_distributed
3837

3938
METADATA_FNAME: str = ".metadata"
4039

@@ -88,10 +87,11 @@ def test_latest_checkpoint_path(self) -> None:
8887

8988
@skip_if_not_distributed
9089
def test_latest_checkpoint_path_distributed(self) -> None:
91-
config = get_pet_launch_config(2)
92-
launcher.elastic_launch(
93-
config, entrypoint=self._latest_checkpoint_path_distributed
94-
)()
90+
spawn_multi_process(
91+
2,
92+
"gloo",
93+
self._latest_checkpoint_path_distributed,
94+
)
9595

9696
@staticmethod
9797
def _latest_checkpoint_path_distributed() -> None:
@@ -130,6 +130,7 @@ def _latest_checkpoint_path_distributed() -> None:
130130
path_container = [path_2] if is_rank0 else [None]
131131
pg.broadcast_object_list(path_container, 0)
132132
expected_path = path_container[0]
133+
tc.assertIsNotNone(expected_path)
133134
tc.assertEqual(
134135
get_latest_checkpoint_path(temp_dir, METADATA_FNAME), expected_path
135136
)

0 commit comments

Comments
 (0)