@@ -196,10 +196,11 @@ class GroupCoordinator:
    #   3     |   1  |  3   |     1      |       3
    local_rank: int  # local rank used to assign devices
    rank_in_group: int  # rank inside the group
-    cpu_group: ProcessGroup  # group for CPU communication
-    device_group: ProcessGroup  # group for device communication
+    cpu_group: Optional[ProcessGroup]  # group for CPU communication
+    device_group: Optional[ProcessGroup]  # group for device communication
    use_device_communicator: bool  # whether to use device communicator
-    device_communicator: DeviceCommunicatorBase  # device communicator
+    device_communicator: Optional[
+        DeviceCommunicatorBase]  # device communicator
    mq_broadcaster: Optional[Any]  # shared memory broadcaster

    def __init__(
@@ -250,7 +251,7 @@ def __init__(

        self.use_device_communicator = use_device_communicator

-        self.device_communicator: DeviceCommunicatorBase = None  # type: ignore
+        self.device_communicator = None
        if use_device_communicator and self.world_size > 1:
            device_comm_cls = resolve_obj_by_qualname(
                current_platform.get_device_communicator_cls())
@@ -364,6 +365,8 @@ def all_reduce(self, input_: torch.Tensor) -> torch.Tensor:
        return self._all_reduce_out_place(input_)

    def _all_reduce_out_place(self, input_: torch.Tensor) -> torch.Tensor:
+        if self.device_communicator is None:
+            raise ValueError("No device communicator found")
        return self.device_communicator.all_reduce(input_)

    def all_gather(self, input_: torch.Tensor, dim: int = -1) -> torch.Tensor:
@@ -384,12 +387,16 @@ def all_gather(self, input_: torch.Tensor, dim: int = -1) -> torch.Tensor:

    def _all_gather_out_place(self, input_: torch.Tensor,
                              dim: int) -> torch.Tensor:
+        if self.device_communicator is None:
+            raise ValueError("No device communicator found")
        return self.device_communicator.all_gather(input_, dim)

    def all_gatherv(self,
                    input_: Union[torch.Tensor, list[torch.Tensor]],
                    dim: int = 0,
                    sizes: Optional[list[int]] = None):
+        if self.device_communicator is None:
+            raise ValueError("No device communicator found")
        return self.device_communicator.all_gatherv(input_, dim, sizes)

    def reduce_scatter(self,
@@ -414,10 +421,14 @@ def reduce_scatterv(self,
                        input_: torch.Tensor,
                        dim: int = -1,
                        sizes: Optional[list[int]] = None) -> torch.Tensor:
+        if self.device_communicator is None:
+            raise ValueError("No device communicator found")
        return self.device_communicator.reduce_scatterv(input_, dim, sizes)

    def _reduce_scatter_out_place(self, input_: torch.Tensor,
                                  dim: int) -> torch.Tensor:
+        if self.device_communicator is None:
+            raise ValueError("No device communicator found")
        return self.device_communicator.reduce_scatter(input_, dim)

    def gather(self,
@@ -433,6 +444,8 @@ def gather(self,
        # Bypass the function if we are using only 1 GPU.
        if world_size == 1:
            return input_
+        if self.device_communicator is None:
+            raise ValueError("No device communicator found")
        return self.device_communicator.gather(input_, dst, dim)

    def broadcast(self, input_: torch.Tensor, src: int = 0):
@@ -667,6 +680,8 @@ def send_tensor_dict(
        assert dst < self.world_size, f"Invalid dst rank ({dst})"

        if self.use_cpu_custom_send_recv:
+            if self.device_communicator is None:
+                raise ValueError("No device communicator found")
            self.device_communicator.send_tensor_dict(  # type: ignore
                tensor_dict, dst)
            return None
@@ -727,6 +742,8 @@ def recv_tensor_dict(
        assert src < self.world_size, f"Invalid src rank ({src})"

        if self.use_cpu_custom_send_recv:
+            if self.device_communicator is None:
+                raise ValueError("No device communicator found")
            return self.device_communicator.recv_tensor_dict(  # type: ignore
                src)

@@ -784,6 +801,8 @@ def barrier(self):
    def send(self, tensor: torch.Tensor, dst: Optional[int] = None) -> None:
        """Sends a tensor to the destination rank in a blocking way"""
        """NOTE: `dst` is the local rank of the destination rank."""
+        if self.device_communicator is None:
+            raise ValueError("No device communicator found")
        self.device_communicator.send(tensor, dst)

    def recv(self,
@@ -792,6 +811,8 @@ def recv(self,
             src: Optional[int] = None) -> torch.Tensor:
        """Receives a tensor from the source rank."""
        """NOTE: `src` is the local rank of the source rank."""
+        if self.device_communicator is None:
+            raise ValueError("No device communicator found")
        return self.device_communicator.recv(size, dtype, src)

    def destroy(self):
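Taken together, these hunks type `device_communicator` as `Optional[DeviceCommunicatorBase]` instead of forcing `None` into a non-optional attribute behind `# type: ignore`, and each call site now raises `ValueError("No device communicator found")` when the communicator was never created (which happens when `use_device_communicator` is false or `world_size == 1`). Below is a minimal, self-contained sketch of that guard pattern, not vLLM's actual classes; `_StubCommunicator` and `_Coordinator` are hypothetical names used only for illustration.

# Sketch of the Optional-plus-guard pattern introduced by this diff.
from typing import Optional


class _StubCommunicator:
    """Hypothetical stand-in for a device communicator backend."""

    def all_reduce(self, value: int) -> int:
        return value  # single-process no-op, for illustration only


class _Coordinator:
    device_communicator: Optional[_StubCommunicator]  # None when unused

    def __init__(self, world_size: int, use_device_communicator: bool) -> None:
        # Mirrors the diff: the attribute starts as None and is only
        # populated when a communicator is actually needed.
        self.device_communicator = None
        if use_device_communicator and world_size > 1:
            self.device_communicator = _StubCommunicator()

    def all_reduce(self, value: int) -> int:
        # Explicit guard, matching the ValueError added in the diff.
        if self.device_communicator is None:
            raise ValueError("No device communicator found")
        return self.device_communicator.all_reduce(value)


if __name__ == "__main__":
    coord = _Coordinator(world_size=1, use_device_communicator=True)
    try:
        coord.all_reduce(1)
    except ValueError as exc:
        print(exc)  # prints: No device communicator found

The explicit check trades a bare `AttributeError` on `None` for a clearer error message, and it lets the attribute be honestly typed as `Optional` without the previous `# type: ignore`.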