@@ -719,14 +719,19 @@ def init_world_group(ranks: List[int], local_rank: int,
     )


-def init_model_parallel_group(group_ranks: List[List[int]], local_rank: int,
-                              backend: str) -> GroupCoordinator:
+def init_model_parallel_group(
+    group_ranks: List[List[int]],
+    local_rank: int,
+    backend: str,
+    use_custom_allreduce: Optional[bool] = None) -> GroupCoordinator:
+    if use_custom_allreduce is None:
+        use_custom_allreduce = _ENABLE_CUSTOM_ALL_REDUCE
     return GroupCoordinator(
         group_ranks=group_ranks,
         local_rank=local_rank,
         torch_distributed_backend=backend,
         use_pynccl=True,
-        use_custom_allreduce=_ENABLE_CUSTOM_ALL_REDUCE,
+        use_custom_allreduce=use_custom_allreduce,
     )


@@ -888,8 +893,11 @@ def initialize_model_parallel(
     for i in range(num_pipeline_model_parallel_groups):
         ranks = list(range(i, world_size, num_pipeline_model_parallel_groups))
         group_ranks.append(ranks)
+    # pipeline parallel does not need custom allreduce
     _PP = init_model_parallel_group(group_ranks,
-                                    get_world_group().local_rank, backend)
+                                    get_world_group().local_rank,
+                                    backend,
+                                    use_custom_allreduce=False)


 def ensure_model_parallel_initialized(
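For reference, a minimal, self-contained sketch of the behavior the hunks above introduce: an Optional[bool] parameter that falls back to a module-level flag when left as None, and that the pipeline-parallel caller overrides with False. GroupCoordinator is replaced by a plain dict here and the flag value is assumed; this is not the actual vLLM implementation.

    from typing import List, Optional

    _ENABLE_CUSTOM_ALL_REDUCE = True  # stand-in for the module-level flag


    def init_model_parallel_group(
            group_ranks: List[List[int]],
            local_rank: int,
            backend: str,
            use_custom_allreduce: Optional[bool] = None) -> dict:
        # None means "fall back to the global default"; an explicit bool
        # overrides it, which is how the PP group opts out in the diff above.
        if use_custom_allreduce is None:
            use_custom_allreduce = _ENABLE_CUSTOM_ALL_REDUCE
        return {"group_ranks": group_ranks, "local_rank": local_rank,
                "backend": backend,
                "use_custom_allreduce": use_custom_allreduce}


    tp = init_model_parallel_group([[0, 1]], local_rank=0, backend="nccl")
    pp = init_model_parallel_group([[0, 1]], local_rank=0, backend="nccl",
                                   use_custom_allreduce=False)
    assert tp["use_custom_allreduce"] is True   # TP inherits the global setting
    assert pp["use_custom_allreduce"] is False  # PP disables custom allreduce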