|
 # https://github.com/NVIDIA/Megatron-LM/blob/main/megatron/core/parallel_state.py
 # Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved.
 """Tensor and pipeline parallel groups."""
+import contextlib
+from multiprocessing import resource_tracker, shared_memory
 from typing import List, Optional

 import torch
@@ -376,3 +378,68 @@ def destroy_model_parallel():
     _PP_DEVICE_GROUP = None
     global _PP_GLOBAL_RANKS
     _PP_GLOBAL_RANKS = None
+
+
+def is_in_the_same_node(pg: ProcessGroup):
+    """
+    This is a collective operation that checks if all processes in the group
+    are in the same node. It tests if all processes are attached to the same
+    memory system (shared access to shared memory).
+    """
+    assert torch.distributed.get_backend(
+        pg) != torch.distributed.Backend.NCCL, (
+            "is_in_the_same_node should be tested with a non-NCCL group.")
+    # local rank inside the group
+    rank = torch.distributed.get_rank(group=pg)
+    world_size = torch.distributed.get_world_size(group=pg)
+
+    # local tensor in each process to store the result
+    is_in_the_same_node = torch.tensor([0] * world_size, dtype=torch.int32)
+
+    # global ranks of the processes in the group
+    ranks = torch.distributed.get_process_group_ranks(pg)
+
+    magic_message = b"magic_message"
+    shm = None
+
+    try:
+        with contextlib.suppress(OSError):
+            if rank == 0:
+                # create a shared memory segment
+                shm = shared_memory.SharedMemory(create=True, size=128)
+                shm.buf[:len(magic_message)] = magic_message
+                torch.distributed.broadcast_object_list([shm.name],
+                                                         src=ranks[0],
+                                                         group=pg)
+                is_in_the_same_node[0] = 1
+            else:
+                # try to open the shared memory segment
+                recv = [None]
+                torch.distributed.broadcast_object_list(recv,
+                                                        src=ranks[0],
+                                                        group=pg)
+                name = recv[0]
+                shm = shared_memory.SharedMemory(name=name)
+                if shm.buf[:len(magic_message)] == magic_message:
+                    is_in_the_same_node[rank] = 1
+    except Exception as e:
+        logger.error("Error ignored in is_in_the_same_node: %s", e)
+    finally:
+        if shm:
+            shm.close()
+
+    torch.distributed.barrier(group=pg)
+
+    # clean up the shared memory segment
+    with contextlib.suppress(OSError):
+        if rank == 0:
+            if shm:
+                shm.unlink()
+        else:
+            if shm:
+                # fix to https://stackoverflow.com/q/62748654/9191338
+                resource_tracker.unregister(
+                    shm._name, "shared_memory")  # type: ignore[attr-defined]
+    torch.distributed.all_reduce(is_in_the_same_node, group=pg)
+
+    return is_in_the_same_node.sum().item() == world_size
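For reference, a minimal sketch of how this helper might be exercised; it is not part of the commit above. It assumes a gloo-backed default process group (the assertion rejects NCCL groups), a torchrun-style launch that sets RANK/WORLD_SIZE/MASTER_ADDR, and that is_in_the_same_node is importable from this module; the import path below is an assumption.

# check_same_node.py -- illustrative sketch only; the import path is assumed.
import torch.distributed as dist

from vllm.distributed.parallel_state import is_in_the_same_node  # assumed path


def main():
    # gloo backs the CPU-side broadcast_object_list / all_reduce the helper uses
    dist.init_process_group(backend="gloo")
    # collective call: every rank in the group must enter it
    same_node = is_in_the_same_node(dist.group.WORLD)
    if dist.get_rank() == 0:
        print(f"all ranks on one node: {same_node}")
    dist.destroy_process_group()


if __name__ == "__main__":
    main()

Launched with, for example, torchrun --nproc_per_node=2 check_same_node.py, the call returns True on every rank only when each process can attach to and read back the shared-memory segment created by rank 0, i.e. when all ranks share one node.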