
Commit 91f50a6

[Core][Distributed] use cpu/gloo to initialize pynccl (#4248)
1 parent 79a268c commit 91f50a6

File tree

5 files changed: 93 additions & 71 deletions

tests/distributed/test_pynccl.py

Lines changed: 10 additions & 5 deletions
@@ -5,6 +5,7 @@
 
 from vllm.distributed.device_communicators.pynccl import (NCCLCommunicator,
                                                           ncclGetUniqueId)
+from vllm.distributed.parallel_state import init_distributed_environment
 from vllm.utils import update_environment_variables
 
 
@@ -26,19 +27,23 @@ def distributed_run(fn, world_size):
     for p in processes:
         p.join()
 
+    for p in processes:
+        assert p.exitcode == 0
+
 
-def update_env(fn):
+def worker_fn_wrapper(fn):
     # `multiprocessing.Process` cannot accept environment variables directly
     # so we need to pass the environment variables as arguments
     # and update the environment variables in the function
-    def wrapper(env):
+    def wrapped_fn(env):
         update_environment_variables(env)
+        init_distributed_environment()
         fn()
 
-    return wrapper
+    return wrapped_fn
 
 
-@update_env
+@worker_fn_wrapper
 def worker_fn():
     comm = NCCLCommunicator()
     tensor = torch.ones(16, 1024, 1024, dtype=torch.float32).cuda(comm.rank)
@@ -53,7 +58,7 @@ def test_pynccl():
     distributed_run(worker_fn, 2)
 
 
-@update_env
+@worker_fn_wrapper
 def worker_fn_with_cudagraph():
     with torch.no_grad():
         graph = torch.cuda.CUDAGraph()
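The substantive change in this test file is that the wrapper now calls init_distributed_environment() before the worker body runs, so NCCLCommunicator() can attach to an already-created cpu/gloo process group instead of initializing torch.distributed itself. Below is a minimal sketch of a new-style worker test under this pattern, assuming distributed_run and worker_fn_wrapper from this file are in scope; the tensor shape and the extra synchronize are illustrative, not part of the commit.

import torch

from vllm.distributed.device_communicators.pynccl import NCCLCommunicator


@worker_fn_wrapper  # applies env vars, then calls init_distributed_environment()
def sketch_worker_fn():
    # the communicator picks up the cpu/gloo world group via parallel_state
    comm = NCCLCommunicator()
    data = torch.ones(4, 4, dtype=torch.float32).cuda(comm.rank)
    comm.all_reduce(data)
    torch.cuda.synchronize()
    assert data.mean().cpu().item() == comm.world_size


def test_sketch():
    distributed_run(sketch_worker_fn, 2)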

vllm/distributed/device_communicators/pynccl.py

Lines changed: 71 additions & 51 deletions
@@ -20,14 +20,15 @@
 # variable in the code.
 
 import ctypes
-import datetime
 import platform
+from typing import Optional, Union
 
 # ===================== import region =====================
 import torch
 import torch.distributed as dist
-from torch.distributed import ReduceOp
+from torch.distributed import ProcessGroup, ReduceOp
 
+from vllm.distributed.parallel_state import get_cpu_world_group, get_local_rank
 from vllm.logger import init_logger
 from vllm.utils import find_nccl_library, nccl_integrity_check
 
@@ -59,6 +60,18 @@
 
 ncclResult_t = ctypes.c_int
 
+_c_ncclGetErrorString = nccl.ncclGetErrorString
+_c_ncclGetErrorString.restype = ctypes.c_char_p
+_c_ncclGetErrorString.argtypes = [ncclResult_t]
+
+
+def NCCL_CHECK(result: ncclResult_t) -> None:
+    if result != 0:
+        error_str = _c_ncclGetErrorString(result)
+        error_str = error_str.decode("utf-8")
+        raise RuntimeError(f"NCCL error: {error_str}")
+
+
 # equivalent to c declaration:
 # ncclResult_t ncclGetVersion(int *version);
 _c_ncclGetVersion = nccl.ncclGetVersion
@@ -68,8 +81,7 @@
 
 def ncclGetVersion() -> str:
     version = ctypes.c_int()
-    result = _c_ncclGetVersion(ctypes.byref(version))
-    assert result == 0
+    NCCL_CHECK(_c_ncclGetVersion(ctypes.byref(version)))
     # something like 21903 --> "2.19.3"
     version_str = str(version.value)
    major = version_str[0].lstrip("0")
@@ -91,8 +103,7 @@ class NcclUniqueId(ctypes.Structure):
 
 def ncclGetUniqueId() -> NcclUniqueId:
     unique_id = NcclUniqueId()
-    result = _c_ncclGetUniqueId(ctypes.byref(unique_id))
-    assert result == 0
+    NCCL_CHECK(_c_ncclGetUniqueId(ctypes.byref(unique_id)))
     return unique_id
 
 
@@ -199,66 +210,75 @@ class NCCLCommunicator:
 
     def __init__(
         self,
-        backend=None,
-        init_method=None,
-        timeout=datetime.timedelta(seconds=10),
-        world_size: int = -1,
-        rank: int = -1,
-        store=None,
-        group_name: str = "",
-        pg_options=None,
-        local_rank: int = -1,
+        group: Optional[ProcessGroup] = None,
+        device: Optional[Union[int, str, torch.device]] = None,
     ):
-        if not dist.is_initialized():
-            backend = backend or "nccl"
-            assert backend == 'nccl', (
-                "only use nccl backend for starting the NCCL communicator")
-            dist.init_process_group(backend=backend,
-                                    init_method=init_method,
-                                    timeout=timeout,
-                                    world_size=world_size,
-                                    rank=rank,
-                                    store=store,
-                                    group_name=group_name,
-                                    pg_options=pg_options)
-        self.rank = dist.get_rank()
-        self.world_size = dist.get_world_size()
-        if local_rank == -1:
-            local_rank = self.rank
-        self.local_rank = local_rank
-        # don't use these args, as they can be -1
-        # use `self.rank`, `self.local_rank` and `self.world_size` instead
-        del world_size, rank, local_rank
-        torch.cuda.set_device(self.local_rank)
+        """
+        Args:
+            group: the process group to work on. If None, it will use the
+                default process group.
+            device: the device to bind the NCCLCommunicator to. If None,
+                it will be bind to f"cuda:{local_rank}".
+        It is the caller's responsibility to make sure each communicator
+        is bind to a unique device.
+        """
+        assert dist.is_initialized()
+        group = get_cpu_world_group() if group is None else group
+        assert dist.get_backend(group) != dist.Backend.NCCL, (
+            "NCCLCommunicator should be attached to a non-NCCL group.")
+        self.group = group
+        self.rank = dist.get_rank(group)
+        self.world_size = dist.get_world_size(group)
         if self.rank == 0:
             self.unique_id = ncclGetUniqueId()
         else:
             self.unique_id = NcclUniqueId()
-        tensor = torch.ByteTensor(list(self.unique_id.internal)).cuda(
-            self.local_rank)
-        dist.broadcast(tensor, src=0)
-        byte_list = tensor.cpu().tolist()
+        tensor = torch.ByteTensor(list(self.unique_id.internal))
+        dist.broadcast(tensor, src=0, group=group)
+        byte_list = tensor.tolist()
         for i, byte in enumerate(byte_list):
             self.unique_id.internal[i] = byte
         self.comm = ctypes.c_void_p()
-        result = _c_ncclCommInitRank(ctypes.byref(self.comm), self.world_size,
-                                     self.unique_id, self.rank)
-        assert result == 0
-        self.stream = torch.cuda.Stream(device=f"cuda:{self.local_rank}")
+        if device is None:
+            local_rank = get_local_rank()
+            device = torch.device(f"cuda:{local_rank}")
+        elif isinstance(device, int):
+            device = torch.device(f"cuda:{device}")
+        elif isinstance(device, str):
+            device = torch.device(device)
+        # now `device` is a `torch.device` object
+        assert isinstance(device, torch.device)
+        self.device = device
+        # nccl communicator and stream will use this device
+        current_device = torch.cuda.current_device()
+        try:
+            torch.cuda.set_device(device)
+            NCCL_CHECK(
+                _c_ncclCommInitRank(ctypes.byref(self.comm), self.world_size,
                                    self.unique_id, self.rank))
+            self.stream = torch.cuda.Stream()
+        finally:
+            torch.cuda.set_device(current_device)
 
     def all_reduce(self,
                    tensor: torch.Tensor,
                    op: ReduceOp = ReduceOp.SUM,
                    stream=None):
+        # nccl communicator created on a specific device
+        # will only work on tensors on the same device
+        # otherwise it will cause "illegal memory access"
+        assert tensor.device == self.device, (
+            f"this nccl communicator is created to work on {self.device}, "
+            f"but the input tensor is on {tensor.device}")
         if stream is None:
             stream = self.stream
-        result = _c_ncclAllReduce(ctypes.c_void_p(tensor.data_ptr()),
-                                  ctypes.c_void_p(tensor.data_ptr()),
-                                  tensor.numel(),
-                                  ncclDataTypeEnum.from_torch(tensor.dtype),
-                                  ncclRedOpTypeEnum.from_torch(op), self.comm,
-                                  ctypes.c_void_p(stream.cuda_stream))
-        assert result == 0
+        NCCL_CHECK(
+            _c_ncclAllReduce(ctypes.c_void_p(tensor.data_ptr()),
+                             ctypes.c_void_p(tensor.data_ptr()),
+                             tensor.numel(),
+                             ncclDataTypeEnum.from_torch(tensor.dtype),
+                             ncclRedOpTypeEnum.from_torch(op), self.comm,
                             ctypes.c_void_p(stream.cuda_stream)))
 
     def __del__(self):
         # `dist` module might have been already destroyed
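In short, the constructor now assumes torch.distributed is already initialized, exchanges the NCCL unique id over a CPU (gloo) group with dist.broadcast, and pins the communicator plus its stream to a single CUDA device, with every raw NCCL call routed through NCCL_CHECK. The following is a hedged usage sketch under the same env-var rendezvous the tests use (MASTER_ADDR, MASTER_PORT, RANK, WORLD_SIZE, LOCAL_RANK already set); the tensor size and the explicit synchronize are illustrative.

import torch

from vllm.distributed.device_communicators.pynccl import NCCLCommunicator
from vllm.distributed.parallel_state import init_distributed_environment

# must happen first: creates the default group and the cpu/gloo world group
init_distributed_environment()

# group=None  -> attach to the cpu/gloo world group from parallel_state
# device=None -> bind the communicator to cuda:{local_rank}
comm = NCCLCommunicator(group=None, device=None)

# all_reduce() asserts the tensor lives on the communicator's device
x = torch.ones(1024, device=comm.device)
comm.all_reduce(x)
torch.cuda.synchronize()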

vllm/distributed/device_communicators/pynccl_utils.py

Lines changed: 3 additions & 9 deletions
@@ -2,7 +2,7 @@
 from typing import Optional
 
 import torch
-from torch.distributed import ReduceOp
+from torch.distributed import ProcessGroup, ReduceOp
 
 from vllm.logger import init_logger
 
@@ -37,17 +37,11 @@ def set_pynccl_stream(stream: torch.cuda.Stream):
         pass
 
 
-def init_process_group(world_size: int,
-                       rank: int,
-                       init_method: str,
-                       local_rank: int = -1) -> None:
+def init_process_group(group: Optional[ProcessGroup] = None) -> None:
     assert not is_initialized()
     global comm
     logger.info(f"vLLM is using nccl=={ncclGetVersion()}")
-    comm = NCCLCommunicator(init_method=init_method,
-                            world_size=world_size,
-                            local_rank=local_rank,
-                            rank=rank)
+    comm = NCCLCommunicator(group=group)
 
 
 def all_reduce(input_: torch.Tensor, op=ReduceOp.SUM) -> None:
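The helper API shrinks to match: callers no longer pass world_size, rank, or init_method, only an optional ProcessGroup. A sketch of the two call styles, assuming the distributed environment has already been initialized (only one of them may run per process, since init_process_group asserts it is not yet initialized):

from vllm.distributed import parallel_state
from vllm.distributed.device_communicators import pynccl_utils

# usual path: let NCCLCommunicator pick the cpu/gloo world group itself
pynccl_utils.init_process_group()

# or hand it an explicit non-NCCL group, e.g. the cpu world group
# pynccl_utils.init_process_group(group=parallel_state.get_cpu_world_group())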

vllm/distributed/parallel_state.py

Lines changed: 6 additions & 0 deletions
@@ -4,6 +4,7 @@
 # Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved.
 """Tensor and pipeline parallel groups."""
 import contextlib
+import os
 from typing import Optional
 
 import torch
@@ -73,6 +74,11 @@ def init_distributed_environment(
         ranks = list(range(torch.distributed.get_world_size()))
         _CPU_WORLD_GROUP = torch.distributed.new_group(ranks=ranks,
                                                        backend="gloo")
+    # set the local rank
+    # local_rank is not available in torch ProcessGroup,
+    # see https://github.com/pytorch/pytorch/issues/122816
+    if local_rank == -1 and distributed_init_method == "env://":
+        local_rank = int(os.environ['LOCAL_RANK'])
     global _LOCAL_RANK
     _LOCAL_RANK = local_rank
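Since a torch ProcessGroup does not expose a local rank (see the linked PyTorch issue), the value must come either from the caller or, under the default env:// initialization, from the LOCAL_RANK environment variable. A small sketch of the torchrun-style path this fallback targets; the assertion is only illustrative:

import os

from vllm.distributed.parallel_state import init_distributed_environment

# a torchrun-style launcher exports LOCAL_RANK (plus RANK, WORLD_SIZE,
# MASTER_ADDR, MASTER_PORT); with env:// the local rank is read from there
assert "LOCAL_RANK" in os.environ, "expected a torchrun-style launch"
init_distributed_environment()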

vllm/worker/worker.py

Lines changed: 3 additions & 6 deletions
@@ -298,12 +298,9 @@ def init_worker_distributed_environment(
     elif parallel_config.world_size > 1:
         # NOTE(woosuk): We don't initialize pynccl process group when world size
         # is 1.
-        pynccl_utils.init_process_group(
-            world_size=parallel_config.world_size,
-            local_rank=local_rank,
-            rank=rank,
-            init_method=distributed_init_method,
-        )
+        # NOTE(kaichao): By default, pynccl will use information inside
+        # `parallel_state` for initialization.
+        pynccl_utils.init_process_group()
 
     ensure_model_parallel_initialized(parallel_config.tensor_parallel_size,
                                       parallel_config.pipeline_parallel_size)
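The resulting contract in the worker is purely about ordering: init_distributed_environment() must run first so the gloo-backed CPU world group and the local rank are recorded in parallel_state, after which pynccl bootstraps itself with no explicit rank or world-size plumbing. A condensed, hypothetical sketch of that flow (the real worker also handles already-initialized cases and model-parallel setup):

from vllm.distributed import parallel_state
from vllm.distributed.device_communicators import pynccl_utils


def init_worker_distributed_environment_sketch(world_size: int) -> None:
    # 1) torch.distributed default group + cpu/gloo world group + local rank
    parallel_state.init_distributed_environment()
    # 2) pynccl reads everything it needs from parallel_state
    if world_size > 1:
        pynccl_utils.init_process_group()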
