Commit f689e63

Update to V1 metrics
1 parent: db4427f

File tree: 3 files changed, +73 -37 lines changed

ci/L0_backend_vllm/metrics_test/vllm_metrics_test.py

Lines changed: 3 additions & 5 deletions

@@ -173,13 +173,11 @@ def test_vllm_metrics(self):
     # TODO: Revisit this test due to the removal of best_of
     def test_custom_sampling_params(self):
         # Adding sampling parameters for testing metrics.
-        # Definitions can be found here https://docs.vllm.ai/en/latest/dev/sampling_params.html
-        n, best_of = 2, 4
+        # Definitions can be found here https://docs.vllm.ai/en/latest/api/vllm/sampling_params.html
+        n, temperature = 2, 1
         custom_sampling_parameters = self.sampling_parameters.copy()
-        # Changing "temperature" because "best_of" must be 1 when using greedy
-        # sampling, i.e. "temperature": "0".
         custom_sampling_parameters.update(
-            {"n": str(n), "best_of": str(best_of), "temperature": "1"}
+            {"n": str(n), "temperature": str(temperature)}
         )
 
         # Test vLLM metrics
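
For illustration, the updated test logic reduces to this self-contained sketch; the base dictionary and its values are assumptions standing in for the test class's self.sampling_parameters, and sampling values are passed to the backend as strings:

    # Hypothetical base parameters; the real test copies self.sampling_parameters.
    sampling_parameters = {"temperature": "0", "top_p": "1"}

    n, temperature = 2, 1
    custom_sampling_parameters = sampling_parameters.copy()
    # "best_of" no longer exists in vLLM V1; temperature is raised to 1 so the
    # n=2 completions are not forced to be identical greedy outputs.
    custom_sampling_parameters.update({"n": str(n), "temperature": str(temperature)})
    print(custom_sampling_parameters)  # {'temperature': '1', 'top_p': '1', 'n': '2'}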

src/model.py

Lines changed: 10 additions & 11 deletions

@@ -35,7 +35,6 @@
 from typing import Dict, List
 
 import numpy as np
-import torch
 import triton_python_backend_utils as pb_utils
 from PIL import Image
 from vllm.engine.arg_utils import AsyncEngineArgs
@@ -45,7 +44,7 @@
 from vllm.lora.request import LoRARequest
 from vllm.utils import random_uuid
 
-from utils.metrics import VllmStatLogger
+from utils.metrics import VllmStatLoggerFactory
 from utils.vllm_backend_utils import TritonSamplingParams
 
 _VLLM_ENGINE_ARGS_FILENAME = "model.json"
@@ -184,12 +183,12 @@ def initialize(self, args):
             and not self._aync_engine_args.disable_log_stats
         )
 
-        # Starting the vLLM engine and its event thread running the AsyncIO event loop.
-        self._init_engine()
-
         # Setup vLLM metrics
         self._setup_metrics()
 
+        # Starting the vLLM engine and its event thread running the AsyncIO event loop.
+        self._init_engine()
+
         # Starting the response thread. It allows vLLM to keep making progress while
         # response sender(s) are sending responses to server frontend.
         self._response_queue = queue.Queue()
@@ -258,6 +257,7 @@ async def _run_llm_engine(self):
         async with build_async_engine_client_from_engine_args(
             engine_args=self._aync_engine_args,
             disable_frontend_multiprocessing=self._enable_metrics,
+            stat_loggers=self._vllm_metrics,
         ) as engine:
             # Capture the engine event loop and make it visible to other threads.
             self._event_loop = asyncio.get_running_loop()
@@ -348,7 +348,7 @@ def _setup_lora(self):
         )
 
     def _setup_metrics(self):
-        self._vllm_metrics = None
+        self._vllm_metrics = []
         # TODO: Do not read metrics directly from the vLLM engine, read from prometheus
         # client to allow the use of ZMQ process when metrics are enabled. See
         # https://github.com/vllm-project/vllm/blob/v0.6.3.post1/vllm/entrypoints/openai/api_server.py#L222-L245
@@ -359,9 +359,8 @@ def _setup_metrics(self):
                 "version": self.args["model_version"],
             }
             # Add vLLM custom metrics
-            vllm_config = self._llm_engine.engine.vllm_config
-            self._vllm_metrics = VllmStatLogger(labels, vllm_config, self.logger)
-            self._llm_engine.add_logger("triton", self._vllm_metrics)
+            factory = VllmStatLoggerFactory(labels, self.logger)
+            self._vllm_metrics.append(factory)
         except pb_utils.TritonModelException as e:
             if "metrics not supported" in str(e):
                 # Metrics are disabled at the server
@@ -785,8 +784,8 @@ def finalize(self):
         self._response_thread = None
 
         # Shutdown the metrics thread.
-        if self._vllm_metrics is not None:
-            self._vllm_metrics.finalize()
+        for stat_logger_factory in self._vllm_metrics:
+            stat_logger_factory.finalize()
 
         # When using parallel tensors, the stub process may not shutdown due to
         # unreleased references, so manually run the garbage collector once.
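
A minimal sketch of the reordered startup in initialize(): metrics factories must exist before the engine starts, because they are forwarded to the engine builder, and the disabled case is now an empty list rather than None. setup_metrics and init_engine below are hypothetical stand-ins, not the backend's methods:

    # Stand-in sketch: an empty list means "metrics disabled", so callers can
    # iterate unconditionally (e.g., in finalize()).
    def setup_metrics(metrics_enabled: bool) -> list:
        return ["stat-logger-factory"] if metrics_enabled else []

    def init_engine(stat_loggers: list) -> None:
        # The real code forwards this list as the stat_loggers argument of
        # build_async_engine_client_from_engine_args.
        print(f"engine started with {len(stat_loggers)} stat-logger factories")

    vllm_metrics = setup_metrics(metrics_enabled=True)  # 1. create factories first
    init_engine(stat_loggers=vllm_metrics)              # 2. then start the engine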

src/utils/metrics.py

Lines changed: 60 additions & 21 deletions

@@ -26,13 +26,12 @@
 
 import queue
 import threading
-from typing import Dict, List, Union
+from typing import Dict, List, Optional, Union
 
 import triton_python_backend_utils as pb_utils
 from vllm.config import VllmConfig
-from vllm.engine.metrics import StatLoggerBase as VllmStatLoggerBase
-from vllm.engine.metrics import Stats as VllmStats
-from vllm.engine.metrics import SupportsMetricsInfo, build_1_2_5_buckets
+from vllm.v1.metrics.loggers import StatLoggerBase, build_1_2_5_buckets
+from vllm.v1.metrics.stats import IterationStats, SchedulerStats
 
 
 class TritonMetrics:
@@ -161,13 +160,35 @@ def __init__(self, labels: List[str], max_model_len: int):
         )
 
 
-class VllmStatLogger(VllmStatLoggerBase):
+# Create a partially initialized callable that adapts VllmStatLogger to StatLoggerFactory interface
+class VllmStatLoggerFactory:
+    def __init__(self, labels, log_logger):
+        self._labels = labels
+        self._log_logger = log_logger
+        self._instances_list = []
+
+    def __call__(self, vllm_config, engine_index):
+        stat_logger = VllmStatLogger(
+            self._labels, self._log_logger, vllm_config, engine_index
+        )
+        self._instances_list.append(stat_logger)
+        return stat_logger
+
+    def finalize(self):
+        for stat_logger in self._instances_list:
+            if stat_logger is not None:
+                stat_logger.finalize()
+
+
+class VllmStatLogger(StatLoggerBase):
     """StatLogger is used as an adapter between vLLM stats collector and Triton metrics provider."""
 
-    def __init__(self, labels: Dict, vllm_config: VllmConfig, log_logger) -> None:
+    def __init__(
+        self, labels: Dict, log_logger, vllm_config: VllmConfig, engine_index: int
+    ) -> None:
         # Tracked stats over current local logging interval.
         # local_interval not used here. It's for vLLM logs to stdout.
-        super().__init__(local_interval=0, vllm_config=vllm_config)
+        super().__init__(vllm_config=vllm_config, engine_index=engine_index)
         self.metrics = TritonMetrics(
             labels=labels, max_model_len=vllm_config.model_config.max_model_len
         )
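
The factory pattern above can be reduced to a self-contained sketch; DummyStatLogger and DummyStatLoggerFactory are stand-ins for the real classes. The key points: __call__ matches the (vllm_config, engine_index) signature the diff shows vLLM invoking per engine, and the factory records each instance it creates so a single finalize() fans out to all of them:

    # Stand-in for VllmStatLogger; only the pieces relevant to the pattern.
    class DummyStatLogger:
        def __init__(self, vllm_config, engine_index):
            self.engine_index = engine_index

        def finalize(self):
            print(f"finalized logger for engine {self.engine_index}")

    class DummyStatLoggerFactory:
        def __init__(self):
            self._instances = []

        def __call__(self, vllm_config, engine_index):
            # Called once per engine; remember the instance for shutdown.
            stat_logger = DummyStatLogger(vllm_config, engine_index)
            self._instances.append(stat_logger)
            return stat_logger

        def finalize(self):
            for stat_logger in self._instances:
                stat_logger.finalize()

    factory = DummyStatLoggerFactory()
    factory(vllm_config=None, engine_index=0)  # normally invoked by vLLM itself
    factory.finalize()                         # prints: finalized logger for engine 0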

@@ -176,12 +197,9 @@ def __init__(self, labels: Dict, vllm_config: VllmConfig, log_logger) -> None:
         # Starting the metrics thread. It allows vLLM to keep making progress
         # while reporting metrics to triton metrics service.
         self._logger_queue = queue.Queue()
-        self._logger_thread = threading.Thread(target=self.logger_loop)
+        self._logger_thread = threading.Thread(target=self._logger_loop)
         self._logger_thread.start()
 
-    def info(self, type: str, obj: SupportsMetricsInfo) -> None:
-        pass
-
     def _log_counter(self, counter, data: Union[int, float]) -> None:
         """Convenience function for logging to counter.
 
@@ -208,7 +226,12 @@ def _log_histogram(self, histogram, data: Union[List[int], List[float]]) -> None
         for datum in data:
             self._logger_queue.put_nowait((histogram, "observe", datum))
 
-    def log(self, stats: VllmStats) -> None:
+    def record(
+        self,
+        scheduler_stats: Optional[SchedulerStats],
+        iteration_stats: Optional[IterationStats],
+        engine_idx: int = 0,
+    ) -> None:
         """Report stats to Triton metrics server.
 
         Args:
@@ -217,38 +240,54 @@ def log(self, stats: VllmStats) -> None:
         Returns:
             None
         """
+
+        # Parse finished request stats into lists
+        e2e_latency: List[float] = []
+        num_prompt_tokens: List[int] = []
+        num_generation_tokens: List[int] = []
+        for finished_req in iteration_stats.finished_requests:
+            e2e_latency.append(finished_req.e2e_latency)
+            num_prompt_tokens.append(finished_req.num_prompt_tokens)
+            num_generation_tokens.append(finished_req.num_generation_tokens)
+
         # The list of vLLM metrics reporting to Triton is also documented here.
         # https://github.com/triton-inference-server/vllm_backend/blob/main/README.md#triton-metrics
         counter_metrics = [
-            (self.metrics.counter_prompt_tokens, stats.num_prompt_tokens_iter),
-            (self.metrics.counter_generation_tokens, stats.num_generation_tokens_iter),
+            (self.metrics.counter_prompt_tokens, iteration_stats.num_prompt_tokens),
+            (
+                self.metrics.counter_generation_tokens,
+                iteration_stats.num_generation_tokens,
+            ),
         ]
         histogram_metrics = [
             (
                 self.metrics.histogram_time_to_first_token,
-                stats.time_to_first_tokens_iter,
+                iteration_stats.time_to_first_tokens_iter,
             ),
             (
                 self.metrics.histogram_time_per_output_token,
-                stats.time_per_output_tokens_iter,
+                iteration_stats.inter_token_latencies_iter,
             ),
-            (self.metrics.histogram_e2e_time_request, stats.time_e2e_requests),
+            (self.metrics.histogram_e2e_time_request, e2e_latency),
             (
                 self.metrics.histogram_num_prompt_tokens_request,
-                stats.num_prompt_tokens_requests,
+                num_prompt_tokens,
            ),
             (
                 self.metrics.histogram_num_generation_tokens_request,
-                stats.num_generation_tokens_requests,
+                num_generation_tokens,
             ),
-            (self.metrics.histogram_n_request, stats.n_requests),
+            (self.metrics.histogram_n_request, iteration_stats.n_params_iter),
         ]
         for metric, data in counter_metrics:
             self._log_counter(metric, data)
         for metric, data in histogram_metrics:
             self._log_histogram(metric, data)
 
-    def logger_loop(self):
+    def log_engine_initialized(self) -> None:
+        pass
+
+    def _logger_loop(self):
         while True:
             item = self._logger_queue.get()
             # To signal shutdown a None item will be added to the queue.
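
_logger_loop is a standard producer-consumer loop with a None sentinel for shutdown: metric updates are enqueued on the hot path and applied on a worker thread. A self-contained sketch, with FakeCounter standing in for a Triton metric object:

    import queue
    import threading

    # FakeCounter stands in for a Triton metric object.
    class FakeCounter:
        def increment(self, value):
            print(f"counter += {value}")

    logger_queue = queue.Queue()

    def logger_loop():
        # Drain (metric, method, value) tuples until the None sentinel arrives.
        while True:
            item = logger_queue.get()
            if item is None:  # shutdown signal, mirroring finalize() in the real code
                break
            metric, method, value = item
            getattr(metric, method)(value)

    thread = threading.Thread(target=logger_loop)
    thread.start()
    logger_queue.put_nowait((FakeCounter(), "increment", 5))
    logger_queue.put_nowait(None)  # signal shutdown
    thread.join()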
