2 changes: 0 additions & 2 deletions ci/L0_backend_vllm/accuracy_test/accuracy_test.py
@@ -190,7 +190,6 @@ def test_guided_decoding(self):
sampling_params = SAMPLING_PARAMETERS
guided_decoding_params = {
"choice": ["Positive", "Negative"],
"backend": "outlines",
}
sampling_params["guided_decoding"] = json.dumps(guided_decoding_params)
for i in range(len(GUIDED_PROMPTS)):
@@ -245,7 +244,6 @@ def tearDown(self):
if FLAGS.generate_guided_baseline:
guided_decoding_params = {
"choice": ["Positive", "Negative"],
"backend": "outlines",
}
guided_generation = GuidedDecodingParams(**guided_decoding_params)
asyncio.run(
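For context, a minimal sketch of the two guided-decoding code paths this test exercises once the per-request "backend" key is dropped (assuming GuidedDecodingParams is importable from vllm.sampling_params, as in recent vLLM releases; variable names here are illustrative):

import json
from vllm.sampling_params import GuidedDecodingParams, SamplingParams

# Triton client side: guided-decoding options travel as a JSON string parameter.
triton_parameters = {"guided_decoding": json.dumps({"choice": ["Positive", "Negative"]})}

# Baseline (direct vLLM) side: the same options build a GuidedDecodingParams object,
# now without a per-request "backend" field.
guided = GuidedDecodingParams(choice=["Positive", "Negative"])
baseline_params = SamplingParams(guided_decoding=guided, temperature=0.0)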
12 changes: 0 additions & 12 deletions ci/L0_backend_vllm/accuracy_test/test.sh
@@ -48,17 +48,11 @@ RET=0
set +e
# Need to generate baseline first, since running 2 vLLM engines causes
# memory issues: https://github.com/vllm-project/vllm/issues/2248
export VLLM_USE_V1=0
export VLLM_WORKER_MULTIPROC_METHOD=spawn
python3 $CLIENT_PY --generate-baseline >> $VLLM_ENGINE_LOG 2>&1 & BASELINE_PID=$!
wait $BASELINE_PID

python3 $CLIENT_PY --generate-guided-baseline > $VLLM_ENGINE_LOG 2>&1 & BASELINE_PID=$!
wait $BASELINE_PID

unset VLLM_USE_V1
unset VLLM_WORKER_MULTIPROC_METHOD

set -e

run_server
@@ -88,12 +82,6 @@ set -e
kill $SERVER_PID
wait $SERVER_PID

# Check that warning about V1 Engine appears in log - this warning is expected
if ! grep -q "Engine in background thread is experimental on VLLM_USE_V1=1. Falling back to V0 Engine." $SERVER_LOG; then
echo -e "\n***\n*** ERROR: Expected warning about vLLM falling back to V0 Engine not found in logs.\n***"
RET=1
fi

rm -rf models/

if [ $RET -eq 1 ]; then
8 changes: 3 additions & 5 deletions ci/L0_backend_vllm/metrics_test/vllm_metrics_test.py
@@ -173,13 +173,11 @@ def test_vllm_metrics(self):
# TODO: Revisit this test due to the removal of best_of
def test_custom_sampling_params(self):
# Adding sampling parameters for testing metrics.
# Definitions can be found here https://docs.vllm.ai/en/latest/dev/sampling_params.html
n, best_of = 2, 4
# Definitions can be found here https://docs.vllm.ai/en/latest/api/vllm/sampling_params.html
n, temperature = 2, 1
custom_sampling_parameters = self.sampling_parameters.copy()
# Changing "temperature" because "best_of" must be 1 when using greedy
# sampling, i.e. "temperature": "0".
custom_sampling_parameters.update(
{"n": str(n), "best_of": str(best_of), "temperature": "1"}
{"n": str(n), "temperature": str(temperature)}
)

# Test vLLM metrics
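For reference, a minimal sketch of what the updated test parameters correspond to on the vLLM side (assuming the current SamplingParams API, in which best_of is no longer a supported field):

from vllm import SamplingParams

# n=2 requests two completions per prompt; temperature=1.0 keeps sampling
# non-greedy so the two completions can differ.
params = SamplingParams(n=2, temperature=1.0)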
19 changes: 10 additions & 9 deletions ci/L0_check_health_vllm/test.sh
@@ -48,23 +48,24 @@ function enable_health_check {
}

VLLM_INSTALL_PATH="/usr/local/lib/python3.12/dist-packages/vllm"
VLLM_V1_ENGINE_PATH="$VLLM_INSTALL_PATH/v1/engine"

function mock_vllm_async_llm_engine {
# backup original file
mv $VLLM_INSTALL_PATH/engine/multiprocessing/client.py $VLLM_INSTALL_PATH/engine/multiprocessing/client.py.backup
cp $VLLM_INSTALL_PATH/engine/multiprocessing/client.py.backup $VLLM_INSTALL_PATH/engine/multiprocessing/client.py
mv $VLLM_V1_ENGINE_PATH/async_llm.py $VLLM_V1_ENGINE_PATH/async_llm.py.backup
cp $VLLM_V1_ENGINE_PATH/async_llm.py.backup $VLLM_V1_ENGINE_PATH/async_llm.py
# overwrite the original check_health method
echo -e "" >> $VLLM_INSTALL_PATH/engine/multiprocessing/client.py
echo -e " async def check_health(self, check_count=[0]):" >> $VLLM_INSTALL_PATH/engine/multiprocessing/client.py
echo -e " check_count[0] += 1" >> $VLLM_INSTALL_PATH/engine/multiprocessing/client.py
echo -e " if check_count[0] > 1:" >> $VLLM_INSTALL_PATH/engine/multiprocessing/client.py
echo -e " raise RuntimeError(\"Simulated vLLM check_health() failure\")" >> $VLLM_INSTALL_PATH/engine/multiprocessing/client.py
echo -e "" >> $VLLM_V1_ENGINE_PATH/async_llm.py
echo -e " async def check_health(self, check_count=[0]):" >> $VLLM_V1_ENGINE_PATH/async_llm.py
echo -e " check_count[0] += 1" >> $VLLM_V1_ENGINE_PATH/async_llm.py
echo -e " if check_count[0] > 1:" >> $VLLM_V1_ENGINE_PATH/async_llm.py
echo -e " raise RuntimeError(\"Simulated vLLM check_health() failure\")" >> $VLLM_V1_ENGINE_PATH/async_llm.py
}

function unmock_vllm_async_llm_engine {
# restore from backup
rm -f $VLLM_INSTALL_PATH/engine/multiprocessing/client.py
mv $VLLM_INSTALL_PATH/engine/multiprocessing/client.py.backup $VLLM_INSTALL_PATH/engine/multiprocessing/client.py
rm -f $VLLM_V1_ENGINE_PATH/async_llm.py
mv $VLLM_V1_ENGINE_PATH/async_llm.py.backup $VLLM_V1_ENGINE_PATH/async_llm.py
}

function test_check_health {
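The echo lines in mock_vllm_async_llm_engine effectively append the following method to the class defined in v1/engine/async_llm.py (reproduced here for readability; the leading indentation must match the class body):

    async def check_health(self, check_count=[0]):
        # The mutable default argument acts as a call counter shared across calls,
        # so the first health check passes and every later one fails.
        check_count[0] += 1
        if check_count[0] > 1:
            raise RuntimeError("Simulated vLLM check_health() failure")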
21 changes: 10 additions & 11 deletions src/model.py
@@ -35,7 +35,6 @@
from typing import Dict, List

import numpy as np
import torch
import triton_python_backend_utils as pb_utils
from PIL import Image
from vllm.engine.arg_utils import AsyncEngineArgs
@@ -45,7 +44,7 @@
from vllm.lora.request import LoRARequest
from vllm.utils import random_uuid

from utils.metrics import VllmStatLogger
from utils.metrics import VllmStatLoggerFactory
from utils.vllm_backend_utils import TritonSamplingParams

_VLLM_ENGINE_ARGS_FILENAME = "model.json"
@@ -184,12 +183,12 @@ def initialize(self, args):
and not self._aync_engine_args.disable_log_stats
)

# Starting the vLLM engine and its event thread running the AsyncIO event loop.
self._init_engine()

# Setup vLLM metrics
self._setup_metrics()

# Starting the vLLM engine and its event thread running the AsyncIO event loop.
self._init_engine()

# Starting the response thread. It allows vLLM to keep making progress while
# response sender(s) are sending responses to server frontend.
self._response_queue = queue.Queue()
@@ -258,6 +257,7 @@ async def _run_llm_engine(self):
async with build_async_engine_client_from_engine_args(
engine_args=self._aync_engine_args,
disable_frontend_multiprocessing=self._enable_metrics,
stat_loggers=self._vllm_metrics,
) as engine:
# Capture the engine event loop and make it visible to other threads.
self._event_loop = asyncio.get_running_loop()
@@ -348,7 +348,7 @@ def _setup_lora(self):
)

def _setup_metrics(self):
self._vllm_metrics = None
self._vllm_metrics = []
# TODO: Do not read metrics directly from the vLLM engine, read from prometheus
# client to allow the use of ZMQ process when metrics are enabled. See
# https://github.com/vllm-project/vllm/blob/v0.6.3.post1/vllm/entrypoints/openai/api_server.py#L222-L245
@@ -359,9 +359,8 @@ def _setup_metrics(self):
"version": self.args["model_version"],
}
# Add vLLM custom metrics
vllm_config = self._llm_engine.engine.vllm_config
self._vllm_metrics = VllmStatLogger(labels, vllm_config, self.logger)
self._llm_engine.add_logger("triton", self._vllm_metrics)
factory = VllmStatLoggerFactory(labels, self.logger)
self._vllm_metrics.append(factory)
except pb_utils.TritonModelException as e:
if "metrics not supported" in str(e):
# Metrics are disabled at the server
Expand Down Expand Up @@ -785,8 +784,8 @@ def finalize(self):
self._response_thread = None

# Shutdown the metrics thread.
if self._vllm_metrics is not None:
self._vllm_metrics.finalize()
for stat_logger_factory in self._vllm_metrics:
stat_logger_factory.finalize()

# When using parallel tensors, the stub process may not shutdown due to
# unreleased references, so manually run the garbage collector once.
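Taken together, the model.py changes stop registering a single logger on a running engine and instead hand vLLM a list of stat-logger factories when the engine client is built. A condensed, non-runnable view stitched from the hunks above (it assumes build_async_engine_client_from_engine_args accepts the stat_loggers keyword shown in the diff):

# _setup_metrics(): collect factories rather than a logger instance.
self._vllm_metrics = []
factory = VllmStatLoggerFactory(labels, self.logger)
self._vllm_metrics.append(factory)

# _run_llm_engine(): pass the factories to vLLM, which is expected to call each
# one to create a VllmStatLogger per engine core.
async with build_async_engine_client_from_engine_args(
    engine_args=self._aync_engine_args,
    disable_frontend_multiprocessing=self._enable_metrics,
    stat_loggers=self._vllm_metrics,
) as engine:
    ...

# finalize(): shut down every logger the factories created.
for stat_logger_factory in self._vllm_metrics:
    stat_logger_factory.finalize()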
81 changes: 60 additions & 21 deletions src/utils/metrics.py
@@ -26,13 +26,12 @@

import queue
import threading
from typing import Dict, List, Union
from typing import Dict, List, Optional, Union

import triton_python_backend_utils as pb_utils
from vllm.config import VllmConfig
from vllm.engine.metrics import StatLoggerBase as VllmStatLoggerBase
from vllm.engine.metrics import Stats as VllmStats
from vllm.engine.metrics import SupportsMetricsInfo, build_1_2_5_buckets
from vllm.v1.metrics.loggers import StatLoggerBase, build_1_2_5_buckets
from vllm.v1.metrics.stats import IterationStats, SchedulerStats


class TritonMetrics:
@@ -161,13 +160,35 @@ def __init__(self, labels: List[str], max_model_len: int):
)


class VllmStatLogger(VllmStatLoggerBase):
# Create a partially initialized callable that adapts VllmStatLogger to StatLoggerFactory interface
class VllmStatLoggerFactory:
def __init__(self, labels, log_logger):
self._labels = labels
self._log_logger = log_logger
self._instances_list = []

def __call__(self, vllm_config, engine_index):
stat_logger = VllmStatLogger(
self._labels, self._log_logger, vllm_config, engine_index
)
self._instances_list.append(stat_logger)
return stat_logger

def finalize(self):
for stat_logger in self._instances_list:
if stat_logger is not None:
stat_logger.finalize()
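
# Illustration (not part of this file): vLLM's V1 metrics stack is expected to call
# each configured factory with the engine config and an engine index, once per engine
# core, matching __call__ above. The arguments below are placeholders; vllm_config is
# supplied by vLLM itself at engine startup.
example_factory = VllmStatLoggerFactory(labels={"model": "vllm_model", "version": "1"}, log_logger=None)
example_logger = example_factory(vllm_config, engine_index=0)  # returns a VllmStatLogger
example_factory.finalize()  # stops the metrics thread of every logger it created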


class VllmStatLogger(StatLoggerBase):
"""StatLogger is used as an adapter between vLLM stats collector and Triton metrics provider."""

def __init__(self, labels: Dict, vllm_config: VllmConfig, log_logger) -> None:
def __init__(
self, labels: Dict, log_logger, vllm_config: VllmConfig, engine_index: int
) -> None:
# Tracked stats over current local logging interval.
# local_interval not used here. It's for vLLM logs to stdout.
super().__init__(local_interval=0, vllm_config=vllm_config)
super().__init__(vllm_config=vllm_config, engine_index=engine_index)
self.metrics = TritonMetrics(
labels=labels, max_model_len=vllm_config.model_config.max_model_len
)
@@ -176,12 +197,9 @@ def __init__(self, labels: Dict, vllm_config: VllmConfig, log_logger) -> None:
# Starting the metrics thread. It allows vLLM to keep making progress
# while reporting metrics to triton metrics service.
self._logger_queue = queue.Queue()
self._logger_thread = threading.Thread(target=self.logger_loop)
self._logger_thread = threading.Thread(target=self._logger_loop)
self._logger_thread.start()

def info(self, type: str, obj: SupportsMetricsInfo) -> None:
pass

def _log_counter(self, counter, data: Union[int, float]) -> None:
"""Convenience function for logging to counter.

@@ -208,7 +226,12 @@ def _log_histogram(self, histogram, data: Union[List[int], List[float]]) -> None
for datum in data:
self._logger_queue.put_nowait((histogram, "observe", datum))

def log(self, stats: VllmStats) -> None:
def record(
self,
scheduler_stats: Optional[SchedulerStats],
iteration_stats: Optional[IterationStats],
engine_idx: int = 0,
) -> None:
"""Report stats to Triton metrics server.

Args:
@@ -217,38 +240,54 @@ def log(self, stats: VllmStats) -> None:
Returns:
None
"""

# Parse finished request stats into lists
e2e_latency: List[float] = []
num_prompt_tokens: List[int] = []
num_generation_tokens: List[int] = []
for finished_req in iteration_stats.finished_requests:
e2e_latency.append(finished_req.e2e_latency)
num_prompt_tokens.append(finished_req.num_prompt_tokens)
num_generation_tokens.append(finished_req.num_generation_tokens)

# The list of vLLM metrics reporting to Triton is also documented here.
# https://github.com/triton-inference-server/vllm_backend/blob/main/README.md#triton-metrics
counter_metrics = [
(self.metrics.counter_prompt_tokens, stats.num_prompt_tokens_iter),
(self.metrics.counter_generation_tokens, stats.num_generation_tokens_iter),
(self.metrics.counter_prompt_tokens, iteration_stats.num_prompt_tokens),
(
self.metrics.counter_generation_tokens,
iteration_stats.num_generation_tokens,
),
]
histogram_metrics = [
(
self.metrics.histogram_time_to_first_token,
stats.time_to_first_tokens_iter,
iteration_stats.time_to_first_tokens_iter,
),
(
self.metrics.histogram_time_per_output_token,
stats.time_per_output_tokens_iter,
iteration_stats.inter_token_latencies_iter,
),
(self.metrics.histogram_e2e_time_request, stats.time_e2e_requests),
(self.metrics.histogram_e2e_time_request, e2e_latency),
(
self.metrics.histogram_num_prompt_tokens_request,
stats.num_prompt_tokens_requests,
num_prompt_tokens,
),
(
self.metrics.histogram_num_generation_tokens_request,
stats.num_generation_tokens_requests,
num_generation_tokens,
),
(self.metrics.histogram_n_request, stats.n_requests),
(self.metrics.histogram_n_request, iteration_stats.n_params_iter),
]
for metric, data in counter_metrics:
self._log_counter(metric, data)
for metric, data in histogram_metrics:
self._log_histogram(metric, data)

def logger_loop(self):
def log_engine_initialized(self) -> None:
pass

def _logger_loop(self):
while True:
item = self._logger_queue.get()
# To signal shutdown a None item will be added to the queue.
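
    # Sketch (not shown in this diff): the shutdown handshake that the None-sentinel
    # comment above and model.py's stat_logger_factory.finalize() rely on would look
    # roughly like this.
    def finalize(self):
        # A None sentinel tells _logger_loop to exit; join() lets queued metric
        # updates flush before the thread stops.
        self._logger_queue.put(None)
        self._logger_thread.join()
        self._logger_thread = None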