Skip to content

Commit 8e93a51

Browse files
diego-urgell authored and facebook-github-bot committed
Make base_checkpointer NCCL test run in CPU (#761)
Summary: Pull Request resolved: #761 Reviewed By: galrotem Differential Revision: D55346092 fbshipit-source-id: da53c356ec50db2c90d7117b42e111061c4e6b7d
1 parent 7bbbb8c commit 8e93a51

File tree

2 files changed

+27
-12
lines changed

2 files changed

+27
-12
lines changed

tests/framework/callbacks/test_base_checkpointer.py

Lines changed: 26 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -28,7 +28,9 @@
2828
get_dummy_fit_state,
2929
get_dummy_train_state,
3030
)
31-
from torchtnt.framework.callbacks.base_checkpointer import BaseCheckpointer
31+
from torchtnt.framework.callbacks.base_checkpointer import (
32+
BaseCheckpointer as BaseCheckpointer,
33+
)
3234
from torchtnt.framework.callbacks.checkpointer_types import (
3335
BestCheckpointConfig,
3436
RestoreOptions,
@@ -41,7 +43,7 @@
4143
from torchtnt.framework.unit import AppStateMixin, TrainUnit, TTrainData
4244
from torchtnt.utils.distributed import get_global_rank, spawn_multi_process
4345
from torchtnt.utils.env import init_from_env
44-
from torchtnt.utils.test_utils import skip_if_not_distributed, skip_if_not_gpu
46+
from torchtnt.utils.test_utils import skip_if_not_distributed
4547

4648

4749
class BaseCheckpointSaver(BaseCheckpointer):
@@ -411,24 +413,20 @@ def test_invalid_args(self) -> None:
411413
BaseCheckpointSaver(temp_dir, save_every_n_epochs=0)
412414

413415
@skip_if_not_distributed
414-
@skip_if_not_gpu
415416
def test_process_group_plumbing(self) -> None:
416-
"""
417-
Creates a new process group and verifies GLOO group is created accordingly
418-
"""
419417
spawn_multi_process(
420418
2,
421-
"nccl",
422-
self._test_process_group_plumbing,
419+
"gloo",
420+
self._test_process_group_plumbing_gloo,
423421
)
424422
spawn_multi_process(
425423
2,
426-
"gloo",
427-
self._test_process_group_plumbing,
424+
"gloo", # inner test mocks nccl backend
425+
self._test_process_group_plumbing_nccl,
428426
)
429427

430428
@staticmethod
431-
def _test_process_group_plumbing() -> None:
429+
def _test_process_group_plumbing_gloo() -> None:
432430
checkpoint_cb = BaseCheckpointSaver(
433431
"foo",
434432
process_group=None,
@@ -441,6 +439,23 @@ def _test_process_group_plumbing() -> None:
441439
# verify no new process group was created
442440
tc.assertEqual(checkpoint_cb._process_group, dist.group.WORLD)
443441

442+
@staticmethod
443+
@patch("torch.cuda.nccl.version", return_value=(1, 0, 0))
444+
def _test_process_group_plumbing_nccl(_: MagicMock) -> None:
445+
with patch("torch.distributed.get_backend", return_value=dist.Backend.NCCL):
446+
checkpoint_cb = BaseCheckpointSaver(
447+
"foo",
448+
process_group=None,
449+
)
450+
451+
tc = unittest.TestCase()
452+
tc.assertIsNotNone(checkpoint_cb._process_group)
453+
tc.assertEqual(
454+
checkpoint_cb._process_group._get_backend_name(), dist.Backend.GLOO
455+
)
456+
# check that a new process group was created
457+
tc.assertNotEqual(checkpoint_cb._process_group, dist.group.WORLD)
458+
444459
@patch(
445460
"torchtnt.framework.callbacks.base_checkpointer.get_checkpoint_dirpaths",
446461
return_value=["epoch_1_step_10", "epoch_2_step_20"],

torchtnt/framework/callbacks/base_checkpointer.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -439,7 +439,7 @@ def restore_from_latest(
439439
**kwargs: Any,
440440
) -> bool:
441441
"""
442-
Given a parent directory where checkpoints are saved, restore the checkppoint state from the latest checkpoint in the directory.
442+
Given a parent directory where checkpoints are saved, restore the checkpoint state from the latest checkpoint in the directory.
443443
444444
There are additional flags offered should the user want to skip loading the train and eval progress.
445445
By default, the train and eval progress are restored, if applicable.

0 commit comments

Comments (0)