@@ -197,19 +197,18 @@ class GroupCoordinator:
     # 3 | 1 | 3 | 1 | 3
     local_rank: int  # local rank used to assign devices
     rank_in_group: int  # rank inside the group
-    cpu_group: Optional[ProcessGroup]  # group for CPU communication
-    device_group: Optional[ProcessGroup]  # group for device communication
-    use_device_communicator: bool  # whether to use device communicator
-    device_communicator: Optional[
-        DeviceCommunicatorBase]  # device communicator
+    cpu_group: ProcessGroup  # group for CPU communication
+    device_group: ProcessGroup  # group for device communication
+    # device communicator (if use_device_communicator=True)
+    device_communicator: Optional[DeviceCommunicatorBase]
     mq_broadcaster: Optional[Any]  # shared memory broadcaster

     def __init__(
         self,
         group_ranks: list[list[int]],
         local_rank: int,
         torch_distributed_backend: Union[str, Backend],
-        use_device_communicator: bool,
+        use_device_communicator: bool,  # whether to use device communicator
         use_message_queue_broadcaster: bool = False,
         group_name: Optional[str] = None,
     ):
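The annotation change above drops `Optional` from `cpu_group` and `device_group`: both groups are always created in `__init__`, so the attributes never legitimately hold `None`, and callers no longer need a guard before every use. A minimal sketch of what the tighter annotation buys under a type checker (the `Before`/`After` classes are illustrative, not part of the patch):

    from typing import Optional

    from torch.distributed import ProcessGroup

    class Before:
        cpu_group: Optional[ProcessGroup]

        def group_size(self) -> int:
            # A checker flags self.cpu_group.size() here unless the
            # Optional is narrowed first.
            assert self.cpu_group is not None
            return self.cpu_group.size()

    class After:
        cpu_group: ProcessGroup  # always set by the end of __init__

        def group_size(self) -> int:
            # No narrowing needed: the attribute is never None.
            return self.cpu_group.size()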
@@ -219,8 +218,9 @@ def __init__(

         self.rank = torch.distributed.get_rank()
         self.local_rank = local_rank
-        self.device_group = None
-        self.cpu_group = None
+
+        self_device_group = None
+        self_cpu_group = None

         for ranks in group_ranks:
             device_group = torch.distributed.new_group(
@@ -232,11 +232,14 @@ def __init__(
                 self.ranks = ranks
                 self.world_size = len(ranks)
                 self.rank_in_group = ranks.index(self.rank)
-                self.device_group = device_group
-                self.cpu_group = cpu_group
+                self_device_group = device_group
+                self_cpu_group = cpu_group
+
+        assert self_cpu_group is not None
+        assert self_device_group is not None

-        assert self.cpu_group is not None
-        assert self.device_group is not None
+        self.cpu_group = self_cpu_group
+        self.device_group = self_device_group

         from vllm.platforms import current_platform
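The constructor now accumulates into the locals `self_device_group`/`self_cpu_group` and assigns to `self` only after the asserts have narrowed them from `Optional[ProcessGroup]` to `ProcessGroup`; that is what lets the class attributes stay non-Optional. A standalone sketch of the same narrowing pattern, with `int` standing in for `ProcessGroup` (the `Coordinator` name is illustrative):

    from typing import Optional

    class Coordinator:
        group: int  # non-Optional attribute, like cpu_group above

        def __init__(self, candidates: list[int], mine: int) -> None:
            # Build into a local; it is Optional while the loop runs.
            group: Optional[int] = None
            for c in candidates:
                if c == mine:
                    group = c
            # Narrow once, then assign: the attribute itself never
            # holds None, so its annotation needs no Optional.
            assert group is not None
            self.group = group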
@@ -251,7 +254,6 @@ def __init__(
             self.device = torch.device("cpu")

         self.use_device_communicator = use_device_communicator
-
         self.device_communicator = None
         if use_device_communicator and self.world_size > 1:
             device_comm_cls = resolve_obj_by_qualname(
@@ -817,12 +819,12 @@ def recv(self,
         return self.device_communicator.recv(size, dtype, src)

     def destroy(self):
-        if self.device_group is not None:
+        if hasattr(self, "device_group"):
             torch.distributed.destroy_process_group(self.device_group)
-            self.device_group = None
-        if self.cpu_group is not None:
+            del self.device_group
+        if hasattr(self, "cpu_group"):
             torch.distributed.destroy_process_group(self.cpu_group)
-            self.cpu_group = None
+            del self.cpu_group
         if self.device_communicator is not None:
             self.device_communicator.destroy()
         if self.mq_broadcaster is not None:
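Since the attributes can no longer be reset to `None`, `destroy()` deletes them instead and guards with `hasattr`, which also keeps a second `destroy()` call a no-op. A small runnable sketch of the same teardown pattern (the `Holder`/`handle` names are illustrative):

    class Holder:
        handle: str  # non-Optional, like device_group above

        def __init__(self) -> None:
            self.handle = "open"

        def destroy(self) -> None:
            # hasattr instead of an `is not None` check: once deleted,
            # the attribute is absent, so repeated calls do nothing.
            if hasattr(self, "handle"):
                print(f"closing {self.handle}")
                del self.handle

    h = Holder()
    h.destroy()  # closes the handle
    h.destroy()  # no-op; attribute already deleted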