@@ -23,8 +23,8 @@ class NaiveAll2AllManager(All2AllManagerBase):
23
23
debugging.
24
24
"""
25
25
26
- def __init__ (self , cpu_group ):
27
- super ().__init__ (cpu_group )
26
+ def __init__ (self , cpu_group , tcp_store_group = None ):
27
+ super ().__init__ (cpu_group , tcp_store_group )
28
28
29
29
def naive_multicast (self , x : torch .Tensor ,
30
30
cu_tokens_across_dp_cpu : torch .Tensor ):
@@ -76,8 +76,8 @@ class AgRsAll2AllManager(All2AllManagerBase):
76
76
all-gather (dispatch) and reduce-scatter (combine).
77
77
"""
78
78
79
- def __init__ (self , cpu_group ):
80
- super ().__init__ (cpu_group )
79
+ def __init__ (self , cpu_group , tcp_store_group = None ):
80
+ super ().__init__ (cpu_group , tcp_store_group )
81
81
82
82
def dispatch (self , hidden_states : torch .Tensor ,
83
83
router_logits : torch .Tensor ):
@@ -113,14 +113,16 @@ class PPLXAll2AllManager(All2AllManagerBase):
113
113
All2All communication based on PPLX kernels.
114
114
"""
115
115
116
- def __init__ (self , cpu_group ):
116
+ def __init__ (self , cpu_group , tcp_store_group = None ):
117
117
assert has_pplx (
118
118
), "pplx_kernels not found. Please follow https://github.com/vllm-project/vllm/blob/main/tools/ep_kernels/README.md to install pplx_kernels." # noqa
119
- super ().__init__ (cpu_group )
119
+ super ().__init__ (cpu_group , tcp_store_group )
120
120
121
- if self .internode :
122
- # inter-node communication needs nvshmem,
123
- # intra-node communication uses p2p mapping directly
121
+ self .nvshmem_initialized = False
122
+ self .handle_cache = Cache ()
123
+
124
+ def get_handle (self , kwargs ):
125
+ if self .internode and not self .nvshmem_initialized :
124
126
from pplx_kernels .nvshmem import (nvshmem_alloc_empty_unique_id ,
125
127
nvshmem_get_unique_id ,
126
128
nvshmem_init )
@@ -129,15 +131,18 @@ def __init__(self, cpu_group):
129
131
"rank=%d, world size=%d" , self .rank , self .world_size )
130
132
uid = nvshmem_get_unique_id (
131
133
) if self .rank == 0 else nvshmem_alloc_empty_unique_id ()
132
- dist .broadcast (uid ,
133
- src = dist .get_process_group_ranks (self .cpu_group )[0 ],
134
- group = self .cpu_group )
134
+
135
+ if self .tcp_store_group is not None :
136
+ uid = self .tcp_store_group .broadcast_obj (uid , src = 0 )
137
+ else :
138
+ dist .broadcast (uid ,
139
+ src = dist .get_process_group_ranks (self .cpu_group )[0 ],
140
+ group = self .cpu_group )
141
+
135
142
logger .debug ("PPLX NVSHMEM UID = %s" , uid )
136
143
nvshmem_init (uid , self .rank , self .world_size )
144
+ self .nvshmem_initialized = True
137
145
138
- self .handle_cache = Cache ()
139
-
140
- def get_handle (self , kwargs ):
141
146
import pplx_kernels as pplx
142
147
return self .handle_cache .get_or_create (
143
148
kwargs , pplx .AllToAll .internode
@@ -166,10 +171,10 @@ class DeepEPAll2AllManagerBase(All2AllManagerBase):
166
171
All2All communication based on DeepEP High-Throughput kernels.
167
172
"""
168
173
169
- def __init__ (self , cpu_group ):
174
+ def __init__ (self , cpu_group , tcp_store_group = None ):
170
175
assert has_deep_ep (
171
176
), "DeepEP kernels not found. Please follow https://github.com/vllm-project/vllm/blob/main/tools/ep_kernels/README.md to install DeepEP kernels." # noqa
172
- super ().__init__ (cpu_group )
177
+ super ().__init__ (cpu_group , tcp_store_group )
173
178
self .handle_cache = Cache ()
174
179
175
180
# This is the DeepEP default. Stick to it till we can establish
@@ -195,8 +200,8 @@ class DeepEPHTAll2AllManager(DeepEPAll2AllManagerBase):
195
200
All2All communication based on DeepEP High-Throughput kernels.
196
201
"""
197
202
198
- def __init__ (self , cpu_group ):
199
- super ().__init__ (cpu_group )
203
+ def __init__ (self , cpu_group , tcp_store_group = None ):
204
+ super ().__init__ (cpu_group , tcp_store_group )
200
205
201
206
def _make_all2all_kwargs (self ) -> dict [Any , Any ]:
202
207
# Defaults for internode and intranode are taken from DeepEP tests.
@@ -243,8 +248,8 @@ class DeepEPLLAll2AllManager(DeepEPAll2AllManagerBase):
243
248
All2All communication based on DeepEP Low-Latency kernels.
244
249
"""
245
250
246
- def __init__ (self , cpu_group ):
247
- super ().__init__ (cpu_group )
251
+ def __init__ (self , cpu_group , tcp_store_group = None ):
252
+ super ().__init__ (cpu_group , tcp_store_group )
248
253
249
254
def _make_all2all_kwargs (
250
255
self ,
@@ -265,7 +270,8 @@ def _make_all2all_kwargs(
265
270
import deep_ep
266
271
267
272
# Defaults for internode and intranode are taken from DeepEP tests.
268
- num_nvl_bytes = 1024 * 1024 * 1024
273
+ # num_nvl_bytes = 1024 * 1024 * 1024
274
+ num_nvl_bytes = 0
269
275
num_qps_per_rank = num_local_experts
270
276
num_rdma_bytes = deep_ep .Buffer .get_low_latency_rdma_size_hint (
271
277
num_max_dispatch_tokens_per_rank = max_num_tokens_per_dp_rank ,
@@ -278,7 +284,8 @@ def _make_all2all_kwargs(
278
284
num_nvl_bytes = num_nvl_bytes ,
279
285
num_rdma_bytes = num_rdma_bytes ,
280
286
low_latency_mode = True ,
281
- num_qps_per_rank = num_qps_per_rank )
287
+ num_qps_per_rank = num_qps_per_rank ,
288
+ allow_mnnvl = True )
282
289
283
290
def get_handle (self , kwargs ):
284
291
"""
0 commit comments