Commit 9606d57

[distributed] fix dp group (#15355)

Authored by youkaichao
Signed-off-by: youkaichao <[email protected]>
Parent commit: cbcdf2c

File tree: 1 file changed (+13, -26 lines)

vllm/distributed/parallel_state.py (13 additions, 26 deletions)
```diff
@@ -897,29 +897,22 @@ def initialize_model_parallel(
         get_world_group().device_group)
 
     data_parallel_size = 1
-    has_external_dp = False
     from vllm.config import get_current_vllm_config
     config = get_current_vllm_config()
     if config is not None:
-        if config.parallel_config.world_size != world_size:
-            # detect external data parallelism.
-            # dp in vllm means all dp instances need to run together.
-            # if the world size does not match, it means this dp is external,
-            # and the dp instances can run independently, e.g. in rlhf workflow
-            # from https://github.com/volcengine/verl .
-            # in that case, we treat the rest dimensions as if they are
-            # data parallel, and create a dummy dp group that is not used.
-            data_parallel_size = world_size // (pipeline_model_parallel_size *
-                                                tensor_model_parallel_size)
-            has_external_dp = True
-        else:
-            data_parallel_size = config.parallel_config.data_parallel_size
-
-    # the layout order is: DP x PP x TP
+        data_parallel_size = config.parallel_config.data_parallel_size
+
+    # the layout order is: ExternalDP x DP x PP x TP
+    # ExternalDP is the data parallel group that is not part of the model,
+    # every dp rank can generate independently (in verl integration).
+    # DP is the data parallel group that is part of the model,
+    # all the ranks in the same DP group should generate simultaneously,
+    # i.e. the `generate` call in the same DP group should be called together,
+    # otherwise it will cause deadlock.
     # to get group_ranks for each dimension, transpose that dimension to the
     # last dimension, then reshape to 2D, then unbind the last dimension
     all_ranks = torch.arange(world_size).reshape(
-        data_parallel_size, pipeline_model_parallel_size,
+        -1, data_parallel_size, pipeline_model_parallel_size,
         tensor_model_parallel_size)  # noqa
 
     # Build the tensor model-parallel groups.
```
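The new comment block describes the ExternalDP x DP x PP x TP layout and the transpose/reshape/unbind trick for deriving process groups. The sketch below is a standalone illustration of that grouping logic, not code from the commit; the sizes (2 external replicas, DP=2, PP=2, TP=2, so a world size of 16) are made-up example values.

```python
# Illustration of the ExternalDP x DP x PP x TP rank layout used above.
import torch

world_size = 16                      # 2 external replicas x 2 DP x 2 PP x 2 TP
data_parallel_size = 2
pipeline_model_parallel_size = 2
tensor_model_parallel_size = 2

# The leading -1 absorbs any external data-parallel replicas that vLLM
# itself does not manage; the remaining dims are DP x PP x TP.
all_ranks = torch.arange(world_size).reshape(
    -1, data_parallel_size, pipeline_model_parallel_size,
    tensor_model_parallel_size)

# TP groups: TP is already the last dimension, so reshape and unbind directly.
tp_groups = [x.tolist() for x in
             all_ranks.reshape(-1, tensor_model_parallel_size).unbind(0)]

# PP groups: move the PP dimension (index 2) to the last position first.
pp_groups = [x.tolist() for x in
             all_ranks.transpose(2, 3).reshape(
                 -1, pipeline_model_parallel_size).unbind(0)]

# DP groups: move the DP dimension (index 1) to the last position first.
dp_groups = [x.tolist() for x in
             all_ranks.transpose(1, 3).reshape(
                 -1, data_parallel_size).unbind(0)]

print(tp_groups)  # [[0, 1], [2, 3], ..., [14, 15]]
print(pp_groups)  # [[0, 2], [1, 3], [4, 6], [5, 7], ...]
print(dp_groups)  # [[0, 4], [2, 6], [1, 5], [3, 7], ...]
```

With the extra leading dimension, ranks belonging to different external replicas never end up in the same DP, PP, or TP group.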
```diff
@@ -939,7 +932,7 @@ def initialize_model_parallel(
     global _PP
     assert _PP is None, (
         "pipeline model parallel group is already initialized")
-    group_ranks = all_ranks.transpose(1, 2).reshape(
+    group_ranks = all_ranks.transpose(2, 3).reshape(
         -1, pipeline_model_parallel_size).unbind(0)
     group_ranks = [x.tolist() for x in group_ranks]
     _PP = init_model_parallel_group(group_ranks,
```
```diff
@@ -949,16 +942,10 @@ def initialize_model_parallel(
 
     global _DP
     assert _DP is None, ("data parallel group is already initialized")
-    group_ranks = all_ranks.transpose(0,
-                                      2).reshape(-1,
+    group_ranks = all_ranks.transpose(1,
+                                      3).reshape(-1,
                                                  data_parallel_size).unbind(0)
     group_ranks = [x.tolist() for x in group_ranks]
-    if has_external_dp:
-        # create a dummy dp group that is not used actually,
-        # since this dp is external.
-        # a dummy dp group means every rank is a group itself.
-        # this way, no communication is needed, no memory is wasted.
-        group_ranks = [[x] for x in range(world_size)]
     _DP = init_model_parallel_group(group_ranks,
                                     get_world_group().local_rank,
                                     backend,
```
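For the external data parallelism case that the removed `has_external_dp` branch special-cased (every rank its own DP group, via `[[x] for x in range(world_size)]`), the new layout produces the same grouping naturally when the configured DP size is 1: each external replica becomes its own slice of the leading dimension. A minimal sketch, assuming 2 external replicas of a TP=4, PP=1, DP=1 deployment (example values, not from the commit):

```python
# External-DP case: vLLM is configured with DP=1 but the launcher spans
# two replicas, so world_size is twice the configured parallel size.
import torch

world_size = 8                       # two external replicas of a TP=4 model
data_parallel_size = 1               # what vLLM itself is configured with
pipeline_model_parallel_size = 1
tensor_model_parallel_size = 4

all_ranks = torch.arange(world_size).reshape(
    -1, data_parallel_size, pipeline_model_parallel_size,
    tensor_model_parallel_size)      # shape (2, 1, 1, 4)

dp_groups = [x.tolist() for x in
             all_ranks.transpose(1, 3).reshape(
                 -1, data_parallel_size).unbind(0)]
print(dp_groups)  # [[0], [1], ..., [7]]: every rank is its own DP group
```

When the configured DP size is larger than 1, the same code yields real DP groups of that size within each external replica, which the old dummy-group fallback could not express.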
