
Commit 29d1ffc

Authored by robertgshaw2-redhat (Robert Shaw)
[DP] Fix Prometheus Logging (#21257)
Signed-off-by: Robert Shaw <[email protected]>
Co-authored-by: Robert Shaw <[email protected]>
1 parent 304dce7 commit 29d1ffc

File tree

6 files changed: +378 −258 lines

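At a high level, the hunks below replace AsyncLLM's per-rank lists of stat loggers with a single StatLoggerManager that is keyed by DP engine index. The following is a minimal sketch of the resulting flow, not vLLM API documentation: the StatLoggerManager keyword arguments and methods are taken from the hunks below, while the helper function names (build_logger_manager, record_iteration) are invented here for illustration.

    from typing import Optional

    from vllm.config import VllmConfig
    from vllm.v1.metrics.loggers import StatLoggerFactory, StatLoggerManager
    from vllm.v1.metrics.stats import IterationStats, SchedulerStats


    def build_logger_manager(
            vllm_config: VllmConfig,
            engine_ranks: list[int],
            custom_stat_loggers: Optional[list[StatLoggerFactory]] = None,
    ) -> StatLoggerManager:
        # Mirrors the new AsyncLLM.__init__ wiring: one manager covering
        # every DP engine rank this client talks to.
        manager = StatLoggerManager(
            vllm_config=vllm_config,
            engine_idxs=engine_ranks,  # e.g. [0, 1, 2, 3] for DP size 4
            custom_stat_loggers=custom_stat_loggers,
        )
        manager.log_engine_initialized()
        return manager


    def record_iteration(manager: Optional[StatLoggerManager],
                         engine_idx: int,
                         scheduler_stats: Optional[SchedulerStats],
                         iteration_stats: Optional[IterationStats]) -> None:
        # Per-batch path used by the output handler: stats are routed to the
        # loggers registered for the engine that produced them.
        if manager:
            manager.record(engine_idx=engine_idx,
                           scheduler_stats=scheduler_stats,
                           iteration_stats=iteration_stats)

Periodic logging (do_log_stats) then becomes a single manager.log() call that fans out to every registered logger, as shown in the async_llm.py hunks below.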

tests/v1/engine/test_async_llm.py

Lines changed: 4 additions & 3 deletions
@@ -336,9 +336,10 @@ async def test_customize_loggers(monkeypatch):
 
         await engine.do_log_stats()
 
-        assert len(engine.stat_loggers) == 1
-        assert len(engine.stat_loggers[0]) == 1
-        engine.stat_loggers[0][0].log.assert_called_once()
+        stat_loggers = engine.logger_manager.per_engine_logger_dict
+        assert len(stat_loggers) == 1
+        assert len(stat_loggers[0]) == 1
+        stat_loggers[0][0].log.assert_called_once()
 
 
 @pytest.mark.asyncio(scope="module")
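Outside of this test, the same attribute is how user code can now reach the loggers of each DP engine. A hedged sketch, assuming (as the name and the indexing above suggest) that per_engine_logger_dict maps an engine index to its list of loggers; the print is purely illustrative:

    # `engine` is an already-constructed AsyncLLM with log_stats enabled.
    if engine.logger_manager is not None:
        for idx, loggers in engine.logger_manager.per_engine_logger_dict.items():
            print(f"engine {idx}: {[type(lg).__name__ for lg in loggers]}")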

tests/v1/test_async_llm_dp.py

Lines changed: 4 additions & 2 deletions
@@ -90,8 +90,10 @@ class SimpleStatsLogger(StatLoggerBase):
     def __init__(self, vllm_config: VllmConfig, engine_index: int = 0):
         stats_loggers[engine_index] = self
 
-    def record(self, scheduler_stats: Optional[SchedulerStats],
-               iteration_stats: Optional[IterationStats]):
+    def record(self,
+               scheduler_stats: Optional[SchedulerStats],
+               iteration_stats: Optional[IterationStats],
+               engine_idx: int = 0):
         if iteration_stats:
             self.finished_req_count += len(
                 iteration_stats.finished_requests)
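For out-of-tree stat loggers, the practical consequence of this hunk is the extra engine_idx parameter on record(). A minimal sketch of a custom logger under the new signature; the class name and counter are invented here, while the base class, constructor signature, and record() signature follow the diff above:

    from typing import Optional

    from vllm.config import VllmConfig
    from vllm.v1.metrics.loggers import StatLoggerBase
    from vllm.v1.metrics.stats import IterationStats, SchedulerStats


    class FinishedRequestCounter(StatLoggerBase):
        """Counts finished requests per DP engine (illustrative only)."""

        def __init__(self, vllm_config: VllmConfig, engine_index: int = 0):
            self.engine_index = engine_index
            self.finished_req_count = 0

        def record(self,
                   scheduler_stats: Optional[SchedulerStats],
                   iteration_stats: Optional[IterationStats],
                   engine_idx: int = 0):
            # engine_idx identifies which DP engine produced these stats.
            if iteration_stats:
                self.finished_req_count += len(
                    iteration_stats.finished_requests)

        def log(self):
            print(f"engine {self.engine_index}: "
                  f"{self.finished_req_count} finished requests")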

vllm/v1/engine/async_llm.py

Lines changed: 26 additions & 43 deletions
@@ -36,10 +36,9 @@
 from vllm.v1.engine.parallel_sampling import ParentRequest
 from vllm.v1.engine.processor import Processor
 from vllm.v1.executor.abstract import Executor
-from vllm.v1.metrics.loggers import (StatLoggerBase, StatLoggerFactory,
-                                     setup_default_loggers)
+from vllm.v1.metrics.loggers import StatLoggerFactory, StatLoggerManager
 from vllm.v1.metrics.prometheus import shutdown_prometheus
-from vllm.v1.metrics.stats import IterationStats, SchedulerStats
+from vllm.v1.metrics.stats import IterationStats
 
 logger = init_logger(__name__)
 
@@ -95,14 +94,6 @@ def __init__(
         self.log_requests = log_requests
         self.log_stats = log_stats
 
-        # Set up stat loggers; independent set for each DP rank.
-        self.stat_loggers: list[list[StatLoggerBase]] = setup_default_loggers(
-            vllm_config=vllm_config,
-            log_stats=self.log_stats,
-            engine_num=vllm_config.parallel_config.data_parallel_size,
-            custom_stat_loggers=stat_loggers,
-        )
-
         # Tokenizer (+ ensure liveness if running in another process).
         self.tokenizer = init_tokenizer_from_configs(
             model_config=vllm_config.model_config,
@@ -121,17 +112,24 @@ def __init__(
             log_stats=self.log_stats)
 
         # EngineCore (starts the engine in background process).
-
         self.engine_core = EngineCoreClient.make_async_mp_client(
             vllm_config=vllm_config,
             executor_class=executor_class,
             log_stats=self.log_stats,
             client_addresses=client_addresses,
             client_index=client_index,
         )
-        if self.stat_loggers:
-            for stat_logger in self.stat_loggers[0]:
-                stat_logger.log_engine_initialized()
+
+        # Loggers.
+        self.logger_manager: Optional[StatLoggerManager] = None
+        if self.log_stats:
+            self.logger_manager = StatLoggerManager(
+                vllm_config=vllm_config,
+                engine_idxs=self.engine_core.engine_ranks,
+                custom_stat_loggers=stat_loggers,
+            )
+            self.logger_manager.log_engine_initialized()
+
         self.output_handler: Optional[asyncio.Task] = None
         try:
             # Start output handler eagerly if we are in the asyncio eventloop.
@@ -370,7 +368,7 @@ def _run_output_handler(self):
         engine_core = self.engine_core
         output_processor = self.output_processor
         log_stats = self.log_stats
-        stat_loggers = self.stat_loggers if log_stats else None
+        logger_manager = self.logger_manager
 
         async def output_handler():
             try:
@@ -410,9 +408,9 @@ async def output_handler():
                     # 4) Logging.
                     # TODO(rob): make into a coroutine and launch it in
                     # background thread once Prometheus overhead is non-trivial.
-                    if stat_loggers:
-                        AsyncLLM._record_stats(
-                            stat_loggers[outputs.engine_index],
+                    if logger_manager:
+                        logger_manager.record(
+                            engine_idx=outputs.engine_index,
                             scheduler_stats=outputs.scheduler_stats,
                             iteration_stats=iteration_stats,
                         )
@@ -431,18 +429,6 @@ async def abort(self, request_id: str) -> None:
         if self.log_requests:
             logger.info("Aborted request %s.", request_id)
 
-    @staticmethod
-    def _record_stats(
-        stat_loggers: list[StatLoggerBase],
-        scheduler_stats: Optional[SchedulerStats],
-        iteration_stats: Optional[IterationStats],
-    ):
-        """static so that it can be used from the output_handler task
-        without a circular ref to AsyncLLM."""
-        for stat_logger in stat_loggers:
-            stat_logger.record(scheduler_stats=scheduler_stats,
-                               iteration_stats=iteration_stats)
-
     async def encode(
         self,
         prompt: PromptType,
@@ -547,9 +533,8 @@ async def do_log_stats(
         scheduler_outputs=None,
         model_output=None,
     ) -> None:
-        for loggers in self.stat_loggers:
-            for stat_logger in loggers:
-                stat_logger.log()
+        if self.logger_manager:
+            self.logger_manager.log()
 
     async def check_health(self) -> None:
         logger.debug("Called check_health.")
@@ -653,18 +638,16 @@ async def scale_elastic_ep(self,
             new_data_parallel_size
 
         # recreate stat loggers
-        if new_data_parallel_size > old_data_parallel_size:
-            stat_loggers: list[list[StatLoggerBase]] = setup_default_loggers(
+        if new_data_parallel_size > old_data_parallel_size and self.log_stats:
+            # TODO(rob): fix this after talking with Ray team.
+            # This resets all the prometheus metrics since we
+            # unregister during initialization. Need to understand
+            # the intended behavior here better.
+            self.logger_manager = StatLoggerManager(
                 vllm_config=self.vllm_config,
-                log_stats=self.log_stats,
-                engine_num=new_data_parallel_size,
+                engine_idxs=list(range(new_data_parallel_size)),
                 custom_stat_loggers=None,
            )
-            num_new_engines = len(stat_loggers) - len(self.stat_loggers)
-            self.stat_loggers.extend(stat_loggers[-num_new_engines:])
-        else:
-            for _ in range(old_data_parallel_size - new_data_parallel_size):
-                self.stat_loggers.pop()
 
     @property
     def is_running(self) -> bool:
vllm/v1/engine/core_client.py

Lines changed: 5 additions & 4 deletions
@@ -432,14 +432,15 @@ def __init__(
         external_dp_lb = parallel_config.data_parallel_external_lb
 
         offline_mode = parallel_config.data_parallel_rank_local is not None
-        engine_ranks = [dp_rank] if (offline_mode
-                                     or external_dp_lb) else range(dp_size)
+        self.engine_ranks = ([dp_rank] if
+                             (offline_mode or external_dp_lb) else list(
+                                 range(dp_size)))
         assert parallel_config.data_parallel_size_local <= len(
-            engine_ranks)
+            self.engine_ranks)
 
         # ZMQ identity of each engine that this client will talk to.
         self.core_engines: list[EngineIdentity] = [
-            index.to_bytes(2, "little") for index in engine_ranks
+            index.to_bytes(2, "little") for index in self.engine_ranks
         ]
 
         # Wait for ready messages from each engine on the input socket.
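The rank-selection rule above is what feeds StatLoggerManager's engine_idxs in AsyncLLM. A standalone sketch of the same rule under the assumptions visible in the diff (plain function, no ZMQ wiring; the function name is invented here):

    def select_engine_ranks(dp_rank: int,
                            dp_size: int,
                            offline_mode: bool,
                            external_dp_lb: bool) -> list[int]:
        # In offline mode or behind an external DP load balancer, the client
        # addresses only its own rank; otherwise it addresses every rank.
        if offline_mode or external_dp_lb:
            return [dp_rank]
        return list(range(dp_size))


    # Example: an internal-LB client with DP size 4 talks to all four engines
    # and builds one ZMQ identity per rank (two little-endian bytes, as above).
    ranks = select_engine_ranks(dp_rank=0, dp_size=4,
                                offline_mode=False, external_dp_lb=False)
    identities = [index.to_bytes(2, "little") for index in ranks]
    assert ranks == [0, 1, 2, 3]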
