# variable in the code.

import ctypes
- import datetime
import platform
+ from typing import Optional, Union

# ===================== import region =====================
import torch
import torch.distributed as dist
- from torch.distributed import ReduceOp
+ from torch.distributed import ProcessGroup, ReduceOp

+ from vllm.distributed.parallel_state import get_cpu_world_group, get_local_rank
from vllm.logger import init_logger
from vllm.utils import find_nccl_library, nccl_integrity_check

ncclResult_t = ctypes.c_int

+ _c_ncclGetErrorString = nccl.ncclGetErrorString
+ _c_ncclGetErrorString.restype = ctypes.c_char_p
+ _c_ncclGetErrorString.argtypes = [ncclResult_t]
+
+
+ def NCCL_CHECK(result: ncclResult_t) -> None:
+     if result != 0:
+         error_str = _c_ncclGetErrorString(result)
+         error_str = error_str.decode("utf-8")
+         raise RuntimeError(f"NCCL error: {error_str}")
+
+
# equivalent to c declaration:
# ncclResult_t ncclGetVersion(int *version);
_c_ncclGetVersion = nccl.ncclGetVersion


def ncclGetVersion() -> str:
    version = ctypes.c_int()
-     result = _c_ncclGetVersion(ctypes.byref(version))
-     assert result == 0
+     NCCL_CHECK(_c_ncclGetVersion(ctypes.byref(version)))
    # something like 21903 --> "2.19.3"
    version_str = str(version.value)
    major = version_str[0].lstrip("0")
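With `NCCL_CHECK` in place, a failing NCCL call surfaces the library's own error message instead of tripping a bare `assert result == 0`. A minimal sketch of the pattern, assuming the wrapper module above is importable as `vllm.distributed.device_communicators.pynccl` (the import path is an assumption for illustration):

```python
# Hedged sketch: NCCL_CHECK turns a nonzero ncclResult_t into a RuntimeError
# that carries ncclGetErrorString's text.
from vllm.distributed.device_communicators.pynccl import NCCL_CHECK, ncclGetVersion

print(ncclGetVersion())  # e.g. "2.19.3"; raises RuntimeError if the call fails

try:
    # 1 corresponds to ncclUnhandledCudaError in NCCL's ncclResult_t enum,
    # so this forces the error path on purpose.
    NCCL_CHECK(1)
except RuntimeError as e:
    print(e)  # prints something like "NCCL error: unhandled cuda error"
```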
@@ -91,8 +103,7 @@ class NcclUniqueId(ctypes.Structure):

def ncclGetUniqueId() -> NcclUniqueId:
    unique_id = NcclUniqueId()
-     result = _c_ncclGetUniqueId(ctypes.byref(unique_id))
-     assert result == 0
+     NCCL_CHECK(_c_ncclGetUniqueId(ctypes.byref(unique_id)))
    return unique_id


@@ -199,66 +210,75 @@ class NCCLCommunicator:

    def __init__(
        self,
-         backend=None,
-         init_method=None,
-         timeout=datetime.timedelta(seconds=10),
-         world_size: int = -1,
-         rank: int = -1,
-         store=None,
-         group_name: str = "",
-         pg_options=None,
-         local_rank: int = -1,
+         group: Optional[ProcessGroup] = None,
+         device: Optional[Union[int, str, torch.device]] = None,
    ):
-         if not dist.is_initialized():
-             backend = backend or "nccl"
-             assert backend == 'nccl', (
-                 "only use nccl backend for starting the NCCL communicator")
-             dist.init_process_group(backend=backend,
-                                     init_method=init_method,
-                                     timeout=timeout,
-                                     world_size=world_size,
-                                     rank=rank,
-                                     store=store,
-                                     group_name=group_name,
-                                     pg_options=pg_options)
-         self.rank = dist.get_rank()
-         self.world_size = dist.get_world_size()
-         if local_rank == -1:
-             local_rank = self.rank
-         self.local_rank = local_rank
-         # don't use these args, as they can be -1
-         # use `self.rank`, `self.local_rank` and `self.world_size` instead
-         del world_size, rank, local_rank
-         torch.cuda.set_device(self.local_rank)
+         """
+         Args:
+             group: the process group to work on. If None, it will use the
+                 default process group.
+             device: the device to bind the NCCLCommunicator to. If None,
+                 it will be bound to f"cuda:{local_rank}".
+         It is the caller's responsibility to make sure each communicator
+         is bound to a unique device.
+         """
+         assert dist.is_initialized()
+         group = get_cpu_world_group() if group is None else group
+         assert dist.get_backend(group) != dist.Backend.NCCL, (
+             "NCCLCommunicator should be attached to a non-NCCL group.")
+         self.group = group
+         self.rank = dist.get_rank(group)
+         self.world_size = dist.get_world_size(group)
        if self.rank == 0:
            self.unique_id = ncclGetUniqueId()
        else:
            self.unique_id = NcclUniqueId()
-         tensor = torch.ByteTensor(list(self.unique_id.internal)).cuda(
-             self.local_rank)
-         dist.broadcast(tensor, src=0)
-         byte_list = tensor.cpu().tolist()
+         tensor = torch.ByteTensor(list(self.unique_id.internal))
+         dist.broadcast(tensor, src=0, group=group)
+         byte_list = tensor.tolist()
        for i, byte in enumerate(byte_list):
            self.unique_id.internal[i] = byte
        self.comm = ctypes.c_void_p()
-         result = _c_ncclCommInitRank(ctypes.byref(self.comm), self.world_size,
-                                      self.unique_id, self.rank)
-         assert result == 0
-         self.stream = torch.cuda.Stream(device=f"cuda:{self.local_rank}")
+         if device is None:
+             local_rank = get_local_rank()
+             device = torch.device(f"cuda:{local_rank}")
+         elif isinstance(device, int):
+             device = torch.device(f"cuda:{device}")
+         elif isinstance(device, str):
+             device = torch.device(device)
+         # now `device` is a `torch.device` object
+         assert isinstance(device, torch.device)
+         self.device = device
+         # nccl communicator and stream will use this device
+         current_device = torch.cuda.current_device()
+         try:
+             torch.cuda.set_device(device)
+             NCCL_CHECK(
+                 _c_ncclCommInitRank(ctypes.byref(self.comm), self.world_size,
+                                     self.unique_id, self.rank))
+             self.stream = torch.cuda.Stream()
+         finally:
+             torch.cuda.set_device(current_device)

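The constructor no longer calls `dist.init_process_group` itself: it expects `torch.distributed` to be initialized already and attaches to a non-NCCL (CPU) process group, binding the communicator to a single device. A minimal usage sketch under those assumptions, with a single-process world for illustration (the import path and the rendezvous address are assumptions, not part of this diff):

```python
import torch
import torch.distributed as dist

# The communicator must be attached to a *non-NCCL* group (it drives its own
# NCCL communicator internally), so a gloo-backed group is used here.
dist.init_process_group(backend="gloo",
                        init_method="tcp://127.0.0.1:29500",
                        world_size=1,
                        rank=0)

# Import path assumed for illustration.
from vllm.distributed.device_communicators.pynccl import NCCLCommunicator

# Passing the group and device explicitly avoids relying on vLLM's
# parallel_state helpers; device=0 is normalized to torch.device("cuda:0").
comm = NCCLCommunicator(group=dist.group.WORLD, device=0)
```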
    def all_reduce(self,
                   tensor: torch.Tensor,
                   op: ReduceOp = ReduceOp.SUM,
                   stream=None):
+         # nccl communicator created on a specific device
+         # will only work on tensors on the same device
+         # otherwise it will cause "illegal memory access"
+         assert tensor.device == self.device, (
+             f"this nccl communicator is created to work on {self.device}, "
+             f"but the input tensor is on {tensor.device}")
        if stream is None:
            stream = self.stream
-         result = _c_ncclAllReduce(ctypes.c_void_p(tensor.data_ptr()),
-                                   ctypes.c_void_p(tensor.data_ptr()),
-                                   tensor.numel(),
-                                   ncclDataTypeEnum.from_torch(tensor.dtype),
-                                   ncclRedOpTypeEnum.from_torch(op), self.comm,
-                                   ctypes.c_void_p(stream.cuda_stream))
-         assert result == 0
+         NCCL_CHECK(
+             _c_ncclAllReduce(ctypes.c_void_p(tensor.data_ptr()),
+                              ctypes.c_void_p(tensor.data_ptr()),
+                              tensor.numel(),
+                              ncclDataTypeEnum.from_torch(tensor.dtype),
+                              ncclRedOpTypeEnum.from_torch(op), self.comm,
+                              ctypes.c_void_p(stream.cuda_stream)))

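Because the communicator is now bound to one device, `all_reduce` asserts that its input lives on that same device; a tensor elsewhere would otherwise surface as an illegal memory access inside NCCL. A hedged sketch continuing the single-process example above:

```python
# The input tensor must live on comm.device.
x = torch.ones(4, device=comm.device)                    # matches the bound device
comm.all_reduce(x, stream=torch.cuda.current_stream())   # in-place SUM across ranks
torch.cuda.synchronize()
print(x)  # with world_size == 1 the values are unchanged

cpu_tensor = torch.ones(4)        # lives on the CPU
# comm.all_reduce(cpu_tensor)     # would fail the device assertion above
```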
    def __del__(self):
        # `dist` module might have been already destroyed