diff --git a/tests/entrypoints/openai/test_metrics.py b/tests/entrypoints/openai/test_metrics.py
index 711505c74bca..dbcec9d31fc9 100644
--- a/tests/entrypoints/openai/test_metrics.py
+++ b/tests/entrypoints/openai/test_metrics.py
@@ -18,25 +18,15 @@
 from ...utils import RemoteOpenAIServer
 
-MODEL_NAME = "TinyLlama/TinyLlama-1.1B-Chat-v1.0"
+MODELS = {
+    "text": "TinyLlama/TinyLlama-1.1B-Chat-v1.0",
+    "multimodal": "HuggingFaceTB/SmolVLM-256M-Instruct",
+}
 
 PREV_MINOR_VERSION = version._prev_minor_version()
 
 
-@pytest.fixture(scope="module", params=[True])
-def use_v1(request):
-    # Module-scoped variant of run_with_both_engines
-    #
-    # Use this fixture to run a test with both v0 and v1, and
-    # also to conditionalize the test logic e.g.
-    #
-    #   def test_metrics_exist(use_v1, server, client):
-    #       ...
-    #       expected = EXPECTED_V1_METRICS if use_v1 else EXPECTED_METRICS
-    #       for metric in expected:
-    #           assert metric in response.text
-    #
-    # @skip_v1 wouldn't work here because this is a module-level
-    # fixture - per-function decorators would have no effect
+@pytest.fixture(scope="module", params=list(MODELS.keys()))
+def model_key(request):
     yield request.param
 
 
@@ -63,13 +53,12 @@ def default_server_args():
         f"--show-hidden-metrics-for-version={PREV_MINOR_VERSION}",
     ],
 )
-def server(use_v1, default_server_args, request):
+def server(model_key, default_server_args, request):
     if request.param:
         default_server_args.append(request.param)
-    env_dict = dict(VLLM_USE_V1="1" if use_v1 else "0")
-    with RemoteOpenAIServer(
-        MODEL_NAME, default_server_args, env_dict=env_dict
-    ) as remote_server:
+
+    model_name = MODELS[model_key]
+    with RemoteOpenAIServer(model_name, default_server_args) as remote_server:
         yield remote_server
 
 
@@ -80,63 +69,70 @@ async def client(server):
 
 
 _PROMPT = "Hello my name is Robert and I love magic"
-tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
-_TOKENIZED_PROMPT = tokenizer(_PROMPT)["input_ids"]
-
-_NUM_REQUESTS = 10
-_NUM_PROMPT_TOKENS_PER_REQUEST = len(_TOKENIZED_PROMPT)
-_NUM_GENERATION_TOKENS_PER_REQUEST = 10
-
-# {metric_family: [(suffix, expected_value)]}
-EXPECTED_VALUES = {
-    "vllm:time_to_first_token_seconds": [("_count", _NUM_REQUESTS)],
-    "vllm:time_per_output_token_seconds": [
-        ("_count", _NUM_REQUESTS * (_NUM_GENERATION_TOKENS_PER_REQUEST - 1))
-    ],
-    "vllm:e2e_request_latency_seconds": [("_count", _NUM_REQUESTS)],
-    "vllm:request_queue_time_seconds": [("_count", _NUM_REQUESTS)],
-    "vllm:request_inference_time_seconds": [("_count", _NUM_REQUESTS)],
-    "vllm:request_prefill_time_seconds": [("_count", _NUM_REQUESTS)],
-    "vllm:request_decode_time_seconds": [("_count", _NUM_REQUESTS)],
-    "vllm:request_prompt_tokens": [
-        ("_sum", _NUM_REQUESTS * _NUM_PROMPT_TOKENS_PER_REQUEST),
-        ("_count", _NUM_REQUESTS),
-    ],
-    "vllm:request_generation_tokens": [
-        ("_sum", _NUM_REQUESTS * _NUM_GENERATION_TOKENS_PER_REQUEST),
-        ("_count", _NUM_REQUESTS),
-    ],
-    "vllm:request_params_n": [("_count", _NUM_REQUESTS)],
-    "vllm:request_params_max_tokens": [
-        ("_sum", _NUM_REQUESTS * _NUM_GENERATION_TOKENS_PER_REQUEST),
-        ("_count", _NUM_REQUESTS),
-    ],
-    "vllm:iteration_tokens_total": [
-        (
-            "_sum",
-            _NUM_REQUESTS
-            * (_NUM_PROMPT_TOKENS_PER_REQUEST + _NUM_GENERATION_TOKENS_PER_REQUEST),
-        ),
-        ("_count", _NUM_REQUESTS * _NUM_GENERATION_TOKENS_PER_REQUEST),
-    ],
-    "vllm:prompt_tokens": [("_total", _NUM_REQUESTS * _NUM_PROMPT_TOKENS_PER_REQUEST)],
-    "vllm:generation_tokens": [
-        ("_total", _NUM_REQUESTS * _NUM_PROMPT_TOKENS_PER_REQUEST)
-    ],
-    "vllm:request_success": [("_total", _NUM_REQUESTS)],
-}
+_IMAGE_URL = "https://upload.wikimedia.org/wikipedia/commons/thumb/d/dd/Gfp-wisconsin-madison-the-nature-boardwalk.jpg/2560px-Gfp-wisconsin-madison-the-nature-boardwalk.jpg"
+
+
+def _get_expected_values(num_requests: int, prompt_ids: list[int], max_tokens: int):
+    num_prompt_tokens = len(prompt_ids)
+
+    # {metric_family: [(suffix, expected_value)]}
+    return {
+        "vllm:time_to_first_token_seconds": [("_count", num_requests)],
+        "vllm:time_per_output_token_seconds": [
+            ("_count", num_requests * (max_tokens - 1))
+        ],
+        "vllm:e2e_request_latency_seconds": [("_count", num_requests)],
+        "vllm:request_queue_time_seconds": [("_count", num_requests)],
+        "vllm:request_inference_time_seconds": [("_count", num_requests)],
+        "vllm:request_prefill_time_seconds": [("_count", num_requests)],
+        "vllm:request_decode_time_seconds": [("_count", num_requests)],
+        "vllm:request_prompt_tokens": [
+            ("_sum", num_requests * num_prompt_tokens),
+            ("_count", num_requests),
+        ],
+        "vllm:request_generation_tokens": [
+            ("_sum", num_requests * max_tokens),
+            ("_count", num_requests),
+        ],
+        "vllm:request_params_n": [("_count", num_requests)],
+        "vllm:request_params_max_tokens": [
+            ("_sum", num_requests * max_tokens),
+            ("_count", num_requests),
+        ],
+        "vllm:iteration_tokens_total": [
+            (
+                "_sum",
+                num_requests * (num_prompt_tokens + max_tokens),
+            ),
+            ("_count", num_requests * max_tokens),
+        ],
+        "vllm:prompt_tokens": [("_total", num_requests * num_prompt_tokens)],
+        "vllm:generation_tokens": [("_total", num_requests * max_tokens)],
+        "vllm:request_success": [("_total", num_requests)],
+    }
 
 
 @pytest.mark.asyncio
 async def test_metrics_counts(
-    server: RemoteOpenAIServer, client: openai.AsyncClient, use_v1: bool
+    server: RemoteOpenAIServer,
+    client: openai.AsyncClient,
+    model_key: str,
 ):
-    for _ in range(_NUM_REQUESTS):
+    if model_key == "multimodal":
+        pytest.skip("Unnecessary test")
+
+    model_name = MODELS[model_key]
+    tokenizer = AutoTokenizer.from_pretrained(model_name)
+    prompt_ids = tokenizer.encode(_PROMPT)
+    num_requests = 10
+    max_tokens = 10
+
+    for _ in range(num_requests):
         # sending a request triggers the metrics to be logged.
         await client.completions.create(
-            model=MODEL_NAME,
-            prompt=_TOKENIZED_PROMPT,
-            max_tokens=_NUM_GENERATION_TOKENS_PER_REQUEST,
+            model=model_name,
+            prompt=prompt_ids,
+            max_tokens=max_tokens,
         )
 
     response = requests.get(server.url_for("metrics"))
@@ -144,8 +140,9 @@ async def test_metrics_counts(
     assert response.status_code == HTTPStatus.OK
 
     # Loop over all expected metric_families
-    for metric_family, suffix_values_list in EXPECTED_VALUES.items():
-        if (use_v1 and metric_family not in EXPECTED_METRICS_V1) or (
+    expected_values = _get_expected_values(num_requests, prompt_ids, max_tokens)
+    for metric_family, suffix_values_list in expected_values.items():
+        if metric_family not in EXPECTED_METRICS_V1 or (
             not server.show_hidden_metrics
             and metric_family in HIDDEN_DEPRECATED_METRICS
         ):
@@ -183,62 +180,6 @@ async def test_metrics_counts(
     assert found_metric, f"Did not find {metric_family} in prom endpoint"
 
 
-EXPECTED_METRICS = [
-    "vllm:num_requests_running",
-    "vllm:num_requests_waiting",
-    "vllm:gpu_cache_usage_perc",
-    "vllm:time_to_first_token_seconds_sum",
-    "vllm:time_to_first_token_seconds_bucket",
-    "vllm:time_to_first_token_seconds_count",
-    "vllm:time_per_output_token_seconds_sum",
-    "vllm:time_per_output_token_seconds_bucket",
-    "vllm:time_per_output_token_seconds_count",
-    "vllm:e2e_request_latency_seconds_sum",
-    "vllm:e2e_request_latency_seconds_bucket",
-    "vllm:e2e_request_latency_seconds_count",
-    "vllm:request_queue_time_seconds_sum",
-    "vllm:request_queue_time_seconds_bucket",
-    "vllm:request_queue_time_seconds_count",
-    "vllm:request_inference_time_seconds_sum",
-    "vllm:request_inference_time_seconds_bucket",
-    "vllm:request_inference_time_seconds_count",
-    "vllm:request_prefill_time_seconds_sum",
-    "vllm:request_prefill_time_seconds_bucket",
-    "vllm:request_prefill_time_seconds_count",
-    "vllm:request_decode_time_seconds_sum",
-    "vllm:request_decode_time_seconds_bucket",
-    "vllm:request_decode_time_seconds_count",
-    "vllm:request_prompt_tokens_sum",
-    "vllm:request_prompt_tokens_bucket",
-    "vllm:request_prompt_tokens_count",
-    "vllm:request_generation_tokens_sum",
-    "vllm:request_generation_tokens_bucket",
-    "vllm:request_generation_tokens_count",
-    "vllm:request_params_n_sum",
-    "vllm:request_params_n_bucket",
-    "vllm:request_params_n_count",
-    "vllm:request_params_max_tokens_sum",
-    "vllm:request_params_max_tokens_bucket",
-    "vllm:request_params_max_tokens_count",
-    "vllm:iteration_tokens_total",
-    "vllm:num_preemptions_total",
-    "vllm:prompt_tokens_total",
-    "vllm:generation_tokens_total",
-    "vllm:request_success_total",
-    "vllm:cache_config_info",
-    # labels in cache_config_info
-    "block_size",
-    "cache_dtype",
-    "cpu_offload_gb",
-    "enable_prefix_caching",
-    "gpu_memory_utilization",
-    "num_cpu_blocks",
-    "num_gpu_blocks",
-    "num_gpu_blocks_override",
-    "sliding_window",
-    "swap_space_bytes",
-]
-
 EXPECTED_METRICS_V1 = [
     "vllm:num_requests_running",
     "vllm:num_requests_waiting",
@@ -292,6 +233,11 @@ async def test_metrics_counts(
     "vllm:request_decode_time_seconds_count",
 ]
 
+EXPECTED_METRICS_MM = [
+    "vllm:mm_cache_queries",
+    "vllm:mm_cache_hits",
+]
+
 HIDDEN_DEPRECATED_METRICS: list[str] = [
     "vllm:gpu_cache_usage_perc",
     "vllm:gpu_prefix_cache_queries",
@@ -304,17 +250,45 @@ async def test_metrics_counts(
 
 @pytest.mark.asyncio
 async def test_metrics_exist(
-    server: RemoteOpenAIServer, client: openai.AsyncClient, use_v1: bool
+    server: RemoteOpenAIServer,
+    client: openai.AsyncClient,
+    model_key: str,
 ):
+    model_name = MODELS[model_key]
+
     # sending a request triggers the metrics to be logged.
-    await client.completions.create(
-        model=MODEL_NAME, prompt="Hello, my name is", max_tokens=5, temperature=0.0
-    )
+    if model_key == "text":
+        await client.completions.create(
+            model=model_name,
+            prompt="Hello, my name is",
+            max_tokens=5,
+            temperature=0.0,
+        )
+    else:
+        await client.chat.completions.create(
+            model=model_name,
+            messages=[
+                {
+                    "role": "user",
+                    "content": [
+                        {"type": "image_url", "image_url": {"url": _IMAGE_URL}},
+                        {"type": "text", "text": "What's in this image?"},
+                    ],
+                }
+            ],
+            max_tokens=5,
+            temperature=0.0,
+        )
 
     response = requests.get(server.url_for("metrics"))
     assert response.status_code == HTTPStatus.OK
 
-    for metric in EXPECTED_METRICS_V1 if use_v1 else EXPECTED_METRICS:
+    expected_metrics = EXPECTED_METRICS_V1
+    if model_key == "multimodal":
+        # NOTE: Don't use in-place assignment
+        expected_metrics = expected_metrics + EXPECTED_METRICS_MM
+
+    for metric in expected_metrics:
         if metric in HIDDEN_DEPRECATED_METRICS and not server.show_hidden_metrics:
             continue
         assert metric in response.text
@@ -322,10 +296,16 @@ async def test_metrics_exist(
 
 @pytest.mark.asyncio
 async def test_abort_metrics_reset(
-    server: RemoteOpenAIServer, client: openai.AsyncClient, use_v1: bool
+    server: RemoteOpenAIServer,
+    client: openai.AsyncClient,
+    model_key: str,
 ):
+    model_name = MODELS[model_key]
+    tokenizer = AutoTokenizer.from_pretrained(model_name)
+    prompt_ids = tokenizer.encode(_PROMPT)
+
     running_requests, waiting_requests, kv_cache_usage = _get_running_metrics_from_api(
-        server, use_v1
+        server,
     )
 
     # Expect no running requests or kvcache usage
@@ -338,8 +318,8 @@ async def test_abort_metrics_reset(
     for _ in range(3):
         task = asyncio.create_task(
             client.completions.create(
-                model=MODEL_NAME,
-                prompt=_TOKENIZED_PROMPT,
+                model=model_name,
+                prompt=prompt_ids,
                 max_tokens=100,  # Long generation to give time to abort
                 temperature=0.0,
             )
         )
@@ -351,7 +331,7 @@ async def test_abort_metrics_reset(
 
     # Check that we have running requests
     running_requests, waiting_requests, kv_cache_usage = _get_running_metrics_from_api(
-        server, use_v1
+        server,
     )
 
     # Expect running requests and kvcache usage
@@ -371,7 +351,7 @@ async def test_abort_metrics_reset(
 
     # Verify running and waiting requests counts and KV cache usage are zero
    running_requests_after, waiting_requests_after, kv_cache_usage_after = (
-        _get_running_metrics_from_api(server, use_v1)
+        _get_running_metrics_from_api(server)
     )
 
     assert running_requests_after == 0, (
@@ -385,7 +365,7 @@ async def test_abort_metrics_reset(
     )
 
 
-def _get_running_metrics_from_api(server: RemoteOpenAIServer, use_v1: bool):
+def _get_running_metrics_from_api(server: RemoteOpenAIServer):
     """Return (running_count, waiting_count, kv_cache_usage)"""
 
     response = requests.get(server.url_for("metrics"))
@@ -394,9 +374,7 @@ def _get_running_metrics_from_api(server: RemoteOpenAIServer, use_v1: bool):
 
     # Verify running and waiting requests counts and KV cache usage are zero
     running_requests, waiting_requests, kv_cache_usage = None, None, None
-    kv_cache_usage_metric = (
-        "vllm:kv_cache_usage_perc" if use_v1 else "vllm:gpu_cache_usage_perc"
-    )
+    kv_cache_usage_metric = "vllm:kv_cache_usage_perc"
 
     for family in text_string_to_metric_families(response.text):
         if family.name == "vllm:num_requests_running":
@@ -422,7 +400,7 @@ def _get_running_metrics_from_api(server: RemoteOpenAIServer, use_v1: bool):
     return running_requests, waiting_requests, kv_cache_usage
 
 
-def test_metrics_exist_run_batch(use_v1: bool):
+def test_metrics_exist_run_batch():
     input_batch = """{"custom_id": "request-0", "method": "POST", "url": "/v1/embeddings", "body": {"model": "intfloat/multilingual-e5-small", "input": "You are a helpful assistant."}}"""  # noqa: E501
 
     base_url = "0.0.0.0"
@@ -452,7 +430,6 @@ def test_metrics_exist_run_batch(use_v1: bool):
             "--port",
             port,
         ],
-        env={"VLLM_USE_V1": "1"},
     )
 
     def is_server_up(url):
diff --git a/tests/v1/core/test_kv_cache_utils.py b/tests/v1/core/test_kv_cache_utils.py
index aed00a60aeb4..50e16d261930 100644
--- a/tests/v1/core/test_kv_cache_utils.py
+++ b/tests/v1/core/test_kv_cache_utils.py
@@ -20,7 +20,6 @@
     BlockHash,
     FreeKVCacheBlockQueue,
     KVCacheBlock,
-    PrefixCachingMetrics,
     estimate_max_model_len,
     generate_block_hash_extra_keys,
     generate_scheduler_kv_cache_config,
@@ -42,7 +41,7 @@
     SlidingWindowSpec,
     UniformTypeKVCacheSpecs,
 )
-from vllm.v1.metrics.stats import PrefixCacheStats
+from vllm.v1.metrics.stats import CachingMetrics, PrefixCacheStats
 from vllm.v1.request import Request
 
 pytestmark = pytest.mark.cpu_test
@@ -536,7 +535,7 @@ def test_metrics():
     """
     Test the prefix caching metrics.
     """
-    metrics = PrefixCachingMetrics(max_recent_requests=5)
+    metrics = CachingMetrics(max_recent_requests=5)
     assert metrics.hit_rate == 0.0
 
     metrics.observe(_stats(1, 20, 9))
@@ -568,7 +567,7 @@ def test_metrics_empty_stats():
     """
     Test the prefix caching metrics with empty stats.
     """
-    metrics = PrefixCachingMetrics(max_recent_requests=5)
+    metrics = CachingMetrics(max_recent_requests=5)
     metrics.observe(_stats(0, 0, 0))
     metrics.observe(_stats(1, 20, 9))
     metrics.observe(_stats(0, 0, 0))
diff --git a/vllm/multimodal/cache.py b/vllm/multimodal/cache.py
index 15aa91a04092..9831d9ded6ee 100644
--- a/vllm/multimodal/cache.py
+++ b/vllm/multimodal/cache.py
@@ -17,8 +17,9 @@
 )
 from vllm.envs import VLLM_OBJECT_STORAGE_SHM_BUFFER_NAME
 from vllm.logger import init_logger
-from vllm.utils import GiB_bytes, LRUCache, MiB_bytes
+from vllm.utils import CacheInfo, GiB_bytes, LRUCache, MiB_bytes
 from vllm.utils.jsontree import json_count_leaves, json_map_leaves, json_reduce_leaves
+from vllm.v1.metrics.stats import MultiModalCacheStats
 
 from .inputs import (
     MultiModalBatchedField,
@@ -301,6 +302,16 @@ def is_cached(self, mm_hashes: list[str]) -> list[bool]:
         """
         return [self.is_cached_item(mm_hash) for mm_hash in mm_hashes]
 
+    @abstractmethod
+    def make_stats(self) -> MultiModalCacheStats:
+        """
+        Get (and reset) the multi-modal cache stats.
+
+        Returns:
+            The current multi-modal caching stats.
+        """
+        raise NotImplementedError
+
 
 class MultiModalProcessorOnlyCache(BaseMultiModalProcessorCache):
     """
@@ -323,6 +334,8 @@ def __init__(self, model_config: "ModelConfig") -> None:
             MultiModalProcessorCacheItem,
         )
 
+        self._stats = MultiModalCacheStats()
+
     @override
     def is_cached_item(self, mm_hash: str) -> bool:
         return mm_hash in self._cache
@@ -345,6 +358,19 @@ def get_and_update_item(
     @override
     def clear_cache(self) -> None:
         self._cache.clear()
+        self._stats.reset = True
+
+    @override
+    def make_stats(self) -> MultiModalCacheStats:
+        cache = self._cache
+
+        info_delta = cache.stat(delta=True)
+        self._stats.hits = info_delta.hits
+        self._stats.queries = info_delta.total
+
+        stats = self._stats
+        self._stats = MultiModalCacheStats()
+        return stats
 
 
 class MultiModalProcessorSenderCache(BaseMultiModalProcessorCache):
@@ -373,6 +399,8 @@ def __init__(self, model_config: "ModelConfig") -> None:
             MultiModalProcessorCacheItemMetadata,
         )
 
+        self._stats = MultiModalCacheStats()
+
     @override
     def is_cached_item(self, mm_hash: str) -> bool:
         return mm_hash in self._cache
@@ -395,6 +423,19 @@ def get_and_update_item(
     @override
     def clear_cache(self) -> None:
         self._cache.clear()
+        self._stats.reset = True
+
+    @override
+    def make_stats(self) -> MultiModalCacheStats:
+        cache = self._cache
+
+        info_delta = cache.stat(delta=True)
+        self._stats.hits = info_delta.hits
+        self._stats.queries = info_delta.total
+
+        stats = self._stats
+        self._stats = MultiModalCacheStats()
+        return stats
 
 
 class ShmObjectStoreSenderCache(BaseMultiModalProcessorCache):
@@ -429,6 +470,21 @@ def __init__(self, vllm_config: "VllmConfig") -> None:
         # cache (prompt_updates, modality) for P0 only
         self._p0_cache: dict[str, tuple[Sequence[ResolvedPromptUpdate], str]] = {}
 
+        self._hits = 0
+        self._total = 0
+        self._last_info = CacheInfo(hits=0, total=0)
+        self._stats = MultiModalCacheStats()
+
+    def _stat(self, *, delta: bool = False) -> CacheInfo:
+        info = CacheInfo(hits=self._hits, total=self._total)
+
+        if delta:
+            info_delta = info - self._last_info
+            self._last_info = info
+            info = info_delta
+
+        return info
+
     @override
     def is_cached_item(self, mm_hash: str) -> bool:
         return self._shm_cache.is_cached(mm_hash)
@@ -440,12 +496,17 @@ def get_and_update_item(
         mm_hash: str,
     ) -> MultiModalProcessorCacheOutItem:
         if self._shm_cache.is_cached(mm_hash):
+            self._hits += 1
+            self._total += 1
+
             address, monotonic_id = self._shm_cache.get_cached(mm_hash)
             prompt_updates, modality = self._p0_cache[mm_hash]
             return self.address_as_item(address, monotonic_id, modality), prompt_updates
 
         assert mm_item is not None, f"Expected a cached item for {mm_hash=}"
 
+        self._total += 1
+
         try:
             address, monotonic_id = self._shm_cache.put(mm_hash, mm_item[0])
             # Try to remove dangling items if p0 cache is too large.
@@ -468,6 +529,21 @@ def clear_cache(self) -> None:
         self._shm_cache.clear()
         self._p0_cache.clear()
 
+        self._hits = 0
+        self._total = 0
+        self._last_info = CacheInfo(hits=0, total=0)
+        self._stats.reset = True
+
+    @override
+    def make_stats(self) -> MultiModalCacheStats:
+        info_delta = self._stat(delta=True)
+        self._stats.hits = info_delta.hits
+        self._stats.queries = info_delta.total
+
+        stats = self._stats
+        self._stats = MultiModalCacheStats()
+        return stats
+
     def remove_dangling_items(self) -> None:
         """Remove items that are no longer in the shared memory cache."""
         cached_hashes = self._shm_cache.key_index.keys()
diff --git a/vllm/v1/core/kv_cache_utils.py b/vllm/v1/core/kv_cache_utils.py
index 4683ad62981f..b1a5fd5454c5 100644
--- a/vllm/v1/core/kv_cache_utils.py
+++ b/vllm/v1/core/kv_cache_utils.py
@@ -4,7 +4,7 @@
 import copy
 import os
-from collections import defaultdict, deque
+from collections import defaultdict
 from collections.abc import Iterable, Sequence
 from dataclasses import dataclass
 from typing import Any, Callable, NewType, Optional, Union
 
@@ -23,7 +23,6 @@
     SlidingWindowSpec,
     UniformTypeKVCacheSpecs,
 )
-from vllm.v1.metrics.stats import PrefixCacheStats
 from vllm.v1.request import Request
 
 # BlockHash represents the hash of a single KV-cache block used for
@@ -101,78 +100,6 @@ def init_none_hash(hash_fn: Callable[[Any], bytes]):
     NONE_HASH = BlockHash(hash_fn(hash_seed))
 
 
-class PrefixCachingMetrics:
-    """Metrics for prefix caching with a hit rate of the max recent N requests.
-
-    Args:
-        max_recent_requests: The number of the max recent requests to aggregate.
-            Defaults to 1000.
-    """
-
-    def __init__(self, max_recent_requests: int = 1000):
-        self.max_recent_requests = max_recent_requests
-        # The current aggregated values.
-        self.aggregated_requests = 0
-        self.aggregated_query_total = 0
-        self.aggregated_query_hit = 0
-        # A deque of (requests, queries, hits) for the most recent requests.
-        self.query_queue: deque[tuple[int, int, int]] = deque()
-
-    def observe(self, stats: PrefixCacheStats):
-        """Observe the prefix caching for a set of requests.
-
-        This function is called with information gathered when new requests
-        are being scheduled and are looking for computed blocks.
-
-        When there are more than `max_recent_requests` requests, the oldest set
-        of requests are removed from the metrics.
-
-        Args:
-            stats: The prefix cache stats.
-        """
-        # reset_prefix_cache was invoked before the current update.
-        # Reset the metrics before aggregating the current stats.
-        if stats.reset:
-            self.reset()
-
-        # DO NOT appending empty stats to avoid helpful info get kicked out
-        # due to sliding window.
-        if stats.requests == 0:
-            return
-
-        # Update the metrics.
-        self.query_queue.append((stats.requests, stats.queries, stats.hits))
-        self.aggregated_requests += stats.requests
-        self.aggregated_query_total += stats.queries
-        self.aggregated_query_hit += stats.hits
-
-        # Remove the oldest stats until number of requests does not exceed
-        # the limit.
-        # NOTE: We preserve the latest added stats regardless.
-        while (
-            len(self.query_queue) > 1
-            and self.aggregated_requests > self.max_recent_requests
-        ):
-            old_requests, old_queries, old_hits = self.query_queue.popleft()
-            self.aggregated_requests -= old_requests
-            self.aggregated_query_total -= old_queries
-            self.aggregated_query_hit -= old_hits
-
-    def reset(self):
-        """Reset the metrics."""
-        self.aggregated_requests = 0
-        self.aggregated_query_total = 0
-        self.aggregated_query_hit = 0
-        self.query_queue.clear()
-
-    @property
-    def hit_rate(self) -> float:
-        """Calculate the hit rate for the past N requests."""
-        if self.aggregated_query_total == 0:
-            return 0.0
-        return self.aggregated_query_hit / self.aggregated_query_total
-
-
 @dataclass
 class KVCacheBlock:
     """KV-cache block metadata."""
diff --git a/vllm/v1/engine/async_llm.py b/vllm/v1/engine/async_llm.py
index ca668bc217e1..826eca6dc09d 100644
--- a/vllm/v1/engine/async_llm.py
+++ b/vllm/v1/engine/async_llm.py
@@ -511,10 +511,14 @@ async def output_handler():
                     # TODO(rob): make into a coroutine and launch it in
                     # background thread once Prometheus overhead is non-trivial.
                     if logger_manager:
+                        mm_cache = self.processor.input_preprocessor.mm_processor_cache
+                        mm_cache_stats = mm_cache.make_stats() if mm_cache else None
+
                        logger_manager.record(
                             engine_idx=outputs.engine_index,
                             scheduler_stats=outputs.scheduler_stats,
                             iteration_stats=iteration_stats,
+                            mm_cache_stats=mm_cache_stats,
                         )
         except Exception as e:
             logger.exception("AsyncLLM output_handler failed.")
diff --git a/vllm/v1/engine/llm_engine.py b/vllm/v1/engine/llm_engine.py
index 9da25c0662a8..e78b939b5021 100644
--- a/vllm/v1/engine/llm_engine.py
+++ b/vllm/v1/engine/llm_engine.py
@@ -307,9 +307,14 @@ def step(self) -> Union[list[RequestOutput], list[PoolingRequestOutput]]:
         # 4) Record stats
         if self.logger_manager is not None:
             assert outputs.scheduler_stats is not None
+
+            mm_cache = self.processor.input_preprocessor.mm_processor_cache
+            mm_cache_stats = mm_cache.make_stats() if mm_cache else None
+
             self.logger_manager.record(
                 scheduler_stats=outputs.scheduler_stats,
                 iteration_stats=iteration_stats,
+                mm_cache_stats=mm_cache_stats,
             )
             self.do_log_stats_with_interval()
diff --git a/vllm/v1/metrics/loggers.py b/vllm/v1/metrics/loggers.py
index 541af7af1725..cacabc0a0fc2 100644
--- a/vllm/v1/metrics/loggers.py
+++ b/vllm/v1/metrics/loggers.py
@@ -11,10 +11,14 @@
 from vllm.config import SupportsMetricsInfo, VllmConfig
 from vllm.distributed.kv_transfer.kv_connector.v1.metrics import KVConnectorLogging
 from vllm.logger import init_logger
-from vllm.v1.core.kv_cache_utils import PrefixCachingMetrics
 from vllm.v1.engine import FinishReason
 from vllm.v1.metrics.prometheus import unregister_vllm_metrics
-from vllm.v1.metrics.stats import IterationStats, SchedulerStats
+from vllm.v1.metrics.stats import (
+    CachingMetrics,
+    IterationStats,
+    MultiModalCacheStats,
+    SchedulerStats,
+)
 from vllm.v1.spec_decode.metrics import SpecDecodingLogging, SpecDecodingProm
 
 logger = init_logger(__name__)
@@ -38,6 +42,7 @@ def record(
         self,
         scheduler_stats: Optional[SchedulerStats],
         iteration_stats: Optional[IterationStats],
+        mm_cache_stats: Optional[MultiModalCacheStats] = None,
         engine_idx: int = 0,
     ): ...
 
@@ -53,10 +58,15 @@ def __init__(self, vllm_config: VllmConfig, engine_index: int = 0):
         self.engine_index = engine_index
         self.vllm_config = vllm_config
         self._reset(time.monotonic())
+
         self.last_scheduler_stats = SchedulerStats()
-        # Prefix cache metrics. This cannot be reset.
+        self.last_mm_cache_stats = MultiModalCacheStats()
+
+        # Caching metrics. This cannot be reset.
         # TODO: Make the interval configurable.
-        self.prefix_caching_metrics = PrefixCachingMetrics()
+        self.prefix_caching_metrics = CachingMetrics()
+        self.mm_caching_metrics = CachingMetrics()
+
         self.spec_decoding_logging = SpecDecodingLogging()
         kv_tranfer_config = self.vllm_config.kv_transfer_config
         self.kv_connector_logging = KVConnectorLogging(kv_tranfer_config)
@@ -86,6 +96,7 @@ def record(
         self,
         scheduler_stats: Optional[SchedulerStats],
         iteration_stats: Optional[IterationStats],
+        mm_cache_stats: Optional[MultiModalCacheStats] = None,
         engine_idx: int = 0,
     ):
         """Log Stats to standard output."""
@@ -101,6 +112,11 @@ def record(
                 self.kv_connector_logging.observe(kv_connector_stats)
             self.last_scheduler_stats = scheduler_stats
 
+        if mm_cache_stats:
+            self.mm_caching_metrics.observe(mm_cache_stats)
+
+        self.last_mm_cache_stats = mm_cache_stats
+
     def log(self):
         now = time.monotonic()
         prompt_throughput = self._get_throughput(self.num_prompt_tokens, now)
@@ -125,21 +141,32 @@ def log(self):
             self.last_prompt_throughput = prompt_throughput
 
         # Format and print output.
-        log_fn(
-            "Engine %03d: "
-            "Avg prompt throughput: %.1f tokens/s, "
-            "Avg generation throughput: %.1f tokens/s, "
-            "Running: %d reqs, Waiting: %d reqs, "
-            "GPU KV cache usage: %.1f%%, "
+        log_parts = [
+            "Avg prompt throughput: %.1f tokens/s",
+            "Avg generation throughput: %.1f tokens/s",
+            "Running: %d reqs",
+            "Waiting: %d reqs",
+            "GPU KV cache usage: %.1f%%",
             "Prefix cache hit rate: %.1f%%",
-            self.engine_index,
+        ]
+        log_args = [
             prompt_throughput,
             generation_throughput,
             scheduler_stats.num_running_reqs,
             scheduler_stats.num_waiting_reqs,
             scheduler_stats.kv_cache_usage * 100,
             self.prefix_caching_metrics.hit_rate * 100,
+        ]
+        if self.last_mm_cache_stats:
+            log_parts.append("MM cache hit rate: %.1f%%")
+            log_args.append(self.mm_caching_metrics.hit_rate * 100)
+
+        log_fn(
+            "Engine %03d: " + ", ".join(log_parts),
+            self.engine_index,
+            *log_args,
         )
+
         self.spec_decoding_logging.log(log_fn=log_fn)
         self.kv_connector_logging.log(log_fn=log_fn)
@@ -288,6 +315,32 @@ def __init__(
             counter_prefix_cache_hits, engine_indexes, model_name
         )
 
+        #
+        # Multi-modal cache
+        #
+
+        counter_mm_cache_queries = self._counter_cls(
+            name="vllm:mm_cache_queries",
+            documentation=(
+                "Multi-modal cache queries, in terms of number of queried items."
+            ),
+            labelnames=labelnames,
+        )
+        self.counter_mm_cache_queries = make_per_engine(
+            counter_mm_cache_queries, engine_indexes, model_name
+        )
+
+        counter_mm_cache_hits = self._counter_cls(
+            name="vllm:mm_cache_hits",
+            documentation=(
+                "Multi-modal cache hits, in terms of number of cached items."
+            ),
+            labelnames=labelnames,
+        )
+        self.counter_mm_cache_hits = make_per_engine(
+            counter_mm_cache_hits, engine_indexes, model_name
+        )
+
         #
         # Counters
         #
@@ -657,6 +710,7 @@ def record(
         self,
         scheduler_stats: Optional[SchedulerStats],
         iteration_stats: Optional[IterationStats],
+        mm_cache_stats: Optional[MultiModalCacheStats] = None,
         engine_idx: int = 0,
     ):
         """Log to prometheus."""
@@ -694,6 +748,10 @@ def record(
                     scheduler_stats.spec_decoding_stats, engine_idx
                 )
 
+        if mm_cache_stats is not None:
+            self.counter_mm_cache_queries[engine_idx].inc(mm_cache_stats.queries)
+            self.counter_mm_cache_hits[engine_idx].inc(mm_cache_stats.hits)
+
         if iteration_stats is None:
             return
 
@@ -871,6 +929,7 @@ def record(
         self,
         scheduler_stats: Optional[SchedulerStats],
         iteration_stats: Optional[IterationStats],
+        mm_cache_stats: Optional[MultiModalCacheStats] = None,
         engine_idx: Optional[int] = None,
     ):
         if engine_idx is None:
@@ -878,9 +937,19 @@ def record(
 
         per_engine_loggers = self.per_engine_logger_dict[engine_idx]
         for logger in per_engine_loggers:
-            logger.record(scheduler_stats, iteration_stats, engine_idx)
+            logger.record(
+                scheduler_stats,
+                iteration_stats,
+                mm_cache_stats=mm_cache_stats,
+                engine_idx=engine_idx,
+            )
 
-        self.prometheus_logger.record(scheduler_stats, iteration_stats, engine_idx)
+        self.prometheus_logger.record(
+            scheduler_stats,
+            iteration_stats,
+            mm_cache_stats=mm_cache_stats,
+            engine_idx=engine_idx,
+        )
 
     def log(self):
         for per_engine_loggers in self.per_engine_logger_dict.values():
diff --git a/vllm/v1/metrics/stats.py b/vllm/v1/metrics/stats.py
index 5564718d5165..f0922288db32 100644
--- a/vllm/v1/metrics/stats.py
+++ b/vllm/v1/metrics/stats.py
@@ -2,6 +2,7 @@
 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project
 
 import time
+from collections import deque
 from dataclasses import dataclass, field
 from typing import TYPE_CHECKING, Any, Optional
 
@@ -13,24 +14,122 @@
 @dataclass
-class PrefixCacheStats:
-    """Stores prefix cache hit statistics."""
+class BaseCacheStats:
+    """Stores cache hit statistics."""
 
-    # Whether reset_prefix_cache was invoked.
     reset: bool = False
-    # The number of new requests in this update.
+    """Whether the cache was reset."""
+
     requests: int = 0
-    # The number of queries in these requests. Note that "queries" here
-    # means the number of tokens that were queried from the cache.
+    """The number of requests in this update."""
+
     queries: int = 0
-    # The number of hits in these requests.
+    """The number of queries in these requests."""
+
     hits: int = 0
-    # The number of previously preempted requests in this update.
+    """The number of hits in these requests."""
+
+
+class CachingMetrics:
+    """Metrics for caching with a hit rate of the most recent N requests.
+    Args:
+        max_recent_requests: The number of the most recent requests to aggregate.
+            Defaults to 1000.
+    """
+
+    def __init__(self, max_recent_requests: int = 1000) -> None:
+        super().__init__()
+
+        self.max_recent_requests = max_recent_requests
+        # The current aggregated values.
+        self.aggregated_requests = 0
+        self.aggregated_query_total = 0
+        self.aggregated_query_hit = 0
+
+        # A deque of (requests, queries, hits) for the most recent requests.
+        self.query_queue = deque[tuple[int, int, int]]()
+
+    def observe(self, stats: BaseCacheStats):
+        """Observe the caching for a set of requests.
+
+        This function is called with information gathered when new requests
+        are being scheduled and are looking for computed blocks.
+
+        When there are more than `max_recent_requests` requests, the oldest set
+        of requests are removed from the metrics.
+
+        Args:
+            stats: The cache stats.
+        """
+        # reset_prefix_cache was invoked before the current update.
+        # Reset the metrics before aggregating the current stats.
+        if stats.reset:
+            self.reset()
+
+        # DO NOT append empty stats, to avoid helpful info getting kicked out
+        # due to the sliding window.
+        if stats.requests == 0:
+            return
+
+        # Update the metrics.
+        self.query_queue.append((stats.requests, stats.queries, stats.hits))
+        self.aggregated_requests += stats.requests
+        self.aggregated_query_total += stats.queries
+        self.aggregated_query_hit += stats.hits
+
+        # Remove the oldest stats until number of requests does not exceed
+        # the limit.
+        # NOTE: We preserve the latest added stats regardless.
+        while (
+            len(self.query_queue) > 1
+            and self.aggregated_requests > self.max_recent_requests
+        ):
+            old_requests, old_queries, old_hits = self.query_queue.popleft()
+            self.aggregated_requests -= old_requests
+            self.aggregated_query_total -= old_queries
+            self.aggregated_query_hit -= old_hits
+
+    def reset(self):
+        """Reset the metrics."""
+        self.aggregated_requests = 0
+        self.aggregated_query_total = 0
+        self.aggregated_query_hit = 0
+        self.query_queue.clear()
+
+    @property
+    def hit_rate(self) -> float:
+        """Calculate the hit rate for the past N requests."""
+        if self.aggregated_query_total == 0:
+            return 0.0
+        return self.aggregated_query_hit / self.aggregated_query_total
+
+
+@dataclass
+class PrefixCacheStats(BaseCacheStats):
+    """
+    Stores prefix cache hit statistics.
+    - `reset`: Whether `reset_prefix_cache` was invoked.
+    - `queries`: Refers to the number of tokens that were queried.
+    """
+
     preempted_requests: int = 0
-    # The `queries` number for preempted requests.
+    """The number of previously preempted requests in this update."""
+
     preempted_queries: int = 0
-    # The `hits` number for preempted requests.
+    """The `queries` number for preempted requests."""
+
     preempted_hits: int = 0
+    """The `hits` number for preempted requests."""
+
+
+@dataclass
+class MultiModalCacheStats(BaseCacheStats):
+    """
+    Stores multi-modal cache hit statistics.
+    - `reset`: Whether `reset_mm_cache` was invoked.
+    - `queries`: Refers to the number of multi-modal data items
+      that were queried.
+    """
 
 
 @dataclass