@@ -897,29 +897,22 @@ def initialize_model_parallel(
         get_world_group().device_group)
 
     data_parallel_size = 1
-    has_external_dp = False
     from vllm.config import get_current_vllm_config
     config = get_current_vllm_config()
     if config is not None:
-        if config.parallel_config.world_size != world_size:
-            # detect external data parallelism.
-            # dp in vllm means all dp instances need to run together.
-            # if the world size does not match, it means this dp is external,
-            # and the dp instances can run independently, e.g. in rlhf workflow
-            # from https://github.com/volcengine/verl .
-            # in that case, we treat the rest dimensions as if they are
-            # data parallel, and create a dummy dp group that is not used.
-            data_parallel_size = world_size // (pipeline_model_parallel_size *
-                                                tensor_model_parallel_size)
-            has_external_dp = True
-        else:
-            data_parallel_size = config.parallel_config.data_parallel_size
-
-    # the layout order is: DP x PP x TP
+        data_parallel_size = config.parallel_config.data_parallel_size
+
+    # the layout order is: ExternalDP x DP x PP x TP
+    # ExternalDP is the data parallel group that is not part of the model,
+    # every dp rank can generate independently (in verl integration).
+    # DP is the data parallel group that is part of the model,
+    # all the ranks in the same DP group should generate simultaneously,
+    # i.e. the `generate` call in the same DP group should be called together,
+    # otherwise it will cause deadlock.
     # to get group_ranks for each dimension, transpose that dimension to the
     # last dimension, then reshape to 2D, then unbind the last dimension
     all_ranks = torch.arange(world_size).reshape(
-        data_parallel_size, pipeline_model_parallel_size,
+        -1, data_parallel_size, pipeline_model_parallel_size,
         tensor_model_parallel_size)  # noqa
 
     # Build the tensor model-parallel groups.
@@ -939,7 +932,7 @@ def initialize_model_parallel(
     global _PP
     assert _PP is None, (
         "pipeline model parallel group is already initialized")
-    group_ranks = all_ranks.transpose(1, 2).reshape(
+    group_ranks = all_ranks.transpose(2, 3).reshape(
         -1, pipeline_model_parallel_size).unbind(0)
     group_ranks = [x.tolist() for x in group_ranks]
     _PP = init_model_parallel_group(group_ranks,
@@ -949,16 +942,10 @@ def initialize_model_parallel(
 
     global _DP
     assert _DP is None, ("data parallel group is already initialized")
-    group_ranks = all_ranks.transpose(0,
-                                      2).reshape(-1,
+    group_ranks = all_ranks.transpose(1,
+                                      3).reshape(-1,
                                                  data_parallel_size).unbind(0)
     group_ranks = [x.tolist() for x in group_ranks]
-    if has_external_dp:
-        # create a dummy dp group that is not used actually,
-        # since this dp is external.
-        # a dummy dp group means every rank is a group itself.
-        # this way, no communication is needed, no memory is wasted.
-        group_ranks = [[x] for x in range(world_size)]
     _DP = init_model_parallel_group(group_ranks,
                                     get_world_group().local_rank,
                                     backend,
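
Not part of the patch, but as a sanity check of the indexing above: a minimal standalone sketch of how the new 4-D ExternalDP x DP x PP x TP rank layout plus the transpose/reshape/unbind trick yields the per-dimension group ranks. The sizes (external_dp = dp = pp = tp = 2) and the names tp_groups/pp_groups/dp_groups are illustrative assumptions, not values or identifiers from the patch.

# sketch only: assumed example sizes, 16 ranks total
import torch

external_dp, dp, pp, tp = 2, 2, 2, 2
world_size = external_dp * dp * pp * tp

# layout order: ExternalDP x DP x PP x TP; the -1 infers the external-DP dim
all_ranks = torch.arange(world_size).reshape(-1, dp, pp, tp)

# TP is already the last dimension: reshape to 2D and unbind the rows
tp_groups = [x.tolist() for x in all_ranks.view(-1, tp).unbind(0)]

# for PP and DP, transpose that dimension to the last position first
pp_groups = [x.tolist() for x in all_ranks.transpose(2, 3).reshape(-1, pp).unbind(0)]
dp_groups = [x.tolist() for x in all_ranks.transpose(1, 3).reshape(-1, dp).unbind(0)]

print(tp_groups[:2])  # [[0, 1], [2, 3]]
print(pp_groups[:2])  # [[0, 2], [1, 3]]
print(dp_groups[:2])  # [[0, 4], [2, 6]]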