16 changes: 16 additions & 0 deletions experimental/bench.sh
@@ -0,0 +1,16 @@
#!/bin/bash
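# Benchmark a running vLLM server (see serve.sh) with random prompts via `vllm bench serve`.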

# MODEL_NAME="deepseek-ai/DeepSeek-V3.1"
MODEL_NAME="Qwen/Qwen3-30B-A3B-Thinking-2507-FP8"
# MODEL_NAME="Qwen/Qwen3-235B-A22B-Thinking-2507-FP8"
HOST="localhost"
PORT=8006

vllm bench serve \
--model $MODEL_NAME \
--host $HOST \
--port $PORT \
--dataset-name random \
--random-input-len 128 \
--random-output-len 128 \
--num-prompts 512
5 changes: 5 additions & 0 deletions experimental/scale.sh
@@ -0,0 +1,5 @@
#!/bin/bash
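# Ask a running vLLM server to rescale elastic EP to a new data-parallel size (4 here).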
HOST="localhost"
PORT=8006

python examples/online_serving/elastic_ep/scale.py --host $HOST --port $PORT --new-dp-size 4
49 changes: 49 additions & 0 deletions experimental/serve.sh
@@ -0,0 +1,49 @@
#!/bin/bash

# MODEL_NAME="deepseek-ai/DeepSeek-V3.1"
MODEL_NAME="Qwen/Qwen3-30B-A3B-Thinking-2507-FP8"
# MODEL_NAME="Qwen/Qwen3-235B-A22B-Thinking-2507-FP8"
HOST="0.0.0.0"
PORT=8006

DATA_PARALLEL_SIZE=2
DATA_PARALLEL_SIZE_LOCAL=2
LEADER_ADDRESS="192.168.5.45"
# LEADER_ADDRESS="172.18.0.3"

NUM_REDUNDANT_EXPERTS=16
EPLB_WINDOW_SIZE=1000
EPLB_STEP_INTERVAL=3000
MAX_MODEL_LEN=16384
GPU_MEMORY_UTILIZATION=0.9

export DG_JIT_NVCC_COMPILER=/usr/local/cuda-12.8/bin/nvcc
export CUDA_HOME='/usr/local/cuda-12.8'

export VLLM_USE_V1=1
export VLLM_ALL2ALL_BACKEND="pplx"
# export VLLM_ALL2ALL_BACKEND="deepep_low_latency"
export VLLM_USE_DEEP_GEMM=1
# export VLLM_ATTENTION_BACKEND="TRITON_MLA"

# Launch the vLLM server
vllm serve $MODEL_NAME --trust-remote-code \
--disable-log-requests \
--host $HOST \
--port $PORT \
--tensor-parallel-size 1 \
--gpu-memory-utilization $GPU_MEMORY_UTILIZATION \
--max-model-len $MAX_MODEL_LEN \
--no-enable-prefix-caching \
--enable-expert-parallel \
--enable-elastic-ep \
--enable-eplb \
--eplb-config.num_redundant_experts $NUM_REDUNDANT_EXPERTS \
--eplb-config.window_size $EPLB_WINDOW_SIZE \
--eplb-config.step_interval $EPLB_STEP_INTERVAL \
--data-parallel-backend ray \
--data-parallel-size $DATA_PARALLEL_SIZE \
--data-parallel-size-local $DATA_PARALLEL_SIZE_LOCAL \
--data-parallel-address $LEADER_ADDRESS \
--data-parallel-rpc-port 9876 \
--data-parallel-start-rank 0
50 changes: 48 additions & 2 deletions vllm/config/parallel.py
@@ -137,6 +137,9 @@ class ParallelConfig:
disable_custom_all_reduce: bool = False
"""Disable the custom all-reduce kernel and fall back to NCCL."""

enable_elastic_ep: bool = False
"""Enable elastic expert parallelism with stateless NCCL groups for DP/EP."""

enable_dbo: bool = False
"""Enable microbatching for the model executor."""

@@ -188,6 +191,21 @@ class is dynamically inherited by the worker class. This is used to inject
Set to be private as it's not intended to be configured by users.
"""

_stateless_world_group_port_list: list[int] = field(default_factory=list)
"""List of open ports for stateless world group when enable_elastic_ep is True.
Set to be private as it's not intended to be configured by users.
"""

_stateless_dp_group_port_list: list[int] = field(default_factory=list)
"""List of open ports for stateless DP groups when enable_elastic_ep is True.
Set to be private as it's not intended to be configured by users.
"""

_stateless_ep_group_port_list: list[int] = field(default_factory=list)
"""List of open ports for stateless EP groups when enable_elastic_ep is True.
Set to be private as it's not intended to be configured by users.
"""

decode_context_parallel_size: int = 1
"""Number of decode context parallel groups, because the world size does
not change by dcp, it simply reuse the GPUs of TP group, and tp_size
@@ -235,7 +253,16 @@ def get_next_dp_init_port(self) -> int:

return answer

def stateless_init_dp_group(self) -> ProcessGroup:
def get_next_stateless_world_group_port(self) -> list[int]:
return self._stateless_world_group_port_list.pop(0)

def get_next_stateless_dp_group_port(self) -> list[int]:
return self._stateless_dp_group_port_list.pop(0)

def get_next_stateless_ep_group_port(self) -> list[int]:
return self._stateless_ep_group_port_list.pop(0)
Comment on lines +256 to +263
Contributor review comment (severity: high):

These methods use pop(0) to retrieve a port from a list without checking if the list is empty. If the port lists (_stateless_world_group_port_list, _stateless_dp_group_port_list, _stateless_ep_group_port_list) are exhausted for any reason, this will raise an IndexError and crash the process. While the logic in __post_init__ seems to pre-allocate the necessary ports, this design is fragile. A more robust implementation would be to check if the list is empty before popping and raise a more informative error message.
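
One way to address this, as a rough sketch only (the `_pop_next_port_group` helper is hypothetical and not part of this PR; it would live on `ParallelConfig` next to the getters above):

```python
def _pop_next_port_group(self, port_list: list[list[int]],
                         name: str) -> list[int]:
    # Fail with a descriptive error instead of an opaque IndexError when the
    # port groups pre-allocated in __post_init__ have been exhausted.
    if not port_list:
        raise RuntimeError(
            f"{name} is exhausted; expected __post_init__ to pre-allocate "
            "one port group per stateless world/DP/EP group when "
            "enable_elastic_ep is enabled.")
    return port_list.pop(0)

def get_next_stateless_world_group_port(self) -> list[int]:
    return self._pop_next_port_group(
        self._stateless_world_group_port_list,
        "_stateless_world_group_port_list")
```

The same guarded helper could back the DP and EP getters.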


def stateless_init_dp_group(self, return_store: bool = False) -> ProcessGroup:
# NOTE: In high-concurrency scenarios multiple processes
# can pick the same (currently free) port through a race
# condition when calling `get_open_port()`. When the first
@@ -258,7 +285,8 @@ def stateless_init_dp_group(self) -> ProcessGroup:
self.get_next_dp_init_port(),
self.data_parallel_rank,
self.data_parallel_size,
backend="gloo")
backend="gloo",
return_store=return_store)
except DistNetworkError as e:
# We only want to retry when the root cause is EADDRINUSE.
if "EADDRINUSE" in str(e):
@@ -351,6 +379,24 @@ def __post_init__(self) -> None:
self.world_size = self.pipeline_parallel_size * \
self.tensor_parallel_size

# Initialize stateless group ports for elastic EP
if self.enable_elastic_ep:
num_world_groups = 1
num_dp_groups = max(1, self.world_size_across_dp // self.data_parallel_size)
num_ep_groups = max(1, self.world_size_across_dp // (self.data_parallel_size * self.tensor_parallel_size))

total_ports_needed = (num_world_groups + num_dp_groups + num_ep_groups) * 3

if not self._stateless_world_group_port_list:
all_ports = get_open_ports_list(total_ports_needed + 5)
self._data_parallel_master_port_list = all_ports[-5:]
all_ports = all_ports[:-5]
self._stateless_world_group_port_list = [all_ports[i:i+3] for i in range(0, num_world_groups * 3, 3)]
start_idx = num_world_groups * 3
self._stateless_dp_group_port_list = [all_ports[i:i+3] for i in range(start_idx, start_idx + num_dp_groups * 3, 3)]
start_idx += num_dp_groups * 3
self._stateless_ep_group_port_list = [all_ports[i:i+3] for i in range(start_idx, start_idx + num_ep_groups * 3, 3)]

if self.data_parallel_size_local > self.data_parallel_size:
raise ValueError(
f"data_parallel_size_local ({self.data_parallel_size_local}) "
53 changes: 30 additions & 23 deletions vllm/distributed/device_communicators/all2all.py
@@ -23,8 +23,8 @@ class NaiveAll2AllManager(All2AllManagerBase):
debugging.
"""

def __init__(self, cpu_group):
super().__init__(cpu_group)
def __init__(self, cpu_group, tcp_store_group=None):
super().__init__(cpu_group, tcp_store_group)

def naive_multicast(self, x: torch.Tensor,
cu_tokens_across_dp_cpu: torch.Tensor):
@@ -76,8 +76,8 @@ class AgRsAll2AllManager(All2AllManagerBase):
all-gather (dispatch) and reduce-scatter (combine).
"""

def __init__(self, cpu_group):
super().__init__(cpu_group)
def __init__(self, cpu_group, tcp_store_group=None):
super().__init__(cpu_group, tcp_store_group)

def dispatch(self, hidden_states: torch.Tensor,
router_logits: torch.Tensor):
@@ -113,14 +113,16 @@ class PPLXAll2AllManager(All2AllManagerBase):
All2All communication based on PPLX kernels.
"""

def __init__(self, cpu_group):
def __init__(self, cpu_group, tcp_store_group=None):
assert has_pplx(
), "pplx_kernels not found. Please follow https://github.com/vllm-project/vllm/blob/main/tools/ep_kernels/README.md to install pplx_kernels." # noqa
super().__init__(cpu_group)
super().__init__(cpu_group, tcp_store_group)

if self.internode:
# inter-node communication needs nvshmem,
# intra-node communication uses p2p mapping directly
self.nvshmem_initialized = False
self.handle_cache = Cache()

def get_handle(self, kwargs):
if self.internode and not self.nvshmem_initialized:
from pplx_kernels.nvshmem import (nvshmem_alloc_empty_unique_id,
nvshmem_get_unique_id,
nvshmem_init)
@@ -129,15 +131,18 @@ def __init__(self, cpu_group):
"rank=%d, world size=%d", self.rank, self.world_size)
uid = nvshmem_get_unique_id(
) if self.rank == 0 else nvshmem_alloc_empty_unique_id()
dist.broadcast(uid,
src=dist.get_process_group_ranks(self.cpu_group)[0],
group=self.cpu_group)

if self.tcp_store_group is not None:
uid = self.tcp_store_group.broadcast_obj(uid, src=0)
else:
dist.broadcast(uid,
src=dist.get_process_group_ranks(self.cpu_group)[0],
group=self.cpu_group)

logger.debug("PPLX NVSHMEM UID = %s", uid)
nvshmem_init(uid, self.rank, self.world_size)
self.nvshmem_initialized = True

self.handle_cache = Cache()

def get_handle(self, kwargs):
import pplx_kernels as pplx
return self.handle_cache.get_or_create(
kwargs, pplx.AllToAll.internode
@@ -166,10 +171,10 @@ class DeepEPAll2AllManagerBase(All2AllManagerBase):
All2All communication based on DeepEP High-Throughput kernels.
"""

def __init__(self, cpu_group):
def __init__(self, cpu_group, tcp_store_group=None):
assert has_deep_ep(
), "DeepEP kernels not found. Please follow https://github.com/vllm-project/vllm/blob/main/tools/ep_kernels/README.md to install DeepEP kernels." # noqa
super().__init__(cpu_group)
super().__init__(cpu_group, tcp_store_group)
self.handle_cache = Cache()

# This is the DeepEP default. Stick to it till we can establish
@@ -195,8 +200,8 @@ class DeepEPHTAll2AllManager(DeepEPAll2AllManagerBase):
All2All communication based on DeepEP High-Throughput kernels.
"""

def __init__(self, cpu_group):
super().__init__(cpu_group)
def __init__(self, cpu_group, tcp_store_group=None):
super().__init__(cpu_group, tcp_store_group)

def _make_all2all_kwargs(self) -> dict[Any, Any]:
# Defaults for internode and intranode are taken from DeepEP tests.
@@ -243,8 +248,8 @@ class DeepEPLLAll2AllManager(DeepEPAll2AllManagerBase):
All2All communication based on DeepEP Low-Latency kernels.
"""

def __init__(self, cpu_group):
super().__init__(cpu_group)
def __init__(self, cpu_group, tcp_store_group=None):
super().__init__(cpu_group, tcp_store_group)

def _make_all2all_kwargs(
self,
@@ -265,7 +270,8 @@ def _make_all2all_kwargs(
import deep_ep

# Defaults for internode and intranode are taken from DeepEP tests.
num_nvl_bytes = 1024 * 1024 * 1024
# num_nvl_bytes = 1024 * 1024 * 1024
num_nvl_bytes = 0
num_qps_per_rank = num_local_experts
num_rdma_bytes = deep_ep.Buffer.get_low_latency_rdma_size_hint(
num_max_dispatch_tokens_per_rank=max_num_tokens_per_dp_rank,
@@ -278,7 +284,8 @@
num_nvl_bytes=num_nvl_bytes,
num_rdma_bytes=num_rdma_bytes,
low_latency_mode=True,
num_qps_per_rank=num_qps_per_rank)
num_qps_per_rank=num_qps_per_rank,
allow_mnnvl=True)

def get_handle(self, kwargs):
"""
44 changes: 32 additions & 12 deletions vllm/distributed/device_communicators/base_device_communicator.py
@@ -29,8 +29,9 @@ def get_or_create(self, kwargs, func):

class All2AllManagerBase:

def __init__(self, cpu_group):
def __init__(self, cpu_group, tcp_store_group=None):
self.cpu_group = cpu_group
self.tcp_store_group = tcp_store_group

# compute some common properties
from vllm.distributed.parallel_state import (get_dp_group,
@@ -44,12 +45,15 @@ def __init__(self, cpu_group):
# when we create this object
self.dp_rank = self.dp_group.rank_in_group
self.dp_world_size = self.dp_group.world_size
self.rank = dist.get_rank(cpu_group)
self.world_size = dist.get_world_size(cpu_group)
self.rank = cpu_group.rank()
self.world_size = cpu_group.size()

# all2all communication often has separate implementations for
# intra-node and inter-node communication
self.internode = not all(in_the_same_node_as(cpu_group, source_rank=0))
if tcp_store_group is None:
self.internode = not all(in_the_same_node_as(cpu_group, source_rank=0))
else:
self.internode = not all(in_the_same_node_as(tcp_store_group, source_rank=0))

def get_handle(self, kwargs):
# get a handle for the all2all communication,
@@ -83,18 +87,34 @@ def __init__(self,
cpu_group: ProcessGroup,
device: Optional[torch.device] = None,
device_group: Optional[ProcessGroup] = None,
unique_name: str = ""):
unique_name: str = "",
global_ranks: Optional[list[int]] = None,
global_world_size: Optional[int] = None):
self.device = device or torch.device("cpu")
self.cpu_group = cpu_group
self.device_group = device_group
self.unique_name = unique_name
self.rank = dist.get_rank(cpu_group)
self.world_size = dist.get_world_size(cpu_group)
self.ranks = dist.get_process_group_ranks(cpu_group)
self.global_rank = dist.get_rank()
self.global_world_size = dist.get_world_size()
self.rank_in_group = dist.get_group_rank(self.cpu_group,
self.global_rank)

# Check if this is a stateless process group
from torch.distributed.distributed_c10d import _world
is_stateless = _world.pg_map.get(cpu_group, None) is None
Comment on lines +98 to +100
Contributor review comment (severity: high):

The check _world.pg_map.get(cpu_group, None) is None relies on an internal, undocumented implementation detail of torch.distributed to determine if a process group is stateless. This is a brittle approach that could break with future PyTorch updates. It would be more robust to use an explicit mechanism to identify stateless groups, such as a custom process group class that carries this information, or passing a flag during initialization.
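
A rough sketch of that alternative, assuming an explicit `is_stateless` constructor flag (this parameter is hypothetical and does not exist in the PR); the branch bodies mirror the ones added below:

```python
from typing import Optional

import torch
import torch.distributed as dist
from torch.distributed import ProcessGroup


class DeviceCommunicatorBase:

    def __init__(self,
                 cpu_group: ProcessGroup,
                 device: Optional[torch.device] = None,
                 device_group: Optional[ProcessGroup] = None,
                 unique_name: str = "",
                 global_ranks: Optional[list[int]] = None,
                 global_world_size: Optional[int] = None,
                 is_stateless: bool = False):
        self.device = device or torch.device("cpu")
        self.cpu_group = cpu_group
        self.device_group = device_group
        self.unique_name = unique_name

        if is_stateless:
            # The caller that built the stateless group passes the flag
            # explicitly, so no inspection of torch.distributed internals
            # (_world.pg_map) is needed.
            self.rank = cpu_group.rank()
            self.world_size = cpu_group.size()
            self.ranks = global_ranks
            self.global_rank = self.ranks[self.rank]
            self.global_world_size = global_world_size
            self.rank_in_group = self.rank
        else:
            self.rank = dist.get_rank(cpu_group)
            self.world_size = dist.get_world_size(cpu_group)
            self.ranks = dist.get_process_group_ranks(cpu_group)
            self.global_rank = dist.get_rank()
            self.global_world_size = dist.get_world_size()
            self.rank_in_group = dist.get_group_rank(cpu_group,
                                                     self.global_rank)
```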


if is_stateless:
# For stateless groups, we can't use torch.distributed methods
self.rank = cpu_group.rank()
self.world_size = cpu_group.size()
self.ranks = global_ranks
self.global_rank = self.ranks[self.rank]
self.global_world_size = global_world_size
self.rank_in_group = self.rank
else:
self.rank = dist.get_rank(cpu_group)
self.world_size = dist.get_world_size(cpu_group)
self.ranks = dist.get_process_group_ranks(cpu_group)
self.global_rank = dist.get_rank()
self.global_world_size = dist.get_world_size()
self.rank_in_group = dist.get_group_rank(self.cpu_group,
self.global_rank)

use_ep = False
from vllm.config import get_current_vllm_config