28
28
get_dummy_fit_state ,
29
29
get_dummy_train_state ,
30
30
)
31
- from torchtnt .framework .callbacks .base_checkpointer import BaseCheckpointer
31
+ from torchtnt .framework .callbacks .base_checkpointer import (
32
+ BaseCheckpointer as BaseCheckpointer ,
33
+ )
32
34
from torchtnt .framework .callbacks .checkpointer_types import (
33
35
BestCheckpointConfig ,
34
36
RestoreOptions ,
41
43
from torchtnt .framework .unit import AppStateMixin , TrainUnit , TTrainData
42
44
from torchtnt .utils .distributed import get_global_rank , spawn_multi_process
43
45
from torchtnt .utils .env import init_from_env
44
- from torchtnt .utils .test_utils import skip_if_not_distributed , skip_if_not_gpu
46
+ from torchtnt .utils .test_utils import skip_if_not_distributed
45
47
46
48
47
49
class BaseCheckpointSaver (BaseCheckpointer ):
@@ -411,24 +413,20 @@ def test_invalid_args(self) -> None:
411
413
BaseCheckpointSaver (temp_dir , save_every_n_epochs = 0 )
412
414
413
415
@skip_if_not_distributed
414
- @skip_if_not_gpu
415
416
def test_process_group_plumbing (self ) -> None :
416
- """
417
- Creates a new process group and verifies GLOO group is created accordingly
418
- """
419
417
spawn_multi_process (
420
418
2 ,
421
- "nccl " ,
422
- self ._test_process_group_plumbing ,
419
+ "gloo " ,
420
+ self ._test_process_group_plumbing_gloo ,
423
421
)
424
422
spawn_multi_process (
425
423
2 ,
426
- "gloo" ,
427
- self ._test_process_group_plumbing ,
424
+ "gloo" , # inner test mocks nccl backend
425
+ self ._test_process_group_plumbing_nccl ,
428
426
)
429
427
430
428
@staticmethod
431
- def _test_process_group_plumbing () -> None :
429
+ def _test_process_group_plumbing_gloo () -> None :
432
430
checkpoint_cb = BaseCheckpointSaver (
433
431
"foo" ,
434
432
process_group = None ,
@@ -441,6 +439,23 @@ def _test_process_group_plumbing() -> None:
441
439
# verify no new process group was created
442
440
tc .assertEqual (checkpoint_cb ._process_group , dist .group .WORLD )
443
441
442
+ @staticmethod
443
+ @patch ("torch.cuda.nccl.version" , return_value = (1 , 0 , 0 ))
444
+ def _test_process_group_plumbing_nccl (_ : MagicMock ) -> None :
445
+ with patch ("torch.distributed.get_backend" , return_value = dist .Backend .NCCL ):
446
+ checkpoint_cb = BaseCheckpointSaver (
447
+ "foo" ,
448
+ process_group = None ,
449
+ )
450
+
451
+ tc = unittest .TestCase ()
452
+ tc .assertIsNotNone (checkpoint_cb ._process_group )
453
+ tc .assertEqual (
454
+ checkpoint_cb ._process_group ._get_backend_name (), dist .Backend .GLOO
455
+ )
456
+ # check that a new process group was created
457
+ tc .assertNotEqual (checkpoint_cb ._process_group , dist .group .WORLD )
458
+
444
459
@patch (
445
460
"torchtnt.framework.callbacks.base_checkpointer.get_checkpoint_dirpaths" ,
446
461
return_value = ["epoch_1_step_10" , "epoch_2_step_20" ],
0 commit comments