50 changes: 48 additions & 2 deletions vllm/config/parallel.py
@@ -137,6 +137,9 @@ class ParallelConfig:
disable_custom_all_reduce: bool = False
"""Disable the custom all-reduce kernel and fall back to NCCL."""

enable_elastic_ep: bool = False
"""Enable elastic expert parallelism with stateless NCCL groups for DP/EP."""

enable_dbo: bool = False
"""Enable microbatching for the model executor."""

@@ -188,6 +191,21 @@ class is dynamically inherited by the worker class. This is used to inject
Set to be private as it's not intended to be configured by users.
"""

_stateless_world_group_port_list: list[int] = field(default_factory=list)
"""List of open ports for stateless world group when enable_elastic_ep is True.
Set to be private as it's not intended to be configured by users.
"""

_stateless_dp_group_port_list: list[int] = field(default_factory=list)
"""List of open ports for stateless DP groups when enable_elastic_ep is True.
Set to be private as it's not intended to be configured by users.
"""

_stateless_ep_group_port_list: list[int] = field(default_factory=list)
"""List of open ports for stateless EP groups when enable_elastic_ep is True.
Set to be private as it's not intended to be configured by users.
"""

decode_context_parallel_size: int = 1
"""Number of decode context parallel groups, because the world size does
not change by dcp, it simply reuse the GPUs of TP group, and tp_size
@@ -235,7 +253,16 @@ def get_next_dp_init_port(self) -> int:

return answer

def stateless_init_dp_group(self) -> ProcessGroup:
def get_next_stateless_world_group_port(self) -> list[int]:
return self._stateless_world_group_port_list.pop(0)

def get_next_stateless_dp_group_port(self) -> list[int]:
return self._stateless_dp_group_port_list.pop(0)

def get_next_stateless_ep_group_port(self) -> list[int]:
return self._stateless_ep_group_port_list.pop(0)
Comment on lines +256 to +263
Contributor review comment (severity: high):

These methods use `pop(0)` to retrieve a port entry from a list without checking whether the list is empty. If the port lists (`_stateless_world_group_port_list`, `_stateless_dp_group_port_list`, `_stateless_ep_group_port_list`) are ever exhausted, this raises an `IndexError` and crashes the process. The logic in `__post_init__` does pre-allocate the necessary ports, but the design is fragile; a more robust implementation would check whether the list is empty before popping and raise a more informative error.
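A minimal sketch of the defensive accessor suggested above, assuming the same pre-allocated lists of port triples; the helper name is hypothetical and not part of this PR:

```python
def _pop_stateless_ports(port_list: list[list[int]], group_name: str) -> list[int]:
    """Pop the next pre-allocated port triple, failing with a clear message
    if the list has been exhausted instead of raising a bare IndexError."""
    if not port_list:
        raise RuntimeError(
            f"No pre-allocated stateless ports left for {group_name}; "
            "check the port allocation in ParallelConfig.__post_init__.")
    return port_list.pop(0)
```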


def stateless_init_dp_group(self, return_store: bool = False) -> ProcessGroup:
# NOTE: In high-concurrency scenarios multiple processes
# can pick the same (currently free) port through a race
# condition when calling `get_open_port()`. When the first
@@ -258,7 +285,8 @@ def stateless_init_dp_group(self) -> ProcessGroup:
self.get_next_dp_init_port(),
self.data_parallel_rank,
self.data_parallel_size,
backend="gloo")
backend="gloo",
return_store=return_store)
except DistNetworkError as e:
# We only want to retry when the root cause is EADDRINUSE.
if "EADDRINUSE" in str(e):
@@ -351,6 +379,24 @@ def __post_init__(self) -> None:
self.world_size = self.pipeline_parallel_size * \
self.tensor_parallel_size

# Initialize stateless group ports for elastic EP
if self.enable_elastic_ep:
num_world_groups = 1
num_dp_groups = max(1, self.world_size_across_dp // self.data_parallel_size)
num_ep_groups = max(1, self.world_size_across_dp // (self.data_parallel_size * self.tensor_parallel_size))

total_ports_needed = (num_world_groups + num_dp_groups + num_ep_groups) * 3

if not self._stateless_world_group_port_list:
all_ports = get_open_ports_list(total_ports_needed + 5)
self._data_parallel_master_port_list = all_ports[-5:]
all_ports = all_ports[:-5]
self._stateless_world_group_port_list = [all_ports[i:i+3] for i in range(0, num_world_groups * 3, 3)]
start_idx = num_world_groups * 3
self._stateless_dp_group_port_list = [all_ports[i:i+3] for i in range(start_idx, start_idx + num_dp_groups * 3, 3)]
start_idx += num_dp_groups * 3
self._stateless_ep_group_port_list = [all_ports[i:i+3] for i in range(start_idx, start_idx + num_ep_groups * 3, 3)]

if self.data_parallel_size_local > self.data_parallel_size:
raise ValueError(
f"data_parallel_size_local ({self.data_parallel_size_local}) "
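For orientation, a small worked example of the port pre-allocation added to `__post_init__` above, using hypothetical sizes that are not taken from the PR:

```python
# Sketch of the port bookkeeping for an illustrative configuration:
# world_size_across_dp=16, data_parallel_size=4, tensor_parallel_size=2.
world_size_across_dp = 16
data_parallel_size = 4
tensor_parallel_size = 2

num_world_groups = 1
num_dp_groups = max(1, world_size_across_dp // data_parallel_size)           # 4
num_ep_groups = max(1, world_size_across_dp //
                    (data_parallel_size * tensor_parallel_size))             # 2

total_ports_needed = (num_world_groups + num_dp_groups + num_ep_groups) * 3  # 21

# get_open_ports_list(total_ports_needed + 5) would return 26 ports: the last
# 5 become _data_parallel_master_port_list, and the first 21 are carved into
# triples -> 1 world-group triple, 4 DP triples, and 2 EP triples.
print(num_dp_groups, num_ep_groups, total_ports_needed)  # 4 2 21
```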
53 changes: 30 additions & 23 deletions vllm/distributed/device_communicators/all2all.py
@@ -23,8 +23,8 @@ class NaiveAll2AllManager(All2AllManagerBase):
debugging.
"""

def __init__(self, cpu_group):
super().__init__(cpu_group)
def __init__(self, cpu_group, tcp_store_group=None):
super().__init__(cpu_group, tcp_store_group)

def naive_multicast(self, x: torch.Tensor,
cu_tokens_across_dp_cpu: torch.Tensor):
@@ -76,8 +76,8 @@ class AgRsAll2AllManager(All2AllManagerBase):
all-gather (dispatch) and reduce-scatter (combine).
"""

def __init__(self, cpu_group):
super().__init__(cpu_group)
def __init__(self, cpu_group, tcp_store_group=None):
super().__init__(cpu_group, tcp_store_group)

def dispatch(self, hidden_states: torch.Tensor,
router_logits: torch.Tensor):
@@ -113,14 +113,16 @@ class PPLXAll2AllManager(All2AllManagerBase):
All2All communication based on PPLX kernels.
"""

def __init__(self, cpu_group):
def __init__(self, cpu_group, tcp_store_group=None):
assert has_pplx(
), "pplx_kernels not found. Please follow https://github.com/vllm-project/vllm/blob/main/tools/ep_kernels/README.md to install pplx_kernels." # noqa
super().__init__(cpu_group)
super().__init__(cpu_group, tcp_store_group)

if self.internode:
# inter-node communication needs nvshmem,
# intra-node communication uses p2p mapping directly
self.nvshmem_initialized = False
self.handle_cache = Cache()

def get_handle(self, kwargs):
if self.internode and not self.nvshmem_initialized:
from pplx_kernels.nvshmem import (nvshmem_alloc_empty_unique_id,
nvshmem_get_unique_id,
nvshmem_init)
@@ -129,15 +131,18 @@ def __init__(self, cpu_group):
"rank=%d, world size=%d", self.rank, self.world_size)
uid = nvshmem_get_unique_id(
) if self.rank == 0 else nvshmem_alloc_empty_unique_id()
dist.broadcast(uid,
src=dist.get_process_group_ranks(self.cpu_group)[0],
group=self.cpu_group)

if self.tcp_store_group is not None:
uid = self.tcp_store_group.broadcast_obj(uid, src=0)
else:
dist.broadcast(uid,
src=dist.get_process_group_ranks(self.cpu_group)[0],
group=self.cpu_group)

logger.debug("PPLX NVSHMEM UID = %s", uid)
nvshmem_init(uid, self.rank, self.world_size)
self.nvshmem_initialized = True

self.handle_cache = Cache()

def get_handle(self, kwargs):
import pplx_kernels as pplx
return self.handle_cache.get_or_create(
kwargs, pplx.AllToAll.internode
@@ -166,10 +171,10 @@ class DeepEPAll2AllManagerBase(All2AllManagerBase):
All2All communication based on DeepEP High-Throughput kernels.
"""

def __init__(self, cpu_group):
def __init__(self, cpu_group, tcp_store_group=None):
assert has_deep_ep(
), "DeepEP kernels not found. Please follow https://github.com/vllm-project/vllm/blob/main/tools/ep_kernels/README.md to install DeepEP kernels." # noqa
super().__init__(cpu_group)
super().__init__(cpu_group, tcp_store_group)
self.handle_cache = Cache()

# This is the DeepEP default. Stick to it till we can establish
@@ -195,8 +200,8 @@ class DeepEPHTAll2AllManager(DeepEPAll2AllManagerBase):
All2All communication based on DeepEP High-Throughput kernels.
"""

def __init__(self, cpu_group):
super().__init__(cpu_group)
def __init__(self, cpu_group, tcp_store_group=None):
super().__init__(cpu_group, tcp_store_group)

def _make_all2all_kwargs(self) -> dict[Any, Any]:
# Defaults for internode and intranode are taken from DeepEP tests.
@@ -243,8 +248,8 @@ class DeepEPLLAll2AllManager(DeepEPAll2AllManagerBase):
All2All communication based on DeepEP Low-Latency kernels.
"""

def __init__(self, cpu_group):
super().__init__(cpu_group)
def __init__(self, cpu_group, tcp_store_group=None):
super().__init__(cpu_group, tcp_store_group)

def _make_all2all_kwargs(
self,
@@ -265,7 +270,8 @@ def _make_all2all_kwargs(
import deep_ep

# Defaults for internode and intranode are taken from DeepEP tests.
num_nvl_bytes = 1024 * 1024 * 1024
# num_nvl_bytes = 1024 * 1024 * 1024
num_nvl_bytes = 0
num_qps_per_rank = num_local_experts
num_rdma_bytes = deep_ep.Buffer.get_low_latency_rdma_size_hint(
num_max_dispatch_tokens_per_rank=max_num_tokens_per_dp_rank,
@@ -278,7 +284,8 @@
num_nvl_bytes=num_nvl_bytes,
num_rdma_bytes=num_rdma_bytes,
low_latency_mode=True,
num_qps_per_rank=num_qps_per_rank)
num_qps_per_rank=num_qps_per_rank,
allow_mnnvl=True)

def get_handle(self, kwargs):
"""
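The PPLX changes above defer NVSHMEM initialization to `get_handle` and exchange the unique id through the stateless TCP-store group when one exists, falling back to a regular `torch.distributed` broadcast otherwise. A condensed sketch of that branching; the standalone function is illustrative, while the real logic lives inside `PPLXAll2AllManager.get_handle`:

```python
import torch.distributed as dist


def broadcast_uid(uid, cpu_group, tcp_store_group=None):
    """Broadcast the NVSHMEM unique id from rank 0 to all ranks.

    Prefer the stateless TCP-store group (elastic EP); only use the regular
    torch.distributed broadcast when no store group was provided.
    """
    if tcp_store_group is not None:
        # broadcast_obj ships the object through the TCP store, so no
        # torch.distributed world needs to be initialized on this path.
        return tcp_store_group.broadcast_obj(uid, src=0)
    dist.broadcast(uid,
                   src=dist.get_process_group_ranks(cpu_group)[0],
                   group=cpu_group)
    return uid
```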
44 changes: 32 additions & 12 deletions vllm/distributed/device_communicators/base_device_communicator.py
@@ -29,8 +29,9 @@ def get_or_create(self, kwargs, func):

class All2AllManagerBase:

def __init__(self, cpu_group):
def __init__(self, cpu_group, tcp_store_group=None):
self.cpu_group = cpu_group
self.tcp_store_group = tcp_store_group

# compute some common properties
from vllm.distributed.parallel_state import (get_dp_group,
@@ -44,12 +45,15 @@ def __init__(self, cpu_group):
# when we create this object
self.dp_rank = self.dp_group.rank_in_group
self.dp_world_size = self.dp_group.world_size
self.rank = dist.get_rank(cpu_group)
self.world_size = dist.get_world_size(cpu_group)
self.rank = cpu_group.rank()
self.world_size = cpu_group.size()

# all2all communication often has separate implementations for
# intra-node and inter-node communication
self.internode = not all(in_the_same_node_as(cpu_group, source_rank=0))
if tcp_store_group is None:
self.internode = not all(in_the_same_node_as(cpu_group, source_rank=0))
else:
self.internode = not all(in_the_same_node_as(tcp_store_group, source_rank=0))

def get_handle(self, kwargs):
# get a handle for the all2all communication,
@@ -83,18 +87,34 @@ def __init__(self,
cpu_group: ProcessGroup,
device: Optional[torch.device] = None,
device_group: Optional[ProcessGroup] = None,
unique_name: str = ""):
unique_name: str = "",
global_ranks: Optional[list[int]] = None,
global_world_size: Optional[int] = None):
self.device = device or torch.device("cpu")
self.cpu_group = cpu_group
self.device_group = device_group
self.unique_name = unique_name
self.rank = dist.get_rank(cpu_group)
self.world_size = dist.get_world_size(cpu_group)
self.ranks = dist.get_process_group_ranks(cpu_group)
self.global_rank = dist.get_rank()
self.global_world_size = dist.get_world_size()
self.rank_in_group = dist.get_group_rank(self.cpu_group,
self.global_rank)

# Check if this is a stateless process group
from torch.distributed.distributed_c10d import _world
is_stateless = _world.pg_map.get(cpu_group, None) is None
Comment on lines +98 to +100
Contributor review comment (severity: high):

The check `_world.pg_map.get(cpu_group, None) is None` relies on an internal, undocumented implementation detail of `torch.distributed` to determine whether a process group is stateless. This is brittle and could break with future PyTorch updates. A more robust approach would be an explicit mechanism for identifying stateless groups, such as a custom process-group class that carries this information or a flag passed during initialization.
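A minimal sketch of the explicit-flag alternative suggested above: the code that created the group tells the communicator whether it is stateless, instead of the communicator probing `torch.distributed._world`. The `is_stateless` parameter is hypothetical and not part of this PR:

```python
from typing import Optional

import torch.distributed as dist
from torch.distributed import ProcessGroup


class DeviceCommunicatorBase:
    """Sketch only: rank bookkeeping driven by an explicit is_stateless flag."""

    def __init__(self,
                 cpu_group: ProcessGroup,
                 unique_name: str = "",
                 is_stateless: bool = False,
                 global_ranks: Optional[list[int]] = None,
                 global_world_size: Optional[int] = None):
        self.cpu_group = cpu_group
        self.unique_name = unique_name
        if is_stateless:
            # Stateless groups cannot be queried through torch.distributed,
            # so rank/size come from the group object and the caller.
            self.rank = cpu_group.rank()
            self.world_size = cpu_group.size()
            self.ranks = global_ranks
        else:
            self.rank = dist.get_rank(cpu_group)
            self.world_size = dist.get_world_size(cpu_group)
            self.ranks = dist.get_process_group_ranks(cpu_group)
```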


if is_stateless:
# For stateless groups, we can't use torch.distributed methods
self.rank = cpu_group.rank()
self.world_size = cpu_group.size()
self.ranks = global_ranks
self.global_rank = self.ranks[self.rank]
self.global_world_size = global_world_size
self.rank_in_group = self.rank
else:
self.rank = dist.get_rank(cpu_group)
self.world_size = dist.get_world_size(cpu_group)
self.ranks = dist.get_process_group_ranks(cpu_group)
self.global_rank = dist.get_rank()
self.global_world_size = dist.get_world_size()
self.rank_in_group = dist.get_group_rank(self.cpu_group,
self.global_rank)

use_ep = False
from vllm.config import get_current_vllm_config
42 changes: 33 additions & 9 deletions vllm/distributed/device_communicators/cuda_communicator.py
@@ -11,6 +11,7 @@
from vllm.platforms import current_platform

from .base_device_communicator import DeviceCommunicatorBase
from ..utils import StatelessProcessGroup

logger = init_logger(__name__)

@@ -21,8 +22,12 @@ def __init__(self,
cpu_group: ProcessGroup,
device: Optional[torch.device] = None,
device_group: Optional[ProcessGroup] = None,
unique_name: str = ""):
super().__init__(cpu_group, device, device_group, unique_name)
unique_name: str = "",
global_ranks: Optional[list[int]] = None,
global_world_size: Optional[int] = None,
tcp_store_group: Optional[StatelessProcessGroup] = None):
super().__init__(cpu_group, device, device_group, unique_name,
global_ranks, global_world_size)
if "tp" not in unique_name:
# only tp uses custom allreduce
use_custom_allreduce = False
@@ -32,7 +37,7 @@ def __init__(self,
use_custom_allreduce = _ENABLE_CUSTOM_ALL_REDUCE

# ep does not use pynccl
use_pynccl = "ep" not in unique_name
use_pynccl = ("ep" not in unique_name) or (tcp_store_group is not None)

self.use_pynccl = use_pynccl
self.use_custom_allreduce = use_custom_allreduce
@@ -50,7 +55,7 @@ def __init__(self,
self.pynccl_comm: Optional[PyNcclCommunicator] = None
if use_pynccl and self.world_size > 1:
self.pynccl_comm = PyNcclCommunicator(
group=self.cpu_group,
group=self.cpu_group if tcp_store_group is None else tcp_store_group,
device=self.device,
)

@@ -85,23 +90,23 @@ def __init__(self,
all2all_backend = envs.VLLM_ALL2ALL_BACKEND
if all2all_backend == "naive":
from .all2all import NaiveAll2AllManager
self.all2all_manager = NaiveAll2AllManager(self.cpu_group)
self.all2all_manager = NaiveAll2AllManager(self.cpu_group, tcp_store_group=tcp_store_group)
logger.info("Using naive all2all manager.")
elif all2all_backend == "allgather_reducescatter":
from .all2all import AgRsAll2AllManager
self.all2all_manager = AgRsAll2AllManager(self.cpu_group)
self.all2all_manager = AgRsAll2AllManager(self.cpu_group, tcp_store_group=tcp_store_group)
logger.info("Using AllGather-ReduceScatter all2all manager.")
elif all2all_backend == "pplx":
from .all2all import PPLXAll2AllManager
self.all2all_manager = PPLXAll2AllManager(self.cpu_group)
self.all2all_manager = PPLXAll2AllManager(self.cpu_group, tcp_store_group=tcp_store_group)
logger.info("Using PPLX all2all manager.")
elif all2all_backend == "deepep_high_throughput":
from .all2all import DeepEPHTAll2AllManager
self.all2all_manager = DeepEPHTAll2AllManager(self.cpu_group)
self.all2all_manager = DeepEPHTAll2AllManager(self.cpu_group, tcp_store_group=tcp_store_group)
logger.info("Using DeepEP High-Throughput all2all manager.")
elif all2all_backend == "deepep_low_latency":
from .all2all import DeepEPLLAll2AllManager
self.all2all_manager = DeepEPLLAll2AllManager(self.cpu_group)
self.all2all_manager = DeepEPLLAll2AllManager(self.cpu_group, tcp_store_group=tcp_store_group)
logger.info("Using DeepEP Low-Latency all2all manager.")
else:
raise ValueError(f"Unknown all2all backend: {all2all_backend}")
@@ -229,6 +234,18 @@ def recv(self,
torch.distributed.recv(tensor, self.ranks[src], self.device_group)
return tensor

def broadcast(self, tensor: torch.Tensor, src: int = 0) -> torch.Tensor:
"""Broadcast a tensor from source rank to all ranks."""
if self.world_size == 1:
return tensor

pynccl_comm = self.pynccl_comm
if pynccl_comm is not None and not pynccl_comm.disabled:
pynccl_comm.broadcast(tensor, src)
return tensor
else:
raise ValueError("No PyNCCL communicator found")

def destroy(self):
if self.pynccl_comm is not None:
self.pynccl_comm = None
@@ -296,3 +313,10 @@ def combine(self, hidden_states: torch.Tensor) -> torch.Tensor:
assert self.all2all_manager is not None
hidden_states = self.all2all_manager.combine(hidden_states)
return hidden_states

def batch_isend_irecv(self, p2p_ops: list):
pynccl_comm = self.pynccl_comm
if pynccl_comm is not None and not pynccl_comm.disabled:
pynccl_comm.batch_isend_irecv(p2p_ops)
else:
raise ValueError("No PyNCCL communicator found")
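
A possible call site for the new `broadcast` helper during an elastic resize; the surrounding orchestration and the `communicator` handle are assumptions for illustration, not code from this PR:

```python
import torch


def broadcast_expert_weights(communicator, weight: torch.Tensor,
                             src_rank: int = 0) -> torch.Tensor:
    """Hypothetical helper: push a (re)sharded expert weight from one rank
    to the rest of the group via CudaCommunicator.broadcast."""
    # CudaCommunicator.broadcast raises ValueError when no PyNCCL
    # communicator is available, so only call it for multi-rank groups.
    if communicator.world_size > 1:
        communicator.broadcast(weight, src=src_rank)
    return weight
```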