1
1
# SPDX-License-Identifier: Apache-2.0
2
2
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
3
+ import copy
3
4
import multiprocessing
4
5
import time
5
6
import weakref
@@ -65,18 +66,14 @@ def __init__(self, parallel_config: ParallelConfig):
65
66
66
67
# Assume coordinator is colocated with front-end procs when not in
67
68
# either external or hybrid DP LB mode.
69
+ local_only = not (external_lb or hybrid_lb )
68
70
front_publish_address = get_engine_client_zmq_addr (
69
- local_only = not external_lb and not hybrid_lb , host = host )
71
+ local_only = local_only , host = host )
70
72
71
73
local_only_eng = dp_size == parallel_config .data_parallel_size_local
72
74
back_publish_address = get_engine_client_zmq_addr (local_only_eng , host )
73
75
back_output_address = get_engine_client_zmq_addr (local_only_eng , host )
74
76
75
- # When in external LB mode, load stats aren't published, only changes
76
- # to request wave / running state, so we don't need to rate-limit the
77
- # updates to the front-end proc(s).
78
- min_stats_update_interval_ms = 0 if external_lb else 100
79
-
80
77
context = get_mp_context ()
81
78
self .proc : multiprocessing .Process = context .Process (
82
79
target = DPCoordinatorProc .run_coordinator ,
@@ -86,7 +83,6 @@ def __init__(self, parallel_config: ParallelConfig):
86
83
"front_publish_address" : front_publish_address ,
87
84
"back_output_address" : back_output_address ,
88
85
"back_publish_address" : back_publish_address ,
89
- "min_stats_update_interval_ms" : min_stats_update_interval_ms ,
90
86
},
91
87
daemon = True )
92
88
self .proc .start ()
@@ -125,10 +121,6 @@ def __init__(self,
125
121
126
122
self .stats_update_interval_ms = min_stats_update_interval_ms
127
123
128
- self .current_wave = 0
129
- self .engines_running = False
130
- self .stats_changed = False
131
-
132
124
@staticmethod
133
125
def run_coordinator (
134
126
engine_count : int ,
@@ -155,6 +147,16 @@ def process_input_socket(self, front_publish_address: str,
155
147
156
148
decoder = MsgpackDecoder (EngineCoreOutputs )
157
149
150
+ # For tracking request wave progression.
151
+ current_wave = 0
152
+ engines_running = False
153
+
154
+ # For tracking request counts for internal load-balancing.
155
+ stats_changed = False
156
+ last_stats_step = - 1
157
+ last_stats_wave = - 1
158
+ last_step_counts : Optional [list [list [int ]]] = None
159
+
158
160
with make_zmq_socket (
159
161
path = front_publish_address , # IPC
160
162
ctx = self .ctx ,
@@ -191,21 +193,33 @@ def process_input_socket(self, front_publish_address: str,
191
193
while True :
192
194
elapsed = int (time .time () * 1000 ) - last_publish_time
193
195
# Send at stats_update_interval_ms interval if the stats have
194
- # changed, or otherwise every 4 seconds.
196
+ # changed, or otherwise every 5 seconds.
195
197
wait_for = (self .stats_update_interval_ms
196
- if self .stats_changed else 4000 )
197
- events = poller .poll (timeout = max (0 , wait_for - elapsed ))
198
+ if stats_changed else 5000 )
199
+
200
+ # Wait at least 50ms to ensure we've received all stats for
201
+ # the current step.
202
+ min_timeout = 50 if last_step_counts is None else 0
203
+
204
+ events = poller .poll (timeout = max (min_timeout , wait_for -
205
+ elapsed ))
198
206
if not events :
199
207
# Poller timeout - publish current stats to front-ends.
200
- engine_req_counts_list = self ._get_engine_counts ()
201
- to_publish = (engine_req_counts_list , self .current_wave ,
202
- self .engines_running )
208
+ if last_step_counts is not None :
209
+ engine_req_counts_list = last_step_counts
210
+ last_step_counts = None
211
+ else :
212
+ engine_req_counts_list = self ._get_engine_counts ()
213
+ stats_changed = False
214
+
215
+ to_publish = (engine_req_counts_list , current_wave ,
216
+ engines_running )
203
217
publish_front .send (msgspec .msgpack .encode (to_publish ))
204
218
last_publish_time = int (time .time () * 1000 )
205
- self .stats_changed = False
206
219
continue
207
220
208
221
events = dict (events )
222
+ wave_state_changed = False
209
223
210
224
if publish_front in events :
211
225
buffer = publish_front .recv ()
@@ -232,7 +246,7 @@ def process_input_socket(self, front_publish_address: str,
232
246
# current_wave
233
247
# we note that 0 is the wave number for the new
234
248
# engine
235
- self . engines_running = False
249
+ engines_running = False
236
250
logger .info (
237
251
"DPCoordinator scaled up from %s to %s "
238
252
"engines" , current_count , new_engine_count )
@@ -248,15 +262,15 @@ def process_input_socket(self, front_publish_address: str,
248
262
# engines are paused, so that we can wake the other
249
263
# engines.
250
264
engine_to_exclude , wave = decoded
251
- if not self . engines_running :
252
- if wave < self . current_wave :
265
+ if not engines_running :
266
+ if wave < current_wave :
253
267
# If the wave number is stale, ensure the message
254
268
# is handled by all the engines.
255
269
engine_to_exclude = None
256
270
257
- self . engines_running = True
258
- self . stats_changed = True
259
- self ._send_start_wave (publish_back , self . current_wave ,
271
+ engines_running = True
272
+ wave_state_changed = True
273
+ self ._send_start_wave (publish_back , current_wave ,
260
274
engine_to_exclude )
261
275
262
276
if output_back in events :
@@ -274,36 +288,56 @@ def process_input_socket(self, front_publish_address: str,
274
288
# 1. Updated request load stats - update our local
275
289
# state with these.
276
290
stats = self .engines [eng_index ].request_counts
291
+ stats_step = scheduler_stats .step_counter
292
+ stats_wave = scheduler_stats .current_wave
293
+ if (stats_wave > last_stats_wave
294
+ or stats_wave == last_stats_wave
295
+ and stats_step > last_stats_step ):
296
+ if stats_changed :
297
+ last_step_counts = self ._get_engine_counts (
298
+ do_copy = True )
299
+ last_stats_step = stats_step
300
+ last_stats_wave = stats_wave
301
+ elif stats_wave != last_stats_wave or (
302
+ stats_step != last_stats_step ):
303
+ logger .warning (
304
+ "Received stats for out-of-order "
305
+ "step (%d, %d) from engine %d (expected "
306
+ "> (%d, %d))" , stats_wave , stats_step ,
307
+ eng_index , last_stats_wave , last_stats_step )
277
308
stats [0 ] = scheduler_stats .num_waiting_reqs
278
309
stats [1 ] = scheduler_stats .num_running_reqs
279
- self . stats_changed = True
310
+ stats_changed = True
280
311
281
312
if (wave := outputs .wave_complete ) is not None :
282
313
# 2. Notification from rank 0 engine that we've
283
314
# moved into the global paused state
284
315
# (engines_running==False).
285
- if self . current_wave <= wave :
316
+ if current_wave <= wave :
286
317
new_wave = wave + 1
287
318
logger .debug ("Moving DP wave from %d to %d." ,
288
- self . current_wave , new_wave )
289
- self . current_wave = new_wave
290
- self . engines_running = False
291
- self . stats_changed = True
319
+ current_wave , new_wave )
320
+ current_wave = new_wave
321
+ engines_running = False
322
+ wave_state_changed = True
292
323
elif (wave := outputs .start_wave ) is not None and (
293
- wave > self .current_wave or
294
- (wave == self .current_wave
295
- and not self .engines_running )):
324
+ wave > current_wave or
325
+ (wave == current_wave and not engines_running )):
296
326
# 3. The engine received a request for a non-current wave
297
327
# so we must ensure that other engines progress to the
298
328
# next wave (race condition handling).
299
329
logger .debug (
300
330
"Starting wave %d after notification of "
301
331
"stale wave request from engine." , wave )
302
- self . current_wave = wave
303
- self . engines_running = True
304
- self . stats_changed = True
332
+ current_wave = wave
333
+ engines_running = True
334
+ wave_state_changed = True
305
335
self ._send_start_wave (publish_back , wave , eng_index )
306
336
337
+ if wave_state_changed :
338
+ message = (None , current_wave , engines_running )
339
+ publish_front .send (msgspec .msgpack .encode (message ))
340
+
307
341
@staticmethod
308
342
def _send_start_wave (socket : zmq .Socket , wave : int ,
309
343
exclude_engine_index : Optional [int ]):
@@ -316,6 +350,8 @@ def _send_start_wave(socket: zmq.Socket, wave: int,
316
350
socket .send_multipart (
317
351
(EngineCoreRequestType .START_DP_WAVE .value , wave_encoded ))
318
352
319
- def _get_engine_counts (self ) -> list [list [int ]]:
353
+ def _get_engine_counts (self , do_copy = False ) -> list [list [int ]]:
320
354
"""Return list of [waiting, running] count lists for each engine."""
355
+ if do_copy :
356
+ return [copy .copy (e .request_counts ) for e in self .engines ]
321
357
return [e .request_counts for e in self .engines ]
0 commit comments