Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
81 changes: 40 additions & 41 deletions safetytooling/apis/inference/cache_manager.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
import logging
import os
import sys
from collections import deque
from collections import OrderedDict, deque
from itertools import chain
from pathlib import Path
from typing import List, Tuple, Union
Expand Down Expand Up @@ -137,57 +137,55 @@ def save_embeddings(self, params: EmbeddingParams, response: EmbeddingResponseBa


class FileBasedCacheManager(BaseCacheManager):
    """File-based cache with an LRU-evicted in-memory layer.

    Bins (cache files) loaded from disk are mirrored in an OrderedDict whose
    insertion order doubles as the LRU queue: the front entry is the
    least-recently-used, the back entry is the most-recently-used.
    """

    def __init__(self, cache_dir: Path, num_bins: int = 20, max_mem_usage_mb: float = 5_000):
        super().__init__(cache_dir, num_bins)
        # OrderedDict iteration order == LRU order (front = LRU, back = MRU).
        self.in_memory_cache: OrderedDict[Path, dict] = OrderedDict()
        # Size in MB of each in-memory bin, keyed identically to in_memory_cache.
        self.sizes: dict[Path, float] = {}
        self.total_usage_mb = 0.0
        self.max_mem_usage_mb = max_mem_usage_mb

    def _evict_lru(self):
        """Evict the least-recently-used bin from memory."""
        lru_key = next(iter(self.in_memory_cache))
        del self.in_memory_cache[lru_key]
        self.total_usage_mb -= self.sizes.pop(lru_key)
        LOGGER.info(f"Evicted LRU entry {lru_key} from mem cache. Total usage is now {self.total_usage_mb:.1f} MB.")

    def add_entry(self, cache_file: Path, contents: dict) -> bool:
        """Add or replace a bin in the in-memory cache, evicting LRU entries if needed.

        Returns False if the single entry is larger than the entire cache limit.
        """
        size_mb = total_size(contents)

        # An entry bigger than the whole budget can never fit; refuse it
        # outright instead of evicting everything and still failing.
        if self.max_mem_usage_mb is not None and size_mb > self.max_mem_usage_mb:
            LOGGER.warning(f"Entry {cache_file} ({size_mb:.1f} MB) exceeds cache limit ({self.max_mem_usage_mb} MB).")
            return False

        # Remove old version first. This prevents self-eviction (the eviction
        # loop below can only see OTHER bins) and prevents double-counting.
        if cache_file in self.sizes:
            del self.in_memory_cache[cache_file]
            self.total_usage_mb -= self.sizes.pop(cache_file)

        # Evict least-recently-used bins until there's room
        if self.max_mem_usage_mb is not None:
            while self.in_memory_cache and self.total_usage_mb + size_mb > self.max_mem_usage_mb:
                self._evict_lru()

        # Insert at the back (most-recently-used position)
        self.in_memory_cache[cache_file] = contents
        self.sizes[cache_file] = size_mb
        self.total_usage_mb += size_mb
        return True

    def touch(self, cache_file: Path):
        """Mark a bin as recently used (moves it to the back of the LRU queue)."""
        if cache_file in self.in_memory_cache:
            self.in_memory_cache.move_to_end(cache_file)
def get_cache_file(self, prompt: Prompt, params: LLMParams) -> tuple[Path, str]:
# Use the SHA-1 hash of the prompt for the dictionary key
prompt_hash = prompt.model_hash() # Assuming this gives a SHA-1 hash as a hex string
Expand All @@ -211,6 +209,7 @@ def maybe_load_cache(self, prompt: Prompt, params: LLMParams):
self.add_entry(cache_file, contents)
else:
contents = self.in_memory_cache[cache_file]
self.touch(cache_file)

data = contents.get(prompt_hash, None)
return None if data is None else LLMCache.model_validate_json(data)
Expand Down
167 changes: 167 additions & 0 deletions tests/test_cache_manager.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,167 @@
"""Unit tests for FileBasedCacheManager LRU memory eviction.
These tests exercise the in-memory cache eviction without making any API calls.
"""

import tempfile
from pathlib import Path

from safetytooling.apis.inference.cache_manager import FileBasedCacheManager, total_size


def _make_data(size_chars: int, num_keys: int = 1) -> dict:
"""Create a dict whose in-memory size is roughly proportional to size_chars."""
return {f"k{i}": "x" * (size_chars // num_keys) for i in range(num_keys)}


class TestAddEntryReAdd:
    """Regression tests for re-adding bins (the self-eviction bug)."""

    def test_re_add_larger_bin_does_not_crash(self):
        """Re-adding a bin that triggers eviction must not self-evict."""
        with tempfile.TemporaryDirectory() as root:
            small_payload = _make_data(200_000)
            large_payload = _make_data(400_000)

            # Budget fits exactly one small + one large payload.
            budget = total_size(small_payload) + total_size(large_payload) + 0.01
            manager = FileBasedCacheManager(Path(root), max_mem_usage_mb=budget)

            first_bin = Path(root) / "binA.json"
            second_bin = Path(root) / "binB.json"
            third_bin = Path(root) / "binC.json"

            manager.add_entry(first_bin, small_payload)
            manager.add_entry(second_bin, large_payload)

            # Re-add first_bin with larger contents (simulates a disk reload
            # after save_cache grew the bin on disk).
            manager.add_entry(first_bin, large_payload)
            manager.add_entry(third_bin, large_payload)  # must not crash

    def test_re_add_keeps_dicts_consistent(self):
        """After re-adding a bin, sizes and in_memory_cache must have the same keys."""
        with tempfile.TemporaryDirectory() as root:
            manager = FileBasedCacheManager(Path(root), max_mem_usage_mb=100)

            target = Path(root) / "binA.json"
            initial = _make_data(100_000)
            replacement = _make_data(500_000)

            manager.add_entry(target, initial)
            manager.add_entry(target, replacement)

            assert set(manager.sizes.keys()) == set(manager.in_memory_cache.keys())
            assert target in manager.in_memory_cache
            assert manager.in_memory_cache[target] is replacement

    def test_re_add_does_not_double_count(self):
        """Re-adding a bin must not inflate total_usage_mb."""
        with tempfile.TemporaryDirectory() as root:
            manager = FileBasedCacheManager(Path(root), max_mem_usage_mb=100)

            target = Path(root) / "binA.json"
            payload = _make_data(200_000)
            payload_mb = total_size(payload)

            manager.add_entry(target, payload)
            assert abs(manager.total_usage_mb - payload_mb) < 0.001

            # Second insert of the same bin: usage stays flat.
            manager.add_entry(target, payload)
            assert abs(manager.total_usage_mb - payload_mb) < 0.001

    def test_oversized_entry_not_leaked(self):
        """If an entry exceeds the entire cache limit, it must not leak."""
        with tempfile.TemporaryDirectory() as root:
            manager = FileBasedCacheManager(Path(root), max_mem_usage_mb=0.001)

            target = Path(root) / "binA.json"
            oversized = _make_data(1_000_000)

            accepted = manager.add_entry(target, oversized)

            assert accepted is False
            assert target not in manager.in_memory_cache
            assert target not in manager.sizes
            assert manager.total_usage_mb == 0


class TestLRUEvictionOrder:
    """Tests that eviction follows LRU order, not smallest-first."""

    def test_evicts_oldest_not_smallest(self):
        """When memory is full, the least-recently-used bin is evicted."""
        with tempfile.TemporaryDirectory() as root:
            small_payload = _make_data(100_000)
            large_payload = _make_data(300_000)

            # Room for small + large, but not small + large + small.
            budget = total_size(small_payload) + total_size(large_payload) + 0.01
            manager = FileBasedCacheManager(Path(root), max_mem_usage_mb=budget)

            oldest_bin = Path(root) / "binA.json"
            newer_bin = Path(root) / "binB.json"
            newest_bin = Path(root) / "binC.json"

            manager.add_entry(oldest_bin, small_payload)  # oldest
            manager.add_entry(newer_bin, large_payload)  # newer

            # Adding the third bin must evict the LRU entry (oldest_bin).
            # oldest_bin also happens to be the smallest, so the old
            # smallest-first policy would have dropped it too — the real LRU
            # property is verified by newer_bin (larger but newer) surviving.
            manager.add_entry(newest_bin, small_payload)

            assert oldest_bin not in manager.in_memory_cache  # evicted (oldest)
            assert newer_bin in manager.in_memory_cache  # kept (newer)
            assert newest_bin in manager.in_memory_cache  # just added

    def test_touch_prevents_eviction(self):
        """Accessing a bin via touch() moves it to the back of the LRU queue."""
        with tempfile.TemporaryDirectory() as root:
            payload = _make_data(200_000)

            # Room for exactly 2 bins.
            budget = total_size(payload) * 2 + 0.01
            manager = FileBasedCacheManager(Path(root), max_mem_usage_mb=budget)

            touched_bin = Path(root) / "binA.json"
            untouched_bin = Path(root) / "binB.json"
            extra_bin = Path(root) / "binC.json"

            manager.add_entry(touched_bin, payload)  # oldest
            manager.add_entry(untouched_bin, payload)  # newer

            # Touching the oldest bin promotes it to most-recently-used.
            manager.touch(touched_bin)

            # The next insert should therefore evict untouched_bin instead.
            manager.add_entry(extra_bin, payload)

            assert touched_bin in manager.in_memory_cache  # survived (was touched)
            assert untouched_bin not in manager.in_memory_cache  # evicted (LRU after touch)
            assert extra_bin in manager.in_memory_cache

    def test_add_entry_moves_to_mru(self):
        """Re-adding a bin moves it to the most-recently-used position."""
        with tempfile.TemporaryDirectory() as root:
            payload = _make_data(200_000)

            budget = total_size(payload) * 2 + 0.01
            manager = FileBasedCacheManager(Path(root), max_mem_usage_mb=budget)

            readded_bin = Path(root) / "binA.json"
            stale_bin = Path(root) / "binB.json"
            extra_bin = Path(root) / "binC.json"

            manager.add_entry(readded_bin, payload)  # oldest
            manager.add_entry(stale_bin, payload)  # newer

            # Re-adding the oldest bin (simulates reload from disk) makes it MRU.
            manager.add_entry(readded_bin, payload)

            # The next insert should evict stale_bin (now the LRU) instead.
            manager.add_entry(extra_bin, payload)

            assert readded_bin in manager.in_memory_cache
            assert stale_bin not in manager.in_memory_cache
            assert extra_bin in manager.in_memory_cache