Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
109 changes: 0 additions & 109 deletions tests/lora/test_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,6 @@

from vllm.lora.utils import (get_adapter_absolute_path,
parse_fine_tuned_lora_name, replace_submodule)
from vllm.utils import LRUCache


def test_parse_fine_tuned_lora_name_valid():
Expand Down Expand Up @@ -85,114 +84,6 @@ def test_replace_submodule():
assert dict(model.named_modules())["seq1.dense2"] == dense2


class TestLRUCache(LRUCache):

def _on_remove(self, key, value):
if not hasattr(self, "_remove_counter"):
self._remove_counter = 0
self._remove_counter += 1


def test_lru_cache():
cache = TestLRUCache(3)

cache.put(1, 1)
assert len(cache) == 1

cache.put(1, 1)
assert len(cache) == 1

cache.put(2, 2)
assert len(cache) == 2

cache.put(3, 3)
assert len(cache) == 3
assert set(cache.cache) == {1, 2, 3}

cache.put(4, 4)
assert len(cache) == 3
assert set(cache.cache) == {2, 3, 4}
assert cache._remove_counter == 1
assert cache.get(2) == 2

cache.put(5, 5)
assert set(cache.cache) == {2, 4, 5}
assert cache._remove_counter == 2

assert cache.pop(5) == 5
assert len(cache) == 2
assert set(cache.cache) == {2, 4}
assert cache._remove_counter == 3

cache.pop(10)
assert len(cache) == 2
assert set(cache.cache) == {2, 4}
assert cache._remove_counter == 3

cache.get(10)
assert len(cache) == 2
assert set(cache.cache) == {2, 4}
assert cache._remove_counter == 3

cache.put(6, 6)
assert len(cache) == 3
assert set(cache.cache) == {2, 4, 6}
assert 2 in cache
assert 4 in cache
assert 6 in cache

cache.remove_oldest()
assert len(cache) == 2
assert set(cache.cache) == {2, 6}
assert cache._remove_counter == 4

cache.clear()
assert len(cache) == 0
assert cache._remove_counter == 6

cache._remove_counter = 0

cache[1] = 1
assert len(cache) == 1

cache[1] = 1
assert len(cache) == 1

cache[2] = 2
assert len(cache) == 2

cache[3] = 3
assert len(cache) == 3
assert set(cache.cache) == {1, 2, 3}

cache[4] = 4
assert len(cache) == 3
assert set(cache.cache) == {2, 3, 4}
assert cache._remove_counter == 1
assert cache[2] == 2

cache[5] = 5
assert set(cache.cache) == {2, 4, 5}
assert cache._remove_counter == 2

del cache[5]
assert len(cache) == 2
assert set(cache.cache) == {2, 4}
assert cache._remove_counter == 3

cache.pop(10)
assert len(cache) == 2
assert set(cache.cache) == {2, 4}
assert cache._remove_counter == 3

cache[6] = 6
assert len(cache) == 3
assert set(cache.cache) == {2, 4, 6}
assert 2 in cache
assert 4 in cache
assert 6 in cache


# Unit tests for get_adapter_absolute_path
@patch('os.path.isabs')
def test_get_adapter_absolute_path_absolute(mock_isabs):
Expand Down
133 changes: 128 additions & 5 deletions tests/test_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,11 +13,11 @@
from vllm_test_utils.monitor import monitor

from vllm.config import ParallelConfig, VllmConfig, set_current_vllm_config
from vllm.utils import (FlexibleArgumentParser, MemorySnapshot,
PlaceholderModule, StoreBoolean, bind_kv_cache,
deprecate_kwargs, get_open_port, memory_profiling,
merge_async_iterators, sha256, supports_kw,
swap_dict_values)
from vllm.utils import (CacheInfo, FlexibleArgumentParser, LRUCache,
MemorySnapshot, PlaceholderModule, StoreBoolean,
bind_kv_cache, deprecate_kwargs, get_open_port,
memory_profiling, merge_async_iterators, sha256,
supports_kw, swap_dict_values)

from .utils import create_new_process_for_each_test, error_on_warning

Expand Down Expand Up @@ -417,6 +417,129 @@ def test_bind_kv_cache_pp():
assert ctx['layers.0.self_attn'].kv_cache[1] is kv_cache[1][0]


class TestLRUCache(LRUCache):

def _on_remove(self, key, value):
if not hasattr(self, "_remove_counter"):
self._remove_counter = 0
self._remove_counter += 1


def test_lru_cache():
cache = TestLRUCache(3)
assert cache.stat() == CacheInfo(hits=0, total=0)
assert cache.stat(delta=True) == CacheInfo(hits=0, total=0)

cache.put(1, 1)
assert len(cache) == 1

cache.put(1, 1)
assert len(cache) == 1

cache.put(2, 2)
assert len(cache) == 2

cache.put(3, 3)
assert len(cache) == 3
assert set(cache.cache) == {1, 2, 3}

cache.put(4, 4)
assert len(cache) == 3
assert set(cache.cache) == {2, 3, 4}
assert cache._remove_counter == 1

assert cache.get(2) == 2
assert cache.stat() == CacheInfo(hits=1, total=1)
assert cache.stat(delta=True) == CacheInfo(hits=1, total=1)

assert cache[2] == 2
assert cache.stat() == CacheInfo(hits=2, total=2)
assert cache.stat(delta=True) == CacheInfo(hits=1, total=1)

cache.put(5, 5)
assert set(cache.cache) == {2, 4, 5}
assert cache._remove_counter == 2

assert cache.pop(5) == 5
assert len(cache) == 2
assert set(cache.cache) == {2, 4}
assert cache._remove_counter == 3

assert cache.get(-1) is None
assert cache.stat() == CacheInfo(hits=2, total=3)
assert cache.stat(delta=True) == CacheInfo(hits=0, total=1)

cache.pop(10)
assert len(cache) == 2
assert set(cache.cache) == {2, 4}
assert cache._remove_counter == 3

cache.get(10)
assert len(cache) == 2
assert set(cache.cache) == {2, 4}
assert cache._remove_counter == 3

cache.put(6, 6)
assert len(cache) == 3
assert set(cache.cache) == {2, 4, 6}
assert 2 in cache
assert 4 in cache
assert 6 in cache

cache.remove_oldest()
assert len(cache) == 2
assert set(cache.cache) == {2, 6}
assert cache._remove_counter == 4

cache.clear()
assert len(cache) == 0
assert cache._remove_counter == 6
assert cache.stat() == CacheInfo(hits=0, total=0)
assert cache.stat(delta=True) == CacheInfo(hits=0, total=0)

cache._remove_counter = 0

cache[1] = 1
assert len(cache) == 1

cache[1] = 1
assert len(cache) == 1

cache[2] = 2
assert len(cache) == 2

cache[3] = 3
assert len(cache) == 3
assert set(cache.cache) == {1, 2, 3}

cache[4] = 4
assert len(cache) == 3
assert set(cache.cache) == {2, 3, 4}
assert cache._remove_counter == 1
assert cache[2] == 2

cache[5] = 5
assert set(cache.cache) == {2, 4, 5}
assert cache._remove_counter == 2

del cache[5]
assert len(cache) == 2
assert set(cache.cache) == {2, 4}
assert cache._remove_counter == 3

cache.pop(10)
assert len(cache) == 2
assert set(cache.cache) == {2, 4}
assert cache._remove_counter == 3

cache[6] = 6
assert len(cache) == 3
assert set(cache.cache) == {2, 4, 6}
assert 2 in cache
assert 4 in cache
assert 6 in cache


def test_placeholder_module_error_handling():
placeholder = PlaceholderModule("placeholder_1234")

Expand Down
69 changes: 58 additions & 11 deletions vllm/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -236,22 +236,39 @@ def hit_ratio(self) -> float:

return self.hits / self.total

def __sub__(self, other: CacheInfo):
return CacheInfo(
hits=self.hits - other.hits,
total=self.total - other.total,
)


class LRUCache(cachetools.LRUCache[_K, _V], Generic[_K, _V]):

def __init__(self,
capacity: float,
getsizeof: Optional[Callable[[_V], float]] = None):
super().__init__(capacity, getsizeof)

self.pinned_items = set[_K]()
self.capacity = capacity

self._hits = 0
self._total = 0
self._last_info = CacheInfo(hits=0, total=0)

def __getitem__(self, key: _K, *, update_info: bool = True) -> _V:
value = super().__getitem__(key)

if update_info:
self._hits += 1
self._total += 1

return value

def __delitem__(self, key: _K) -> None:
run_on_remove = key in self
value = self.__getitem__(key)
value = self.__getitem__(key,
update_info=False) # type: ignore[call-arg]
super().__delitem__(key)
if key in self.pinned_items:
# Todo: add warning to inform that del pinned item
Expand All @@ -271,8 +288,32 @@ def order(self) -> Mapping[_K, None]:
"""Return the internal order dictionary (read-only)."""
return MappingProxyType(self._LRUCache__order) # type: ignore

def stat(self) -> CacheInfo:
return CacheInfo(hits=self._hits, total=self._total)
@property
def capacity(self) -> float:
return self.maxsize

@property
def usage(self) -> float:
if self.maxsize == 0:
return 0

return self.currsize / self.maxsize

def stat(self, *, delta: bool = False) -> CacheInfo:
"""
Gets the cumulative number of hits and queries against this cache.

If :code:`delta=True`, instead gets these statistics
since the last call that also passed :code:`delta=True`.
"""
info = CacheInfo(hits=self._hits, total=self._total)

if delta:
info_delta = info - self._last_info
self._last_info = info
info = info_delta

return info

def touch(self, key: _K) -> None:
self._LRUCache__update(key) # type: ignore
Expand All @@ -292,7 +333,8 @@ def get(self,
_T]] = None) -> Optional[Union[_V, _T]]:
value: Optional[Union[_V, _T]]
if key in self:
value = self.__getitem__(key)
value = self.__getitem__(
key, update_info=False) # type: ignore[call-arg]

self._hits += 1
else:
Expand All @@ -317,8 +359,9 @@ def pop(self,
if key not in self:
return default

value = self[key]
del self[key]
value = self.__getitem__(key,
update_info=False) # type: ignore[call-arg]
self.__delitem__(key)
return value

def put(self, key: _K, value: _V) -> None:
Expand Down Expand Up @@ -353,10 +396,6 @@ def _remove_old_if_needed(self) -> None:
while self.currsize > self.capacity:
self.remove_oldest()

def clear(self) -> None:
while len(self) > 0:
self.remove_oldest(remove_pinned=True)

def popitem(self, remove_pinned: bool = False):
"""Remove and return the `(key, value)` pair least recently used."""
if not remove_pinned:
Expand All @@ -372,6 +411,14 @@ def popitem(self, remove_pinned: bool = False):
value = self.pop(cast(_K, lru_key))
return (lru_key, value)

def clear(self) -> None:
while len(self) > 0:
self.remove_oldest(remove_pinned=True)

self._hits = 0
self._total = 0
self._last_info = CacheInfo(hits=0, total=0)


class PyObjectCache:
"""Used to cache python objects to avoid object allocations
Expand Down
Loading