|
14 | 14 | import torch |
15 | 15 | import torch.distributed as dist |
16 | 16 | from torch import nn |
17 | | -from torch.distributed import launcher |
18 | 17 | from torchsnapshot import Snapshot |
19 | 18 | from torchsnapshot.snapshot import SNAPSHOT_METADATA_FNAME |
20 | 19 | from torchtnt.framework._test_utils import DummyTrainUnit, get_dummy_train_state |
|
34 | 33 | from torchtnt.utils.distributed import get_global_rank, PGWrapper, spawn_multi_process |
35 | 34 | from torchtnt.utils.env import init_from_env |
36 | 35 | from torchtnt.utils.fsspec import get_filesystem |
37 | | -from torchtnt.utils.test_utils import get_pet_launch_config, skip_if_not_distributed |
| 36 | +from torchtnt.utils.test_utils import skip_if_not_distributed |
38 | 37 |
|
39 | 38 | METADATA_FNAME: str = ".metadata" |
40 | 39 |
|
@@ -88,10 +87,11 @@ def test_latest_checkpoint_path(self) -> None: |
88 | 87 |
|
89 | 88 | @skip_if_not_distributed |
90 | 89 | def test_latest_checkpoint_path_distributed(self) -> None: |
91 | | - config = get_pet_launch_config(2) |
92 | | - launcher.elastic_launch( |
93 | | - config, entrypoint=self._latest_checkpoint_path_distributed |
94 | | - )() |
| 90 | + spawn_multi_process( |
| 91 | + 2, |
| 92 | + "gloo", |
| 93 | + self._latest_checkpoint_path_distributed, |
| 94 | + ) |
95 | 95 |
|
96 | 96 | @staticmethod |
97 | 97 | def _latest_checkpoint_path_distributed() -> None: |
@@ -130,6 +130,7 @@ def _latest_checkpoint_path_distributed() -> None: |
130 | 130 | path_container = [path_2] if is_rank0 else [None] |
131 | 131 | pg.broadcast_object_list(path_container, 0) |
132 | 132 | expected_path = path_container[0] |
| 133 | + tc.assertIsNotNone(expected_path) |
133 | 134 | tc.assertEqual( |
134 | 135 | get_latest_checkpoint_path(temp_dir, METADATA_FNAME), expected_path |
135 | 136 | ) |
|
0 commit comments