@@ -30,8 +30,8 @@ class NaiveAll2AllManager(All2AllManagerBase):
30
30
debugging.
31
31
"""
32
32
33
- def __init__ (self , cpu_group ):
34
- super ().__init__ (cpu_group )
33
+ def __init__ (self , cpu_group , tcp_store_group = None ):
34
+ super ().__init__ (cpu_group , tcp_store_group )
35
35
36
36
def naive_multicast (
37
37
self ,
@@ -101,8 +101,8 @@ class AgRsAll2AllManager(All2AllManagerBase):
101
101
all-gather (dispatch) and reduce-scatter (combine).
102
102
"""
103
103
104
- def __init__ (self , cpu_group ):
105
- super ().__init__ (cpu_group )
104
+ def __init__ (self , cpu_group , tcp_store_group = None ):
105
+ super ().__init__ (cpu_group , tcp_store_group )
106
106
107
107
def dispatch (
108
108
self ,
@@ -145,13 +145,16 @@ class PPLXAll2AllManager(All2AllManagerBase):
145
145
All2All communication based on PPLX kernels.
146
146
"""
147
147
148
- def __init__ (self , cpu_group ):
148
+ def __init__ (self , cpu_group , tcp_store_group = None ):
149
149
assert has_pplx (), (
150
150
"pplx_kernels not found. Please follow https://github.com/vllm-project/vllm/blob/main/tools/ep_kernels/README.md"
151
151
" to install pplx_kernels."
152
152
)
153
- super ().__init__ (cpu_group )
153
+ super ().__init__ (cpu_group , tcp_store_group )
154
+ self .nvshmem_initialized = False
155
+ self .handle_cache = Cache ()
154
156
157
+ def get_handle (self , kwargs ):
155
158
if self .internode :
156
159
# inter-node communication needs nvshmem,
157
160
# intra-node communication uses p2p mapping directly
@@ -171,17 +174,18 @@ def __init__(self, cpu_group):
171
174
if self .rank == 0
172
175
else nvshmem_alloc_empty_unique_id ()
173
176
)
174
- dist .broadcast (
175
- uid ,
176
- src = dist .get_process_group_ranks (self .cpu_group )[0 ],
177
- group = self .cpu_group ,
178
- )
177
+ if self .tcp_store_group is not None :
178
+ uid = self .tcp_store_group .broadcast_obj (uid , src = 0 )
179
+ else :
180
+ dist .broadcast (
181
+ uid ,
182
+ src = dist .get_process_group_ranks (self .cpu_group )[0 ],
183
+ group = self .cpu_group ,
184
+ )
179
185
logger .debug ("PPLX NVSHMEM UID = %s" , uid )
180
186
nvshmem_init (uid , self .rank , self .world_size )
181
-
182
- self .handle_cache = Cache ()
183
-
184
- def get_handle (self , kwargs ):
187
+ self .nvshmem_initialized = True
188
+
185
189
import pplx_kernels as pplx
186
190
187
191
return self .handle_cache .get_or_create (
@@ -219,12 +223,12 @@ class DeepEPAll2AllManagerBase(All2AllManagerBase):
219
223
All2All communication based on DeepEP High-Throughput kernels.
220
224
"""
221
225
222
- def __init__ (self , cpu_group ):
226
+ def __init__ (self , cpu_group , tcp_store_group = None ):
223
227
assert has_deep_ep (), (
224
228
"DeepEP kernels not found. Please follow https://github.com/vllm-project/vllm/blob/main/tools/ep_kernels/README.md"
225
229
" to install DeepEP kernels."
226
230
) # noqa
227
- super ().__init__ (cpu_group )
231
+ super ().__init__ (cpu_group , tcp_store_group )
228
232
self .handle_cache = Cache ()
229
233
230
234
# This is the DeepEP default. Stick to it till we can establish
@@ -256,8 +260,8 @@ class DeepEPHTAll2AllManager(DeepEPAll2AllManagerBase):
256
260
All2All communication based on DeepEP High-Throughput kernels.
257
261
"""
258
262
259
- def __init__ (self , cpu_group ):
260
- super ().__init__ (cpu_group )
263
+ def __init__ (self , cpu_group , tcp_store_group = None ):
264
+ super ().__init__ (cpu_group , tcp_store_group )
261
265
262
266
def _make_all2all_kwargs (self ) -> dict [Any , Any ]:
263
267
# Defaults for internode and intranode are taken from DeepEP tests.
@@ -313,8 +317,8 @@ class DeepEPLLAll2AllManager(DeepEPAll2AllManagerBase):
313
317
All2All communication based on DeepEP Low-Latency kernels.
314
318
"""
315
319
316
- def __init__ (self , cpu_group ):
317
- super ().__init__ (cpu_group )
320
+ def __init__ (self , cpu_group , tcp_store_group = None ):
321
+ super ().__init__ (cpu_group , tcp_store_group )
318
322
319
323
def _make_all2all_kwargs (
320
324
self ,
0 commit comments