
Commit 8d524ce

[BugFix] Improve internal DP load balancing (#21617)
Signed-off-by: Nick Hill <[email protected]>
1 parent: 9f9c38c · commit: 8d524ce

7 files changed: +122 −59 lines

vllm/entrypoints/openai/api_server.py

Lines changed: 3 additions & 0 deletions

@@ -199,6 +199,8 @@ async def build_async_engine_client_from_engine_args(

         from vllm.v1.engine.async_llm import AsyncLLM
         async_llm: Optional[AsyncLLM] = None
+        client_count = client_config.pop(
+            "client_count") if client_config else 1
         client_index = client_config.pop(
             "client_index") if client_config else 0
         try:
@@ -208,6 +210,7 @@ async def build_async_engine_client_from_engine_args(
                 enable_log_requests=engine_args.enable_log_requests,
                 disable_log_stats=engine_args.disable_log_stats,
                 client_addresses=client_config,
+                client_count=client_count,
                 client_index=client_index)

             # Don't keep the dummy data in memory
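The pattern worth noting here: client_config is treated as a mixed bag. The two bookkeeping keys are popped off (with defaults 1 and 0 preserving single-front-end behaviour), and whatever remains in the dict is forwarded as client_addresses. A standalone sketch of that split, with hypothetical address keys since the real key names don't appear in this excerpt:

from typing import Optional

def split_client_config(
        client_config: Optional[dict]) -> tuple[int, int, Optional[dict]]:
    # Pop the bookkeeping keys first so that only addresses remain.
    client_count = client_config.pop(
        "client_count") if client_config else 1
    client_index = client_config.pop(
        "client_index") if client_config else 0
    return client_count, client_index, client_config

count, index, addresses = split_client_config({
    "client_count": 2,                       # two API-server processes
    "client_index": 1,                       # this is the second one
    "input_address": "ipc:///tmp/in.ipc",    # hypothetical key
    "output_address": "ipc:///tmp/out.ipc",  # hypothetical key
})
assert (count, index) == (2, 1)
assert "client_count" not in addresses       # only addresses remain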

vllm/v1/engine/async_llm.py

Lines changed: 4 additions & 0 deletions

@@ -57,6 +57,7 @@ def __init__(
         start_engine_loop: bool = True,
         stat_loggers: Optional[list[StatLoggerFactory]] = None,
         client_addresses: Optional[dict[str, str]] = None,
+        client_count: int = 1,
         client_index: int = 0,
     ) -> None:
         """
@@ -120,6 +121,7 @@ def __init__(
             executor_class=executor_class,
             log_stats=self.log_stats,
             client_addresses=client_addresses,
+            client_count=client_count,
             client_index=client_index,
         )

@@ -156,6 +158,7 @@ def from_vllm_config(
         enable_log_requests: bool = False,
         disable_log_stats: bool = False,
         client_addresses: Optional[dict[str, str]] = None,
+        client_count: int = 1,
         client_index: int = 0,
         disable_log_requests: bool = True,  # Deprecated, will be removed
     ) -> "AsyncLLM":
@@ -176,6 +179,7 @@ def from_vllm_config(
             log_stats=not disable_log_stats,
             usage_context=usage_context,
             client_addresses=client_addresses,
+            client_count=client_count,
             client_index=client_index,
         )
vllm/v1/engine/coordinator.py

Lines changed: 73 additions & 37 deletions

@@ -1,5 +1,6 @@
 # SPDX-License-Identifier: Apache-2.0
 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project
+import copy
 import multiprocessing
 import time
 import weakref
@@ -65,18 +66,14 @@ def __init__(self, parallel_config: ParallelConfig):

         # Assume coordinator is colocated with front-end procs when not in
         # either external or hybrid DP LB mode.
+        local_only = not (external_lb or hybrid_lb)
         front_publish_address = get_engine_client_zmq_addr(
-            local_only=not external_lb and not hybrid_lb, host=host)
+            local_only=local_only, host=host)

         local_only_eng = dp_size == parallel_config.data_parallel_size_local
         back_publish_address = get_engine_client_zmq_addr(local_only_eng, host)
         back_output_address = get_engine_client_zmq_addr(local_only_eng, host)

-        # When in external LB mode, load stats aren't published, only changes
-        # to request wave / running state, so we don't need to rate-limit the
-        # updates to the front-end proc(s).
-        min_stats_update_interval_ms = 0 if external_lb else 100
-
         context = get_mp_context()
         self.proc: multiprocessing.Process = context.Process(
             target=DPCoordinatorProc.run_coordinator,
@@ -86,7 +83,6 @@ def __init__(self, parallel_config: ParallelConfig):
                 "front_publish_address": front_publish_address,
                 "back_output_address": back_output_address,
                 "back_publish_address": back_publish_address,
-                "min_stats_update_interval_ms": min_stats_update_interval_ms,
             },
             daemon=True)
         self.proc.start()
@@ -125,10 +121,6 @@ def __init__(self,

         self.stats_update_interval_ms = min_stats_update_interval_ms

-        self.current_wave = 0
-        self.engines_running = False
-        self.stats_changed = False
-
     @staticmethod
     def run_coordinator(
         engine_count: int,
@@ -155,6 +147,16 @@ def process_input_socket(self, front_publish_address: str,

         decoder = MsgpackDecoder(EngineCoreOutputs)

+        # For tracking request wave progression.
+        current_wave = 0
+        engines_running = False
+
+        # For tracking request counts for internal load-balancing.
+        stats_changed = False
+        last_stats_step = -1
+        last_stats_wave = -1
+        last_step_counts: Optional[list[list[int]]] = None
+
         with make_zmq_socket(
                 path=front_publish_address,  # IPC
                 ctx=self.ctx,
@@ -191,21 +193,33 @@ def process_input_socket(self, front_publish_address: str,
             while True:
                 elapsed = int(time.time() * 1000) - last_publish_time
                 # Send at stats_update_interval_ms interval if the stats have
-                # changed, or otherwise every 4 seconds.
+                # changed, or otherwise every 5 seconds.
                 wait_for = (self.stats_update_interval_ms
-                            if self.stats_changed else 4000)
-                events = poller.poll(timeout=max(0, wait_for - elapsed))
+                            if stats_changed else 5000)
+
+                # Wait at least 50ms to ensure we've received all stats for
+                # the current step.
+                min_timeout = 50 if last_step_counts is None else 0
+
+                events = poller.poll(timeout=max(min_timeout, wait_for -
+                                                 elapsed))
                 if not events:
                     # Poller timeout - publish current stats to front-ends.
-                    engine_req_counts_list = self._get_engine_counts()
-                    to_publish = (engine_req_counts_list, self.current_wave,
-                                  self.engines_running)
+                    if last_step_counts is not None:
+                        engine_req_counts_list = last_step_counts
+                        last_step_counts = None
+                    else:
+                        engine_req_counts_list = self._get_engine_counts()
+                        stats_changed = False
+
+                    to_publish = (engine_req_counts_list, current_wave,
+                                  engines_running)
                     publish_front.send(msgspec.msgpack.encode(to_publish))
                     last_publish_time = int(time.time() * 1000)
-                    self.stats_changed = False
                     continue

                 events = dict(events)
+                wave_state_changed = False

                 if publish_front in events:
                     buffer = publish_front.recv()
@@ -232,7 +246,7 @@ def process_input_socket(self, front_publish_address: str,
                         # current_wave
                         # we note that 0 is the wave number for the new
                         # engine
-                        self.engines_running = False
+                        engines_running = False
                         logger.info(
                             "DPCoordinator scaled up from %s to %s "
                             "engines", current_count, new_engine_count)
@@ -248,15 +262,15 @@ def process_input_socket(self, front_publish_address: str,
                         # engines are paused, so that we can wake the other
                         # engines.
                         engine_to_exclude, wave = decoded
-                        if not self.engines_running:
-                            if wave < self.current_wave:
+                        if not engines_running:
+                            if wave < current_wave:
                                 # If the wave number is stale, ensure the message
                                 # is handled by all the engines.
                                 engine_to_exclude = None

-                            self.engines_running = True
-                            self.stats_changed = True
-                            self._send_start_wave(publish_back, self.current_wave,
+                            engines_running = True
+                            wave_state_changed = True
+                            self._send_start_wave(publish_back, current_wave,
                                                   engine_to_exclude)

                 if output_back in events:
@@ -274,36 +288,56 @@ def process_input_socket(self, front_publish_address: str,
                         # 1. Updated request load stats - update our local
                         # state with these.
                         stats = self.engines[eng_index].request_counts
+                        stats_step = scheduler_stats.step_counter
+                        stats_wave = scheduler_stats.current_wave
+                        if (stats_wave > last_stats_wave
+                                or stats_wave == last_stats_wave
+                                and stats_step > last_stats_step):
+                            if stats_changed:
+                                last_step_counts = self._get_engine_counts(
+                                    do_copy=True)
+                            last_stats_step = stats_step
+                            last_stats_wave = stats_wave
+                        elif stats_wave != last_stats_wave or (
+                                stats_step != last_stats_step):
+                            logger.warning(
+                                "Received stats for out-of-order "
+                                "step (%d, %d) from engine %d (expected "
+                                "> (%d, %d))", stats_wave, stats_step,
+                                eng_index, last_stats_wave, last_stats_step)
                         stats[0] = scheduler_stats.num_waiting_reqs
                         stats[1] = scheduler_stats.num_running_reqs
-                        self.stats_changed = True
+                        stats_changed = True

                     if (wave := outputs.wave_complete) is not None:
                         # 2. Notification from rank 0 engine that we've
                         # moved into the global paused state
                         # (engines_running==False).
-                        if self.current_wave <= wave:
+                        if current_wave <= wave:
                             new_wave = wave + 1
                             logger.debug("Moving DP wave from %d to %d.",
-                                         self.current_wave, new_wave)
-                            self.current_wave = new_wave
-                            self.engines_running = False
-                            self.stats_changed = True
+                                         current_wave, new_wave)
+                            current_wave = new_wave
+                            engines_running = False
+                            wave_state_changed = True
                     elif (wave := outputs.start_wave) is not None and (
-                            wave > self.current_wave or
-                            (wave == self.current_wave
-                             and not self.engines_running)):
+                            wave > current_wave or
+                            (wave == current_wave and not engines_running)):
                         # 3. The engine received request for a non-current wave
                         # so we must ensure that other engines progress to the
                         # next wave (race condition handling).
                         logger.debug(
                             "Starting wave %d after notification of "
                             "stale wave request from engine.", wave)
-                        self.current_wave = wave
-                        self.engines_running = True
-                        self.stats_changed = True
+                        current_wave = wave
+                        engines_running = True
+                        wave_state_changed = True
                         self._send_start_wave(publish_back, wave, eng_index)

+                if wave_state_changed:
+                    message = (None, current_wave, engines_running)
+                    publish_front.send(msgspec.msgpack.encode(message))
+
     @staticmethod
     def _send_start_wave(socket: zmq.Socket, wave: int,
                          exclude_engine_index: Optional[int]):
@@ -316,6 +350,8 @@ def _send_start_wave(socket: zmq.Socket, wave: int,
         socket.send_multipart(
             (EngineCoreRequestType.START_DP_WAVE.value, wave_encoded))

-    def _get_engine_counts(self) -> list[list[int]]:
+    def _get_engine_counts(self, do_copy=False) -> list[list[int]]:
         """Return list of [waiting, running] count lists for each engine."""
+        if do_copy:
+            return [copy.copy(e.request_counts) for e in self.engines]
         return [e.request_counts for e in self.engines]
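This is the heart of the fix. Wave state (current_wave, engines_running) and stats bookkeeping move from instance attributes into locals of the socket loop; incoming stats are now ordered by the (current_wave, step_counter) pair each engine reports; and when stats for a newer step begin arriving while unpublished changes are pending, the previous step's counts are snapshotted via _get_engine_counts(do_copy=True) so front-ends receive a consistent per-step view rather than a mix of two steps. Wave-state changes are also pushed to front-ends immediately, as the (None, current_wave, engines_running) message, instead of waiting for the next stats interval. The poll-timeout computation drives the publishing cadence; a standalone sketch of just that calculation, mirroring the constants in the diff (the 100 ms interval is an assumed default):

import time
from typing import Optional

# Sketch of the coordinator's poll-timeout logic, decoupled from ZMQ:
# publish after stats_update_interval_ms when stats changed, heartbeat
# every 5s otherwise, and wait at least 50ms so that every engine's stats
# for the current step have a chance to arrive before publishing.
STATS_UPDATE_INTERVAL_MS = 100  # assumed default
HEARTBEAT_MS = 5000

def next_poll_timeout_ms(stats_changed: bool,
                         last_step_counts: Optional[list[list[int]]],
                         last_publish_time_ms: int) -> int:
    elapsed = int(time.time() * 1000) - last_publish_time_ms
    wait_for = STATS_UPDATE_INTERVAL_MS if stats_changed else HEARTBEAT_MS
    # Once a completed-step snapshot is pending, drop the 50ms floor so
    # it can be published promptly on the next timeout.
    min_timeout = 50 if last_step_counts is None else 0
    return max(min_timeout, wait_for - elapsed)

# Example: stats just changed, last publish 30ms ago, no snapshot pending.
now_ms = int(time.time() * 1000)
print(next_poll_timeout_ms(True, None, now_ms - 30))  # -> ~70

The 50 ms floor gives all engines' stats for a step time to land before the first publish; the do_copy snapshot covers the opposite race, preserving a complete view of the previous step when a newer step's stats start overwriting the shared counts.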

vllm/v1/engine/core.py

Lines changed: 8 additions & 5 deletions

@@ -928,7 +928,7 @@ def __init__(
     ):
         # Counts forward-passes of the model so that we can synchronize
         # finished with DP peers every N steps.
-        self.counter = 0
+        self.step_counter = 0
         self.current_wave = 0
         self.last_counts = (0, 0)

@@ -999,7 +999,9 @@ def _maybe_publish_request_counts(self):
         counts = self.scheduler.get_request_counts()
         if counts != self.last_counts:
             self.last_counts = counts
-            stats = SchedulerStats(*counts)
+            stats = SchedulerStats(*counts,
+                                   step_counter=self.step_counter,
+                                   current_wave=self.current_wave)
             self.output_queue.put_nowait(
                 (-1, EngineCoreOutputs(scheduler_stats=stats)))

@@ -1041,15 +1043,16 @@ def run_busy_loop(self):
                 self.output_queue.put_nowait(
                     (client_index,
                      EngineCoreOutputs(wave_complete=self.current_wave)))
+                # Increment wave count and reset step counter.
                 self.current_wave += 1
+                self.step_counter = 0

     def _has_global_unfinished_reqs(self, local_unfinished: bool) -> bool:

         # Optimization - only perform finish-sync all-reduce every 32 steps.
-        self.counter += 1
-        if self.counter != 32:
+        self.step_counter += 1
+        if self.step_counter % 32 != 0:
             return True
-        self.counter = 0

         return ParallelConfig.has_unfinished_dp(self.dp_group,
                                                 local_unfinished)
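step_counter (renamed from counter) now does double duty: it still gates the every-32-steps finish-sync all-reduce, and it is also published in SchedulerStats (the new step_counter/current_wave fields live in one of the three changed files not shown in this excerpt) so the coordinator can order incoming stats within a wave. That is why the increment-and-reset pattern becomes a modulo test: the counter must keep increasing between wave boundaries for the coordinator's (wave, step) comparison to work, and only wave completion resets it. A minimal cadence sketch; sync_with_peers is hypothetical:

# Sketch of the new sync cadence: the counter grows monotonically within
# a wave (so it can double as a stats ordering key) and a modulo test
# fires the all-reduce every 32 steps. sync_with_peers is hypothetical.
def sync_with_peers() -> None:
    print("finish-sync all-reduce")

step_counter = 0
for _ in range(70):             # 70 model steps within one wave
    step_counter += 1
    if step_counter % 32 == 0:  # fires at steps 32 and 64
        sync_with_peers()

step_counter = 0                # reset only at a wave boundary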
