@@ -684,6 +684,56 @@ def test_extra_cuda_context(self):
684684 except ModuleNotFoundError :
685685 self ._helper_test_extra_cuda_context_by_memory ()
686686
    @requires_nccl()
    @skip_if_lt_x_gpu(2)
    def test_extra_cuda_context_sync_ops(self):
        """Run a batch of synchronous collective ops and verify that none of
        them leaves an extra CUDA context on a peer's device.

        Each rank counts, via NVML, the processes holding a compute context on
        its own device; at most one (itself) is expected. Skips when pynvml is
        not importable/initializable.
        """
        # Loop a bunch of sync ops and see if any of them creates extra context.
        # Requires nvml to check number of processes resident on a device.
        try:
            import pynvml

            pynvml.nvmlInit()
        except Exception:
            # Broad catch: both ModuleNotFoundError and NVML init failures
            # (e.g. no driver) mean we cannot perform the check.
            self.skipTest("pynvml not available")

        # Check if non-0 ranks would create extra CUDA context on device 0.
        # Binding the PG to this rank's device via device_id is what should
        # prevent any cross-device context creation.
        store = c10d.FileStore(self.file_name, self.world_size)
        device = torch.device(f"cuda:{self.rank:d}")
        c10d.init_process_group(
            backend="nccl",
            store=store,
            rank=self.rank,
            world_size=self.world_size,
            device_id=device,
        )

        x = torch.empty((1,), device=device)
        y = torch.empty((self.world_size,), device=device)

        # Exercise each sync collective once; any of these could touch a
        # remote device if the backend mis-targets an operation.
        c10d.all_reduce(x)
        c10d.reduce(x, dst=0)
        c10d.broadcast(x, src=0)
        c10d.all_gather_into_tensor(y, x)
        c10d.reduce_scatter_tensor(x, y)
        c10d.barrier()

        # Wait a bit for remote processes to touch my device
        if self.rank == 0:
            time.sleep(5)

        # Count processes with a compute context on this rank's device.
        handle = pynvml.nvmlDeviceGetHandleByIndex(self.rank)
        processes = pynvml.nvmlDeviceGetComputeRunningProcesses(handle)
        nprocs = len(processes)

        # Don't exit till rank 0 is done with the nvml detection
        c10d.barrier()
        c10d.destroy_process_group()
        # Assert after teardown so a failure doesn't leave the PG hanging.
        self.assertLessEqual(
            nprocs,
            1,
            f"Found {nprocs} processes creating contexts on {device}, expecting 1 at most",
        )
736+
687737 @requires_nccl ()
688738 @skip_but_pass_in_sandcastle_if (not TEST_MULTIGPU , "NCCL test requires 2+ GPUs" )
689739 def test_destruct_before_terminate_pg (self ):
0 commit comments