
Commit 1f51ac5

eep phase2 init
stateless group elastic EP: support CUDA graph + peer weights transfer
update state filter
small fix
bench script small fix
fix intra-node to inter-node scaling
remove unused code
1 parent 06a4133 commit 1f51ac5

27 files changed (+2123, -551 lines)

experimental/bench.sh

Lines changed: 16 additions & 0 deletions
@@ -0,0 +1,16 @@
+#!/bin/bash
+
+# MODEL_NAME="deepseek-ai/DeepSeek-V3.1"
+MODEL_NAME="Qwen/Qwen3-30B-A3B-Thinking-2507-FP8"
+# MODEL_NAME="Qwen/Qwen3-235B-A22B-Thinking-2507-FP8"
+HOST="localhost"
+PORT=8006
+
+vllm bench serve \
+    --model $MODEL_NAME \
+    --host $HOST \
+    --port $PORT \
+    --dataset-name random \
+    --random-input-len 128 \
+    --random-output-len 128 \
+    --num-prompts 512

experimental/scale.sh

Lines changed: 5 additions & 0 deletions
@@ -0,0 +1,5 @@
+#!/bin/bash
+HOST="localhost"
+PORT=8006
+
+python examples/online_serving/elastic_ep/scale.py --host $HOST --port $PORT --new-dp-size 4

experimental/serve.sh

Lines changed: 49 additions & 0 deletions
@@ -0,0 +1,49 @@
+#!/bin/bash
+
+# MODEL_NAME="deepseek-ai/DeepSeek-V3.1"
+MODEL_NAME="Qwen/Qwen3-30B-A3B-Thinking-2507-FP8"
+# MODEL_NAME="Qwen/Qwen3-235B-A22B-Thinking-2507-FP8"
+HOST="0.0.0.0"
+PORT=8006
+
+DATA_PARALLEL_SIZE=2
+DATA_PARALLEL_SIZE_LOCAL=2
+LEADER_ADDRESS="192.168.5.45"
+# LEADER_ADDRESS="172.18.0.3"
+
+NUM_REDUNDANT_EXPERTS=16
+EPLB_WINDOW_SIZE=1000
+EPLB_STEP_INTERVAL=3000
+MAX_MODEL_LEN=16384
+GPU_MEMORY_UTILIZATION=0.9
+
+export DG_JIT_NVCC_COMPILER=/usr/local/cuda-12.8/bin/nvcc
+export CUDA_HOME='/usr/local/cuda-12.8'
+
+export VLLM_USE_V1=1
+export VLLM_ALL2ALL_BACKEND="pplx"
+# export VLLM_ALL2ALL_BACKEND="deepep_low_latency"
+export VLLM_USE_DEEP_GEMM=1
+# export VLLM_ATTENTION_BACKEND="TRITON_MLA"
+
+# Launch the vLLM server
+vllm serve $MODEL_NAME --trust-remote-code \
+    --disable-log-requests \
+    --host $HOST \
+    --port $PORT \
+    --tensor-parallel-size 1 \
+    --gpu-memory-utilization $GPU_MEMORY_UTILIZATION \
+    --max-model-len $MAX_MODEL_LEN \
+    --no-enable-prefix-caching \
+    --enable-expert-parallel \
+    --enable-elastic-ep \
+    --enable-eplb \
+    --eplb-config.num_redundant_experts $NUM_REDUNDANT_EXPERTS \
+    --eplb-config.window_size $EPLB_WINDOW_SIZE \
+    --eplb-config.step_interval $EPLB_STEP_INTERVAL \
+    --data-parallel-backend ray \
+    --data-parallel-size $DATA_PARALLEL_SIZE \
+    --data-parallel-size-local $DATA_PARALLEL_SIZE_LOCAL \
+    --data-parallel-address $LEADER_ADDRESS \
+    --data-parallel-rpc-port 9876 \
+    --data-parallel-start-rank 0

vllm/config/parallel.py

Lines changed: 48 additions & 2 deletions
@@ -137,6 +137,9 @@ class ParallelConfig:
     disable_custom_all_reduce: bool = False
     """Disable the custom all-reduce kernel and fall back to NCCL."""
 
+    enable_elastic_ep: bool = False
+    """Enable elastic expert parallelism with stateless NCCL groups for DP/EP."""
+
     enable_dbo: bool = False
     """Enable microbatching for the model executor."""
 
@@ -188,6 +191,21 @@ class is dynamically inherited by the worker class. This is used to inject
     Set to be private as it's not intended to be configured by users.
     """
 
+    _stateless_world_group_port_list: list[int] = field(default_factory=list)
+    """List of open ports for stateless world group when enable_elastic_ep is True.
+    Set to be private as it's not intended to be configured by users.
+    """
+
+    _stateless_dp_group_port_list: list[int] = field(default_factory=list)
+    """List of open ports for stateless DP groups when enable_elastic_ep is True.
+    Set to be private as it's not intended to be configured by users.
+    """
+
+    _stateless_ep_group_port_list: list[int] = field(default_factory=list)
+    """List of open ports for stateless EP groups when enable_elastic_ep is True.
+    Set to be private as it's not intended to be configured by users.
+    """
+
     decode_context_parallel_size: int = 1
     """Number of decode context parallel groups, because the world size does
     not change by dcp, it simply reuse the GPUs of TP group, and tp_size
@@ -235,7 +253,16 @@ def get_next_dp_init_port(self) -> int:
 
         return answer
 
-    def stateless_init_dp_group(self) -> ProcessGroup:
+    def get_next_stateless_world_group_port(self) -> list[int]:
+        return self._stateless_world_group_port_list.pop(0)
+
+    def get_next_stateless_dp_group_port(self) -> list[int]:
+        return self._stateless_dp_group_port_list.pop(0)
+
+    def get_next_stateless_ep_group_port(self) -> list[int]:
+        return self._stateless_ep_group_port_list.pop(0)
+
+    def stateless_init_dp_group(self, return_store: bool = False) -> ProcessGroup:
         # NOTE: In high-concurrency scenarios multiple processes
         # can pick the same (currently free) port through a race
         # condition when calling `get_open_port()`. When the first
@@ -258,7 +285,8 @@ def stateless_init_dp_group(self) -> ProcessGroup:
                     self.get_next_dp_init_port(),
                     self.data_parallel_rank,
                     self.data_parallel_size,
-                    backend="gloo")
+                    backend="gloo",
+                    return_store=return_store)
             except DistNetworkError as e:
                 # We only want to retry when the root cause is EADDRINUSE.
                 if "EADDRINUSE" in str(e):
@@ -351,6 +379,24 @@ def __post_init__(self) -> None:
         self.world_size = self.pipeline_parallel_size * \
             self.tensor_parallel_size
 
+        # Initialize stateless group ports for elastic EP
+        if self.enable_elastic_ep:
+            num_world_groups = 1
+            num_dp_groups = max(1, self.world_size_across_dp // self.data_parallel_size)
+            num_ep_groups = max(1, self.world_size_across_dp // (self.data_parallel_size * self.tensor_parallel_size))
+
+            total_ports_needed = (num_world_groups + num_dp_groups + num_ep_groups) * 3
+
+            if not self._stateless_world_group_port_list:
+                all_ports = get_open_ports_list(total_ports_needed + 5)
+                self._data_parallel_master_port_list = all_ports[-5:]
+                all_ports = all_ports[:-5]
+                self._stateless_world_group_port_list = [all_ports[i:i+3] for i in range(0, num_world_groups * 3, 3)]
+                start_idx = num_world_groups * 3
+                self._stateless_dp_group_port_list = [all_ports[i:i+3] for i in range(start_idx, start_idx + num_dp_groups * 3, 3)]
+                start_idx += num_dp_groups * 3
+                self._stateless_ep_group_port_list = [all_ports[i:i+3] for i in range(start_idx, start_idx + num_ep_groups * 3, 3)]
+
         if self.data_parallel_size_local > self.data_parallel_size:
             raise ValueError(
                 f"data_parallel_size_local ({self.data_parallel_size_local}) "

vllm/distributed/device_communicators/all2all.py

Lines changed: 30 additions & 23 deletions
@@ -23,8 +23,8 @@ class NaiveAll2AllManager(All2AllManagerBase):
     debugging.
     """
 
-    def __init__(self, cpu_group):
-        super().__init__(cpu_group)
+    def __init__(self, cpu_group, tcp_store_group=None):
+        super().__init__(cpu_group, tcp_store_group)
 
     def naive_multicast(self, x: torch.Tensor,
                         cu_tokens_across_dp_cpu: torch.Tensor):
@@ -76,8 +76,8 @@ class AgRsAll2AllManager(All2AllManagerBase):
     all-gather (dispatch) and reduce-scatter (combine).
     """
 
-    def __init__(self, cpu_group):
-        super().__init__(cpu_group)
+    def __init__(self, cpu_group, tcp_store_group=None):
+        super().__init__(cpu_group, tcp_store_group)
 
     def dispatch(self, hidden_states: torch.Tensor,
                  router_logits: torch.Tensor):
@@ -113,14 +113,16 @@ class PPLXAll2AllManager(All2AllManagerBase):
     All2All communication based on PPLX kernels.
     """
 
-    def __init__(self, cpu_group):
+    def __init__(self, cpu_group, tcp_store_group=None):
         assert has_pplx(
         ), "pplx_kernels not found. Please follow https://github.com/vllm-project/vllm/blob/main/tools/ep_kernels/README.md to install pplx_kernels." # noqa
-        super().__init__(cpu_group)
+        super().__init__(cpu_group, tcp_store_group)
 
-        if self.internode:
-            # inter-node communication needs nvshmem,
-            # intra-node communication uses p2p mapping directly
+        self.nvshmem_initialized = False
+        self.handle_cache = Cache()
+
+    def get_handle(self, kwargs):
+        if self.internode and not self.nvshmem_initialized:
             from pplx_kernels.nvshmem import (nvshmem_alloc_empty_unique_id,
                                               nvshmem_get_unique_id,
                                               nvshmem_init)
@@ -129,15 +131,18 @@ def __init__(self, cpu_group):
                          "rank=%d, world size=%d", self.rank, self.world_size)
             uid = nvshmem_get_unique_id(
             ) if self.rank == 0 else nvshmem_alloc_empty_unique_id()
-            dist.broadcast(uid,
-                           src=dist.get_process_group_ranks(self.cpu_group)[0],
-                           group=self.cpu_group)
+
+            if self.tcp_store_group is not None:
+                uid = self.tcp_store_group.broadcast_obj(uid, src=0)
+            else:
+                dist.broadcast(uid,
+                               src=dist.get_process_group_ranks(self.cpu_group)[0],
+                               group=self.cpu_group)
+
             logger.debug("PPLX NVSHMEM UID = %s", uid)
             nvshmem_init(uid, self.rank, self.world_size)
+            self.nvshmem_initialized = True
 
-        self.handle_cache = Cache()
-
-    def get_handle(self, kwargs):
         import pplx_kernels as pplx
         return self.handle_cache.get_or_create(
             kwargs, pplx.AllToAll.internode
@@ -166,10 +171,10 @@ class DeepEPAll2AllManagerBase(All2AllManagerBase):
     All2All communication based on DeepEP High-Throughput kernels.
     """
 
-    def __init__(self, cpu_group):
+    def __init__(self, cpu_group, tcp_store_group=None):
         assert has_deep_ep(
         ), "DeepEP kernels not found. Please follow https://github.com/vllm-project/vllm/blob/main/tools/ep_kernels/README.md to install DeepEP kernels." # noqa
-        super().__init__(cpu_group)
+        super().__init__(cpu_group, tcp_store_group)
         self.handle_cache = Cache()
 
         # This is the DeepEP default. Stick to it till we can establish
@@ -195,8 +200,8 @@ class DeepEPHTAll2AllManager(DeepEPAll2AllManagerBase):
     All2All communication based on DeepEP High-Throughput kernels.
     """
 
-    def __init__(self, cpu_group):
-        super().__init__(cpu_group)
+    def __init__(self, cpu_group, tcp_store_group=None):
+        super().__init__(cpu_group, tcp_store_group)
 
     def _make_all2all_kwargs(self) -> dict[Any, Any]:
         # Defaults for internode and intranode are taken from DeepEP tests.
@@ -243,8 +248,8 @@ class DeepEPLLAll2AllManager(DeepEPAll2AllManagerBase):
     All2All communication based on DeepEP Low-Latency kernels.
     """
 
-    def __init__(self, cpu_group):
-        super().__init__(cpu_group)
+    def __init__(self, cpu_group, tcp_store_group=None):
+        super().__init__(cpu_group, tcp_store_group)
 
     def _make_all2all_kwargs(
         self,
@@ -265,7 +270,8 @@ def _make_all2all_kwargs(
         import deep_ep
 
         # Defaults for internode and intranode are taken from DeepEP tests.
-        num_nvl_bytes = 1024 * 1024 * 1024
+        # num_nvl_bytes = 1024 * 1024 * 1024
+        num_nvl_bytes = 0
         num_qps_per_rank = num_local_experts
         num_rdma_bytes = deep_ep.Buffer.get_low_latency_rdma_size_hint(
             num_max_dispatch_tokens_per_rank=max_num_tokens_per_dp_rank,
@@ -278,7 +284,8 @@ def _make_all2all_kwargs(
             num_nvl_bytes=num_nvl_bytes,
             num_rdma_bytes=num_rdma_bytes,
             low_latency_mode=True,
-            num_qps_per_rank=num_qps_per_rank)
+            num_qps_per_rank=num_qps_per_rank,
+            allow_mnnvl=True)
 
     def get_handle(self, kwargs):
         """

vllm/distributed/device_communicators/base_device_communicator.py

Lines changed: 32 additions & 12 deletions
@@ -29,8 +29,9 @@ def get_or_create(self, kwargs, func):
 
 class All2AllManagerBase:
 
-    def __init__(self, cpu_group):
+    def __init__(self, cpu_group, tcp_store_group=None):
         self.cpu_group = cpu_group
+        self.tcp_store_group = tcp_store_group
 
         # compute some common properties
         from vllm.distributed.parallel_state import (get_dp_group,
@@ -44,12 +45,15 @@ def __init__(self, cpu_group):
         # when we create this object
         self.dp_rank = self.dp_group.rank_in_group
         self.dp_world_size = self.dp_group.world_size
-        self.rank = dist.get_rank(cpu_group)
-        self.world_size = dist.get_world_size(cpu_group)
+        self.rank = cpu_group.rank()
+        self.world_size = cpu_group.size()
 
         # all2all communication often has separate implementations for
         # intra-node and inter-node communication
-        self.internode = not all(in_the_same_node_as(cpu_group, source_rank=0))
+        if tcp_store_group is None:
+            self.internode = not all(in_the_same_node_as(cpu_group, source_rank=0))
+        else:
+            self.internode = not all(in_the_same_node_as(tcp_store_group, source_rank=0))
 
     def get_handle(self, kwargs):
         # get a handle for the all2all communication,
@@ -83,18 +87,34 @@ def __init__(self,
                  cpu_group: ProcessGroup,
                  device: Optional[torch.device] = None,
                  device_group: Optional[ProcessGroup] = None,
-                 unique_name: str = ""):
+                 unique_name: str = "",
+                 global_ranks: Optional[list[int]] = None,
+                 global_world_size: Optional[int] = None):
         self.device = device or torch.device("cpu")
         self.cpu_group = cpu_group
         self.device_group = device_group
         self.unique_name = unique_name
-        self.rank = dist.get_rank(cpu_group)
-        self.world_size = dist.get_world_size(cpu_group)
-        self.ranks = dist.get_process_group_ranks(cpu_group)
-        self.global_rank = dist.get_rank()
-        self.global_world_size = dist.get_world_size()
-        self.rank_in_group = dist.get_group_rank(self.cpu_group,
-                                                 self.global_rank)
+
+        # Check if this is a stateless process group
+        from torch.distributed.distributed_c10d import _world
+        is_stateless = _world.pg_map.get(cpu_group, None) is None
+
+        if is_stateless:
+            # For stateless groups, we can't use torch.distributed methods
+            self.rank = cpu_group.rank()
+            self.world_size = cpu_group.size()
+            self.ranks = global_ranks
+            self.global_rank = self.ranks[self.rank]
+            self.global_world_size = global_world_size
+            self.rank_in_group = self.rank
+        else:
+            self.rank = dist.get_rank(cpu_group)
+            self.world_size = dist.get_world_size(cpu_group)
+            self.ranks = dist.get_process_group_ranks(cpu_group)
+            self.global_rank = dist.get_rank()
+            self.global_world_size = dist.get_world_size()
+            self.rank_in_group = dist.get_group_rank(self.cpu_group,
+                                                     self.global_rank)
 
         use_ep = False
         from vllm.config import get_current_vllm_config
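
The `is_stateless` probe above leans on a private PyTorch detail: process groups created through `torch.distributed` (e.g. `init_process_group` or `new_group`) are registered in the global `_world` bookkeeping, while a ProcessGroup assembled directly from a TCPStore is not, so the lookup returns None exactly for the stateless case. Below is a tiny standalone restatement of that check, a sketch for this page only, using the same private API as the diff.

# Sketch: detect a "stateless" process group the same way the diff does.
import torch.distributed as dist
from torch.distributed.distributed_c10d import _world


def is_stateless_group(pg: dist.ProcessGroup) -> bool:
    # Groups registered via torch.distributed have an entry in _world.pg_map;
    # hand-built stateless groups do not, so .get() returns None for them.
    return _world.pg_map.get(pg, None) is None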
