
Commit c4bd03c

[Core][Distributed] add same-node detection (#5369)
1 parent dcbf428 commit c4bd03c

File tree

4 files changed: +87 -1 lines changed

.buildkite/test-pipeline.yaml

Lines changed: 1 addition & 0 deletions
@@ -37,6 +37,7 @@ steps:
   working_dir: "/vllm-workspace/tests"
   num_gpus: 2
   commands:
+  - VLLM_TEST_SAME_HOST=1 torchrun --nproc-per-node=4 distributed/test_same_node.py
   - TEST_DIST_MODEL=facebook/opt-125m DISTRIBUTED_EXECUTOR_BACKEND=ray pytest -v -s distributed/test_basic_distributed_correctness.py
   - TEST_DIST_MODEL=meta-llama/Llama-2-7b-hf DISTRIBUTED_EXECUTOR_BACKEND=ray pytest -v -s distributed/test_basic_distributed_correctness.py
   - TEST_DIST_MODEL=facebook/opt-125m DISTRIBUTED_EXECUTOR_BACKEND=ray pytest -v -s distributed/test_chunked_prefill_distributed.py

tests/distributed/test_same_node.py

Lines changed: 11 additions & 0 deletions
@@ -0,0 +1,11 @@
+import os
+
+import torch
+
+from vllm.distributed.parallel_state import is_in_the_same_node
+
+torch.distributed.init_process_group(backend="gloo")
+test_result = is_in_the_same_node(torch.distributed.group.WORLD)
+
+expected = os.environ.get("VLLM_TEST_SAME_HOST", "1") == "1"
+assert test_result == expected, f"Expected {expected}, got {test_result}"
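
Note (not part of the commit): the CI entry above drives this test with torchrun across 4 ranks. For a quick local sanity check, the same assertion can be exercised from a single script with torch.multiprocessing. The sketch below is illustrative only: it assumes vLLM at this commit is importable, and the worker name and TCP port are arbitrary choices. With both gloo ranks on one machine, is_in_the_same_node should return True.

import torch.distributed as dist
import torch.multiprocessing as mp

from vllm.distributed.parallel_state import is_in_the_same_node


def worker(rank: int, world_size: int) -> None:
    # Each spawned process joins a gloo group; a non-NCCL backend is needed
    # because is_in_the_same_node broadcasts a Python object (the shm name).
    dist.init_process_group(backend="gloo",
                            init_method="tcp://127.0.0.1:29555",  # arbitrary port
                            rank=rank,
                            world_size=world_size)
    same_node = is_in_the_same_node(dist.group.WORLD)
    assert same_node, "both ranks run on this machine, so this should be True"
    dist.destroy_process_group()


if __name__ == "__main__":
    # Two ranks on the same host -> the check is expected to pass on both.
    mp.spawn(worker, args=(2, ), nprocs=2)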

vllm/distributed/device_communicators/custom_all_reduce.py

Lines changed: 8 additions & 1 deletion
@@ -10,7 +10,7 @@
 from vllm.distributed.device_communicators.custom_all_reduce_utils import (
     gpu_p2p_access_check)
 from vllm.distributed.parallel_state import (
-    get_local_rank, get_tensor_model_parallel_cpu_group)
+    get_local_rank, get_tensor_model_parallel_cpu_group, is_in_the_same_node)
 from vllm.logger import init_logger

 try:
@@ -113,6 +113,13 @@ def __init__(self,
         assert dist.get_backend(group) != dist.Backend.NCCL, (
             "CustomAllreduce should be attached to a non-NCCL group.")

+        if not is_in_the_same_node(group):
+            # No need to initialize custom allreduce for multi-node case.
+            logger.warning(
+                "Custom allreduce is disabled because this process group"
+                " spans across nodes.")
+            return
+
         rank = dist.get_rank(group=self.group)
         world_size = dist.get_world_size(group=self.group)
         if world_size == 1:

vllm/distributed/parallel_state.py

Lines changed: 67 additions & 0 deletions
@@ -3,6 +3,8 @@
 # https://github.com/NVIDIA/Megatron-LM/blob/main/megatron/core/parallel_state.py
 # Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved.
 """Tensor and pipeline parallel groups."""
+import contextlib
+from multiprocessing import resource_tracker, shared_memory
 from typing import List, Optional

 import torch
@@ -376,3 +378,68 @@ def destroy_model_parallel():
     _PP_DEVICE_GROUP = None
     global _PP_GLOBAL_RANKS
     _PP_GLOBAL_RANKS = None
+
+
+def is_in_the_same_node(pg: ProcessGroup):
+    """
+    This is a collective operation that checks if all processes in the group
+    are in the same node. It tests if all processes are attached to the same
+    memory system (shared access to shared memory).
+    """
+    assert torch.distributed.get_backend(
+        pg) != torch.distributed.Backend.NCCL, (
+            "is_in_the_same_node should be tested with a non-NCCL group.")
+    # local rank inside the group
+    rank = torch.distributed.get_rank(group=pg)
+    world_size = torch.distributed.get_world_size(group=pg)
+
+    # local tensor in each process to store the result
+    is_in_the_same_node = torch.tensor([0] * world_size, dtype=torch.int32)
+
+    # global ranks of the processes in the group
+    ranks = torch.distributed.get_process_group_ranks(pg)
+
+    magic_message = b"magic_message"
+    shm = None
+
+    try:
+        with contextlib.suppress(OSError):
+            if rank == 0:
+                # create a shared memory segment
+                shm = shared_memory.SharedMemory(create=True, size=128)
+                shm.buf[:len(magic_message)] = magic_message
+                torch.distributed.broadcast_object_list([shm.name],
+                                                        src=ranks[0],
+                                                        group=pg)
+                is_in_the_same_node[0] = 1
+            else:
+                # try to open the shared memory segment
+                recv = [None]
+                torch.distributed.broadcast_object_list(recv,
+                                                        src=ranks[0],
+                                                        group=pg)
+                name = recv[0]
+                shm = shared_memory.SharedMemory(name=name)
+                if shm.buf[:len(magic_message)] == magic_message:
+                    is_in_the_same_node[rank] = 1
+    except Exception as e:
+        logger.error("Error ignored in is_in_the_same_node: %s", e)
+    finally:
+        if shm:
+            shm.close()
+
+    torch.distributed.barrier(group=pg)
+
+    # clean up the shared memory segment
+    with contextlib.suppress(OSError):
+        if rank == 0:
+            if shm:
+                shm.unlink()
+        else:
+            if shm:
+                # fix to https://stackoverflow.com/q/62748654/9191338
+                resource_tracker.unregister(
+                    shm._name, "shared_memory")  # type: ignore[attr-defined]
+    torch.distributed.all_reduce(is_in_the_same_node, group=pg)
+
+    return is_in_the_same_node.sum().item() == world_size
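
Note (not part of the commit): the sketch below isolates the shared-memory handshake that is_in_the_same_node builds on, using plain multiprocessing instead of torch.distributed. A creator process writes a magic message into a named segment; a peer attaches by name and verifies the bytes, which can only succeed when both processes share the same node's memory system. The names MAGIC and peer are illustrative, the resource_tracker.unregister call mirrors the workaround referenced in the diff, and POSIX/Linux behavior is assumed.

import multiprocessing as mp
from multiprocessing import resource_tracker, shared_memory

MAGIC = b"magic_message"


def peer(name, result):
    # Attach to the segment created by the "rank 0" process and check it.
    shm = shared_memory.SharedMemory(name=name)
    try:
        result.put(bytes(shm.buf[:len(MAGIC)]) == MAGIC)
    finally:
        shm.close()
        # A non-creating process must stop the resource tracker from
        # unlinking the segment on exit; same fix as in the patch above
        # (https://stackoverflow.com/q/62748654).
        resource_tracker.unregister(shm._name, "shared_memory")


if __name__ == "__main__":
    shm = shared_memory.SharedMemory(create=True, size=128)
    try:
        shm.buf[:len(MAGIC)] = MAGIC
        result = mp.Queue()
        p = mp.Process(target=peer, args=(shm.name, result))
        p.start()
        p.join()
        # True: the peer saw the same bytes, so it runs on the same node.
        print("same node:", result.get())
    finally:
        shm.close()
        shm.unlink()  # only the creator unlinks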
