 # variable in the code.
 
 import ctypes
-import datetime
 import platform
+from typing import Optional, Union
 
 # ===================== import region =====================
 import torch
 import torch.distributed as dist
-from torch.distributed import ReduceOp
+from torch.distributed import ProcessGroup, ReduceOp
 
+from vllm.distributed.parallel_state import get_cpu_world_group, get_local_rank
 from vllm.logger import init_logger
 from vllm.utils import find_nccl_library, nccl_integrity_check
 
 
 ncclResult_t = ctypes.c_int
 
+_c_ncclGetErrorString = nccl.ncclGetErrorString
+_c_ncclGetErrorString.restype = ctypes.c_char_p
+_c_ncclGetErrorString.argtypes = [ncclResult_t]
+
+
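+# Raise a RuntimeError carrying the message from ncclGetErrorString whenever
+# an NCCL call returns a nonzero status, instead of a bare assert.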
+def NCCL_CHECK(result: ncclResult_t) -> None:
+    if result != 0:
+        error_str = _c_ncclGetErrorString(result)
+        error_str = error_str.decode("utf-8")
+        raise RuntimeError(f"NCCL error: {error_str}")
+
+
 # equivalent to c declaration:
 # ncclResult_t ncclGetVersion(int *version);
 _c_ncclGetVersion = nccl.ncclGetVersion
 
 def ncclGetVersion() -> str:
     version = ctypes.c_int()
-    result = _c_ncclGetVersion(ctypes.byref(version))
-    assert result == 0
+    NCCL_CHECK(_c_ncclGetVersion(ctypes.byref(version)))
     # something like 21903 --> "2.19.3"
     version_str = str(version.value)
     major = version_str[0].lstrip("0")
@@ -91,8 +103,7 @@ class NcclUniqueId(ctypes.Structure):
 
 def ncclGetUniqueId() -> NcclUniqueId:
     unique_id = NcclUniqueId()
-    result = _c_ncclGetUniqueId(ctypes.byref(unique_id))
-    assert result == 0
+    NCCL_CHECK(_c_ncclGetUniqueId(ctypes.byref(unique_id)))
    return unique_id
 
 
@@ -199,66 +210,75 @@ class NCCLCommunicator:
 
     def __init__(
         self,
-        backend=None,
-        init_method=None,
-        timeout=datetime.timedelta(seconds=10),
-        world_size: int = -1,
-        rank: int = -1,
-        store=None,
-        group_name: str = "",
-        pg_options=None,
-        local_rank: int = -1,
+        group: Optional[ProcessGroup] = None,
+        device: Optional[Union[int, str, torch.device]] = None,
     ):
-        if not dist.is_initialized():
-            backend = backend or "nccl"
-            assert backend == 'nccl', (
-                "only use nccl backend for starting the NCCL communicator")
-            dist.init_process_group(backend=backend,
-                                    init_method=init_method,
-                                    timeout=timeout,
-                                    world_size=world_size,
-                                    rank=rank,
-                                    store=store,
-                                    group_name=group_name,
-                                    pg_options=pg_options)
-        self.rank = dist.get_rank()
-        self.world_size = dist.get_world_size()
-        if local_rank == -1:
-            local_rank = self.rank
-        self.local_rank = local_rank
-        # don't use these args, as they can be -1
-        # use `self.rank`, `self.local_rank` and `self.world_size` instead
-        del world_size, rank, local_rank
-        torch.cuda.set_device(self.local_rank)
216+ """
217+ Args:
218+ group: the process group to work on. If None, it will use the
219+ default process group.
220+ device: the device to bind the NCCLCommunicator to. If None,
221+ it will be bind to f"cuda:{local_rank}".
222+ It is the caller's responsibility to make sure each communicator
223+ is bind to a unique device.
224+ """
+        assert dist.is_initialized()
+        group = get_cpu_world_group() if group is None else group
+        assert dist.get_backend(group) != dist.Backend.NCCL, (
+            "NCCLCommunicator should be attached to a non-NCCL group.")
+        self.group = group
+        self.rank = dist.get_rank(group)
+        self.world_size = dist.get_world_size(group)
         if self.rank == 0:
             self.unique_id = ncclGetUniqueId()
         else:
             self.unique_id = NcclUniqueId()
-        tensor = torch.ByteTensor(list(self.unique_id.internal)).cuda(
-            self.local_rank)
-        dist.broadcast(tensor, src=0)
-        byte_list = tensor.cpu().tolist()
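+        # rank 0 owns the unique id; broadcast it over the (CPU) process group
+        # so every rank can initialize the same NCCL communicator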
+        tensor = torch.ByteTensor(list(self.unique_id.internal))
+        dist.broadcast(tensor, src=0, group=group)
+        byte_list = tensor.tolist()
         for i, byte in enumerate(byte_list):
             self.unique_id.internal[i] = byte
         self.comm = ctypes.c_void_p()
-        result = _c_ncclCommInitRank(ctypes.byref(self.comm), self.world_size,
-                                     self.unique_id, self.rank)
-        assert result == 0
-        self.stream = torch.cuda.Stream(device=f"cuda:{self.local_rank}")
+        if device is None:
+            local_rank = get_local_rank()
+            device = torch.device(f"cuda:{local_rank}")
+        elif isinstance(device, int):
+            device = torch.device(f"cuda:{device}")
+        elif isinstance(device, str):
+            device = torch.device(device)
+        # now `device` is a `torch.device` object
+        assert isinstance(device, torch.device)
+        self.device = device
+        # nccl communicator and stream will use this device
+        current_device = torch.cuda.current_device()
+        try:
+            torch.cuda.set_device(device)
+            NCCL_CHECK(
+                _c_ncclCommInitRank(ctypes.byref(self.comm), self.world_size,
+                                    self.unique_id, self.rank))
+            self.stream = torch.cuda.Stream()
+        finally:
+            torch.cuda.set_device(current_device)
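+    # Minimal usage sketch, assuming torch.distributed is already initialized
+    # with a CPU-capable backend such as "gloo":
+    #
+    #     comm = NCCLCommunicator()  # binds to f"cuda:{local_rank}" by default
+    #     t = torch.ones(16, device=comm.device)
+    #     comm.all_reduce(t)  # in-place sum across all ranks in the group
+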
 
     def all_reduce(self,
                    tensor: torch.Tensor,
                    op: ReduceOp = ReduceOp.SUM,
                    stream=None):
+        # an nccl communicator created on a specific device will only work on
+        # tensors on that same device; otherwise it causes an
+        # "illegal memory access"
+        assert tensor.device == self.device, (
+            f"this nccl communicator is created to work on {self.device}, "
+            f"but the input tensor is on {tensor.device}")
         if stream is None:
             stream = self.stream
-        result = _c_ncclAllReduce(ctypes.c_void_p(tensor.data_ptr()),
-                                  ctypes.c_void_p(tensor.data_ptr()),
-                                  tensor.numel(),
-                                  ncclDataTypeEnum.from_torch(tensor.dtype),
-                                  ncclRedOpTypeEnum.from_torch(op), self.comm,
-                                  ctypes.c_void_p(stream.cuda_stream))
-        assert result == 0
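+        # sendbuff and recvbuff point to the same tensor, so this performs an
+        # in-place all-reduce on the communicator's dedicated stream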
+        NCCL_CHECK(
+            _c_ncclAllReduce(ctypes.c_void_p(tensor.data_ptr()),
+                             ctypes.c_void_p(tensor.data_ptr()),
+                             tensor.numel(),
+                             ncclDataTypeEnum.from_torch(tensor.dtype),
+                             ncclRedOpTypeEnum.from_torch(op), self.comm,
+                             ctypes.c_void_p(stream.cuda_stream)))
 
     def __del__(self):
         # `dist` module might have been already destroyed