
Commit 747b1a7

[Core][Distributed] fix _is_full_nvlink detection (#4233)
1 parent 95e5b08 commit 747b1a7

File tree

1 file changed: +30 -18 lines changed


vllm/distributed/device_communicators/custom_all_reduce.py

Lines changed: 30 additions & 18 deletions
@@ -1,5 +1,6 @@
+import os
 from contextlib import contextmanager
-from typing import Optional
+from typing import List, Optional

 import torch
 import torch.distributed as dist
@@ -53,14 +54,20 @@ def init_custom_ar() -> None:
         return False
     # test nvlink first, this will filter out most of the cases
     # where custom allreduce is not supported
-    full_nvlink = _is_full_nvlink(rank, world_size)
+    if "CUDA_VISIBLE_DEVICES" in os.environ:
+        device_ids = list(
+            map(int, os.environ["CUDA_VISIBLE_DEVICES"].split(",")))
+    else:
+        device_ids = list(range(num_dev))
+    # this checks hardware and driver support for NVLink
+    full_nvlink = _is_full_nvlink(device_ids)
     if world_size > 2 and not full_nvlink:
         logger.warn(
             "Custom allreduce is disabled because it's not supported on more"
             " than two PCIe-only GPUs. To silence this warning, specify"
             " disable_custom_all_reduce=True explicitly.")
         return
-    # test P2P capability
+    # test P2P capability, this checks software/cudaruntime support
     # this is expensive to compute at the first time
     # then we cache the result
     if not _can_p2p(rank, world_size):
@@ -138,23 +145,28 @@ def _nvml():
         pynvml.nvmlShutdown()


-# query if the set of gpus are fully connected by nvlink (1 hop)
 @_nvml()
-def _is_full_nvlink(rank, world_size):
-    handle = pynvml.nvmlDeviceGetHandleByIndex(rank)
-    for i in range(world_size):
-        if i != rank:
-            try:
-                peer_handle = pynvml.nvmlDeviceGetHandleByIndex(i)
-                p2p_status = pynvml.nvmlDeviceGetP2PStatus(
-                    handle, peer_handle, pynvml.NVML_P2P_CAPS_INDEX_NVLINK)
-                if p2p_status != pynvml.NVML_P2P_STATUS_OK:
+def _is_full_nvlink(device_ids: List[int]) -> bool:
+    """
+    query if the set of gpus are fully connected by nvlink (1 hop)
+    Note that `pynvml` is not affected by `CUDA_VISIBLE_DEVICES`,
+    so it works on real physical device ids.
+    """
+    handles = [pynvml.nvmlDeviceGetHandleByIndex(i) for i in device_ids]
+    for i, handle in enumerate(handles):
+        for j, peer_handle in enumerate(handles):
+            if i < j:
+                try:
+                    p2p_status = pynvml.nvmlDeviceGetP2PStatus(
+                        handle, peer_handle, pynvml.NVML_P2P_CAPS_INDEX_NVLINK)
+                    if p2p_status != pynvml.NVML_P2P_STATUS_OK:
+                        return False
+                except pynvml.NVMLError as error:
+                    logger.error(
+                        "NVLink detection failed. This is normal if your"
+                        " machine has no NVLink equipped.",
+                        exc_info=error)
                     return False
-            except pynvml.NVMLError as error:
-                logger.info(
-                    f"NVLink detection failed with message \"{str(error)}\". "
-                    "This is normal if your machine has no NVLink equipped")
-                return False
     return True

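For context, below is a minimal standalone sketch of the detection approach this commit lands, assuming `pynvml` is installed and that `CUDA_VISIBLE_DEVICES`, if set, contains integer device indices (the same assumption the diff makes). The helper name `is_full_nvlink` and the `__main__` driver are illustrative only, not part of the vLLM codebase:

import os
from typing import List

import pynvml


def is_full_nvlink(device_ids: List[int]) -> bool:
    # Returns True only if every pair of the given physical GPUs reports an
    # OK NVLink P2P status via NVML. Caller is responsible for nvmlInit().
    handles = [pynvml.nvmlDeviceGetHandleByIndex(i) for i in device_ids]
    for i, handle in enumerate(handles):
        for j, peer_handle in enumerate(handles):
            if i < j:
                try:
                    status = pynvml.nvmlDeviceGetP2PStatus(
                        handle, peer_handle,
                        pynvml.NVML_P2P_CAPS_INDEX_NVLINK)
                    if status != pynvml.NVML_P2P_STATUS_OK:
                        return False
                except pynvml.NVMLError:
                    # NVML cannot report NVLink status (e.g. no NVLink at
                    # all): treat the topology as not fully connected, as
                    # the patched function above does.
                    return False
    return True


if __name__ == "__main__":
    pynvml.nvmlInit()
    try:
        # NVML indexes physical GPUs and ignores CUDA_VISIBLE_DEVICES, while
        # the values in CUDA_VISIBLE_DEVICES are themselves physical indices,
        # so they can be parsed and passed to NVML directly.
        if "CUDA_VISIBLE_DEVICES" in os.environ:
            device_ids = list(
                map(int, os.environ["CUDA_VISIBLE_DEVICES"].split(",")))
        else:
            device_ids = list(range(pynvml.nvmlDeviceGetCount()))
        print("fully NVLink connected:", is_full_nvlink(device_ids))
    finally:
        pynvml.nvmlShutdown()

This illustrates why the old per-rank check was unreliable: the previous code passed logical rank indices to NVML, but NVML enumerates physical GPUs and is unaffected by `CUDA_VISIBLE_DEVICES`, so the query could target the wrong devices when that variable remapped ordinals. Passing the physical ids explicitly, as the commit does, avoids the mismatch.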