Commit f48b6a0

luccafong and Lucia (Lu) Fang authored
[Misc]allow disable pynccl (#25421)
Signed-off-by: Lu Fang <[email protected]>
Co-authored-by: Lucia (Lu) Fang <[email protected]>
1 parent 2a69ab4 commit f48b6a0

File tree: 3 files changed (+12, -1 lines)

vllm/distributed/device_communicators/cuda_communicator.py

Lines changed: 4 additions & 0 deletions
@@ -147,6 +147,10 @@ def all_reduce(self, input_):
             assert out is not None
             return out
         pynccl_comm = self.pynccl_comm
+        if pynccl_comm is None or pynccl_comm.disabled:
+            out = input_.clone()
+            torch.distributed.all_reduce(out, group=self.device_group)
+            return out
         assert pynccl_comm is not None
         out = pynccl_comm.all_reduce(input_)
         if out is None:
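The added branch makes pynccl optional at this call site: when the communicator is missing or flagged disabled, the all-reduce is performed out of place through torch.distributed on the device group. Below is a minimal runnable sketch of that pattern using a single-process gloo group; all_reduce_with_fallback and its arguments are illustrative names, not vLLM's API.

import os
import torch
import torch.distributed as dist

def all_reduce_with_fallback(input_, pynccl_comm, device_group):
    # Mirrors the new branch: bypass pynccl when it is missing or disabled.
    if pynccl_comm is None or pynccl_comm.disabled:
        out = input_.clone()  # out-of-place, as in the patch
        dist.all_reduce(out, group=device_group)
        return out
    return pynccl_comm.all_reduce(input_)

if __name__ == "__main__":
    # Single-process gloo group, just to make the sketch runnable on CPU.
    os.environ.setdefault("MASTER_ADDR", "127.0.0.1")
    os.environ.setdefault("MASTER_PORT", "29500")
    dist.init_process_group(backend="gloo", rank=0, world_size=1)
    x = torch.ones(4)
    out = all_reduce_with_fallback(x, pynccl_comm=None,
                                   device_group=dist.group.WORLD)
    print(out)  # tensor([1., 1., 1., 1.]) since world_size == 1
    dist.destroy_process_group()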

vllm/distributed/device_communicators/pynccl.py

Lines changed: 2 additions & 1 deletion
@@ -8,6 +8,7 @@
 import torch.distributed as dist
 from torch.distributed import ProcessGroup, ReduceOp
 
+import vllm.envs as envs
 from vllm.distributed.device_communicators.pynccl_wrapper import (
     NCCLLibrary, buffer_type, cudaStream_t, ncclComm_t, ncclDataTypeEnum,
     ncclRedOpTypeEnum, ncclUniqueId)
@@ -83,7 +84,7 @@ def __init__(
         self.group = group
 
         # if world_size == 1, no need to create communicator
-        if self.world_size == 1:
+        if self.world_size == 1 or envs.VLLM_DISABLE_PYNCCL:
             self.available = False
             self.disabled = True
             return
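With this change, the constructor returns before loading the NCCL library whenever VLLM_DISABLE_PYNCCL is set, leaving disabled = True for callers (such as the all_reduce fallback above) to check. A toy reconstruction of the early-return pattern, assuming only the two fields shown in the hunk; this is not vLLM's actual class:

import os

class ToyPyNcclCommunicator:
    # Toy stand-in for PyNcclCommunicator.__init__; illustration only.
    def __init__(self, world_size: int):
        disable = os.getenv("VLLM_DISABLE_PYNCCL",
                            "False").lower() in ("true", "1")
        if world_size == 1 or disable:
            # Same effect as the patched check: mark the communicator
            # disabled and skip NCCL initialization entirely.
            self.available = False
            self.disabled = True
            return
        self.available = True
        self.disabled = False

os.environ["VLLM_DISABLE_PYNCCL"] = "1"
print(ToyPyNcclCommunicator(world_size=2).disabled)  # True: the flag wins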

vllm/envs.py

Lines changed: 6 additions & 0 deletions
@@ -98,6 +98,7 @@
     VLLM_SKIP_P2P_CHECK: bool = False
     VLLM_DISABLED_KERNELS: list[str] = []
     VLLM_DISABLE_NCCL_FOR_DP_SYNCHRONIZATION: bool = False
+    VLLM_DISABLE_PYNCCL: bool = False
     VLLM_USE_V1: bool = True
     VLLM_ROCM_USE_AITER: bool = False
     VLLM_ROCM_USE_AITER_PAGED_ATTN: bool = False
@@ -897,6 +898,11 @@ def get_vllm_port() -> Optional[int]:
     (os.getenv("VLLM_DISABLE_NCCL_FOR_DP_SYNCHRONIZATION", "False").lower() in
      ("true", "1")),
 
+    # Disable pynccl (using torch.distributed instead)
+    "VLLM_DISABLE_PYNCCL":
+    lambda:
+    (os.getenv("VLLM_DISABLE_PYNCCL", "False").lower() in ("true", "1")),
+
     # If set, use the V1 code path.
     "VLLM_USE_V1":
     lambda: bool(int(os.getenv("VLLM_USE_V1", "1"))),
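The new entry follows the same parsing convention as its neighbors: the flag is truthy only for the strings "true" and "1", case-insensitive. A small self-contained sketch of that convention, with _parse_bool_flag as a hypothetical helper rather than anything exported by vllm.envs:

import os

def _parse_bool_flag(name: str, default: str = "False") -> bool:
    # Truthy only for "true"/"1" (case-insensitive), matching the lambda.
    return os.getenv(name, default).lower() in ("true", "1")

os.environ["VLLM_DISABLE_PYNCCL"] = "True"
print(_parse_bool_flag("VLLM_DISABLE_PYNCCL"))  # True
os.environ["VLLM_DISABLE_PYNCCL"] = "0"
print(_parse_bool_flag("VLLM_DISABLE_PYNCCL"))  # False

To exercise the new path end to end, one would set VLLM_DISABLE_PYNCCL=1 in the environment before launching vLLM; the pynccl communicator then marks itself disabled and CUDA all-reduce routes through torch.distributed.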
