
Commit 194bad1

eep phase2 init

- stateless group elastic EP: support CUDA graph + peer weights transfer
- update state filter
- small fix bench script
- small fix
- fix intra-node to inter-node scaling
- remove unused code

1 parent 185d8ed, commit 194bad1

27 files changed: +2043 -549 lines

experimental/bench.sh

Lines changed: 16 additions & 0 deletions
@@ -0,0 +1,16 @@
+#!/bin/bash
+
+# MODEL_NAME="deepseek-ai/DeepSeek-V3.1"
+MODEL_NAME="Qwen/Qwen3-30B-A3B-Thinking-2507-FP8"
+# MODEL_NAME="Qwen/Qwen3-235B-A22B-Thinking-2507-FP8"
+HOST="localhost"
+PORT=8006
+
+vllm bench serve \
+    --model $MODEL_NAME \
+    --host $HOST \
+    --port $PORT \
+    --dataset-name random \
+    --random-input-len 128 \
+    --random-output-len 128 \
+    --num-prompts 512

experimental/scale.sh

Lines changed: 5 additions & 0 deletions
@@ -0,0 +1,5 @@
+#!/bin/bash
+HOST="localhost"
+PORT=8006
+
+python examples/online_serving/elastic_ep/scale.py --host $HOST --port $PORT --new-dp-size 4
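The referenced scale.py is not part of this commit. As a rough sketch only of the client-side surface implied by the invocation above (--host, --port, --new-dp-size), assuming a hypothetical /scale_elastic_ep endpoint and JSON payload; the real example script may use a different route and fields:

```python
# Illustrative sketch only: the endpoint name and payload are assumptions, not
# taken from this commit. examples/online_serving/elastic_ep/scale.py may differ.
import argparse

import requests


def main() -> None:
    parser = argparse.ArgumentParser(description="Ask a running server for a new DP size")
    parser.add_argument("--host", default="localhost")
    parser.add_argument("--port", type=int, default=8006)
    parser.add_argument("--new-dp-size", type=int, required=True)
    args = parser.parse_args()

    # Hypothetical route; check the real example script for the actual API.
    url = f"http://{args.host}:{args.port}/scale_elastic_ep"
    resp = requests.post(url, json={"new_data_parallel_size": args.new_dp_size})
    print(resp.status_code, resp.text)


if __name__ == "__main__":
    main()
```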

experimental/serve.sh

Lines changed: 49 additions & 0 deletions
@@ -0,0 +1,49 @@
+#!/bin/bash
+
+# MODEL_NAME="deepseek-ai/DeepSeek-V3.1"
+MODEL_NAME="Qwen/Qwen3-30B-A3B-Thinking-2507-FP8"
+# MODEL_NAME="Qwen/Qwen3-235B-A22B-Thinking-2507-FP8"
+HOST="0.0.0.0"
+PORT=8006
+
+DATA_PARALLEL_SIZE=2
+DATA_PARALLEL_SIZE_LOCAL=2
+LEADER_ADDRESS="192.168.5.45"
+# LEADER_ADDRESS="172.18.0.3"
+
+NUM_REDUNDANT_EXPERTS=16
+EPLB_WINDOW_SIZE=1000
+EPLB_STEP_INTERVAL=3000
+MAX_MODEL_LEN=16384
+GPU_MEMORY_UTILIZATION=0.9
+
+export DG_JIT_NVCC_COMPILER=/usr/local/cuda-12.8/bin/nvcc
+export CUDA_HOME='/usr/local/cuda-12.8'
+
+export VLLM_USE_V1=1
+export VLLM_ALL2ALL_BACKEND="pplx"
+# export VLLM_ALL2ALL_BACKEND="deepep_low_latency"
+export VLLM_USE_DEEP_GEMM=1
+# export VLLM_ATTENTION_BACKEND="TRITON_MLA"
+
+# Launch the vLLM server
+vllm serve $MODEL_NAME --trust-remote-code \
+    --disable-log-requests \
+    --host $HOST \
+    --port $PORT \
+    --tensor-parallel-size 1 \
+    --gpu-memory-utilization $GPU_MEMORY_UTILIZATION \
+    --max-model-len $MAX_MODEL_LEN \
+    --no-enable-prefix-caching \
+    --enable-expert-parallel \
+    --enable-elastic-ep \
+    --enable-eplb \
+    --eplb-config.num_redundant_experts $NUM_REDUNDANT_EXPERTS \
+    --eplb-config.window_size $EPLB_WINDOW_SIZE \
+    --eplb-config.step_interval $EPLB_STEP_INTERVAL \
+    --data-parallel-backend ray \
+    --data-parallel-size $DATA_PARALLEL_SIZE \
+    --data-parallel-size-local $DATA_PARALLEL_SIZE_LOCAL \
+    --data-parallel-address $LEADER_ADDRESS \
+    --data-parallel-rpc-port 9876 \
+    --data-parallel-start-rank 0
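One detail worth noting in the EPLB settings: with expert parallelism, each EP rank holds (logical experts + redundant experts) / EP size physical experts, so the redundant-expert count is usually chosen so that this division stays exact at every scale the deployment will pass through. A small sketch of the arithmetic, assuming 128 routed experts for Qwen3-30B-A3B (a figure not stated in this commit):

```python
# Hedged sketch: why NUM_REDUNDANT_EXPERTS=16 keeps the per-rank expert count
# integral for both the launch size (DP=2 in serve.sh) and the scaled size
# (DP=4 in scale.sh). 128 routed experts for Qwen3-30B-A3B is an assumption.
NUM_LOGICAL_EXPERTS = 128    # assumed for Qwen/Qwen3-30B-A3B
NUM_REDUNDANT_EXPERTS = 16   # from serve.sh

num_physical = NUM_LOGICAL_EXPERTS + NUM_REDUNDANT_EXPERTS

for ep_size in (2, 4):       # EP size before and after scaling (TP=1, so EP == DP)
    assert num_physical % ep_size == 0, "pick redundant experts so this divides evenly"
    print(f"EP size {ep_size}: {num_physical // ep_size} physical experts per rank")
```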

vllm/config/parallel.py

Lines changed: 47 additions & 1 deletion
@@ -138,6 +138,9 @@ class ParallelConfig:
     disable_custom_all_reduce: bool = False
     """Disable the custom all-reduce kernel and fall back to NCCL."""

+    enable_elastic_ep: bool = False
+    """Enable elastic expert parallelism with stateless NCCL groups for DP/EP."""
+
     enable_dbo: bool = False
     """Enable dual batch overlap for the model executor."""

@@ -199,6 +202,21 @@ class is dynamically inherited by the worker class. This is used to inject
     Set to be private as it's not intended to be configured by users.
     """

+    _stateless_world_group_port_list: list[int] = field(default_factory=list)
+    """List of open ports for stateless world group when enable_elastic_ep is True.
+    Set to be private as it's not intended to be configured by users.
+    """
+
+    _stateless_dp_group_port_list: list[int] = field(default_factory=list)
+    """List of open ports for stateless DP groups when enable_elastic_ep is True.
+    Set to be private as it's not intended to be configured by users.
+    """
+
+    _stateless_ep_group_port_list: list[int] = field(default_factory=list)
+    """List of open ports for stateless EP groups when enable_elastic_ep is True.
+    Set to be private as it's not intended to be configured by users.
+    """
+
     decode_context_parallel_size: int = 1
     """Number of decode context parallel groups, because the world size does
     not change by dcp, it simply reuse the GPUs of TP group, and tp_size
@@ -246,7 +264,16 @@ def get_next_dp_init_port(self) -> int:

         return answer

-    def stateless_init_dp_group(self) -> ProcessGroup:
+    def get_next_stateless_world_group_port(self) -> list[int]:
+        return self._stateless_world_group_port_list.pop(0)
+
+    def get_next_stateless_dp_group_port(self) -> list[int]:
+        return self._stateless_dp_group_port_list.pop(0)
+
+    def get_next_stateless_ep_group_port(self) -> list[int]:
+        return self._stateless_ep_group_port_list.pop(0)
+
+    def stateless_init_dp_group(self, return_store: bool = False) -> ProcessGroup:
         # NOTE: In high-concurrency scenarios multiple processes
         # can pick the same (currently free) port through a race
         # condition when calling `get_open_port()`. When the first
@@ -271,6 +298,7 @@ def stateless_init_dp_group(self) -> ProcessGroup:
                 self.data_parallel_rank,
                 self.data_parallel_size,
                 backend="gloo",
+                return_store=return_store
             )
         except DistNetworkError as e:
             # We only want to retry when the root cause is EADDRINUSE.
@@ -387,6 +415,24 @@ def __post_init__(self) -> None:
             logger.info("Using external launcher for distributed inference.")
             self.world_size *= self.data_parallel_size

+        # Initialize stateless group ports for elastic EP
+        if self.enable_elastic_ep:
+            num_world_groups = 1
+            num_dp_groups = max(1, self.world_size_across_dp // self.data_parallel_size)
+            num_ep_groups = max(1, self.world_size_across_dp // (self.data_parallel_size * self.tensor_parallel_size))
+
+            total_ports_needed = (num_world_groups + num_dp_groups + num_ep_groups) * 3
+
+            if not self._stateless_world_group_port_list:
+                all_ports = get_open_ports_list(total_ports_needed + 5)
+                self._data_parallel_master_port_list = all_ports[-5:]
+                all_ports = all_ports[:-5]
+                self._stateless_world_group_port_list = [all_ports[i:i+3] for i in range(0, num_world_groups * 3, 3)]
+                start_idx = num_world_groups * 3
+                self._stateless_dp_group_port_list = [all_ports[i:i+3] for i in range(start_idx, start_idx + num_dp_groups * 3, 3)]
+                start_idx += num_dp_groups * 3
+                self._stateless_ep_group_port_list = [all_ports[i:i+3] for i in range(start_idx, start_idx + num_ep_groups * 3, 3)]
+
         if self.data_parallel_size_local > self.data_parallel_size:
             raise ValueError(
                 f"data_parallel_size_local ({self.data_parallel_size_local}) "

vllm/distributed/device_communicators/all2all.py

Lines changed: 25 additions & 21 deletions
@@ -30,8 +30,8 @@ class NaiveAll2AllManager(All2AllManagerBase):
     debugging.
     """

-    def __init__(self, cpu_group):
-        super().__init__(cpu_group)
+    def __init__(self, cpu_group, tcp_store_group=None):
+        super().__init__(cpu_group, tcp_store_group)

     def naive_multicast(
         self,
@@ -101,8 +101,8 @@ class AgRsAll2AllManager(All2AllManagerBase):
     all-gather (dispatch) and reduce-scatter (combine).
     """

-    def __init__(self, cpu_group):
-        super().__init__(cpu_group)
+    def __init__(self, cpu_group, tcp_store_group=None):
+        super().__init__(cpu_group, tcp_store_group)

     def dispatch(
         self,
@@ -145,13 +145,16 @@ class PPLXAll2AllManager(All2AllManagerBase):
     All2All communication based on PPLX kernels.
     """

-    def __init__(self, cpu_group):
+    def __init__(self, cpu_group, tcp_store_group=None):
         assert has_pplx(), (
             "pplx_kernels not found. Please follow https://github.com/vllm-project/vllm/blob/main/tools/ep_kernels/README.md"
             " to install pplx_kernels."
         )
-        super().__init__(cpu_group)
+        super().__init__(cpu_group, tcp_store_group)
+        self.nvshmem_initialized = False
+        self.handle_cache = Cache()

+    def get_handle(self, kwargs):
         if self.internode:
             # inter-node communication needs nvshmem,
             # intra-node communication uses p2p mapping directly
@@ -171,17 +174,18 @@ def __init__(self, cpu_group):
                 if self.rank == 0
                 else nvshmem_alloc_empty_unique_id()
             )
-            dist.broadcast(
-                uid,
-                src=dist.get_process_group_ranks(self.cpu_group)[0],
-                group=self.cpu_group,
-            )
+            if self.tcp_store_group is not None:
+                uid = self.tcp_store_group.broadcast_obj(uid, src=0)
+            else:
+                dist.broadcast(
+                    uid,
+                    src=dist.get_process_group_ranks(self.cpu_group)[0],
+                    group=self.cpu_group,
+                )
             logger.debug("PPLX NVSHMEM UID = %s", uid)
             nvshmem_init(uid, self.rank, self.world_size)
-
-        self.handle_cache = Cache()
-
-    def get_handle(self, kwargs):
+            self.nvshmem_initialized = True
+
         import pplx_kernels as pplx

         return self.handle_cache.get_or_create(
@@ -219,12 +223,12 @@ class DeepEPAll2AllManagerBase(All2AllManagerBase):
     All2All communication based on DeepEP High-Throughput kernels.
     """

-    def __init__(self, cpu_group):
+    def __init__(self, cpu_group, tcp_store_group=None):
         assert has_deep_ep(), (
             "DeepEP kernels not found. Please follow https://github.com/vllm-project/vllm/blob/main/tools/ep_kernels/README.md"
             " to install DeepEP kernels."
         ) # noqa
-        super().__init__(cpu_group)
+        super().__init__(cpu_group, tcp_store_group)
         self.handle_cache = Cache()

         # This is the DeepEP default. Stick to it till we can establish
@@ -256,8 +260,8 @@ class DeepEPHTAll2AllManager(DeepEPAll2AllManagerBase):
     All2All communication based on DeepEP High-Throughput kernels.
     """

-    def __init__(self, cpu_group):
-        super().__init__(cpu_group)
+    def __init__(self, cpu_group, tcp_store_group=None):
+        super().__init__(cpu_group, tcp_store_group)

     def _make_all2all_kwargs(self) -> dict[Any, Any]:
         # Defaults for internode and intranode are taken from DeepEP tests.
@@ -313,8 +317,8 @@ class DeepEPLLAll2AllManager(DeepEPAll2AllManagerBase):
     All2All communication based on DeepEP Low-Latency kernels.
     """

-    def __init__(self, cpu_group):
-        super().__init__(cpu_group)
+    def __init__(self, cpu_group, tcp_store_group=None):
+        super().__init__(cpu_group, tcp_store_group)

     def _make_all2all_kwargs(
         self,
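The PPLX manager now defers NVSHMEM setup to get_handle and, when the manager was built from a stateless TCP-store group, broadcasts the UID through that store instead of a torch.distributed collective. A sketch of the same fallback pattern generalized to any picklable object (broadcast_from_rank0 is an illustrative helper name, not part of the commit; the diff itself broadcasts the UID tensor with dist.broadcast):

```python
# Sketch of the broadcast fallback used in PPLXAll2AllManager.get_handle above.
# broadcast_from_rank0 is a made-up helper for illustration; tcp_store_group is
# expected to expose broadcast_obj(obj, src) as in the diff.
from typing import Any

import torch.distributed as dist


def broadcast_from_rank0(obj: Any, cpu_group, tcp_store_group=None) -> Any:
    if tcp_store_group is not None:
        # Stateless path: no torch.distributed world exists, so ship the object
        # through the TCP store, with rank 0 as the source.
        return tcp_store_group.broadcast_obj(obj, src=0)
    # Regular path: use the process group. For arbitrary picklable objects the
    # object-list collective is the generic counterpart of the tensor broadcast.
    src = dist.get_process_group_ranks(cpu_group)[0]
    buf = [obj]
    dist.broadcast_object_list(buf, src=src, group=cpu_group)
    return buf[0]
```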

vllm/distributed/device_communicators/base_device_communicator.py

Lines changed: 30 additions & 10 deletions
@@ -30,8 +30,9 @@ class All2AllManagerBase:
     rank: int
     world_size: int

-    def __init__(self, cpu_group):
+    def __init__(self, cpu_group, tcp_store_group=None):
         self.cpu_group = cpu_group
+        self.tcp_store_group = tcp_store_group

         # compute some common properties
         from vllm.distributed.parallel_state import (
@@ -48,12 +49,15 @@ def __init__(self, cpu_group):
         # when we create this object
         self.dp_rank = self.dp_group.rank_in_group
         self.dp_world_size = self.dp_group.world_size
-        self.rank = dist.get_rank(cpu_group)
-        self.world_size = dist.get_world_size(cpu_group)
+        self.rank = cpu_group.rank()
+        self.world_size = cpu_group.size()

         # all2all communication often has separate implementations for
         # intra-node and inter-node communication
-        self.internode = not all(in_the_same_node_as(cpu_group, source_rank=0))
+        if tcp_store_group is None:
+            self.internode = not all(in_the_same_node_as(cpu_group, source_rank=0))
+        else:
+            self.internode = not all(in_the_same_node_as(tcp_store_group, source_rank=0))

     def get_handle(self, kwargs):
         # get a handle for the all2all communication,
@@ -99,17 +103,33 @@ def __init__(
         device: Optional[torch.device] = None,
         device_group: Optional[ProcessGroup] = None,
         unique_name: str = "",
+        global_ranks: Optional[list[int]] = None,
+        global_world_size: Optional[int] = None
     ):
         self.device = device or torch.device("cpu")
         self.cpu_group = cpu_group
         self.device_group = device_group
         self.unique_name = unique_name
-        self.rank = dist.get_rank(cpu_group)
-        self.world_size = dist.get_world_size(cpu_group)
-        self.ranks = dist.get_process_group_ranks(cpu_group)
-        self.global_rank = dist.get_rank()
-        self.global_world_size = dist.get_world_size()
-        self.rank_in_group = dist.get_group_rank(self.cpu_group, self.global_rank)
+
+        # Check if this is a stateless process group
+        from torch.distributed.distributed_c10d import _world
+        is_stateless = _world.pg_map.get(cpu_group, None) is None
+
+        if is_stateless:
+            # For stateless groups, we can't use torch.distributed methods
+            self.rank = cpu_group.rank()
+            self.world_size = cpu_group.size()
+            self.ranks = global_ranks
+            self.global_rank = self.ranks[self.rank]
+            self.global_world_size = global_world_size
+            self.rank_in_group = self.rank
+        else:
+            self.rank = dist.get_rank(cpu_group)
+            self.world_size = dist.get_world_size(cpu_group)
+            self.ranks = dist.get_process_group_ranks(cpu_group)
+            self.global_rank = dist.get_rank()
+            self.global_world_size = dist.get_world_size()
+            self.rank_in_group = dist.get_group_rank(self.cpu_group, self.global_rank)

         use_ep = False
         from vllm.config import get_current_vllm_config
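The stateless-group check above relies on the fact that groups created through torch.distributed's regular init path are registered in the private _world.pg_map registry, while stateless groups are not. A minimal sketch of that probe in isolation (is_stateless_group is an illustrative name; _world is a private torch API and may change between releases):

```python
# Sketch of the detection used in DeviceCommunicatorBase.__init__ above.
# Groups registered via torch.distributed.new_group() appear in _world.pg_map;
# a group built outside init_process_group does not, so its rank/size must be
# read from the group object itself rather than from dist.get_rank() and friends.
from torch.distributed import ProcessGroup
from torch.distributed.distributed_c10d import _world


def is_stateless_group(cpu_group: ProcessGroup) -> bool:
    return _world.pg_map.get(cpu_group, None) is None
```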
