Commit 74878ac

kwen2501 authored and pytorchmergebot committed
[PGNCCL] Make sure we do not use split for P2P comm creation (pytorch#139013)
Resolves comment pytorch#138527 (comment).

There was a split-vs-P2P bug: when P2P comm creation invokes `getNCCLComm`, it may see a `split_from` option that was meant for the previous PG creation. The P2P comm creation may then use `ncclCommSplit` and hang, because not all ranks join that call.

The bug has gone unnoticed until now because there is no CI test covering the following recipe: eager init + new group + P2P in that new group.

Pull Request resolved: pytorch#139013
Approved by: https://github.com/shuqiangzhang
1 parent fb2c750 commit 74878ac
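For context, the failing recipe named in the commit message (eager init + new group + P2P in that new group) looks roughly like the sketch below. This is a minimal illustration, assuming a 2-rank job launched with torchrun and one CUDA device per rank; it is not part of the commit, and the regression test added below exercises the same pattern through the test harness.

# Hedged sketch of the recipe that triggered the hang. Assumes a 2-rank job
# launched with torchrun (RANK/WORLD_SIZE/MASTER_ADDR/MASTER_PORT set) and
# one CUDA device per rank.
import os
import torch
import torch.distributed as dist

rank = int(os.environ["RANK"])
world_size = int(os.environ["WORLD_SIZE"])
device = torch.device(f"cuda:{rank}")

# Eager init: passing device_id creates the default-PG NCCL communicator up
# front, which can leave a `split_from` option set for later group creations.
dist.init_process_group("nccl", rank=rank, world_size=world_size, device_id=device)

# New group over all ranks.
group = dist.new_group()

# P2P in that new group: before this fix, creating the communicator for the
# send/recv pair could pick up the stale `split_from` option and call
# ncclCommSplit, which is collective and can hang because not all ranks join it.
t = torch.ones(4, device=device)
if rank == 0:
    dist.send(t, dst=1, group=group)
elif rank == 1:
    dist.recv(t, src=0, group=group)

dist.destroy_process_group()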

2 files changed: +28 −1 lines

test/distributed/test_c10d_nccl.py

Lines changed: 22 additions & 0 deletions
@@ -982,6 +982,28 @@ def test_non_blocking_p2p(self):
             self.assertEqual(send_tensor, recv_tensor)
         dist.destroy_process_group()
 
+    @skip_but_pass_in_sandcastle_if(not TEST_MULTIGPU, "NCCL test requires 2+ GPUs")
+    @parametrize("eager_init", [True, False])
+    def test_subgroup_p2p(self, eager_init: bool):
+        store = c10d.FileStore(self.file_name, self.world_size)
+        device = torch.device(f"cuda:{self.rank % torch.cuda.device_count()}")
+        c10d.init_process_group(
+            "nccl",
+            world_size=self.world_size,
+            rank=self.rank,
+            store=store,
+            device_id=device if eager_init else None,
+        )
+        send_tensor = torch.ones(10, 10, device=device)
+        group = dist.new_group()
+        if self.rank == 0:
+            dist.send(send_tensor, 1, group=group)
+        if self.rank == 1:
+            recv_tensor = torch.rand(10, 10, device=device)
+            dist.recv(recv_tensor, 0, group=group)
+            self.assertEqual(send_tensor, recv_tensor)
+        dist.destroy_process_group()
+
     @requires_nccl()
     @skip_but_pass_in_sandcastle_if(not TEST_MULTIGPU, "NCCL test requires 2+ GPUs")
     def test_get_uid(self):
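After applying the commit, the new test can be run on a machine with at least two GPUs, for example with `python test/distributed/test_c10d_nccl.py -k test_subgroup_p2p` (the exact invocation depends on the local test setup); both the eager_init=True and eager_init=False parametrizations are expected to pass with the fix in place.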

torch/csrc/distributed/c10d/ProcessGroupNCCL.cpp

Lines changed: 6 additions & 1 deletion
@@ -2401,7 +2401,12 @@ std::shared_ptr<NCCLComm> ProcessGroupNCCL::getNCCLComm(
 #endif
 
 #ifdef NCCL_HAS_COMM_SPLIT
-  if (options_->split_from) {
+  // Use split to create a new communicator only if:
+  // 1. The parent comm is known; AND
+  // 2. The new comm is not for a point-to-point operation.
+  // ncclCommSplit() is a collective call, so it does not work for P2P
+  // operations.
+  if (options_->split_from && !singleP2POp) {
     // Find a valid, healthy communicator to split from if possible.
     std::lock_guard<std::mutex> lock(options_->split_from->mutex_);
     auto& other_comms = options_->split_from->devNCCLCommMap_;
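With the added guard, a P2P communicator request no longer enters the split branch even when `split_from` is set; it falls through to the regular (non-split) communicator creation further down in `getNCCLComm`, which only the ranks involved in the P2P pair need to perform.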
