Commit 3d44643
[Bugfix] Fix size calculation of processing cache (#15114)
Signed-off-by: DarkLight1337 <[email protected]>
1 parent 1fe0fd1 commit 3d44643
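
Background: the processing cache sizes its entries by mapping a per-leaf size function over the cached value. As the new comments in the diff note, MultiModalKwargs and MultiModalKwargsItem are not subclasses of dict, so the leaf mapper treated them as opaque leaves and fell through to sys.getsizeof, which reports only the shallow Python object and ignores the tensors held inside. The cache could therefore drastically under-count its memory footprint and evict far too late. Below is a minimal sketch of that failure mode, using a hypothetical FakeKwargs stand-in rather than the real vLLM classes:

```python
import sys
from collections import UserDict

import torch


class FakeKwargs(UserDict):
    """Hypothetical stand-in for a container that is *not* a dict subclass."""


item = FakeKwargs(pixel_values=torch.empty((1024, 1024), dtype=torch.int8))

# A leaf mapper that only recurses into plain dicts/lists sees `item` as a
# single opaque leaf, so the old estimate reduced to sys.getsizeof:
print(sys.getsizeof(item))  # tens of bytes: just the shallow wrapper object

# The memory actually held by the cached entry:
print(item["pixel_values"].nbytes)  # 1048576 bytes (1 MiB)
```

The fix special-cases both container types in get_leaf_size and recurses into their tensor data, as shown in the vllm/multimodal/processing.py diff below.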

2 files changed: +92 -16 lines changed

tests/multimodal/test_processing.py

Lines changed: 46 additions & 2 deletions
```diff
@@ -7,15 +7,20 @@
 
 import numpy as np
 import pytest
+import torch
 from transformers import ProcessorMixin
 
 from vllm.config import ModelConfig
 from vllm.multimodal import MULTIMODAL_REGISTRY
+from vllm.multimodal.inputs import (MultiModalFieldElem, MultiModalKwargs,
+                                    MultiModalKwargsItem,
+                                    MultiModalSharedField)
 # yapf conflicts with isort for this block
 # yapf: disable
 from vllm.multimodal.processing import (PlaceholderFeaturesInfo,
-                                        PromptIndexTargets, PromptInsertion,
-                                        PromptReplacement, apply_text_matches,
+                                        ProcessingCache, PromptIndexTargets,
+                                        PromptInsertion, PromptReplacement,
+                                        apply_text_matches,
                                         apply_token_matches,
                                         find_mm_placeholders,
                                         find_text_matches, find_token_matches,
@@ -890,6 +895,45 @@ def test_find_mm_placeholders(
     assert result == expected
 
 
+def _dummy_elem(modality: str, key: str, size: int):
+    return MultiModalFieldElem(
+        modality=modality,
+        key=key,
+        data=torch.empty((size, ), dtype=torch.int8),
+        field=MultiModalSharedField(1),
+    )
+
+
+def _dummy_item(modality: str, size_by_key: dict[str, int]):
+    return MultiModalKwargsItem.from_elems([
+        _dummy_elem(modality, key, size) for key, size in size_by_key.items()
+    ])
+
+
+def _dummy_kw(size_by_key_modality: dict[str, dict[str, int]]):
+    return MultiModalKwargs.from_items([
+        _dummy_item(modality, size_by_key)
+        for modality, size_by_key in size_by_key_modality.items()
+    ])
+
+
+# yapf: disable
+@pytest.mark.parametrize(
+    ("item", "expected_size"),
+    [
+        (_dummy_item("a", {"a1": 100}), 100),
+        (_dummy_item("a", {"a1": 100, "a2": 110}), 210),
+        (_dummy_kw({"a": {"a1": 100, "a2": 110}, "b": {"b1": 120, "b2": 130}}), 460),  # noqa: E501
+    ],
+)
+# yapf: enable
+def test_cache_item_size(item, expected_size):
+    cache = ProcessingCache.get_lru_cache(2048, type(item))
+    cache[""] = item
+
+    assert cache.currsize == expected_size
+
+
 @pytest.mark.parametrize("model_id", ["llava-hf/llava-v1.6-mistral-7b-hf"])
 @pytest.mark.parametrize(
     ("limit", "num_supported", "is_valid"),
```

vllm/multimodal/processing.py

Lines changed: 46 additions & 14 deletions
```diff
@@ -26,7 +26,7 @@
 from .hasher import MultiModalHasher
 from .inputs import (MultiModalDataDict, MultiModalEncDecInputs,
                      MultiModalFieldConfig, MultiModalInputs, MultiModalKwargs,
-                     MultiModalKwargsItem, PlaceholderRange)
+                     MultiModalKwargsItem, NestedTensors, PlaceholderRange)
 from .parse import (DictEmbeddingItems, EmbeddingItems, MultiModalDataItems,
                     MultiModalDataParser)
 
@@ -853,33 +853,62 @@ class ProcessingCache:
 
     @staticmethod
     def get_lru_cache(
-        capacity_gb: int,
+        capacity_gb: float,
         value_type: type[_V],
+        *,
+        debug: bool = False,
     ) -> LRUCache[str, _V]:
 
-        def get_size(leaf: object) -> int:
+        def get_leaf_size(leaf: object) -> int:
+            # MultiModalKwargs is not a subclass of dict
+            if isinstance(leaf, MultiModalKwargs):
+                return get_item_size(leaf.data)
+
+            # MultiModalKwargsItem is not a subclass of dict
+            if isinstance(leaf, MultiModalKwargsItem):
+                leaf_data = {k: v.data for k, v in leaf.items()}
+                return get_item_size(leaf_data)
+
+            # sys.getsizeof doesn't work for tensors
             if isinstance(leaf, torch.Tensor):
-                return leaf.nbytes  # sys.getsizeof doesn't work for tensors
+                return leaf.nbytes
 
             return sys.getsizeof(leaf)
 
-        return LRUCache[str, _V](
-            GiB_bytes * capacity_gb,
-            getsizeof=lambda x: json_reduce_leaves(
+        def get_item_size(
+            value: Union[MultiModalKwargs, MultiModalKwargsItem,
+                         Mapping[str, NestedTensors]]
+        ) -> int:
+            size = json_reduce_leaves(
                 lambda a, b: a + b,
-                json_map_leaves(get_size, x),
-            ),
-        )
+                json_map_leaves(get_leaf_size, value),
+            )
+
+            if debug:
+                logger.debug("Calculated size of %s to be %.2f GiB",
+                             type(value), size / GiB_bytes)
 
-    def __init__(self, capacity_gb: int) -> None:
+            return size
+
+        return LRUCache(GiB_bytes * capacity_gb, getsizeof=get_item_size)
+
+    def __init__(
+        self,
+        capacity_gb: float,
+        *,
+        debug_cache_hit_ratio_steps: Optional[int] = None,
+    ) -> None:
         super().__init__()
 
-        # DEBUG: Set to None to disable
-        self.debug_cache_hit_ratio_steps: Optional[int] = None
+        self.debug_cache_hit_ratio_steps = debug_cache_hit_ratio_steps
        self.debug_cache_hits = 0
         self.debug_cache_total = 0
 
-        self._cache = self.get_lru_cache(capacity_gb, MultiModalKwargsItem)
+        self._cache = self.get_lru_cache(
+            capacity_gb,
+            MultiModalKwargsItem,
+            debug=bool(debug_cache_hit_ratio_steps),
+        )
 
     def _maybe_log_cache_stats(self) -> None:
         steps = self.debug_cache_hit_ratio_steps
@@ -890,6 +919,9 @@ def _maybe_log_cache_stats(self) -> None:
         if total > 0 and total % steps == 0:
             logger.debug("ProcessingCache: hit_ratio = %.2f",
                          self.debug_cache_hits / total)
+            logger.debug("ProcessingCache: size = %.2f / %.2f GiB",
+                         self._cache.currsize / GiB_bytes,
+                         self._cache.maxsize / GiB_bytes)
 
     def get(
         self,
```
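
With the new signatures, the capacity may be fractional (capacity_gb: float) and the debug statistics become an opt-in constructor argument instead of a field edited in source. A sketch of how the revised API might be used (the capacity and step values here are illustrative, not defaults):

```python
from vllm.multimodal.processing import ProcessingCache

# 0.5 GiB cache with statistics enabled per instance:
cache = ProcessingCache(0.5, debug_cache_hit_ratio_steps=100)
# Every 100 lookups, _maybe_log_cache_stats() logs the hit ratio plus the
# current/max cache size in GiB, using the corrected per-item accounting.
```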
