
Commit 3c6325f

[core][distributed] custom allreduce when pp size > 1 (#6117)
1 parent 47f0954 commit 3c6325f

File tree

2 files changed: +17, -15 lines changed

vllm/config.py

Lines changed: 5 additions & 11 deletions
@@ -723,17 +723,11 @@ def _verify_args(self) -> None:
         if self.distributed_executor_backend == "ray":
             from vllm.executor import ray_utils
             ray_utils.assert_ray_available()
-        if not self.disable_custom_all_reduce and self.world_size > 1:
-            if is_hip():
-                self.disable_custom_all_reduce = True
-                logger.info(
-                    "Disabled the custom all-reduce kernel because it is not "
-                    "supported on AMD GPUs.")
-            elif self.pipeline_parallel_size > 1:
-                self.disable_custom_all_reduce = True
-                logger.info(
-                    "Disabled the custom all-reduce kernel because it is not "
-                    "supported with pipeline parallelism.")
+        if is_hip():
+            self.disable_custom_all_reduce = True
+            logger.info(
+                "Disabled the custom all-reduce kernel because it is not "
+                "supported on AMD GPUs.")
         if self.ray_workers_use_nsight and (
                 not self.distributed_executor_backend == "ray"):
             raise ValueError("Unable to use nsight profiling unless workers "
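In effect, the old check disabled the custom all-reduce kernel either on AMD GPUs or whenever pipeline parallelism was enabled; the new check keeps only the AMD case, since the pipeline-parallel group now opts out on its own (see parallel_state.py below). A minimal, self-contained Python sketch of that before/after logic, using hypothetical function names and plain arguments rather than the actual ParallelConfig fields:

def should_disable_custom_all_reduce_before(is_amd: bool, world_size: int,
                                            pipeline_parallel_size: int) -> bool:
    # Old behavior: any multi-GPU run on AMD, or any run with PP > 1,
    # switched the custom all-reduce kernel off globally.
    if world_size > 1:
        if is_amd:
            return True
        if pipeline_parallel_size > 1:
            return True
    return False


def should_disable_custom_all_reduce_after(is_amd: bool) -> bool:
    # New behavior: only AMD GPUs disable it globally; PP > 1 no longer does,
    # because the PP group itself is created without custom all-reduce
    # (see vllm/distributed/parallel_state.py below).
    return is_amd


assert should_disable_custom_all_reduce_before(False, 4, 2) is True
assert should_disable_custom_all_reduce_after(False) is False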

vllm/distributed/parallel_state.py

Lines changed: 12 additions & 4 deletions
@@ -719,14 +719,19 @@ def init_world_group(ranks: List[int], local_rank: int,
     )
 
 
-def init_model_parallel_group(group_ranks: List[List[int]], local_rank: int,
-                              backend: str) -> GroupCoordinator:
+def init_model_parallel_group(
+        group_ranks: List[List[int]],
+        local_rank: int,
+        backend: str,
+        use_custom_allreduce: Optional[bool] = None) -> GroupCoordinator:
+    if use_custom_allreduce is None:
+        use_custom_allreduce = _ENABLE_CUSTOM_ALL_REDUCE
     return GroupCoordinator(
         group_ranks=group_ranks,
         local_rank=local_rank,
         torch_distributed_backend=backend,
         use_pynccl=True,
-        use_custom_allreduce=_ENABLE_CUSTOM_ALL_REDUCE,
+        use_custom_allreduce=use_custom_allreduce,
     )
 
 
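The refactored init_model_parallel_group takes use_custom_allreduce as a three-state argument: an explicit True/False from the caller wins, while the default None falls back to the module-level _ENABLE_CUSTOM_ALL_REDUCE flag. A small, hypothetical sketch of just that fallback pattern (resolve_use_custom_allreduce is illustrative, not part of vLLM):

from typing import Optional

# Stand-in for the module-level flag in vllm/distributed/parallel_state.py;
# the function below only illustrates the fallback pattern.
_ENABLE_CUSTOM_ALL_REDUCE = True


def resolve_use_custom_allreduce(use_custom_allreduce: Optional[bool] = None) -> bool:
    # None means "defer to the global flag"; an explicit True/False from the
    # caller always wins, which is how the PP group forces it off below.
    if use_custom_allreduce is None:
        use_custom_allreduce = _ENABLE_CUSTOM_ALL_REDUCE
    return use_custom_allreduce


assert resolve_use_custom_allreduce() is True        # follows the global flag
assert resolve_use_custom_allreduce(False) is False  # explicit opt-out (PP group)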
@@ -888,8 +893,11 @@ def initialize_model_parallel(
     for i in range(num_pipeline_model_parallel_groups):
         ranks = list(range(i, world_size, num_pipeline_model_parallel_groups))
         group_ranks.append(ranks)
+    # pipeline parallel does not need custom allreduce
     _PP = init_model_parallel_group(group_ranks,
-                                    get_world_group().local_rank, backend)
+                                    get_world_group().local_rank,
+                                    backend,
+                                    use_custom_allreduce=False)
 
 
 def ensure_model_parallel_initialized(
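With this change, only the pipeline-parallel group hard-codes use_custom_allreduce=False, so the tensor-parallel group can still use the custom all-reduce kernel even when pp size > 1. A self-contained sketch (not vLLM code) of how the TP and PP rank groups are laid out for 4 GPUs with TP=2 and PP=2, mirroring the range arithmetic in initialize_model_parallel:

from typing import List, Tuple


def layout_groups(world_size: int, tp_size: int,
                  pp_size: int) -> Tuple[List[List[int]], List[List[int]]]:
    # Assumes ranks 0..world_size-1 laid out so that consecutive ranks share
    # a tensor-parallel group, as in initialize_model_parallel.
    assert world_size == tp_size * pp_size
    num_tp_groups = world_size // tp_size
    num_pp_groups = world_size // pp_size
    tp_groups = [list(range(i * tp_size, (i + 1) * tp_size))
                 for i in range(num_tp_groups)]
    pp_groups = [list(range(i, world_size, num_pp_groups))
                 for i in range(num_pp_groups)]
    return tp_groups, pp_groups


tp_groups, pp_groups = layout_groups(world_size=4, tp_size=2, pp_size=2)
print(tp_groups)  # [[0, 1], [2, 3]] -> may use the custom all-reduce kernel
print(pp_groups)  # [[0, 2], [1, 3]] -> created with use_custom_allreduce=False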
