diff --git a/tests/entrypoints/llm/test_mm_cache_stats.py b/tests/entrypoints/llm/test_mm_cache_stats.py new file mode 100644 index 000000000000..ed5bab9e9e68 --- /dev/null +++ b/tests/entrypoints/llm/test_mm_cache_stats.py @@ -0,0 +1,76 @@ +# SPDX-License-Identifier: Apache-2.0 + +import pytest + +from vllm import LLM +from vllm.engine.llm_engine import LLMEngine as V0LLMEngine +from vllm.entrypoints.chat_utils import ChatCompletionMessageParam +from vllm.v1.engine.llm_engine import LLMEngine as V1LLMEngine + +from ..openai.test_vision import TEST_IMAGE_URLS + + +def _make_messages(image_url: str) -> list[ChatCompletionMessageParam]: + return [{ + "role": + "user", + "content": [ + { + "type": "image_url", + "image_url": { + "url": image_url + }, + }, + ], + }] + + +@pytest.mark.parametrize("image_urls", + [[TEST_IMAGE_URLS[0], TEST_IMAGE_URLS[1]]]) +@pytest.mark.parametrize("use_v1", [True, False]) +def test_mm_cache_stats( + image_urls: list[str], + use_v1: bool, + monkeypatch: pytest.MonkeyPatch, +): + with monkeypatch.context() as m: + m.setenv("VLLM_USE_V1", "1" if use_v1 else "0") + + llm = LLM( + model="HuggingFaceTB/SmolVLM-256M-Instruct", + max_model_len=4096, + max_num_seqs=5, + enforce_eager=True, + limit_mm_per_prompt={"image": 2}, + ) + engine = llm.llm_engine + if isinstance(engine, V0LLMEngine): + mm_registry = engine.input_preprocessor.mm_registry + elif isinstance(engine, V1LLMEngine): + mm_registry = engine.processor.mm_registry + + # In case the previous test failed, we still need to reset the cache + # (which is shared across tests) + engine.reset_mm_cache() + mm_registry.make_processor_cache_stats() + + llm.chat(_make_messages(image_urls[0])) + + cache_stats = mm_registry.make_processor_cache_stats() + assert cache_stats.size_items == 1 + + llm.chat(_make_messages(image_urls[1])) + + cache_stats = mm_registry.make_processor_cache_stats() + assert cache_stats.size_items == 2 + + llm.chat(_make_messages(image_urls[0])) + + cache_stats = mm_registry.make_processor_cache_stats() + assert cache_stats.size_items == 2 + + engine.reset_mm_cache() + + cache_stats = mm_registry.make_processor_cache_stats() + assert cache_stats.size_items == 0 + assert cache_stats.reset is True diff --git a/tests/entrypoints/openai/test_metrics.py b/tests/entrypoints/openai/test_metrics.py index 42f7b098f917..228ad235b09f 100644 --- a/tests/entrypoints/openai/test_metrics.py +++ b/tests/entrypoints/openai/test_metrics.py @@ -17,10 +17,18 @@ from ...utils import RemoteOpenAIServer -MODEL_NAME = "TinyLlama/TinyLlama-1.1B-Chat-v1.0" +MODELS = { + "text": "TinyLlama/TinyLlama-1.1B-Chat-v1.0", + "multimodal": "HuggingFaceTB/SmolVLM-256M-Instruct", +} PREV_MINOR_VERSION = version._prev_minor_version() +@pytest.fixture(scope="module", params=list(MODELS.keys())) +def model_key(request): + yield request.param + + @pytest.fixture(scope="module", params=[True, False]) def use_v1(request): # Module-scoped variant of run_with_both_engines @@ -60,11 +68,13 @@ def default_server_args(): "--disable-frontend-multiprocessing", f"--show-hidden-metrics-for-version={PREV_MINOR_VERSION}", ]) -def server(use_v1, default_server_args, request): +def server(model_key, use_v1, default_server_args, request): if request.param: default_server_args.append(request.param) + + model_name = MODELS[model_key] env_dict = dict(VLLM_USE_V1='1' if use_v1 else '0') - with RemoteOpenAIServer(MODEL_NAME, default_server_args, + with RemoteOpenAIServer(model_name, default_server_args, env_dict=env_dict) as remote_server: yield 
remote_server @@ -76,55 +86,65 @@ async def client(server): _PROMPT = "Hello my name is Robert and I love magic" -tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME) -_TOKENIZED_PROMPT = tokenizer(_PROMPT)["input_ids"] +_IMAGE_URL = "https://upload.wikimedia.org/wikipedia/commons/thumb/d/dd/Gfp-wisconsin-madison-the-nature-boardwalk.jpg/2560px-Gfp-wisconsin-madison-the-nature-boardwalk.jpg" _NUM_REQUESTS = 10 -_NUM_PROMPT_TOKENS_PER_REQUEST = len(_TOKENIZED_PROMPT) _NUM_GENERATION_TOKENS_PER_REQUEST = 10 -# {metric_family: [(suffix, expected_value)]} -EXPECTED_VALUES = { - "vllm:time_to_first_token_seconds": [("_count", _NUM_REQUESTS)], - "vllm:time_per_output_token_seconds": - [("_count", _NUM_REQUESTS * (_NUM_GENERATION_TOKENS_PER_REQUEST - 1))], - "vllm:e2e_request_latency_seconds": [("_count", _NUM_REQUESTS)], - "vllm:request_queue_time_seconds": [("_count", _NUM_REQUESTS)], - "vllm:request_inference_time_seconds": [("_count", _NUM_REQUESTS)], - "vllm:request_prefill_time_seconds": [("_count", _NUM_REQUESTS)], - "vllm:request_decode_time_seconds": [("_count", _NUM_REQUESTS)], - "vllm:request_prompt_tokens": - [("_sum", _NUM_REQUESTS * _NUM_PROMPT_TOKENS_PER_REQUEST), - ("_count", _NUM_REQUESTS)], - "vllm:request_generation_tokens": - [("_sum", _NUM_REQUESTS * _NUM_GENERATION_TOKENS_PER_REQUEST), - ("_count", _NUM_REQUESTS)], - "vllm:request_params_n": [("_count", _NUM_REQUESTS)], - "vllm:request_params_max_tokens": [ - ("_sum", _NUM_REQUESTS * _NUM_GENERATION_TOKENS_PER_REQUEST), - ("_count", _NUM_REQUESTS) - ], - "vllm:iteration_tokens_total": - [("_sum", _NUM_REQUESTS * - (_NUM_PROMPT_TOKENS_PER_REQUEST + _NUM_GENERATION_TOKENS_PER_REQUEST)), - ("_count", _NUM_REQUESTS * _NUM_GENERATION_TOKENS_PER_REQUEST)], - "vllm:prompt_tokens": [("_total", - _NUM_REQUESTS * _NUM_PROMPT_TOKENS_PER_REQUEST)], - "vllm:generation_tokens": [ - ("_total", _NUM_REQUESTS * _NUM_PROMPT_TOKENS_PER_REQUEST) - ], - "vllm:request_success": [("_total", _NUM_REQUESTS)], -} + +def _get_expected_values(prompt_ids: list[int]): + num_prompt_tokens = len(prompt_ids) + + # {metric_family: [(suffix, expected_value)]} + return { + "vllm:time_to_first_token_seconds": [("_count", _NUM_REQUESTS)], + "vllm:time_per_output_token_seconds": + [("_count", _NUM_REQUESTS * (_NUM_GENERATION_TOKENS_PER_REQUEST - 1))], + "vllm:e2e_request_latency_seconds": [("_count", _NUM_REQUESTS)], + "vllm:request_queue_time_seconds": [("_count", _NUM_REQUESTS)], + "vllm:request_inference_time_seconds": [("_count", _NUM_REQUESTS)], + "vllm:request_prefill_time_seconds": [("_count", _NUM_REQUESTS)], + "vllm:request_decode_time_seconds": [("_count", _NUM_REQUESTS)], + "vllm:request_prompt_tokens": [("_sum", + _NUM_REQUESTS * num_prompt_tokens), + ("_count", _NUM_REQUESTS)], + "vllm:request_generation_tokens": + [("_sum", _NUM_REQUESTS * _NUM_GENERATION_TOKENS_PER_REQUEST), + ("_count", _NUM_REQUESTS)], + "vllm:request_params_n": [("_count", _NUM_REQUESTS)], + "vllm:request_params_max_tokens": + [("_sum", _NUM_REQUESTS * _NUM_GENERATION_TOKENS_PER_REQUEST), + ("_count", _NUM_REQUESTS)], + "vllm:iteration_tokens_total": + [("_sum", _NUM_REQUESTS * + (num_prompt_tokens + _NUM_GENERATION_TOKENS_PER_REQUEST)), + ("_count", _NUM_REQUESTS * _NUM_GENERATION_TOKENS_PER_REQUEST)], + "vllm:prompt_tokens": [("_total", _NUM_REQUESTS * num_prompt_tokens)], + "vllm:generation_tokens": [("_total", + _NUM_REQUESTS * num_prompt_tokens)], + "vllm:request_success": [("_total", _NUM_REQUESTS)], + } @pytest.mark.asyncio -async def test_metrics_counts(server: 
RemoteOpenAIServer, - client: openai.AsyncClient, use_v1: bool): +async def test_metrics_counts( + server: RemoteOpenAIServer, + client: openai.AsyncClient, + model_key: str, + use_v1: bool, +): + if model_key == "multimodal": + pytest.skip("Unnecessary test") + + model_name = MODELS[model_key] + tokenizer = AutoTokenizer.from_pretrained(model_name) + prompt_ids = tokenizer.encode(_PROMPT) + for _ in range(_NUM_REQUESTS): # sending a request triggers the metrics to be logged. await client.completions.create( - model=MODEL_NAME, - prompt=_TOKENIZED_PROMPT, + model=model_name, + prompt=prompt_ids, max_tokens=_NUM_GENERATION_TOKENS_PER_REQUEST) response = requests.get(server.url_for("metrics")) @@ -132,7 +152,8 @@ async def test_metrics_counts(server: RemoteOpenAIServer, assert response.status_code == HTTPStatus.OK # Loop over all expected metric_families - for metric_family, suffix_values_list in EXPECTED_VALUES.items(): + expected_values = _get_expected_values(prompt_ids) + for metric_family, suffix_values_list in expected_values.items(): if ((use_v1 and metric_family not in EXPECTED_METRICS_V1) or (not server.show_hidden_metrics and metric_family in HIDDEN_DEPRECATED_METRICS)): @@ -274,6 +295,14 @@ async def test_metrics_counts(server: RemoteOpenAIServer, "vllm:request_decode_time_seconds_count", ] +EXPECTED_METRICS_MM = [ + "vllm:mm_cache_usage", + "vllm:mm_cache_size_G", + "vllm:mm_cache_size_items", + "vllm:mm_cache_queries", + "vllm:mm_cache_hits", +] + HIDDEN_DEPRECATED_METRICS = [ "vllm:num_requests_swapped", "vllm:cpu_cache_usage_perc", @@ -281,18 +310,52 @@ async def test_metrics_counts(server: RemoteOpenAIServer, @pytest.mark.asyncio -async def test_metrics_exist(server: RemoteOpenAIServer, - client: openai.AsyncClient, use_v1: bool): +async def test_metrics_exist( + server: RemoteOpenAIServer, + client: openai.AsyncClient, + model_key: str, + use_v1: bool, +): # sending a request triggers the metrics to be logged. - await client.completions.create(model=MODEL_NAME, - prompt="Hello, my name is", - max_tokens=5, - temperature=0.0) + model_name = MODELS[model_key] + + if model_key == "text": + await client.completions.create(model=model_name, + prompt="Hello, my name is", + max_tokens=5, + temperature=0.0) + else: + messages = [{ + "role": + "user", + "content": [ + { + "type": "image_url", + "image_url": { + "url": _IMAGE_URL + } + }, + { + "type": "text", + "text": "What's in this image?" 
+ }, + ], + }] + + await client.chat.completions.create(model=model_name, + messages=messages, + max_tokens=5, + temperature=0.0) response = requests.get(server.url_for("metrics")) assert response.status_code == HTTPStatus.OK - for metric in (EXPECTED_METRICS_V1 if use_v1 else EXPECTED_METRICS): + expected_metrics = EXPECTED_METRICS_V1 if use_v1 else EXPECTED_METRICS + if model_key == "multimodal": + # NOTE: Don't use in-place assignment + expected_metrics = expected_metrics + EXPECTED_METRICS_MM + + for metric in expected_metrics: if (not server.show_hidden_metrics and metric not in HIDDEN_DEPRECATED_METRICS): assert metric in response.text diff --git a/tests/metrics/test_metrics.py b/tests/metrics/test_metrics.py index e71c87ff3fc8..ae6af0a55034 100644 --- a/tests/metrics/test_metrics.py +++ b/tests/metrics/test_metrics.py @@ -288,7 +288,6 @@ def test_metric_spec_decode( @pytest.mark.parametrize("max_tokens", [10]) @pytest.mark.parametrize("log_interval", [1, 3, 5, 7]) def test_metric_spec_decode_interval( - vllm_runner, example_prompts, model: str, dtype: str, diff --git a/tests/v1/core/test_kv_cache_utils.py b/tests/v1/core/test_kv_cache_utils.py index 1cdc80dd3546..3ee6e98f3c98 100644 --- a/tests/v1/core/test_kv_cache_utils.py +++ b/tests/v1/core/test_kv_cache_utils.py @@ -12,7 +12,6 @@ # yapf: disable from vllm.v1.core.kv_cache_utils import (NONE_HASH, BlockHashType, FreeKVCacheBlockQueue, KVCacheBlock, - PrefixCachingMetrics, estimate_max_model_len, generate_block_hash_extra_keys, hash_block_tokens, @@ -20,7 +19,7 @@ unify_kv_cache_configs) from vllm.v1.kv_cache_interface import (FullAttentionSpec, KVCacheConfig, KVCacheGroupSpec, KVCacheTensor) -from vllm.v1.metrics.stats import PrefixCacheStats +from vllm.v1.metrics.stats import CachingMetrics, PrefixCacheStats from vllm.v1.request import Request # yapf: enable @@ -351,7 +350,7 @@ def test_metrics(): def stats(requests, queries, hits): return PrefixCacheStats(requests=requests, queries=queries, hits=hits) - metrics = PrefixCachingMetrics(max_recent_requests=5) + metrics = CachingMetrics(max_recent_requests=5) assert metrics.hit_rate == 0.0 metrics.observe(stats(1, 20, 9)) diff --git a/vllm/engine/llm_engine.py b/vllm/engine/llm_engine.py index 2a27afe9757e..de0b0a39a43f 100644 --- a/vllm/engine/llm_engine.py +++ b/vllm/engine/llm_engine.py @@ -1662,6 +1662,15 @@ def _get_stats(self, gpu_prefix_cache_hit_rate = self.scheduler[ 0].get_prefix_cache_hit_rate(Device.GPU) + # Multi-modal cache stats + mm_registry = self.input_preprocessor.mm_registry + processor_cache_stats = mm_registry.make_processor_cache_stats() + mm_cache_usage = processor_cache_stats.usage + mm_cache_size_G = processor_cache_stats.size_G + mm_cache_size_items = processor_cache_stats.size_items + mm_cache_queries = processor_cache_stats.queries + mm_cache_hits = processor_cache_stats.hits + # Iteration stats num_prompt_tokens_iter = 0 num_generation_tokens_iter = 0 @@ -1848,6 +1857,12 @@ def _get_stats(self, # Prefix Cache Hit Rate cpu_prefix_cache_hit_rate=cpu_prefix_cache_hit_rate, gpu_prefix_cache_hit_rate=gpu_prefix_cache_hit_rate, + # Multi-modal cache stats + mm_cache_usage=mm_cache_usage, + mm_cache_size_G=mm_cache_size_G, + mm_cache_size_items=mm_cache_size_items, + mm_cache_queries=mm_cache_queries, + mm_cache_hits=mm_cache_hits, # Iteration stats num_prompt_tokens_iter=num_prompt_tokens_iter, diff --git a/vllm/engine/metrics.py b/vllm/engine/metrics.py index 033551d07c39..3654013e1673 100644 --- a/vllm/engine/metrics.py +++ b/vllm/engine/metrics.py @@ 
-127,6 +127,33 @@ def __init__(self, labelnames: List[str], vllm_config: VllmConfig): labelnames=labelnames, multiprocess_mode="sum") + # Multi-modal cache stats + self.gauge_mm_cache_usage = self._gauge_cls( + name="vllm:mm_cache_usage", + documentation="Multi-modal cache usage. " + "1 means 100 percent usage.", + labelnames=labelnames) + + self.gauge_mm_cache_size_G = self._gauge_cls( + name="vllm:mm_cache_size_G", + documentation="Multi-modal cache size (in GiB).", + labelnames=labelnames) + + self.gauge_mm_cache_size_items = self._gauge_cls( + name="vllm:mm_cache_size_items", + documentation="Multi-modal cache size (in number of items).", + labelnames=labelnames) + + self.counter_mm_cache_queries = self._counter_cls( + name="vllm:mm_cache_queries", + documentation="Multi-modal cache queries.", + labelnames=labelnames) + + self.counter_mm_cache_hits = self._counter_cls( + name="vllm:mm_cache_hits", + documentation="Multi-modal cache hits.", + labelnames=labelnames) + # Iteration stats self.counter_num_preemption = self._counter_cls( name="vllm:num_preemptions_total", @@ -497,6 +524,13 @@ def log(self, stats: Stats) -> None: stats.gpu_cache_usage_sys * 100, stats.cpu_cache_usage_sys * 100, ) + if (stats.mm_cache_usage >= 0): + log_fn( + "MM cache usage: %.2f%% (%d items = %.2f GiB)", + stats.mm_cache_usage * 100, + stats.mm_cache_size_items, + stats.mm_cache_size_G, + ) if (stats.cpu_prefix_cache_hit_rate >= 0 or stats.gpu_prefix_cache_hit_rate >= 0): log_fn( @@ -504,6 +538,7 @@ def log(self, stats: Stats) -> None: stats.gpu_prefix_cache_hit_rate * 100, stats.cpu_prefix_cache_hit_rate * 100, ) + if self.spec_decode_metrics is not None: log_fn( self._format_spec_decode_metrics_str( @@ -594,6 +629,18 @@ def _log_prometheus(self, stats: Stats) -> None: stats.cpu_prefix_cache_hit_rate) self._log_gauge(self.metrics.gauge_gpu_prefix_cache_hit_rate, stats.gpu_prefix_cache_hit_rate) + + self._log_gauge(self.metrics.gauge_mm_cache_usage, + stats.mm_cache_usage) + self._log_gauge(self.metrics.gauge_mm_cache_size_G, + stats.mm_cache_size_G) + self._log_gauge(self.metrics.gauge_mm_cache_size_items, + stats.mm_cache_size_items) + self._log_counter(self.metrics.counter_mm_cache_queries, + stats.mm_cache_queries) + self._log_counter(self.metrics.counter_mm_cache_hits, + stats.mm_cache_hits) + # Including max-lora in metric, in future this property of lora # config maybe extended to be dynamic. 
lora_info = { diff --git a/vllm/engine/metrics_types.py b/vllm/engine/metrics_types.py index 9e6d5ef29bed..f67d0372c3d7 100644 --- a/vllm/engine/metrics_types.py +++ b/vllm/engine/metrics_types.py @@ -37,6 +37,12 @@ class Stats: # Prefix caching block hit rate cpu_prefix_cache_hit_rate: float gpu_prefix_cache_hit_rate: float + # Multi-modal cache stats + mm_cache_usage: float + mm_cache_size_G: float + mm_cache_size_items: int + mm_cache_queries: int + mm_cache_hits: int # Iteration stats (should have _iter suffix) num_prompt_tokens_iter: int diff --git a/vllm/entrypoints/llm.py b/vllm/entrypoints/llm.py index cebddcc8e6aa..21ab1f3cea21 100644 --- a/vllm/entrypoints/llm.py +++ b/vllm/entrypoints/llm.py @@ -1259,6 +1259,9 @@ def start_profile(self) -> None: def stop_profile(self) -> None: self.llm_engine.stop_profile() + def reset_mm_cache(self) -> bool: + return self.llm_engine.reset_mm_cache() + def reset_prefix_cache(self, device: Optional[Device] = None) -> bool: return self.llm_engine.reset_prefix_cache(device) diff --git a/vllm/entrypoints/openai/api_server.py b/vllm/entrypoints/openai/api_server.py index a954a9ff90bc..3c25bb2d55cc 100644 --- a/vllm/entrypoints/openai/api_server.py +++ b/vllm/entrypoints/openai/api_server.py @@ -728,6 +728,16 @@ async def show_server_info(raw_request: Request): server_info = {"vllm_config": str(raw_request.app.state.vllm_config)} return JSONResponse(content=server_info) + @router.post("/reset_mm_cache") + async def reset_mm_cache(raw_request: Request): + """ + Reset the multi-modal cache. Note that we currently do not check if the + multi-modal cache is successfully reset in the API server. + """ + logger.info("Resetting multi-modal cache...") + await engine_client(raw_request).reset_mm_cache() + return Response(status_code=200) + @router.post("/reset_prefix_cache") async def reset_prefix_cache(raw_request: Request): """ diff --git a/vllm/envs.py b/vllm/envs.py index 0c742bf05623..0d66f40af333 100644 --- a/vllm/envs.py +++ b/vllm/envs.py @@ -626,7 +626,7 @@ def maybe_convert_int(value: Optional[str]) -> Optional[int]: # If set, vllm will run in development mode, which will enable # some additional endpoints for developing and debugging, - # e.g. `/reset_prefix_cache` + # e.g. `/reset_mm_cache`, `/reset_prefix_cache` "VLLM_SERVER_DEV_MODE": lambda: bool(int(os.getenv("VLLM_SERVER_DEV_MODE", "0"))), diff --git a/vllm/multimodal/processing.py b/vllm/multimodal/processing.py index 92f9e70b5234..fe52621c98e2 100644 --- a/vllm/multimodal/processing.py +++ b/vllm/multimodal/processing.py @@ -22,6 +22,7 @@ from vllm.transformers_utils.tokenizer import (AnyTokenizer, decode_tokens, encode_tokens) from vllm.utils import GiB_bytes, LRUCache, flatten_2d_lists, full_groupby +from vllm.v1.metrics.stats import MultiModalCacheStats from .hasher import MultiModalHasher from .inputs import (MultiModalDataDict, MultiModalEncDecInputs, @@ -933,32 +934,35 @@ def __init__( self, capacity_gb: float, *, - debug_cache_hit_ratio_steps: Optional[int] = None, + debug: bool = False, ) -> None: super().__init__() - self.debug_cache_hit_ratio_steps = debug_cache_hit_ratio_steps - self.debug_cache_hits = 0 - self.debug_cache_total = 0 - self._cache = self.get_lru_cache( capacity_gb, MultiModalKwargsItem, - debug=bool(debug_cache_hit_ratio_steps), + debug=debug, ) + self._stats = MultiModalCacheStats() + + def make_stats(self) -> MultiModalCacheStats: + """Get (and reset) the multi-modal cache stats. + + Returns: + The current multi-modal caching stats. 
+ """ + mm_cache = self._cache - def _maybe_log_cache_stats(self) -> None: - steps = self.debug_cache_hit_ratio_steps - if not steps: - return + info_delta = mm_cache.stat(delta=True) + self._stats.hits = info_delta.hits + self._stats.queries = info_delta.total + self._stats.usage = mm_cache.usage + self._stats.size_G = mm_cache.currsize / GiB_bytes + self._stats.size_items = len(mm_cache) - total = self.debug_cache_total - if total > 0 and total % steps == 0: - logger.debug("ProcessingCache: hit_ratio = %.2f", - self.debug_cache_hits / total) - logger.debug("ProcessingCache: size = %.2f / %.2f GiB", - self._cache.currsize / GiB_bytes, - self._cache.maxsize / GiB_bytes) + stats = self._stats + self._stats = MultiModalCacheStats() + return stats def get( self, @@ -976,18 +980,10 @@ def get( - The original data item passed to the HF processor - The configuration options of the HF processor """ - self._maybe_log_cache_stats() - cache_key = MultiModalHasher.hash_kwargs(model_id=model_id, **{modality: input_item}, **input_kwargs) - if self.debug_cache_hit_ratio_steps: - if cache_key in self._cache: - self.debug_cache_hits += 1 - - self.debug_cache_total += 1 - return self._cache.get(cache_key) def get_item( @@ -1028,6 +1024,7 @@ def put_item(self, item: ProcessingCacheItem) -> None: def reset(self) -> bool: self._cache.clear() + self._stats.reset = True return True diff --git a/vllm/multimodal/registry.py b/vllm/multimodal/registry.py index 67d0d7fc1183..72b73c0e52f0 100644 --- a/vllm/multimodal/registry.py +++ b/vllm/multimodal/registry.py @@ -12,6 +12,7 @@ from vllm.transformers_utils.tokenizer import (AnyTokenizer, cached_tokenizer_from_config) from vllm.utils import ClassRegistry +from vllm.v1.metrics.stats import MultiModalCacheStats from .processing import (BaseMultiModalProcessor, BaseProcessingInfo, ProcessingCache) @@ -88,6 +89,14 @@ def __init__(self) -> None: self._processing_cache = ProcessingCache(VLLM_MM_INPUT_CACHE_GIB) + def make_processor_cache_stats(self) -> MultiModalCacheStats: + """Get (and reset) the multi-modal cache stats. + + Returns: + The current multi-modal caching stats. + """ + return self._processing_cache.make_stats() + def reset_processor_cache(self) -> bool: """Reset the multi-modal processing cache.""" self._processing_cache.reset() diff --git a/vllm/v1/core/kv_cache_utils.py b/vllm/v1/core/kv_cache_utils.py index 27c515835087..1a3b252a2dab 100644 --- a/vllm/v1/core/kv_cache_utils.py +++ b/vllm/v1/core/kv_cache_utils.py @@ -1,7 +1,6 @@ # SPDX-License-Identifier: Apache-2.0 """KV-Cache Utilities.""" import os -from collections import deque from collections.abc import Sequence from dataclasses import dataclass from typing import Any, Callable, NamedTuple, Optional @@ -12,7 +11,6 @@ from vllm.v1.kv_cache_interface import (FullAttentionSpec, KVCacheConfig, KVCacheGroupSpec, KVCacheSpec, KVCacheTensor, SlidingWindowSpec) -from vllm.v1.metrics.stats import PrefixCacheStats from vllm.v1.request import Request logger = init_logger(__name__) @@ -46,68 +44,6 @@ class BlockHashType(NamedTuple): 'PYTHONHASHSEED') is None else sha256(os.getenv('PYTHONHASHSEED')) -class PrefixCachingMetrics: - """Metrics for prefix caching with a hit rate of the max recent N requests. - - Args: - max_recent_requests: The number of the max recent requests to aggregate. - Defaults to 1000. - """ - - def __init__(self, max_recent_requests: int = 1000): - self.max_recent_requests = max_recent_requests - # The current aggregated values. 
- self.aggregated_requests = 0 - self.aggregated_query_total = 0 - self.aggregated_query_hit = 0 - # A deque of (requests, queries, hits) for the most recent requests. - self.query_queue: deque[tuple[int, int, int]] = deque() - - def observe(self, stats: PrefixCacheStats): - """Observe the prefix caching for a set of requests. - - This function is called with information gathered when new requests - are being scheduled and are looking for computed blocks. - - When there are more than `interval` requests, the oldest set of - requests are removed from the metrics. - - Args: - stats: The prefix cache stats. - """ - # reset_prefix_cache was invoked before the current update. - # Reset the metrics before aggregating the current stats. - if stats.reset: - self.reset() - - # Update the metrics. - self.query_queue.append((stats.requests, stats.queries, stats.hits)) - self.aggregated_requests += stats.requests - self.aggregated_query_total += stats.queries - self.aggregated_query_hit += stats.hits - - # Remove the oldest stats if the number of requests exceeds. - if self.aggregated_requests > self.max_recent_requests: - old_requests, old_queries, old_hits = self.query_queue.popleft() - self.aggregated_requests -= old_requests - self.aggregated_query_total -= old_queries - self.aggregated_query_hit -= old_hits - - def reset(self): - """Reset the metrics.""" - self.aggregated_requests = 0 - self.aggregated_query_total = 0 - self.aggregated_query_hit = 0 - self.query_queue.clear() - - @property - def hit_rate(self) -> float: - """Calculate the hit rate for the past N requests.""" - if self.aggregated_query_total == 0: - return 0.0 - return self.aggregated_query_hit / self.aggregated_query_total - - @dataclass class KVCacheBlock: """KV-cache block metadata.""" diff --git a/vllm/v1/engine/__init__.py b/vllm/v1/engine/__init__.py index 122a5a72cc36..7bbf15341555 100644 --- a/vllm/v1/engine/__init__.py +++ b/vllm/v1/engine/__init__.py @@ -11,7 +11,7 @@ from vllm.multimodal import MultiModalKwargs from vllm.multimodal.inputs import PlaceholderRange from vllm.sampling_params import SamplingParams -from vllm.v1.metrics.stats import SchedulerStats +from vllm.v1.metrics.stats import MultiModalCacheStats, SchedulerStats from vllm.v1.outputs import LogprobsLists, LogprobsTensors # These are possible values of RequestOutput.finish_reason, @@ -138,6 +138,7 @@ class EngineCoreOutputs( # [num_reqs] outputs: list[EngineCoreOutput] = [] scheduler_stats: Optional[SchedulerStats] = None + mm_cache_stats: Optional[MultiModalCacheStats] = None timestamp: float = 0.0 utility_output: Optional[UtilityOutput] = None diff --git a/vllm/v1/engine/async_llm.py b/vllm/v1/engine/async_llm.py index 0d646d8dd575..fc97588bcb5a 100644 --- a/vllm/v1/engine/async_llm.py +++ b/vllm/v1/engine/async_llm.py @@ -34,7 +34,9 @@ from vllm.v1.executor.abstract import Executor from vllm.v1.metrics.loggers import (StatLoggerBase, StatLoggerFactory, setup_default_loggers) -from vllm.v1.metrics.stats import IterationStats, SchedulerStats +from vllm.v1.metrics.stats import (IterationStats, MultiModalCacheStats, + MultiModalCacheStatsCollection, + SchedulerStats) logger = init_logger(__name__) @@ -350,6 +352,7 @@ def _run_output_handler(self): # Ensure that the task doesn't have a circular ref back to the AsyncLLM # object, or else it won't be garbage collected and cleaned up properly. 
engine_core = self.engine_core + input_processor = self.processor output_processor = self.output_processor log_stats = self.log_stats stat_loggers = self.stat_loggers if log_stats else None @@ -395,8 +398,10 @@ async def output_handler(): if stat_loggers: assert outputs.scheduler_stats is not None AsyncLLM._record_stats( - stat_loggers[outputs.engine_index], + input_processor=input_processor, + stat_loggers=stat_loggers[outputs.engine_index], scheduler_stats=outputs.scheduler_stats, + p1_mm_cache_stats=outputs.mm_cache_stats, iteration_stats=iteration_stats, ) except Exception as e: @@ -416,14 +421,24 @@ async def abort(self, request_id: str) -> None: @staticmethod def _record_stats( + input_processor: Processor, stat_loggers: list[StatLoggerBase], scheduler_stats: SchedulerStats, + p1_mm_cache_stats: Optional[MultiModalCacheStats], iteration_stats: Optional[IterationStats], ): """static so that it can be used from the output_handler task without a circular ref to AsyncLLM.""" + mm_cache_stats = MultiModalCacheStatsCollection( + p0_processor=input_processor.mm_registry. + make_processor_cache_stats(), + p0_mirror=input_processor.mm_input_cache_client.make_stats(), + p1_mirror=p1_mm_cache_stats, + ) if p1_mm_cache_stats else None + for stat_logger in stat_loggers: stat_logger.record(scheduler_stats=scheduler_stats, + mm_cache_stats=mm_cache_stats, iteration_stats=iteration_stats) def encode( diff --git a/vllm/v1/engine/core.py b/vllm/v1/engine/core.py index edc79ae20b9f..21cefac44e87 100644 --- a/vllm/v1/engine/core.py +++ b/vllm/v1/engine/core.py @@ -105,6 +105,7 @@ def __init__(self, ) # Setup MM Input Mapper. + self.is_multimodal_model = vllm_config.model_config.is_multimodal_model self.mm_input_cache_server = MirroredProcessingCache( vllm_config.model_config) @@ -214,20 +215,28 @@ def execute_model(self, scheduler_output: SchedulerOutput): def step(self) -> EngineCoreOutputs: """Schedule, execute, and make output.""" - # Check for any requests remaining in the scheduler - unfinished, # or finished and not yet removed from the batch. - if not self.scheduler.has_requests(): - return EngineCoreOutputs( + if self.scheduler.has_requests(): + scheduler_output = self.scheduler.schedule() + model_output = self.execute_model(scheduler_output) + core_outputs = self.scheduler.update_from_output( + scheduler_output, model_output) # type: ignore + else: + core_outputs = EngineCoreOutputs( outputs=[], scheduler_stats=self.scheduler.make_stats(), ) - scheduler_output = self.scheduler.schedule() - model_output = self.execute_model(scheduler_output) - engine_core_outputs = self.scheduler.update_from_output( - scheduler_output, model_output) # type: ignore - return engine_core_outputs + if self.is_multimodal_model: + mm_cache_stats = self.mm_input_cache_server.make_stats() + if (scheduler_stats := core_outputs.scheduler_stats): + mm_cache_stats.requests = ( + scheduler_stats.prefix_cache_stats.requests) + + core_outputs.mm_cache_stats = mm_cache_stats + + return core_outputs def step_with_batch_queue(self) -> Optional[EngineCoreOutputs]: """Schedule and execute batches with the batch queue. 
diff --git a/vllm/v1/engine/mm_input_cache.py b/vllm/v1/engine/mm_input_cache.py index fcb90bebdb62..33c75191fd28 100644 --- a/vllm/v1/engine/mm_input_cache.py +++ b/vllm/v1/engine/mm_input_cache.py @@ -1,11 +1,15 @@ # SPDX-License-Identifier: Apache-2.0 from collections.abc import Sequence -from typing import Optional +from typing import TYPE_CHECKING, Optional from vllm.envs import VLLM_MM_INPUT_CACHE_GIB from vllm.multimodal import MultiModalKwargs from vllm.multimodal.processing import ProcessingCache -from vllm.utils import is_list_of +from vllm.utils import GiB_bytes, is_list_of +from vllm.v1.metrics.stats import MultiModalCacheStats + +if TYPE_CHECKING: + from vllm.config import ModelConfig # The idea of multimodal preprocessing caching is based on having a client and # a server, where the client executes in the frontend process (=P0) and the @@ -32,7 +36,7 @@ class MirroredProcessingCache: - def __init__(self, model_config): + def __init__(self, model_config: "ModelConfig") -> None: mm_config = model_config.multimodal_config disable_mm_preprocessor_cache = mm_config is not None and \ not mm_config.disable_mm_preprocessor_cache @@ -40,6 +44,27 @@ def __init__(self, model_config): self.mm_cache = ProcessingCache.get_lru_cache(VLLM_MM_INPUT_CACHE_GIB, MultiModalKwargs) + self._stats = MultiModalCacheStats() + + def make_stats(self) -> MultiModalCacheStats: + """Get (and reset) the multi-modal cache stats. + + Returns: + The current multi-modal caching stats. + """ + mm_cache = self.mm_cache + + info_delta = mm_cache.stat(delta=True) + self._stats.hits = info_delta.hits + self._stats.queries = info_delta.total + self._stats.usage = mm_cache.usage + self._stats.size_G = mm_cache.currsize / GiB_bytes + self._stats.size_items = len(mm_cache) + + stats = self._stats + self._stats = MultiModalCacheStats() + return stats + def get_and_update_p0( self, mm_inputs: Sequence[MultiModalKwargs], @@ -86,5 +111,6 @@ def get_and_update_p1( def reset(self) -> bool: self.mm_cache.clear() + self._stats.reset = True return True diff --git a/vllm/v1/metrics/loggers.py b/vllm/v1/metrics/loggers.py index 6ee40850beb1..9e460275acdb 100644 --- a/vllm/v1/metrics/loggers.py +++ b/vllm/v1/metrics/loggers.py @@ -10,9 +10,10 @@ from vllm.config import SupportsMetricsInfo, VllmConfig from vllm.logger import init_logger -from vllm.v1.core.kv_cache_utils import PrefixCachingMetrics from vllm.v1.engine import FinishReason -from vllm.v1.metrics.stats import IterationStats, SchedulerStats +from vllm.v1.metrics.stats import (CachingMetrics, IterationStats, + MultiModalCacheStatsCollection, + SchedulerStats) from vllm.v1.spec_decode.metrics import SpecDecodingLogging, SpecDecodingProm logger = init_logger(__name__) @@ -35,8 +36,12 @@ def __init__(self, vllm_config: VllmConfig, engine_index: int = 0): ... @abstractmethod - def record(self, scheduler_stats: SchedulerStats, - iteration_stats: Optional[IterationStats]): + def record( + self, + scheduler_stats: SchedulerStats, + mm_cache_stats: Optional[MultiModalCacheStatsCollection], + iteration_stats: Optional[IterationStats], + ): ... @abstractmethod @@ -54,9 +59,16 @@ def __init__(self, vllm_config: VllmConfig, engine_index: int = 0): self.vllm_config = vllm_config self._reset(time.monotonic()) self.last_scheduler_stats = SchedulerStats() - # Prefix cache metrics. This cannot be reset. + + # Caching metrics. This cannot be reset. # TODO: Make the interval configurable. 
- self.prefix_caching_metrics = PrefixCachingMetrics() + self.last_mm_cache_stats: Optional[ + MultiModalCacheStatsCollection] = None + self.p0_processor_mm_caching_metrics = CachingMetrics() + self.p0_mirror_mm_caching_metrics = CachingMetrics() + self.p1_mirror_mm_caching_metrics = CachingMetrics() + self.prefix_caching_metrics = CachingMetrics() + self.spec_decoding_logging = SpecDecodingLogging() self.last_prompt_throughput: float = 0.0 self.last_generation_throughput: float = 0.0 @@ -78,9 +90,19 @@ def _get_throughput(self, tracked_stats: list[int], now: float) -> float: # Compute summary metrics for tracked stats return float(np.sum(tracked_stats) / (now - self.last_log_time)) - def record(self, scheduler_stats: SchedulerStats, - iteration_stats: Optional[IterationStats]): + def record( + self, + scheduler_stats: SchedulerStats, + mm_cache_stats: Optional[MultiModalCacheStatsCollection], + iteration_stats: Optional[IterationStats], + ): """Log Stats to standard output.""" + if mm_cache_stats: + self.p0_processor_mm_caching_metrics.observe( + mm_cache_stats.p0_processor) + self.p0_mirror_mm_caching_metrics.observe(mm_cache_stats.p0_mirror) + self.p1_mirror_mm_caching_metrics.observe(mm_cache_stats.p1_mirror) + self.last_mm_cache_stats = mm_cache_stats if iteration_stats: self._track_iteration_stats(iteration_stats) @@ -128,6 +150,29 @@ def log(self): scheduler_stats.gpu_cache_usage * 100, self.prefix_caching_metrics.hit_rate * 100, ) + + if self.last_mm_cache_stats: + log_fn( + "P0 Processor MM cache usage: %.2f%% (%d items = %.2f GiB), " + "hit rate: %.2f%%; " + "P0 Mirrored MM cache usage: %.2f%% (%d items = %.2f GiB), " + "hit rate: %.2f%%; " + "P1 Mirrored MM cache usage: %.2f%% (%d items = %.2f GiB), " + "hit rate: %.2f%%", + self.last_mm_cache_stats.p0_processor.usage * 100, + self.last_mm_cache_stats.p0_processor.size_items, + self.last_mm_cache_stats.p0_processor.size_G, + self.p0_processor_mm_caching_metrics.hit_rate * 100, + self.last_mm_cache_stats.p0_mirror.usage * 100, + self.last_mm_cache_stats.p0_mirror.size_items, + self.last_mm_cache_stats.p0_mirror.size_G, + self.p0_mirror_mm_caching_metrics.hit_rate * 100, + self.last_mm_cache_stats.p1_mirror.usage * 100, + self.last_mm_cache_stats.p1_mirror.size_items, + self.last_mm_cache_stats.p1_mirror.size_G, + self.p1_mirror_mm_caching_metrics.hit_rate * 100, + ) + self.spec_decoding_logging.log(log_fn=log_fn) def log_engine_initialized(self): @@ -192,6 +237,57 @@ def __init__(self, vllm_config: VllmConfig, engine_index: int = 0): "GPU prefix cache hits, in terms of number of cached tokens.", labelnames=labelnames).labels(*labelvalues) + # + # Multi-modal cache + # + mm_cache_keys = ("p0_processor", "p0_mirror", "p1_mirror") + + gauge_mm_cache_usage = prometheus_client.Gauge( + name="vllm:mm_cache_usage", + documentation="Multi-modal cache usage. 
" + "1 means 100 percent usage.", + labelnames=labelnames + ["cache"]) + self.gauge_mm_cache_usage = { + k: gauge_mm_cache_usage.labels(*(labelvalues + [k])) + for k in mm_cache_keys + } + + gauge_mm_cache_size_G = prometheus_client.Gauge( + name="vllm:mm_cache_size_G", + documentation="Multi-modal cache size (in GiB).", + labelnames=labelnames + ["cache"]) + self.gauge_mm_cache_size_G = { + k: gauge_mm_cache_size_G.labels(*(labelvalues + [k])) + for k in mm_cache_keys + } + + gauge_mm_cache_size_items = prometheus_client.Gauge( + name="vllm:mm_cache_size_items", + documentation="Multi-modal cache size (in number of items).", + labelnames=labelnames + ["cache"]) + self.gauge_mm_cache_size_items = { + k: gauge_mm_cache_size_items.labels(*(labelvalues + [k])) + for k in mm_cache_keys + } + + gauge_mm_cache_size_queries = prometheus_client.Gauge( + name="vllm:mm_cache_queries", + documentation="Multi-modal cache queries.", + labelnames=labelnames + ["cache"]) + self.gauge_mm_cache_queries = { + k: gauge_mm_cache_size_queries.labels(*(labelvalues + [k])) + for k in mm_cache_keys + } + + gauge_mm_cache_hits = prometheus_client.Gauge( + name="vllm:mm_cache_hits", + documentation="Multi-modal cache hits.", + labelnames=labelnames + ["cache"]) + self.gauge_mm_cache_hits = { + k: gauge_mm_cache_hits.labels(*(labelvalues + [k])) + for k in mm_cache_keys + } + # # Counters # @@ -371,8 +467,12 @@ def log_metrics_info(self, type: str, config_obj: SupportsMetricsInfo): labelnames=metrics_info.keys()).labels(**metrics_info) info_gauge.set(1) - def record(self, scheduler_stats: SchedulerStats, - iteration_stats: Optional[IterationStats]): + def record( + self, + scheduler_stats: SchedulerStats, + mm_cache_stats: Optional[MultiModalCacheStatsCollection], + iteration_stats: Optional[IterationStats], + ): """Log to prometheus.""" self.gauge_scheduler_running.set(scheduler_stats.num_running_reqs) self.gauge_scheduler_waiting.set(scheduler_stats.num_waiting_reqs) @@ -388,6 +488,18 @@ def record(self, scheduler_stats: SchedulerStats, self.spec_decoding_prom.observe( scheduler_stats.spec_decoding_stats) + if mm_cache_stats is not None: + for key, gauge in self.gauge_mm_cache_usage.items(): + gauge.set(mm_cache_stats[key].usage) + for key, gauge in self.gauge_mm_cache_size_G.items(): + gauge.set(mm_cache_stats[key].size_G) + for key, gauge in self.gauge_mm_cache_size_items.items(): + gauge.set(mm_cache_stats[key].size_items) + for key, counter in self.gauge_mm_cache_queries.items(): + counter.set(mm_cache_stats[key].queries) + for key, counter in self.gauge_mm_cache_hits.items(): + counter.set(mm_cache_stats[key].hits) + if iteration_stats is None: return diff --git a/vllm/v1/metrics/stats.py b/vllm/v1/metrics/stats.py index 8fe1630616a4..9d8569cf2cc7 100644 --- a/vllm/v1/metrics/stats.py +++ b/vllm/v1/metrics/stats.py @@ -1,6 +1,7 @@ # SPDX-License-Identifier: Apache-2.0 import time +from collections import deque from dataclasses import dataclass, field from typing import TYPE_CHECKING, Optional @@ -12,17 +13,91 @@ @dataclass -class PrefixCacheStats: - """Stores prefix cache hit statistics.""" - # Whether reset_prefix_cache was invoked. +class CacheStats: + """Stores cache hit statistics.""" reset: bool = False - # The number of requests in this update. + """Whether the cache was reset.""" + requests: int = 0 - # The number of queries in these requests. Note that "queries" here - # means the number of tokens that were queried from the cache. 
+    """The number of requests in this update."""
+
     queries: int = 0
-    # The number of hits in these requests.
+    """The number of queries in these requests."""
+
     hits: int = 0
+    """The number of hits in these requests."""
+
+
+class CachingMetrics:
+    """Metrics for caching with a hit rate of the most recent N requests.
+
+    Args:
+        max_recent_requests: The number of the most recent requests to
+            aggregate. Defaults to 1000.
+    """
+
+    def __init__(self, max_recent_requests: int = 1000):
+        self.max_recent_requests = max_recent_requests
+        # The current aggregated values.
+        self.aggregated_requests = 0
+        self.aggregated_query_total = 0
+        self.aggregated_query_hit = 0
+        # A deque of (requests, queries, hits) for the most recent requests.
+        self.query_queue: deque[tuple[int, int, int]] = deque()
+
+    def observe(self, stats: CacheStats):
+        """Observe the cache statistics for a set of requests.
+
+        This function is called with information gathered when new requests
+        query the cache, e.g. when looking for computed blocks.
+
+        When there are more than `max_recent_requests` requests, the oldest
+        set of requests is removed from the metrics.
+
+        Args:
+            stats: The cache stats.
+        """
+        # The cache was reset before the current update.
+        # Reset the metrics before aggregating the current stats.
+        if stats.reset:
+            self.reset()
+
+        # Update the metrics.
+        self.query_queue.append((stats.requests, stats.queries, stats.hits))
+        self.aggregated_requests += stats.requests
+        self.aggregated_query_total += stats.queries
+        self.aggregated_query_hit += stats.hits
+
+        # Remove the oldest stats if the request count exceeds the limit.
+        if self.aggregated_requests > self.max_recent_requests:
+            old_requests, old_queries, old_hits = self.query_queue.popleft()
+            self.aggregated_requests -= old_requests
+            self.aggregated_query_total -= old_queries
+            self.aggregated_query_hit -= old_hits
+
+    def reset(self):
+        """Reset the metrics."""
+        self.aggregated_requests = 0
+        self.aggregated_query_total = 0
+        self.aggregated_query_hit = 0
+        self.query_queue.clear()
+
+    @property
+    def hit_rate(self) -> float:
+        """Calculate the hit rate for the past N requests."""
+        if self.aggregated_query_total == 0:
+            return 0.0
+        return self.aggregated_query_hit / self.aggregated_query_total
+
+
+@dataclass
+class PrefixCacheStats(CacheStats):
+    """
+    Stores prefix cache hit statistics.
+
+    - `reset`: Whether `reset_prefix_cache` was invoked.
+    - `queries`: Refers to the number of tokens that were queried.
+    """
 
 
 @dataclass
@@ -40,6 +115,43 @@ class SchedulerStats:
     spec_decoding_stats: Optional[SpecDecodingStats] = None
 
 
+@dataclass
+class MultiModalCacheStats(CacheStats):
+    """
+    Stores multi-modal cache hit statistics.
+
+    - `reset`: Whether `reset_mm_cache` was invoked.
+    - `queries`: Refers to the number of multi-modal data items
+        that were queried.
+ """ + usage: float = 0.0 + """The current memory usage of the cache relative to its capacity.""" + + size_G: float = 0.0 + """The current size of the cache (in GiB).""" + + size_items: int = 0 + """The current size of the cache (in number of items).""" + + +@dataclass +class MultiModalCacheStatsCollection: + p0_processor: MultiModalCacheStats + """Stats for Process 0 processor cache.""" + + p0_mirror: MultiModalCacheStats + """Stats for Process 0 mirrored cache.""" + + p1_mirror: MultiModalCacheStats + """Stats for Process 1 mirrored cache.""" + + def __getitem__(self, key: str) -> MultiModalCacheStats: + try: + return getattr(self, key) + except AttributeError as e: + raise KeyError(key) from e + + @dataclass class LoRAStats: waiting_requests: set[str] = field(default_factory=set) diff --git a/vllm/v1/spec_decode/metrics.py b/vllm/v1/spec_decode/metrics.py index f71a59908ef3..a5f0b370a2b5 100644 --- a/vllm/v1/spec_decode/metrics.py +++ b/vllm/v1/spec_decode/metrics.py @@ -1,14 +1,16 @@ # SPDX-License-Identifier: Apache-2.0 from dataclasses import dataclass, field -from typing import Optional +from typing import TYPE_CHECKING, Optional import numpy as np import prometheus_client -from vllm.config import SpeculativeConfig from vllm.logger import init_logger +if TYPE_CHECKING: + from vllm.config import SpeculativeConfig + logger = init_logger(__name__) @@ -120,7 +122,7 @@ class SpecDecodingProm: vllm:spec_decode_num_drafts[$interval] """ - def __init__(self, speculative_config: Optional[SpeculativeConfig], + def __init__(self, speculative_config: Optional["SpeculativeConfig"], labelnames: list[str], labelvalues: list[str]): self.spec_decoding_enabled = speculative_config is not None if not self.spec_decoding_enabled: