diff --git a/changes/3357.feature.rst b/changes/3357.feature.rst new file mode 100644 index 0000000000..e9217d4841 --- /dev/null +++ b/changes/3357.feature.rst @@ -0,0 +1,3 @@ +Add LRUStoreCache for improved performance with remote stores + +The new ``LRUStoreCache`` provides a least-recently-used (LRU) caching layer that can be wrapped around any zarr store to significantly improve performance, especially for remote stores where network latency is a bottleneck. \ No newline at end of file diff --git a/docs/user-guide/lrustorecache.rst b/docs/user-guide/lrustorecache.rst new file mode 100644 index 0000000000..6e8f66d612 --- /dev/null +++ b/docs/user-guide/lrustorecache.rst @@ -0,0 +1,186 @@ +.. only:: doctest + + >>> import shutil + >>> shutil.rmtree('test.zarr', ignore_errors=True) + +.. _user-guide-lrustorecache: + +LRUStoreCache guide +=================== + +The :class:`zarr.storage.LRUStoreCache` provides a least-recently-used (LRU) cache layer +that can be wrapped around any Zarr store to improve performance for repeated data access. +This is particularly useful when working with remote stores (e.g., S3, HTTP) where network +latency can significantly impact data access speed. + +The LRUStoreCache implements a cache that stores frequently accessed data chunks in memory, +automatically evicting the least recently used items when the cache reaches its maximum size. + +.. note:: + The LRUStoreCache is a wrapper store that maintains compatibility with the full + :class:`zarr.abc.store.Store` API while adding transparent caching functionality. 
+ +Basic Usage +----------- + +Creating an LRUStoreCache is straightforward - simply wrap any existing store with the cache: + + >>> import zarr + >>> import zarr.storage + >>> import numpy as np + >>> + >>> # Create a local store and wrap it with LRU cache + >>> local_store = zarr.storage.LocalStore('test.zarr') + >>> cache = zarr.storage.LRUStoreCache(local_store, max_size=1024 * 1024 * 256) # 256MB cache + >>> + >>> # Create an array using the cached store + >>> zarr_array = zarr.zeros((100, 100), chunks=(10, 10), dtype='f8', store=cache, mode='w') + >>> + >>> # Write some data to force chunk creation + >>> zarr_array[:] = np.random.random((100, 100)) + +The ``max_size`` parameter controls the maximum memory usage of the cache in bytes. It is +required and must be a positive integer; values larger than ``max_size`` are never cached. + +Performance Benefits +-------------------- + +The LRUStoreCache provides significant performance improvements for repeated data access: + + >>> import time + >>> + >>> # Benchmark reading with cache + >>> start = time.time() + >>> for _ in range(100): + ... _ = zarr_array[:] + >>> elapsed_cache = time.time() - start + >>> + >>> # Compare with direct store access (without cache) + >>> zarr_array_nocache = zarr.open('test.zarr', mode='r') + >>> start = time.time() + >>> for _ in range(100): + ... _ = zarr_array_nocache[:] + >>> elapsed_nocache = time.time() - start + >>> + >>> speedup = elapsed_nocache/elapsed_cache + +Cache effectiveness is particularly pronounced with repeated access to the same data chunks. + +Remote Store Caching +-------------------- + +The LRUStoreCache is most beneficial when used with remote stores where network latency +is a significant factor. 
Here's a conceptual example:: + + # Example with a remote store (requires gcsfs) + import gcsfs + + # Create a remote store (Google Cloud Storage example) + gcs = gcsfs.GCSFileSystem(token='anon') + remote_store = gcsfs.GCSMap( + root='your-bucket/data.zarr', + gcs=gcs, + check=False + ) + + # Wrap with LRU cache for better performance + cached_store = zarr.storage.LRUStoreCache(remote_store, max_size=2**28) + + # Open array through cached store + z = zarr.open(cached_store) + +The first access to any chunk will be slow (network retrieval), but subsequent accesses +to the same chunk will be served from the local cache, providing dramatic speedup. + +Cache Configuration +------------------- + +The LRUStoreCache can be configured with several parameters: + +**max_size**: Controls the maximum memory usage of the cache in bytes + + >>> # Create a base store for demonstration + >>> store = zarr.storage.LocalStore('config_example.zarr') + >>> + >>> # 256MB cache + >>> cache = zarr.storage.LRUStoreCache(store, max_size=2**28) + >>> + >>> # max_size is required and must be positive; pick a very large value with caution + >>> cache = zarr.storage.LRUStoreCache(store, max_size=2**40) + +**read_only**: The cache inherits its read-only state from the store it wraps + + >>> cache = zarr.storage.LRUStoreCache(store.with_read_only(True), max_size=2**28) + +Cache Statistics +---------------- + +The LRUStoreCache provides statistics to monitor cache performance: + + >>> # Access some data to generate cache activity + >>> data = zarr_array[0:50, 0:50] # First access - cache miss + >>> data = zarr_array[0:50, 0:50] # Second access - cache hit + >>> + >>> cache_hits = cache.hits + >>> cache_misses = cache.misses + >>> total_requests = cache.hits + cache.misses + >>> cache_hit_ratio = cache.hits / total_requests if total_requests > 0 else 0 + >>> # Typical hit ratio is > 50% with repeated access patterns + +Cache Management +---------------- + +The cache provides methods for manual cache management: + + >>> # Clear all cached values but keep keys cache + >>> 
cache.invalidate_values() + >>> + >>> # Clear keys cache + >>> cache.invalidate_keys() + >>> + >>> # Clear entire cache + >>> cache.invalidate() + +Best Practices +-------------- + +1. **Size the cache appropriately**: Set ``max_size`` based on available memory and expected data access patterns +2. **Use with remote stores**: The cache provides the most benefit when wrapping slow remote stores +3. **Monitor cache statistics**: Use hit/miss ratios to tune cache size and access patterns +4. **Consider data locality**: Access data in chunks sequentially rather than jumping around randomly to maximize cache reuse + +Examples from Real Usage +------------------------ + +Here's a complete example demonstrating cache effectiveness: + + >>> import zarr + >>> import zarr.storage + >>> import time + >>> import numpy as np + >>> + >>> # Create test data + >>> local_store = zarr.storage.LocalStore('benchmark.zarr') + >>> cache = zarr.storage.LRUStoreCache(local_store, max_size=2**28) + >>> zarr_array = zarr.zeros((100, 100), chunks=(10, 10), dtype='f8', store=cache, mode='w') + >>> zarr_array[:] = np.random.random((100, 100)) + >>> + >>> # Demonstrate cache effectiveness with repeated access + >>> # First access (cache miss): + >>> start = time.time() + >>> data = zarr_array[20:30, 20:30] + >>> first_access = time.time() - start + >>> + >>> # Second access (cache hit): + >>> start = time.time() + >>> data = zarr_array[20:30, 20:30] # Same data should be cached + >>> second_access = time.time() - start + >>> + >>> # Calculate cache performance metrics + >>> cache_speedup = first_access/second_access + +This example shows how the LRUStoreCache can significantly reduce access times for repeated +data reads, particularly important when working with remote data sources. + +.. _Zip Store Specification: https://github.com/zarr-developers/zarr-specs/pull/311 +.. 
class LRUStoreCache(Store):
    """
    Storage class that implements a least-recently-used (LRU) cache layer over
    some other store.

    Intended primarily for use with stores that can be slow to
    access, e.g., remote stores that require network communication to store and
    retrieve data.

    The cache stores the raw bytes returned by the underlying store, before any
    decompression or array processing. This means that compressed data remains
    compressed in the cache, and decompression happens each time the cached data
    is accessed. This design choice keeps the cache lightweight while still
    providing significant performance benefits for network-bound operations.

    This store supports both read and write operations. Write operations use a
    write-through strategy where data is written to both the underlying store
    and cached locally. Cached entries are evicted when the corresponding data
    in the underlying store is deleted or partially overwritten through this
    object. Changes made directly on the underlying store are not detected;
    call :meth:`invalidate` after such changes.

    Only full-value reads are cached. Reads with a ``byte_range`` always go to
    the underlying store.

    Parameters
    ----------
    store : Store
        The store containing the actual data to be cached.
    max_size : int
        The maximum size that the cache may grow to, in number of bytes.
        This parameter is required to prevent unbounded memory growth.

        Values smaller than your typical chunk size will result in most data
        being excluded from the cache (with a ``UserWarning``), reducing
        effectiveness.

    Attributes
    ----------
    hits : int
        Number of full-value lookups served from the cache.
    misses : int
        Number of full-value lookups that had to consult the underlying store.

    Raises
    ------
    ValueError
        If ``max_size`` is not strictly positive.

    Examples
    --------
    The example below wraps a LocalStore with an LRU cache for demonstration::

        >>> import tempfile
        >>> import zarr
        >>> from zarr.storage import LocalStore, LRUStoreCache
        >>>
        >>> # Create a temporary directory for the example
        >>> temp_dir = tempfile.mkdtemp()
        >>> store = LocalStore(temp_dir)
        >>>
        >>> # Create some test data first
        >>> arr = zarr.create((1000, 1000), chunks=(100, 100), store=store, dtype='f4')
        >>> arr[:] = 42.0
        >>>
        >>> # Now wrap with cache for faster access
        >>> cached_store = LRUStoreCache(store, max_size=1024 * 1024 * 256)  # 256MB cache
        >>> cached_arr = zarr.open(cached_store)
        >>>
        >>> # First access loads from disk and caches
        >>> data1 = cached_arr[0:100, 0:100]  # Cache miss
        >>>
        >>> # Second access uses cache (much faster for remote stores)
        >>> data2 = cached_arr[0:100, 0:100]  # Cache hit

    For remote stores where the performance benefit is more apparent::

        >>> # from zarr.storage import FsspecStore
        >>> # remote_store = FsspecStore.from_url("https://example.com/data.zarr")
        >>> # cached_remote = LRUStoreCache(remote_store, max_size=2**28)
    """

    @property
    def supports_writes(self) -> bool:
        """Whether the underlying store supports write operations."""
        return self._store.supports_writes

    @property
    def supports_deletes(self) -> bool:
        """Whether the underlying store supports delete operations."""
        return self._store.supports_deletes

    @property
    def supports_partial_writes(self) -> bool:
        """Whether the underlying store supports partial write operations."""
        return self._store.supports_partial_writes

    @property
    def supports_listing(self) -> bool:
        """Whether the underlying store supports listing operations."""
        return self._store.supports_listing

    def __init__(self, store: Store, *, max_size: int) -> None:
        if max_size <= 0:
            raise ValueError("max_size must be a positive integer (bytes)")

        # Always inherit read_only state from the underlying store
        read_only = store.read_only
        super().__init__(read_only=read_only)

        self._store = store
        self._max_size = max_size
        # Sum of len() over every entry of ``_values_cache``.
        self._current_size = 0
        # Reserved key caches: part of the pickled state and cleared by
        # invalidate()/invalidate_keys(), but not populated by this class.
        self._contains_cache: dict[Any, Any] = {}
        self._listdir_cache: dict[str | None, list[str]] = {}
        # Insertion order doubles as recency order: first item == least recently used.
        self._values_cache: OrderedDict[str, bytes] = OrderedDict()
        # Guards the cache dicts, ``_current_size`` and the hit/miss counters.
        # It is a plain (non-reentrant) Lock and is never held across an ``await``.
        self._mutex = Lock()
        self.hits = self.misses = 0

    @classmethod
    async def open(cls, store: Store, *, max_size: int, read_only: bool = False) -> Self:
        """
        Create and open a new LRU cache store.

        Parameters
        ----------
        store : Store
            The underlying store to wrap with caching.
        max_size : int
            The maximum size that the cache may grow to, in number of bytes.
        read_only : bool
            If True, the cache is forced into read-only mode. If False, the
            read-only state of ``store`` is inherited unchanged.

        Returns
        -------
        LRUStoreCache
            The opened cache store instance.
        """
        cache = cls(store, max_size=max_size)

        if read_only:
            cache._read_only = True
        await cache._open()
        return cache

    def with_read_only(self, read_only: bool = False) -> "LRUStoreCache":
        """
        Return a new LRUStoreCache with a new read_only setting.

        The returned cache wraps ``self._store.with_read_only(read_only)`` and
        starts out empty; cached values and statistics are not carried over.

        Parameters
        ----------
        read_only
            If True, the store will be created in read-only mode. Defaults to False.

        Returns
        -------
        LRUStoreCache
            A new LRUStoreCache with the specified read_only setting.
        """
        # Create a new underlying store with the new read_only setting
        underlying_store = self._store.with_read_only(read_only)
        # type(self) keeps subclasses intact instead of downgrading to the base class.
        return type(self)(underlying_store, max_size=self._max_size)

    def __getstate__(
        self,
    ) -> tuple[
        Store,
        int,
        int,
        dict[Any, Any],
        dict[str | None, list[str]],
        OrderedDict[str, bytes],
        int,
        int,
        bool,
        bool,
    ]:
        # The lock is not picklable, so it is left out and rebuilt in __setstate__.
        return (
            self._store,
            self._max_size,
            self._current_size,
            self._contains_cache,
            self._listdir_cache,
            self._values_cache,
            self.hits,
            self.misses,
            self._read_only,
            self._is_open,
        )

    def __setstate__(
        self,
        state: tuple[
            Store,
            int,
            int,
            dict[Any, Any],
            dict[str | None, list[str]],
            OrderedDict[str, bytes],
            int,
            int,
            bool,
            bool,
        ],
    ) -> None:
        (
            self._store,
            self._max_size,
            self._current_size,
            self._contains_cache,
            self._listdir_cache,
            self._values_cache,
            self.hits,
            self.misses,
            self._read_only,
            self._is_open,
        ) = state
        self._mutex = Lock()

    def __len__(self) -> int:
        """Return the number of values currently held in the cache."""
        return len(self._values_cache)

    async def clear(self) -> None:
        """
        Remove all keys from the store and clear the cache.

        This operation clears both the underlying store and invalidates
        all cached data to maintain consistency.
        """
        await self._store.clear()
        self.invalidate()

    async def getsize(
        self,
        key: str,
        prototype: BufferPrototype | None = None,
    ) -> int:
        """
        Get the size in bytes of the value stored at the given key.

        A cached value answers the query directly. Otherwise the full value is
        fetched with ``get()`` and cached, on the assumption that for remote
        stores latency dominates both calls. If that fetch fails with one of the
        tolerated errors, or returns ``None``, the query falls back to the
        underlying store's ``getsize()``, whose exceptions propagate.

        Parameters
        ----------
        key : str
            Key to look up.
        prototype : BufferPrototype, optional
            Prototype used when the value has to be fetched. Defaults to
            ``default_buffer_prototype()``.

        Returns
        -------
        int
            Size of the stored value in bytes.
        """
        with self._mutex:
            cached = self._values_cache.get(key)
            if cached is not None:
                # Move to end to mark as recently used
                self._values_cache.move_to_end(key)
                self.hits += 1
                return len(cached)
            # Counted under the lock so concurrent threads cannot lose updates.
            self.misses += 1

        if prototype is None:
            prototype = default_buffer_prototype()

        # Try to get the full value first (better for remote stores). Only the
        # fetch itself is best-effort; anything else propagates.
        try:
            value = await self._store.get(key, prototype)
        except (
            KeyError,
            FileNotFoundError,
            NotImplementedError,
            ConnectionError,
            TimeoutError,
            RuntimeError,
        ):
            value = None

        if value is not None:
            with self._mutex:
                # A concurrent writer may already have stored a newer value.
                if key not in self._values_cache:
                    self._cache_value(key, value)
            return len(value)

        # Fallback to underlying store's getsize() method
        return await self._store.getsize(key)

    def _pop_value(self) -> bytes:
        """Remove and return the least recently used value. Caller holds the lock."""
        _, v = self._values_cache.popitem(last=False)
        return v

    def _accommodate_value(self, value_size: int) -> None:
        """Evict LRU entries until ``value_size`` more bytes fit. Caller holds the lock."""
        # The emptiness guard keeps popitem() from raising should the size
        # accounting ever disagree with the dict contents.
        while self._values_cache and self._current_size + value_size > self._max_size:
            self._current_size -= len(self._pop_value())
        if not self._values_cache:
            self._current_size = 0

    def _cache_value(self, key: str, value: Buffer | bytes) -> None:
        """
        Cache a value, handling both Buffer objects and bytes.

        Any existing entry for ``key`` is dropped first, so the size accounting
        stays correct on replacement. Values larger than ``max_size`` are not
        cached and trigger a ``UserWarning``. Caller holds the lock.
        """
        cache_value = value.to_bytes() if isinstance(value, Buffer) else value
        value_size = len(cache_value)

        previous = self._values_cache.pop(key, None)
        if previous is not None:
            self._current_size -= len(previous)

        if value_size > self._max_size:
            # Emit warning when value is too large to cache
            warnings.warn(
                f"Value for key '{key}' ({value_size:,} bytes) exceeds cache max_size "
                f"({self._max_size:,} bytes) and will not be cached. Consider increasing "
                f"max_size if this data is frequently accessed.",
                UserWarning,
                stacklevel=3,
            )
            return

        self._accommodate_value(value_size)
        self._values_cache[key] = cache_value
        self._current_size += value_size

    def _evict(self, key: str, *, include_children: bool = False) -> None:
        """
        Drop ``key`` from the values cache.

        With ``include_children`` every cached key below ``key + "/"`` is
        dropped as well; an empty ``key`` then matches the whole cache.
        """
        with self._mutex:
            if include_children:
                stem = key.rstrip("/")
                prefix = f"{stem}/" if stem else ""
                stale = [k for k in self._values_cache if k == key or k.startswith(prefix)]
            else:
                stale = [key] if key in self._values_cache else []
            for stale_key in stale:
                self._current_size -= len(self._values_cache.pop(stale_key))

    def invalidate(self) -> None:
        """Completely clear the cache (values and key caches)."""
        with self._mutex:
            self._values_cache.clear()
            self._current_size = 0
            self._contains_cache.clear()
            self._listdir_cache.clear()

    def invalidate_values(self) -> None:
        """Clear only the values cache, keeping other caches intact."""
        with self._mutex:
            self._values_cache.clear()
            self._current_size = 0

    def invalidate_keys(self) -> None:
        """Clear only the key caches, keeping cached values intact."""
        with self._mutex:
            self._contains_cache.clear()
            self._listdir_cache.clear()

    def __eq__(self, value: object) -> bool:
        # Compare with ``==`` rather than calling ``__eq__`` directly: the dunder
        # may return NotImplemented, which is truthy and would leak out as "equal".
        return type(self) is type(value) and bool(self._store == value._store)  # type: ignore[attr-defined]

    def __str__(self) -> str:
        return f"cache://{self._store}"

    def __repr__(self) -> str:
        return f"LRUStoreCache({self._store!r}, max_size={self._max_size})"

    async def delete(self, key: str) -> None:
        """
        Remove a key from the store.

        Parameters
        ----------
        key : str

        Raises
        ------
        NotImplementedError
            If the underlying store does not support deletes.

        Notes
        -----
        If ``key`` is a directory within this store, the entire directory
        at ``store.root / key`` is deleted. Cached values for ``key`` and for
        every key below it are evicted; unrelated entries stay cached.
        """
        # Check if store is writable
        self._check_writable()

        if not self._store.supports_deletes:
            raise NotImplementedError(
                f"Store {type(self._store).__name__} does not support delete operations"
            )
        await self._store.delete(key)

        # Evict only what the delete can have affected.
        self._evict(key, include_children=True)

    async def exists(self, key: str) -> bool:
        """
        Check if a key exists in the store.

        This method first checks the cache for the key to avoid
        unnecessary calls to the underlying store.
        """
        with self._mutex:
            if key in self._values_cache:
                # Key exists in cache, so it exists in store
                # Move to end to mark as recently used
                self._values_cache.move_to_end(key)
                return True

        # Not in cache, delegate to underlying store
        return await self._store.exists(key)

    async def get(
        self,
        key: str,
        prototype: BufferPrototype | None = None,
        byte_range: ByteRequest | None = None,
    ) -> Buffer | None:
        """
        Retrieve the value for ``key``, using the cache for full-value reads.

        Parameters
        ----------
        key : str
            Key to read.
        prototype : BufferPrototype, optional
            Prototype used to build the returned buffer. Defaults to
            ``default_buffer_prototype()``.
        byte_range : ByteRequest, optional
            If given, the request bypasses the cache and is served by the
            underlying store; hit/miss counters are not touched.

        Returns
        -------
        Buffer or None
            The stored value, or ``None`` if the underlying store has no such key.
        """
        if prototype is None:
            prototype = default_buffer_prototype()

        # For byte_range requests, don't use cache for now (could be optimized later)
        if byte_range is not None:
            return await self._store.get(key, prototype, byte_range)

        # Look up and read under one lock acquisition: checking membership
        # outside the lock lets another thread evict the entry in between.
        with self._mutex:
            cached = self._values_cache.get(key)
            if cached is not None:
                self.hits += 1
                self._values_cache.move_to_end(key)
        if cached is not None:
            return prototype.buffer.from_bytes(cached)

        # Cache miss - get from store
        result = await self._store.get(key, prototype, byte_range)

        with self._mutex:
            # Still count as a miss even if result is None
            self.misses += 1
            if result is not None and key not in self._values_cache:
                self._cache_value(key, result)

        return result

    async def get_partial_values(
        self,
        prototype: BufferPrototype,
        key_ranges: Iterable[tuple[str, ByteRequest | None]],
    ) -> list[Buffer | None]:
        """
        Retrieve several (possibly partial) values in one call.

        Full-value requests (``byte_range is None``) are answered from the cache
        when possible. Everything else is forwarded to the underlying store in a
        single ``get_partial_values`` call, and full values fetched that way are
        cached afterwards.

        Parameters
        ----------
        prototype : BufferPrototype
            Prototype used to build the returned buffers.
        key_ranges : Iterable[tuple[str, ByteRequest | None]]
            Pairs of key and optional byte range.

        Returns
        -------
        list[Buffer | None]
            One entry per request, in request order.
        """
        requests = list(key_ranges)
        results: list[Buffer | None] = [None] * len(requests)
        pending: list[int] = []

        with self._mutex:
            for index, (key, byte_range) in enumerate(requests):
                cached = self._values_cache.get(key) if byte_range is None else None
                if cached is None:
                    pending.append(index)
                    continue
                self.hits += 1
                self._values_cache.move_to_end(key)
                results[index] = prototype.buffer.from_bytes(cached)

        if not pending:
            return results

        # Reading does not depend on partial *write* support, so always delegate.
        fetched = await self._store.get_partial_values(
            prototype, [requests[index] for index in pending]
        )

        with self._mutex:
            for index, value in zip(pending, fetched, strict=True):
                results[index] = value
                key, byte_range = requests[index]
                if byte_range is not None:
                    # Range reads are neither cached nor counted.
                    continue
                self.misses += 1
                if value is not None and key not in self._values_cache:
                    self._cache_value(key, value)

        return results

    async def list(self) -> AsyncIterator[str]:
        """Yield every key of the underlying store; nothing if listing is unsupported."""
        if self.supports_listing:
            async for key in self._store.list():
                yield key

    async def list_dir(self, prefix: str) -> AsyncIterator[str]:
        """Yield the direct children of ``prefix``; nothing if listing is unsupported."""
        if self.supports_listing:
            async for key in self._store.list_dir(prefix):
                yield key

    async def list_prefix(self, prefix: str) -> AsyncIterator[str]:
        """Yield every key starting with ``prefix``; nothing if listing is unsupported."""
        if self.supports_listing:
            async for key in self._store.list_prefix(prefix):
                yield key

    async def set(self, key: str, value: Buffer, byte_range: tuple[int, int] | None = None) -> None:
        """
        Store ``value`` under ``key`` (write-through).

        The value is written to the underlying store first and then cached,
        replacing any previously cached value for ``key``.

        Parameters
        ----------
        key : str
            Key to write.
        value : Buffer
            Full value to store.
        byte_range : tuple[int, int], optional
            Accepted for signature compatibility only; it is not forwarded to
            the underlying store and the full ``value`` is always written.
        """
        # Check if store is writable
        self._check_writable()

        # Write to underlying store first
        await self._store.set(key, value)

        # _cache_value drops the previous entry (and its size) before inserting.
        with self._mutex:
            self._cache_value(key, value)

    async def set_partial_values(
        self, key_start_values: Iterable[tuple[str, int, bytes | bytearray | memoryview]]
    ) -> None:
        """
        Apply partial writes through the underlying store.

        Cached values of the touched keys are evicted afterwards, because the
        cached bytes no longer match what the store holds.

        Parameters
        ----------
        key_start_values : Iterable[tuple[str, int, bytes | bytearray | memoryview]]
            Triples of key, start offset and data to write.

        Raises
        ------
        NotImplementedError
            If the underlying store does not support partial writes. Silently
            skipping the write would lose data.
        """
        # Check if store is writable
        self._check_writable()

        if not self.supports_partial_writes:
            raise NotImplementedError(
                f"Store {type(self._store).__name__} does not support partial write operations"
            )

        # Materialise once: the iterable is needed for the write and the eviction.
        updates = list(key_start_values)
        touched = {key for key, _start, _value in updates}

        try:
            await self._store.set_partial_values(updates)
        except NotImplementedError:
            # The store rejected the call outright, so cached values remain valid.
            raise
        except Exception:
            # Some writes may have landed before the failure; drop anything suspect.
            for key in touched:
                self._evict(key)
            raise

        for key in touched:
            self._evict(key)
class CounterStore(MemoryStore):  # type: ignore[misc]
    """
    A thin wrapper of MemoryStore to count different method calls for testing.

    Every overridden method increments ``self.counter`` under either a plain
    name (``"clear"``) or a ``(name, key)`` tuple, so tests can assert how often
    the cache reached through to the underlying store.
    """

    def __init__(self) -> None:
        super().__init__()
        self.counter: Counter[tuple[str, Any] | str] = Counter()

        # Add Store-like attributes that LRUStoreCache expects
        self.supports_writes = True
        self.supports_deletes = True
        # False here, so LRUStoreCache treats this store as lacking partial writes.
        self.supports_partial_writes = False
        self.supports_listing = True

    async def clear(self) -> None:
        self.counter["clear"] += 1
        # docstring inherited
        self._store_dict.clear()

    async def set(self, key: str, value: Any) -> None:
        """Store-like set method for async interface; accepts Buffers or raw bytes."""
        self.counter["set", key] += 1
        # Convert Buffer to bytes if needed
        if hasattr(value, "to_bytes"):
            self._store_dict[key] = value.to_bytes()
        else:
            self._store_dict[key] = value

    async def get(self, key: str, prototype: Any = None, byte_range: Any = None) -> Any:
        """
        Store-like get method for async interface.

        Returns ``None`` for a missing key. ``byte_range`` is accepted but ignored.
        """
        self.counter["get", key] += 1
        try:
            data = self._store_dict[key]
            # Return as Buffer if prototype provided
            if prototype is not None and hasattr(prototype, "buffer"):
                from zarr.core.buffer.cpu import Buffer

                return Buffer.from_bytes(data)
            return data  # noqa: TRY300
        except KeyError:
            return None

    async def delete(self, key: str) -> None:
        """Store-like delete method for async interface; missing keys are ignored."""
        self.counter["delete", key] += 1
        with contextlib.suppress(KeyError):
            del self._store_dict[key]

    async def exists(self, key: str) -> bool:
        """Store-like exists method for async interface."""
        self.counter["exists", key] += 1
        return key in self._store_dict

    async def getsize(self, key: str) -> int:
        """Store-like getsize method for async interface; raises KeyError if missing."""
        self.counter["getsize", key] += 1
        try:
            return len(self._store_dict[key])
        except KeyError:  # noqa: TRY203
            raise


class TestLRUStoreCache(StoreTests[LRUStoreCache, Buffer]):  # type: ignore[misc]
    """Runs the generic ``StoreTests`` suite plus cache-specific tests on LRUStoreCache."""

    store_cls = LRUStoreCache
    buffer_cls = cpu.Buffer
    # Store wrapper used by the cache-specific tests to count pass-through calls.
    CountingClass = CounterStore
    LRUStoreClass = LRUStoreCache
    # Prefix prepended to the keys used in the cache-specific tests.
    root = ""

    async def get(self, store: LRUStoreCache, key: str) -> Buffer:
        """Get method required by StoreTests."""
        return await store.get(key, prototype=cpu.buffer_prototype)

    async def set(self, store: LRUStoreCache, key: str, value: Buffer) -> None:
        """Set method required by StoreTests."""
        await store.set(key, value)

    @pytest.fixture
    def store_kwargs(self) -> dict[str, Any]:
        """Provide default kwargs for store creation."""
        return {"store": MemoryStore(), "max_size": 2**27}

    @pytest.fixture
    async def store(self, store_kwargs: dict[str, Any]) -> LRUStoreCache:
        """Override store fixture to use constructor instead of open."""
        return self.store_cls(**store_kwargs)

    @pytest.fixture
    def open_kwargs(self) -> dict[str, Any]:
        """Provide default kwargs for store.open()."""
        return {"store": MemoryStore(), "max_size": 2**27}

    def create_store(self, **kwargs: Any) -> LRUStoreCache:
        """Build a cache over a fresh MemoryStore; ``kwargs`` are ignored."""
        return self.LRUStoreClass(MemoryStore(), max_size=2**27)

    def create_store_from_mapping(self, mapping: dict[str, Any], **kwargs: Any) -> LRUStoreCache:
        """Build a cache over a MemoryStore pre-populated from ``mapping``."""
        # Handle creation from existing mapping
        # Create a MemoryStore from the mapping
        underlying_store = MemoryStore()
        if mapping:
            # Convert mapping to store data
            for k, v in mapping.items():
                underlying_store._store_dict[k] = v
        return self.LRUStoreClass(underlying_store, max_size=2**27)

    async def test_cache_values_no_max_size(self) -> None:
        """Exercise get/set/invalidate/delete with a cache large enough to never evict."""
        # setup store
        store = self.CountingClass()
        foo_key = self.root + "foo"
        bar_key = self.root + "bar"
        await store.set(foo_key, b"xxx")
        await store.set(bar_key, b"yyy")
        assert 0 == store.counter["get", foo_key]
        assert 1 == store.counter["set", foo_key]
        assert 0 == store.counter["get", bar_key]
        assert 1 == store.counter["set", bar_key]

        # setup cache (1 MiB: far larger than the 3-byte values, so nothing is evicted)
        cache = self.LRUStoreClass(store, max_size=1024 * 1024)
        assert 0 == cache.hits
        assert 0 == cache.misses

        # test first get(), cache miss
        result = await cache.get(foo_key)
        assert result is not None
        assert result.to_bytes() == b"xxx"
        assert 1 == store.counter["get", foo_key]
        assert 1 == store.counter["set", foo_key]
        assert 0 == cache.hits
        assert 1 == cache.misses

        # test second get(), cache hit
        result = await cache.get(foo_key)
        assert result is not None
        assert result.to_bytes() == b"xxx"
        assert 1 == store.counter["get", foo_key]  # No additional get call due to cache
        assert 1 == store.counter["set", foo_key]
        assert 1 == cache.hits
        assert 1 == cache.misses

        # test set(), get()
        from zarr.core.buffer.cpu import Buffer

        await cache.set(foo_key, Buffer.from_bytes(b"zzz"))
        assert 1 == store.counter["get", foo_key]
        assert 2 == store.counter["set", foo_key]
        # should be a cache hit
        result = await cache.get(foo_key)
        assert result is not None
        assert result.to_bytes() == b"zzz"
        assert 1 == store.counter["get", foo_key]  # No additional get call due to cache
        assert 2 == store.counter["set", foo_key]
        assert 2 == cache.hits
        assert 1 == cache.misses

        # manually invalidate all cached values
        cache.invalidate_values()
        result = await cache.get(foo_key)
        assert result is not None
        assert result.to_bytes() == b"zzz"
        assert 2 == store.counter["get", foo_key]  # Cache invalidated, so new get call
        assert 2 == store.counter["set", foo_key]
        cache.invalidate()
        result = await cache.get(foo_key)
        assert result is not None
        assert result.to_bytes() == b"zzz"
        assert 3 == store.counter["get", foo_key]  # Cache invalidated again, so another get call
        assert 2 == store.counter["set", foo_key]

        # test delete()
        await cache.delete(foo_key)
        result = await cache.get(foo_key)
        assert result is None
        # Verify the key is actually deleted from underlying store
        result = await store.get(foo_key)
        assert result is None

        # verify other keys untouched
        assert 0 == store.counter["get", bar_key]
        assert 1 == store.counter["set", bar_key]

    async def test_cache_values_with_max_size(self) -> None:
        """Check LRU eviction when the cache can hold only one of two values."""
        # setup store
        store = self.CountingClass()
        foo_key = self.root + "foo"
        bar_key = self.root + "bar"
        await store.set(foo_key, b"xxx")
        await store.set(bar_key, b"yyy")
        assert 0 == store.counter["get", foo_key]
        assert 0 == store.counter["get", bar_key]
        # setup cache - can only hold one item (two 3-byte values exceed 5 bytes)
        cache = self.LRUStoreClass(store, max_size=5)
        assert 0 == cache.hits
        assert 0 == cache.misses

        # test first 'foo' get(), cache miss
        result = await cache.get(foo_key)
        assert result is not None
        assert result.to_bytes() == b"xxx"
        assert 1 == store.counter["get", foo_key]
        assert 0 == cache.hits
        assert 1 == cache.misses

        # test second 'foo' get(), cache hit
        result = await cache.get(foo_key)
        assert result is not None
        assert result.to_bytes() == b"xxx"
        assert 1 == store.counter["get", foo_key]  # No additional get call due to cache
        assert 1 == cache.hits
        assert 1 == cache.misses

        # test first 'bar' get(), cache miss
        result = await cache.get(bar_key)
        assert result is not None
        assert result.to_bytes() == b"yyy"
        assert 1 == store.counter["get", bar_key]
        assert 1 == cache.hits
        assert 2 == cache.misses

        # test second 'bar' get(), cache hit
        result = await cache.get(bar_key)
        assert result is not None
        assert result.to_bytes() == b"yyy"
        assert 1 == store.counter["get", bar_key]  # No additional get call due to cache
        assert 2 == cache.hits
        assert 2 == cache.misses

        # test 'foo' get(), should have been evicted, cache miss
        result = await cache.get(foo_key)
        assert result is not None
        assert result.to_bytes() == b"xxx"
        assert 2 == store.counter["get", foo_key]  # Cache miss due to eviction
        assert 2 == cache.hits
        assert 3 == cache.misses

        # test 'bar' get(), should have been evicted, cache miss
        result = await cache.get(bar_key)
        assert result is not None
        assert result.to_bytes() == b"yyy"
        assert 2 == store.counter["get", bar_key]  # Cache miss due to eviction
        assert 2 == cache.hits
        assert 4 == cache.misses

    async def test_cache_value_too_large_warning(self) -> None:
        """Test that a warning is emitted when a value is too large to cache."""
        # setup store with small cache
        store = self.CountingClass()
        foo_key = self.root + "foo"
        large_value = b"x" * 1000  # 1000 bytes
        small_cache_size = 500  # 500 bytes max cache

        await store.set(foo_key, large_value)
        cache = self.LRUStoreClass(store, max_size=small_cache_size)

        # Test that warning is emitted when trying to cache a value that's too large
        # This should trigger the warning since 1000 bytes > 500 bytes cache limit
        with pytest.warns(
            UserWarning, match=r"Value for key.*exceeds cache max_size.*and will not be cached"
        ):
            result = await cache.get(foo_key)
            assert result is not None
            assert result.to_bytes() == large_value

        # Verify the value was not actually cached (cache miss on second access)
        assert cache.hits == 0  # No hits yet
        assert cache.misses == 1  # One miss from the first access

        # Second access should also be a miss since value wasn't cached
        # And it will also emit a warning, so we need to catch that too
        with warnings.catch_warnings():
            warnings.simplefilter("ignore")  # Ignore warnings for this call
            result2 = await cache.get(foo_key)
            assert result2 is not None
            assert result2.to_bytes() == large_value
            assert cache.hits == 0  # Still no hits
            assert cache.misses == 2  # Two misses total

        # Verify the warning message contains expected information
        with warnings.catch_warnings(record=True) as w:
            warnings.simplefilter("always")
            await cache.get(foo_key)  # Trigger warning again

            assert len(w) == 1
            warning_message = str(w[0].message)
            assert f"Value for key '{foo_key}'" in warning_message
            assert "1,000 bytes" in warning_message  # Check formatted number
            assert "500 bytes" in warning_message  # Check cache size
            assert "Consider increasing max_size" in warning_message

    async def test_getsize_uses_cache(self) -> None:
        """Test that getsize() uses cached values when available."""
        store = self.CountingClass()
        cache = self.LRUStoreClass(store, max_size=1000)

        await store.set("key", b"value")

        # Populate cache
        await cache.get("key")
        assert 1 == store.counter["get", "key"]

        # getsize() should use cached value
        size = await cache.getsize("key")
        assert size == 5
        assert 1 == store.counter["get", "key"]  # No additional store access
        assert cache.hits == 1

    async def test_getsize_exception_handling(self) -> None:
        """Test that getsize() handles get() exceptions gracefully."""

        # Store whose get() fails for one key, forcing getsize() onto its fallback path.
        class FailingStore(CounterStore):
            async def get(self, key: str, prototype: Any = None, byte_range: Any = None) -> Any:
                if key == "fail":
                    raise RuntimeError("Simulated failure")
                return await super().get(key, prototype, byte_range)

        store = FailingStore()
        cache = self.LRUStoreClass(store, max_size=1000)
        await store.set("fail", b"x" * 50)  # Small value that would be cached

        # getsize() should work despite get() failing
        size = await cache.getsize("fail")
        assert size == 50
        assert cache.hits == 0  # No successful caching

    async def test_get_partial_values(self) -> None:
        """Test get_partial_values method."""

        # setup store
        store = MemoryStore()
        foo_key = "foo"
        bar_key = "bar"
        foo_data = b"hello world"
        bar_data = b"goodbye world"

        await store.set(foo_key, Buffer.from_bytes(foo_data))
        await store.set(bar_key, Buffer.from_bytes(bar_data))

        cache = self.LRUStoreClass(store, max_size=1024 * 1024)

        # Test getting partial values with byte ranges
        key_ranges = [
            (foo_key, RangeByteRequest(start=0, end=5)),  # "hello"
            (bar_key, RangeByteRequest(start=8, end=13)),  # "world"
            (foo_key, None),  # full value
        ]

        results = await cache.get_partial_values(buffer_prototype, key_ranges)

        assert len(results) == 3
        assert results[0] is not None
        assert results[0].to_bytes() == b"hello"
        assert results[1] is not None
        assert results[1].to_bytes() == b"world"
        assert results[2] is not None
        assert results[2].to_bytes() == foo_data

        # Test with non-existent key
        key_ranges_with_missing = [
            (foo_key, RangeByteRequest(start=0, end=5)),
            ("missing_key", None),
        ]

        results = await cache.get_partial_values(buffer_prototype, key_ranges_with_missing)
        assert len(results) == 2
        assert results[0] is not None
        assert results[0].to_bytes() == b"hello"
        assert results[1] is None

    async def test_set_partial_values(self) -> None:
        """Test set_partial_values method."""
        # setup store
        store = MemoryStore()
        cache = self.LRUStoreClass(store, max_size=1024 * 1024)

        key = "test_key"
        original_data = b"hello world 123"

        # Set initial data
        await cache.set(key, Buffer.from_bytes(original_data))

        # Test partial value setting
        partial_updates = [
            (key, 6, b"WORLD"),  # Replace "world" with "WORLD"
            (key, 12, b"456"),  # Replace "123" with "456"
        ]

        # Since MemoryStore doesn't implement set_partial_values,
        # it should raise NotImplementedError
        # NOTE(review): this relies on the cache delegating to MemoryStore, i.e. on
        # MemoryStore reporting supports_partial_writes as truthy - confirm.
        with pytest.raises(NotImplementedError):
            await cache.set_partial_values(partial_updates)

        # The original data should still be in cache since the operation failed
        result = await cache.get(key)
        assert result is not None
        assert result.to_bytes() == original_data

        # Since the operation failed, there should be at least one cache hit
        # (from the initial set and then the get)
        assert cache.hits >= 1