Skip to content

Commit 25149cd

Browse files
kwen2501 authored and pytorchmergebot committed
[c10d] Add more tests to prevent extra context (pytorch#154174)
Stack from [ghstack](https://github.com/ezyang/ghstack) (oldest at bottom): Loop a bunch of sync ops and see if any of them creates extra context. Requires nvml to check number of processes resident on a device. Pull Request resolved: pytorch#154174 Approved by: https://github.com/atalman
1 parent ba5d45d commit 25149cd

File tree

1 file changed

+50
-0
lines changed

1 file changed

+50
-0
lines changed

test/distributed/test_c10d_nccl.py

Lines changed: 50 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -684,6 +684,56 @@ def test_extra_cuda_context(self):
684684
except ModuleNotFoundError:
685685
self._helper_test_extra_cuda_context_by_memory()
686686

687+
@requires_nccl()
@skip_if_lt_x_gpu(2)
def test_extra_cuda_context_sync_ops(self):
    """Loop over a batch of synchronous c10d ops and verify that none of
    them creates an extra CUDA context on this rank's device.

    Uses NVML (via pynvml) to count the compute processes resident on the
    device; the test is skipped when pynvml is unavailable or NVML fails
    to initialize.
    """
    try:
        import pynvml

        pynvml.nvmlInit()
    except Exception:
        # Broad catch is intentional: any pynvml import/init failure
        # (missing module, driver mismatch) should skip, not fail.
        self.skipTest("pynvml not available")

    # Check if non-0 ranks would create extra CUDA context on device 0
    store = c10d.FileStore(self.file_name, self.world_size)
    device = torch.device(f"cuda:{self.rank:d}")
    # Passing device_id binds the PG to this device (eager NCCL comm
    # init) -- the code path being checked for context leaks.
    c10d.init_process_group(
        backend="nccl",
        store=store,
        rank=self.rank,
        world_size=self.world_size,
        device_id=device,
    )

    x = torch.empty((1,), device=device)
    y = torch.empty((self.world_size,), device=device)

    # Exercise a variety of sync collectives; any of them could
    # accidentally touch a peer device and spawn a context there.
    c10d.all_reduce(x)
    c10d.reduce(x, dst=0)
    c10d.broadcast(x, src=0)
    c10d.all_gather_into_tensor(y, x)
    c10d.reduce_scatter_tensor(x, y)
    c10d.barrier()

    # Wait a bit for remote processes to touch my device
    if self.rank == 0:
        time.sleep(5)

    handle = pynvml.nvmlDeviceGetHandleByIndex(self.rank)
    processes = pynvml.nvmlDeviceGetComputeRunningProcesses(handle)
    nprocs = len(processes)
    # Pair the earlier nvmlInit(): release NVML once the query is done
    # (the original left the library initialized).
    pynvml.nvmlShutdown()

    # Don't exit till rank 0 is done with the nvml detection
    c10d.barrier()
    c10d.destroy_process_group()
    self.assertLessEqual(
        nprocs,
        1,
        f"Found {nprocs} processes creating contexts on {device}, expecting 1 at most",
    )
736+
687737
@requires_nccl()
688738
@skip_but_pass_in_sandcastle_if(not TEST_MULTIGPU, "NCCL test requires 2+ GPUs")
689739
def test_destruct_before_terminate_pg(self):

0 commit comments

Comments
 (0)