diff --git a/changes/3638.feature.md b/changes/3638.feature.md new file mode 100644 index 0000000000..ad2276fd51 --- /dev/null +++ b/changes/3638.feature.md @@ -0,0 +1 @@ +Add methods for reading stored objects as raw bytes and as JSON-decoded data to store classes. \ No newline at end of file diff --git a/src/zarr/abc/store.py b/src/zarr/abc/store.py index 4b3edf78d1..25aaba4aa9 100644 --- a/src/zarr/abc/store.py +++ b/src/zarr/abc/store.py @@ -1,19 +1,24 @@ from __future__ import annotations +import asyncio +import json from abc import ABC, abstractmethod -from asyncio import gather from dataclasses import dataclass from itertools import starmap from typing import TYPE_CHECKING, Literal, Protocol, runtime_checkable +from zarr.core.buffer import Buffer, BufferPrototype +from zarr.core.buffer.core import default_buffer_prototype +from zarr.core.sync import sync + if TYPE_CHECKING: from collections.abc import AsyncGenerator, AsyncIterator, Iterable from types import TracebackType from typing import Any, Self, TypeAlias - from zarr.core.buffer import Buffer, BufferPrototype +__all__ = ["BufferLike", "ByteGetter", "ByteSetter", "Store", "set_or_delete"] -__all__ = ["ByteGetter", "ByteSetter", "Store", "set_or_delete"] +BufferLike = type[Buffer] | BufferPrototype @dataclass @@ -180,11 +185,17 @@ def __eq__(self, value: object) -> bool: """Equality comparison.""" ... + def _get_default_buffer_class(self) -> type[Buffer]: + """ + Get the default buffer class. + """ + return default_buffer_prototype().buffer + @abstractmethod async def get( self, key: str, - prototype: BufferPrototype, + prototype: BufferLike | None = None, byte_range: ByteRequest | None = None, ) -> Buffer | None: """Retrieve the value associated with a given key. @@ -192,8 +203,12 @@ async def get( Parameters ---------- key : str - prototype : BufferPrototype - The prototype of the output buffer. Stores may support a default buffer prototype. + prototype : BufferLike | None, optional + The prototype of the output buffer. + Can be either a Buffer class or an instance of `BufferPrototype`, in which case the + `buffer` attribute will be used. + If `None`, the default buffer class for this store will be retrieved via the + ``_get_default_buffer_class`` method. byte_range : ByteRequest, optional ByteRequest may be one of the following. If not provided, all data associated with the key is retrieved. - RangeByteRequest(int, int): Request a specific range of bytes in the form (start, end). The end is exclusive. If the given range is zero-length or starts after the end of the object, an error will be returned. Additionally, if the range ends after the end of the object, the entire remainder of the object will be returned. Otherwise, the exact requested range will be returned. @@ -206,18 +221,259 @@ async def get( """ ... + async def get_bytes( + self, + key: str, + *, + prototype: BufferLike | None = None, + byte_range: ByteRequest | None = None, + ) -> bytes: + """ + Retrieve raw bytes from the store asynchronously. + + This is a convenience method that wraps ``get()`` and converts the result + to bytes. Use this when you need the raw byte content of a stored value. + + Parameters + ---------- + key : str + The key identifying the data to retrieve. + prototype : BufferLike | None, optional + The prototype of the output buffer. + Can be either a Buffer class or an instance of `BufferPrototype`, in which case the + `buffer` attribute will be used.
+ If `None`, the default buffer class for this store will be retrieved via the + ``_get_default_buffer_class`` method. + byte_range : ByteRequest, optional + If specified, only retrieve a portion of the stored data. + Can be a ``RangeByteRequest``, ``OffsetByteRequest``, or ``SuffixByteRequest``. + + Returns + ------- + bytes + The raw bytes stored at the given key. + + Raises + ------ + FileNotFoundError + If the key does not exist in the store. + + See Also + -------- + get : Lower-level method that returns a Buffer object. + get_bytes_sync : Synchronous version of this method. + get_json : Asynchronous method for retrieving and parsing JSON data. + + Examples + -------- + >>> store = await MemoryStore.open() + >>> await store.set("data", Buffer.from_bytes(b"hello world")) + >>> data = await store.get_bytes("data", prototype=default_buffer_prototype()) + >>> print(data) + b'hello world' + """ + buffer = await self.get(key, prototype, byte_range) + if buffer is None: + raise FileNotFoundError(key) + return buffer.to_bytes() + + def get_bytes_sync( + self, + key: str = "", + *, + prototype: BufferLike | None = None, + byte_range: ByteRequest | None = None, + ) -> bytes: + """ + Retrieve raw bytes from the store synchronously. + + This is a synchronous wrapper around ``get_bytes()``. It should only + be called from non-async code. For async contexts, use ``get_bytes()`` + instead. + + Parameters + ---------- + key : str, optional + The key identifying the data to retrieve. Defaults to an empty string. + prototype : BufferLike | None, optional + The prototype of the output buffer. + Can be either a Buffer class or an instance of `BufferPrototype`, in which case the + `buffer` attribute will be used. + If `None`, the default buffer class for this store will be retrieved via the + ``_get_default_buffer_class`` method. + byte_range : ByteRequest, optional + If specified, only retrieve a portion of the stored data. + Can be a ``RangeByteRequest``, ``OffsetByteRequest``, or ``SuffixByteRequest``. + + Returns + ------- + bytes + The raw bytes stored at the given key. + + Raises + ------ + FileNotFoundError + If the key does not exist in the store. + + Warnings + -------- + Do not call this method from async functions. Use ``get_bytes()`` instead + to avoid blocking the event loop. + + See Also + -------- + get_bytes : Asynchronous version of this method. + get_json_sync : Synchronous method for retrieving and parsing JSON data. + + Examples + -------- + >>> store = MemoryStore() + >>> await store.set("data", Buffer.from_bytes(b"hello world")) + >>> data = store.get_bytes_sync("data", prototype=default_buffer_prototype()) + >>> print(data) + b'hello world' + """ + + return sync(self.get_bytes(key, prototype=prototype, byte_range=byte_range)) + + async def get_json( + self, + key: str, + *, + prototype: BufferLike | None = None, + byte_range: ByteRequest | None = None, + ) -> Any: + """ + Retrieve and parse JSON data from the store asynchronously. + + This is a convenience method that retrieves bytes from the store and + parses them as JSON. + + Parameters + ---------- + key : str + The key identifying the JSON data to retrieve. + prototype : BufferLike | None, optional + The prototype of the output buffer. + Can be either a Buffer class or an instance of `BufferPrototype`, in which case the + `buffer` attribute will be used. + If `None`, the default buffer class for this store will be retrieved via the + ``_get_default_buffer_class`` method.
+ byte_range : ByteRequest, optional + If specified, only retrieve a portion of the stored data. + Can be a ``RangeByteRequest``, ``OffsetByteRequest``, or ``SuffixByteRequest``. + Note: Using byte ranges with JSON may result in invalid JSON. + + Returns + ------- + Any + The parsed JSON data. This follows the behavior of ``json.loads()`` and + can be any JSON-serializable type: dict, list, str, int, float, bool, or None. + + Raises + ------ + FileNotFoundError + If the key does not exist in the store. + json.JSONDecodeError + If the stored data is not valid JSON. + + See Also + -------- + get_bytes : Method for retrieving raw bytes. + get_json_sync : Synchronous version of this method. + + Examples + -------- + >>> store = await MemoryStore.open() + >>> metadata = {"zarr_format": 3, "node_type": "array"} + >>> await store.set("zarr.json", Buffer.from_bytes(json.dumps(metadata).encode())) + >>> data = await store.get_json("zarr.json", prototype=default_buffer_prototype()) + >>> print(data) + {'zarr_format': 3, 'node_type': 'array'} + """ + + return json.loads(await self.get_bytes(key, prototype=prototype, byte_range=byte_range)) + + def get_json_sync( + self, + key: str = "", + *, + prototype: BufferLike | None = None, + byte_range: ByteRequest | None = None, + ) -> Any: + """ + Retrieve and parse JSON data from the store synchronously. + + This is a synchronous wrapper around ``get_json()``. It should only + be called from non-async code. For async contexts, use ``get_json()`` + instead. + + Parameters + ---------- + key : str, optional + The key identifying the JSON data to retrieve. Defaults to an empty string. + prototype : BufferLike | None, optional + The prototype of the output buffer. + Can be either a Buffer class or an instance of `BufferPrototype`, in which the + `buffer` attribute will be used. + If `None`, the default buffer prototype for this store will be retrieved via the + ``_get_default_buffer_class`` method. + byte_range : ByteRequest, optional + If specified, only retrieve a portion of the stored data. + Can be a ``RangeByteRequest``, ``OffsetByteRequest``, or ``SuffixByteRequest``. + Note: Using byte ranges with JSON may result in invalid JSON. + + Returns + ------- + Any + The parsed JSON data. This follows the behavior of ``json.loads()`` and + can be any JSON-serializable type: dict, list, str, int, float, bool, or None. + + Raises + ------ + FileNotFoundError + If the key does not exist in the store. + json.JSONDecodeError + If the stored data is not valid JSON. + + Warnings + -------- + Do not call this method from async functions. Use ``get_json()`` instead + to avoid blocking the event loop. + + See Also + -------- + get_json : Asynchronous version of this method. + get_bytes_sync : Synchronous method for retrieving raw bytes without parsing. + + Examples + -------- + >>> store = MemoryStore() + >>> metadata = {"zarr_format": 3, "node_type": "array"} + >>> store.set("zarr.json", Buffer.from_bytes(json.dumps(metadata).encode())) + >>> data = store.get_json_sync("zarr.json", prototype=default_buffer_prototype()) + >>> print(data) + {'zarr_format': 3, 'node_type': 'array'} + """ + + return sync(self.get_json(key, prototype=prototype, byte_range=byte_range)) + @abstractmethod async def get_partial_values( self, - prototype: BufferPrototype, + prototype: BufferLike | None, key_ranges: Iterable[tuple[str, ByteRequest | None]], ) -> list[Buffer | None]: """Retrieve possibly partial values from given key_ranges. 
Parameters ---------- - prototype : BufferPrototype - The prototype of the output buffer. Stores may support a default buffer prototype. + prototype : BufferLike | None + The prototype of the output buffer. + Can be either a Buffer class or an instance of `BufferPrototype`, in which the + `buffer` attribute will be used. + If `None`, the default buffer prototype for this store will be retrieved via the + ``_get_default_buffer_class`` method. key_ranges : Iterable[tuple[str, tuple[int | None, int | None]]] Ordered set of key, range pairs, a key may occur multiple times with different ranges @@ -278,7 +534,7 @@ async def _set_many(self, values: Iterable[tuple[str, Buffer]]) -> None: """ Insert multiple (key, value) pairs into storage. """ - await gather(*starmap(self.set, values)) + await asyncio.gather(*starmap(self.set, values)) @property def supports_consolidated_metadata(self) -> bool: @@ -389,7 +645,7 @@ def close(self) -> None: self._is_open = False async def _get_many( - self, requests: Iterable[tuple[str, BufferPrototype, ByteRequest | None]] + self, requests: Iterable[tuple[str, BufferLike | None, ByteRequest | None]] ) -> AsyncGenerator[tuple[str, Buffer | None], None]: """ Retrieve a collection of objects from storage. In general this method does not guarantee diff --git a/src/zarr/experimental/cache_store.py b/src/zarr/experimental/cache_store.py index 3456c94320..e696e0eb0f 100644 --- a/src/zarr/experimental/cache_store.py +++ b/src/zarr/experimental/cache_store.py @@ -6,13 +6,13 @@ from collections import OrderedDict from typing import TYPE_CHECKING, Any, Literal -from zarr.abc.store import ByteRequest, Store +from zarr.abc.store import BufferLike, ByteRequest, Store from zarr.storage._wrapper import WrapperStore logger = logging.getLogger(__name__) if TYPE_CHECKING: - from zarr.core.buffer.core import Buffer, BufferPrototype + from zarr.core.buffer.core import Buffer class CacheStore(WrapperStore[Store]): @@ -218,7 +218,7 @@ def _remove_from_tracking(self, key: str) -> None: self._key_sizes.pop(key, None) async def _get_try_cache( - self, key: str, prototype: BufferPrototype, byte_range: ByteRequest | None = None + self, key: str, prototype: BufferLike | None, byte_range: ByteRequest | None = None ) -> Buffer | None: """Try to get data from cache first, falling back to source store.""" maybe_cached_result = await self._cache.get(key, prototype, byte_range) @@ -246,7 +246,7 @@ async def _get_try_cache( return maybe_fresh_result async def _get_no_cache( - self, key: str, prototype: BufferPrototype, byte_range: ByteRequest | None = None + self, key: str, prototype: BufferLike | None, byte_range: ByteRequest | None = None ) -> Buffer | None: """Get data directly from source store and update cache.""" self._misses += 1 @@ -265,7 +265,7 @@ async def _get_no_cache( async def get( self, key: str, - prototype: BufferPrototype, + prototype: BufferLike | None = None, byte_range: ByteRequest | None = None, ) -> Buffer | None: """ @@ -275,8 +275,12 @@ async def get( ---------- key : str The key to retrieve - prototype : BufferPrototype - Buffer prototype for creating the result buffer + prototype : BufferLike | None, optional + The prototype of the output buffer. + Can be either a Buffer class or an instance of `BufferPrototype`, in which the + `buffer` attribute will be used. + If `None`, the default buffer class for this store will be retrieved via the + ``_get_default_buffer_class`` method. 
byte_range : ByteRequest, optional Byte range to retrieve diff --git a/src/zarr/storage/_common.py b/src/zarr/storage/_common.py index 4bea04f024..e381c65839 100644 --- a/src/zarr/storage/_common.py +++ b/src/zarr/storage/_common.py @@ -3,10 +3,10 @@ import importlib.util import json from pathlib import Path -from typing import TYPE_CHECKING, Any, Literal, Self, TypeAlias +from typing import Any, Literal, Self, TypeAlias -from zarr.abc.store import ByteRequest, Store -from zarr.core.buffer import Buffer, default_buffer_prototype +from zarr.abc.store import BufferLike, ByteRequest, Store +from zarr.core.buffer import Buffer from zarr.core.common import ( ANY_ACCESS_MODE, ZARR_JSON, @@ -26,9 +26,6 @@ else: FSMap = None -if TYPE_CHECKING: - from zarr.core.buffer import BufferPrototype - def _dereference_path(root: str, path: str) -> str: if not isinstance(root, str): @@ -145,7 +142,7 @@ async def open(cls, store: Store, path: str, mode: AccessModeLiteral | None = No async def get( self, - prototype: BufferPrototype | None = None, + prototype: BufferLike | None = None, byte_range: ByteRequest | None = None, ) -> Buffer | None: """ @@ -153,8 +150,12 @@ async def get( Parameters ---------- - prototype : BufferPrototype, optional - The buffer prototype to use when reading the bytes. + prototype : BufferLike | None, optional + The prototype of the output buffer. + Can be either a Buffer class or an instance of `BufferPrototype`, in which the + `buffer` attribute will be used. + If `None`, the default buffer class for this store will be retrieved via the + store's ``_get_default_buffer_class`` method. byte_range : ByteRequest, optional The range of bytes to read. @@ -164,7 +165,7 @@ async def get( The read bytes, or None if the key does not exist. """ if prototype is None: - prototype = default_buffer_prototype() + prototype = self.store._get_default_buffer_class() return await self.store.get(self.path, prototype=prototype, byte_range=byte_range) async def set(self, value: Buffer) -> None: diff --git a/src/zarr/storage/_fsspec.py b/src/zarr/storage/_fsspec.py index f9e4ed375d..b16712c786 100644 --- a/src/zarr/storage/_fsspec.py +++ b/src/zarr/storage/_fsspec.py @@ -8,13 +8,14 @@ from packaging.version import parse as parse_version from zarr.abc.store import ( + BufferLike, ByteRequest, OffsetByteRequest, RangeByteRequest, Store, SuffixByteRequest, ) -from zarr.core.buffer import Buffer +from zarr.core.buffer import Buffer, BufferPrototype from zarr.errors import ZarrUserWarning from zarr.storage._common import _dereference_path @@ -25,8 +26,6 @@ from fsspec.asyn import AsyncFileSystem from fsspec.mapping import FSMap - from zarr.core.buffer import BufferPrototype - ALLOWED_EXCEPTIONS: tuple[type[Exception], ...] 
= ( FileNotFoundError, @@ -276,19 +275,27 @@ def __eq__(self, other: object) -> bool: async def get( self, key: str, - prototype: BufferPrototype, + prototype: BufferLike | None = None, byte_range: ByteRequest | None = None, ) -> Buffer | None: # docstring inherited if not self._is_open: await self._open() + if prototype is None: + prototype = self._get_default_buffer_class() + # Extract buffer class from BufferLike + if isinstance(prototype, BufferPrototype): + buffer_cls = prototype.buffer + else: + buffer_cls = prototype + path = _dereference_path(self.path, key) try: if byte_range is None: - value = prototype.buffer.from_bytes(await self.fs._cat_file(path)) + value = buffer_cls.from_bytes(await self.fs._cat_file(path)) elif isinstance(byte_range, RangeByteRequest): - value = prototype.buffer.from_bytes( + value = buffer_cls.from_bytes( await self.fs._cat_file( path, start=byte_range.start, @@ -296,11 +303,11 @@ async def get( ) ) elif isinstance(byte_range, OffsetByteRequest): - value = prototype.buffer.from_bytes( + value = buffer_cls.from_bytes( await self.fs._cat_file(path, start=byte_range.offset, end=None) ) elif isinstance(byte_range, SuffixByteRequest): - value = prototype.buffer.from_bytes( + value = buffer_cls.from_bytes( await self.fs._cat_file(path, start=-byte_range.suffix, end=None) ) else: @@ -310,7 +317,7 @@ async def get( except OSError as e: if "not satisfiable" in str(e): # this is an s3-specific condition we probably don't want to leak - return prototype.buffer.from_bytes(b"") + return buffer_cls.from_bytes(b"") raise else: return value @@ -367,10 +374,18 @@ async def exists(self, key: str) -> bool: async def get_partial_values( self, - prototype: BufferPrototype, + prototype: BufferLike | None, key_ranges: Iterable[tuple[str, ByteRequest | None]], ) -> list[Buffer | None]: # docstring inherited + if prototype is None: + prototype = self._get_default_buffer_class() + # Extract buffer class from BufferLike + if isinstance(prototype, BufferPrototype): + buffer_cls = prototype.buffer + else: + buffer_cls = prototype + if key_ranges: # _cat_ranges expects a list of paths, start, and end ranges, so we need to reformat each ByteRequest. 
key_ranges = list(key_ranges) @@ -403,7 +418,7 @@ async def get_partial_values( if isinstance(r, Exception) and not isinstance(r, self.allowed_exceptions): raise r - return [None if isinstance(r, Exception) else prototype.buffer.from_bytes(r) for r in res] + return [None if isinstance(r, Exception) else buffer_cls.from_bytes(r) for r in res] async def list(self) -> AsyncIterator[str]: # docstring inherited diff --git a/src/zarr/storage/_local.py b/src/zarr/storage/_local.py index f64da71bb4..f991765723 100644 --- a/src/zarr/storage/_local.py +++ b/src/zarr/storage/_local.py @@ -8,40 +8,44 @@ import sys import uuid from pathlib import Path -from typing import TYPE_CHECKING, BinaryIO, Literal, Self +from typing import TYPE_CHECKING, Any, BinaryIO, Literal, Self from zarr.abc.store import ( + BufferLike, ByteRequest, OffsetByteRequest, RangeByteRequest, Store, SuffixByteRequest, ) -from zarr.core.buffer import Buffer -from zarr.core.buffer.core import default_buffer_prototype +from zarr.core.buffer import Buffer, BufferPrototype from zarr.core.common import AccessModeLiteral, concurrent_map if TYPE_CHECKING: from collections.abc import AsyncIterator, Iterable, Iterator - from zarr.core.buffer import BufferPrototype +def _get(path: Path, prototype: BufferLike, byte_range: ByteRequest | None) -> Buffer: + # Extract buffer class from BufferLike + if isinstance(prototype, BufferPrototype): + buffer_cls = prototype.buffer + else: + buffer_cls = prototype -def _get(path: Path, prototype: BufferPrototype, byte_range: ByteRequest | None) -> Buffer: if byte_range is None: - return prototype.buffer.from_bytes(path.read_bytes()) + return buffer_cls.from_bytes(path.read_bytes()) with path.open("rb") as f: size = f.seek(0, io.SEEK_END) if isinstance(byte_range, RangeByteRequest): f.seek(byte_range.start) - return prototype.buffer.from_bytes(f.read(byte_range.end - f.tell())) + return buffer_cls.from_bytes(f.read(byte_range.end - f.tell())) elif isinstance(byte_range, OffsetByteRequest): f.seek(byte_range.offset) elif isinstance(byte_range, SuffixByteRequest): f.seek(max(0, size - byte_range.suffix)) else: raise TypeError(f"Unexpected byte_range, got {byte_range}.") - return prototype.buffer.from_bytes(f.read()) + return buffer_cls.from_bytes(f.read()) if sys.platform == "win32": @@ -190,12 +194,12 @@ def __eq__(self, other: object) -> bool: async def get( self, key: str, - prototype: BufferPrototype | None = None, + prototype: BufferLike | None = None, byte_range: ByteRequest | None = None, ) -> Buffer | None: # docstring inherited if prototype is None: - prototype = default_buffer_prototype() + prototype = self._get_default_buffer_class() if not self._is_open: await self._open() assert isinstance(key, str) @@ -208,10 +212,12 @@ async def get( async def get_partial_values( self, - prototype: BufferPrototype, + prototype: BufferLike | None, key_ranges: Iterable[tuple[str, ByteRequest | None]], ) -> list[Buffer | None]: # docstring inherited + if prototype is None: + prototype = self._get_default_buffer_class() args = [] for key, byte_range in key_ranges: assert isinstance(key, str) @@ -306,6 +312,236 @@ async def list_dir(self, prefix: str) -> AsyncIterator[str]: except (FileNotFoundError, NotADirectoryError): pass + async def get_bytes( + self, + key: str = "", + *, + prototype: BufferLike | None = None, + byte_range: ByteRequest | None = None, + ) -> bytes: + """ + Retrieve raw bytes from the local store asynchronously. 
+ + This is a convenience override that makes the ``prototype`` parameter optional + by defaulting to the standard buffer prototype. See the base ``Store.get_bytes`` + for full documentation. + + Parameters + ---------- + key : str, optional + The key identifying the data to retrieve. Defaults to an empty string. + prototype : BufferPrototype, optional + The buffer prototype to use for reading the data. If None, uses + ``default_buffer_prototype()``. + byte_range : ByteRequest, optional + If specified, only retrieve a portion of the stored data. + + Returns + ------- + bytes + The raw bytes stored at the given key. + + Raises + ------ + FileNotFoundError + If the key does not exist in the store. + + See Also + -------- + Store.get_bytes : Base implementation with full documentation. + get_bytes_sync : Synchronous version of this method. + + Examples + -------- + >>> store = await LocalStore.open("data") + >>> await store.set("data", Buffer.from_bytes(b"hello")) + >>> # No need to specify prototype for LocalStore + >>> data = await store.get_bytes("data") + >>> print(data) + b'hello' + """ + if prototype is None: + prototype = self._get_default_buffer_class() + return await super().get_bytes(key, prototype=prototype, byte_range=byte_range) + + def get_bytes_sync( + self, + key: str = "", + *, + prototype: BufferLike | None = None, + byte_range: ByteRequest | None = None, + ) -> bytes: + """ + Retrieve raw bytes from the local store synchronously. + + This is a convenience override that makes the ``prototype`` parameter optional + by defaulting to the standard buffer prototype. See the base ``Store.get_bytes`` + for full documentation. + + Parameters + ---------- + key : str, optional + The key identifying the data to retrieve. Defaults to an empty string. + prototype : BufferPrototype, optional + The buffer prototype to use for reading the data. If None, uses + ``default_buffer_prototype()``. + byte_range : ByteRequest, optional + If specified, only retrieve a portion of the stored data. + + Returns + ------- + bytes + The raw bytes stored at the given key. + + Raises + ------ + FileNotFoundError + If the key does not exist in the store. + + Warnings + -------- + Do not call this method from async functions. Use ``get_bytes()`` instead. + + See Also + -------- + Store.get_bytes_sync : Base implementation with full documentation. + get_bytes : Asynchronous version of this method. + + Examples + -------- + >>> store = LocalStore("data") + >>> store.set("data", Buffer.from_bytes(b"hello")) + >>> # No need to specify prototype for LocalStore + >>> data = store.get_bytes("data") + >>> print(data) + b'hello' + """ + if prototype is None: + prototype = self._get_default_buffer_class() + return super().get_bytes_sync(key, prototype=prototype, byte_range=byte_range) + + async def get_json( + self, + key: str = "", + *, + prototype: BufferLike | None = None, + byte_range: ByteRequest | None = None, + ) -> Any: + """ + Retrieve and parse JSON data from the local store asynchronously. + + This is a convenience override that makes the ``prototype`` parameter optional + by defaulting to the standard buffer prototype. See the base ``Store.get_json`` + for full documentation. + + Parameters + ---------- + key : str, optional + The key identifying the JSON data to retrieve. Defaults to an empty string. + prototype : BufferPrototype, optional + The buffer prototype to use for reading the data. If None, uses + ``default_buffer_prototype()``. 
+ byte_range : ByteRequest, optional + If specified, only retrieve a portion of the stored data. + Note: Using byte ranges with JSON may result in invalid JSON. + + Returns + ------- + Any + The parsed JSON data. This follows the behavior of ``json.loads()`` and + can be any JSON-serializable type: dict, list, str, int, float, bool, or None. + + Raises + ------ + FileNotFoundError + If the key does not exist in the store. + json.JSONDecodeError + If the stored data is not valid JSON. + + See Also + -------- + Store.get_json : Base implementation with full documentation. + get_json_sync : Synchronous version of this method. + get_bytes : Method for retrieving raw bytes without parsing. + + Examples + -------- + >>> store = await LocalStore.open("data") + >>> import json + >>> metadata = {"zarr_format": 3, "node_type": "array"} + >>> await store.set("zarr.json", Buffer.from_bytes(json.dumps(metadata).encode())) + >>> # No need to specify prototype for LocalStore + >>> data = await store.get_json("zarr.json") + >>> print(data) + {'zarr_format': 3, 'node_type': 'array'} + """ + if prototype is None: + prototype = self._get_default_buffer_class() + return await super().get_json(key, prototype=prototype, byte_range=byte_range) + + def get_json_sync( + self, + key: str = "", + *, + prototype: BufferLike | None = None, + byte_range: ByteRequest | None = None, + ) -> Any: + """ + Retrieve and parse JSON data from the local store synchronously. + + This is a convenience override that makes the ``prototype`` parameter optional + by defaulting to the standard buffer prototype. See the base ``Store.get_json`` + for full documentation. + + Parameters + ---------- + key : str, optional + The key identifying the JSON data to retrieve. Defaults to an empty string. + prototype : BufferPrototype, optional + The buffer prototype to use for reading the data. If None, uses + ``default_buffer_prototype()``. + byte_range : ByteRequest, optional + If specified, only retrieve a portion of the stored data. + Note: Using byte ranges with JSON may result in invalid JSON. + + Returns + ------- + Any + The parsed JSON data. This follows the behavior of ``json.loads()`` and + can be any JSON-serializable type: dict, list, str, int, float, bool, or None. + + Raises + ------ + FileNotFoundError + If the key does not exist in the store. + json.JSONDecodeError + If the stored data is not valid JSON. + + Warnings + -------- + Do not call this method from async functions. Use ``get_json()`` instead. + + See Also + -------- + Store.get_json_sync : Base implementation with full documentation. + get_json : Asynchronous version of this method. + get_bytes_sync : Method for retrieving raw bytes without parsing. + + Examples + -------- + >>> store = LocalStore("data") + >>> import json + >>> metadata = {"zarr_format": 3, "node_type": "array"} + >>> store.set("zarr.json", Buffer.from_bytes(json.dumps(metadata).encode())) + >>> # No need to specify prototype for LocalStore + >>> data = store.get_json("zarr.json") + >>> print(data) + {'zarr_format': 3, 'node_type': 'array'} + """ + if prototype is None: + prototype = self._get_default_buffer_class() + return super().get_json_sync(key, prototype=prototype, byte_range=byte_range) + async def move(self, dest_root: Path | str) -> None: """ Move the store to another path. The old root directory is deleted. 
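Reviewer note: a minimal usage sketch of the convenience methods introduced above (``get_bytes`` / ``get_json`` and their ``*_sync`` wrappers), assuming the ``MemoryStore`` and ``default_buffer_prototype`` exports already present in zarr. The snippet is illustrative only and is not part of the patch.

import asyncio
import json

from zarr.core.buffer import default_buffer_prototype
from zarr.storage import MemoryStore


async def demo() -> None:
    store = await MemoryStore.open()
    proto = default_buffer_prototype()
    payload = json.dumps({"zarr_format": 3, "node_type": "group"}).encode()
    await store.set("zarr.json", proto.buffer.from_bytes(payload))

    # prototype accepts a BufferPrototype, a concrete Buffer class, or None (store default)
    raw = await store.get_bytes("zarr.json", prototype=proto)
    meta = await store.get_json("zarr.json", prototype=proto.buffer)
    assert json.loads(raw) == meta


asyncio.run(demo())

# Outside async code, the synchronous wrappers cover the same ground:
# meta = store.get_json_sync("zarr.json")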
diff --git a/src/zarr/storage/_logging.py b/src/zarr/storage/_logging.py index dd20d49ae5..7d82dac948 100644 --- a/src/zarr/storage/_logging.py +++ b/src/zarr/storage/_logging.py @@ -8,14 +8,14 @@ from contextlib import contextmanager from typing import TYPE_CHECKING, Any, Self, TypeVar -from zarr.abc.store import Store +from zarr.abc.store import BufferLike, Store from zarr.storage._wrapper import WrapperStore if TYPE_CHECKING: from collections.abc import AsyncGenerator, Generator, Iterable from zarr.abc.store import ByteRequest - from zarr.core.buffer import Buffer, BufferPrototype + from zarr.core.buffer import Buffer counter: defaultdict[str, int] @@ -165,7 +165,7 @@ def __eq__(self, other: object) -> bool: async def get( self, key: str, - prototype: BufferPrototype, + prototype: BufferLike | None = None, byte_range: ByteRequest | None = None, ) -> Buffer | None: # docstring inherited @@ -174,7 +174,7 @@ async def get( async def get_partial_values( self, - prototype: BufferPrototype, + prototype: BufferLike | None, key_ranges: Iterable[tuple[str, ByteRequest | None]], ) -> list[Buffer | None]: # docstring inherited diff --git a/src/zarr/storage/_memory.py b/src/zarr/storage/_memory.py index 904be922d7..c28dc910b4 100644 --- a/src/zarr/storage/_memory.py +++ b/src/zarr/storage/_memory.py @@ -1,19 +1,16 @@ from __future__ import annotations from logging import getLogger -from typing import TYPE_CHECKING, Self +from typing import TYPE_CHECKING, Any, Self -from zarr.abc.store import ByteRequest, Store -from zarr.core.buffer import Buffer, gpu -from zarr.core.buffer.core import default_buffer_prototype +from zarr.abc.store import BufferLike, ByteRequest, Store +from zarr.core.buffer import Buffer, BufferPrototype, gpu from zarr.core.common import concurrent_map from zarr.storage._utils import _normalize_byte_range_index if TYPE_CHECKING: from collections.abc import AsyncIterator, Iterable, MutableMapping - from zarr.core.buffer import BufferPrototype - logger = getLogger(__name__) @@ -80,25 +77,30 @@ def __eq__(self, other: object) -> bool: async def get( self, key: str, - prototype: BufferPrototype | None = None, + prototype: BufferLike | None = None, byte_range: ByteRequest | None = None, ) -> Buffer | None: # docstring inherited if prototype is None: - prototype = default_buffer_prototype() + prototype = self._get_default_buffer_class() + # Extract buffer class from BufferLike + if isinstance(prototype, BufferPrototype): + buffer_cls = prototype.buffer + else: + buffer_cls = prototype if not self._is_open: await self._open() assert isinstance(key, str) try: value = self._store_dict[key] start, stop = _normalize_byte_range_index(value, byte_range) - return prototype.buffer.from_buffer(value[start:stop]) + return buffer_cls.from_buffer(value[start:stop]) except KeyError: return None async def get_partial_values( self, - prototype: BufferPrototype, + prototype: BufferLike | None, key_ranges: Iterable[tuple[str, ByteRequest | None]], ) -> list[Buffer | None]: # docstring inherited @@ -175,6 +177,236 @@ async def list_dir(self, prefix: str) -> AsyncIterator[str]: for key in keys_unique: yield key + async def get_bytes( + self, + key: str = "", + *, + prototype: BufferLike | None = None, + byte_range: ByteRequest | None = None, + ) -> bytes: + """ + Retrieve raw bytes from the memory store asynchronously. + + This is a convenience override that makes the ``prototype`` parameter optional + by defaulting to the standard buffer prototype. See the base ``Store.get_bytes`` + for full documentation. 
+ + Parameters + ---------- + key : str, optional + The key identifying the data to retrieve. Defaults to an empty string. + prototype : BufferPrototype, optional + The buffer prototype to use for reading the data. If None, uses + ``default_buffer_prototype()``. + byte_range : ByteRequest, optional + If specified, only retrieve a portion of the stored data. + + Returns + ------- + bytes + The raw bytes stored at the given key. + + Raises + ------ + FileNotFoundError + If the key does not exist in the store. + + See Also + -------- + Store.get_bytes : Base implementation with full documentation. + get_bytes_sync : Synchronous version of this method. + + Examples + -------- + >>> store = await MemoryStore.open() + >>> await store.set("data", Buffer.from_bytes(b"hello")) + >>> # No need to specify prototype for MemoryStore + >>> data = await store.get_bytes("data") + >>> print(data) + b'hello' + """ + if prototype is None: + prototype = self._get_default_buffer_class() + return await super().get_bytes(key, prototype=prototype, byte_range=byte_range) + + def get_bytes_sync( + self, + key: str = "", + *, + prototype: BufferLike | None = None, + byte_range: ByteRequest | None = None, + ) -> bytes: + """ + Retrieve raw bytes from the memory store synchronously. + + This is a convenience override that makes the ``prototype`` parameter optional + by defaulting to the standard buffer prototype. See the base ``Store.get_bytes`` + for full documentation. + + Parameters + ---------- + key : str, optional + The key identifying the data to retrieve. Defaults to an empty string. + prototype : BufferPrototype, optional + The buffer prototype to use for reading the data. If None, uses + ``default_buffer_prototype()``. + byte_range : ByteRequest, optional + If specified, only retrieve a portion of the stored data. + + Returns + ------- + bytes + The raw bytes stored at the given key. + + Raises + ------ + FileNotFoundError + If the key does not exist in the store. + + Warnings + -------- + Do not call this method from async functions. Use ``get_bytes()`` instead. + + See Also + -------- + Store.get_bytes_sync : Base implementation with full documentation. + get_bytes : Asynchronous version of this method. + + Examples + -------- + >>> store = MemoryStore() + >>> store.set("data", Buffer.from_bytes(b"hello")) + >>> # No need to specify prototype for MemoryStore + >>> data = store.get_bytes("data") + >>> print(data) + b'hello' + """ + if prototype is None: + prototype = self._get_default_buffer_class() + return super().get_bytes_sync(key, prototype=prototype, byte_range=byte_range) + + async def get_json( + self, + key: str = "", + *, + prototype: BufferLike | None = None, + byte_range: ByteRequest | None = None, + ) -> Any: + """ + Retrieve and parse JSON data from the memory store asynchronously. + + This is a convenience override that makes the ``prototype`` parameter optional + by defaulting to the standard buffer prototype. See the base ``Store.get_json`` + for full documentation. + + Parameters + ---------- + key : str, optional + The key identifying the JSON data to retrieve. Defaults to an empty string. + prototype : BufferPrototype, optional + The buffer prototype to use for reading the data. If None, uses + ``default_buffer_prototype()``. + byte_range : ByteRequest, optional + If specified, only retrieve a portion of the stored data. + Note: Using byte ranges with JSON may result in invalid JSON. + + Returns + ------- + Any + The parsed JSON data. 
This follows the behavior of ``json.loads()`` and + can be any JSON-serializable type: dict, list, str, int, float, bool, or None. + + Raises + ------ + FileNotFoundError + If the key does not exist in the store. + json.JSONDecodeError + If the stored data is not valid JSON. + + See Also + -------- + Store.get_json : Base implementation with full documentation. + get_json_sync : Synchronous version of this method. + get_bytes : Method for retrieving raw bytes without parsing. + + Examples + -------- + >>> store = await MemoryStore.open() + >>> import json + >>> metadata = {"zarr_format": 3, "node_type": "array"} + >>> await store.set("zarr.json", Buffer.from_bytes(json.dumps(metadata).encode())) + >>> # No need to specify prototype for MemoryStore + >>> data = await store.get_json("zarr.json") + >>> print(data) + {'zarr_format': 3, 'node_type': 'array'} + """ + if prototype is None: + prototype = self._get_default_buffer_class() + return await super().get_json(key, prototype=prototype, byte_range=byte_range) + + def get_json_sync( + self, + key: str = "", + *, + prototype: BufferLike | None = None, + byte_range: ByteRequest | None = None, + ) -> Any: + """ + Retrieve and parse JSON data from the memory store synchronously. + + This is a convenience override that makes the ``prototype`` parameter optional + by defaulting to the standard buffer prototype. See the base ``Store.get_json`` + for full documentation. + + Parameters + ---------- + key : str, optional + The key identifying the JSON data to retrieve. Defaults to an empty string. + prototype : BufferPrototype, optional + The buffer prototype to use for reading the data. If None, uses + ``default_buffer_prototype()``. + byte_range : ByteRequest, optional + If specified, only retrieve a portion of the stored data. + Note: Using byte ranges with JSON may result in invalid JSON. + + Returns + ------- + Any + The parsed JSON data. This follows the behavior of ``json.loads()`` and + can be any JSON-serializable type: dict, list, str, int, float, bool, or None. + + Raises + ------ + FileNotFoundError + If the key does not exist in the store. + json.JSONDecodeError + If the stored data is not valid JSON. + + Warnings + -------- + Do not call this method from async functions. Use ``get_json()`` instead. + + See Also + -------- + Store.get_json_sync : Base implementation with full documentation. + get_json : Asynchronous version of this method. + get_bytes_sync : Method for retrieving raw bytes without parsing. 
+ + Examples + -------- + >>> store = MemoryStore() + >>> import json + >>> metadata = {"zarr_format": 3, "node_type": "array"} + >>> store.set("zarr.json", Buffer.from_bytes(json.dumps(metadata).encode())) + >>> # No need to specify prototype for MemoryStore + >>> data = store.get_json_sync("zarr.json") + >>> print(data) + {'zarr_format': 3, 'node_type': 'array'} + """ + if prototype is None: + prototype = self._get_default_buffer_class() + return super().get_json_sync(key, prototype=prototype, byte_range=byte_range) + class GpuMemoryStore(MemoryStore): """ diff --git a/src/zarr/storage/_obstore.py b/src/zarr/storage/_obstore.py index 5c2197ecf6..aff000afe9 100644 --- a/src/zarr/storage/_obstore.py +++ b/src/zarr/storage/_obstore.py @@ -7,12 +7,15 @@ from typing import TYPE_CHECKING, Generic, Self, TypedDict, TypeVar from zarr.abc.store import ( + BufferLike, ByteRequest, OffsetByteRequest, RangeByteRequest, Store, SuffixByteRequest, ) +from zarr.core.buffer import BufferPrototype +from zarr.core.buffer.core import default_buffer_prototype from zarr.core.common import concurrent_map from zarr.core.config import config @@ -23,7 +26,7 @@ from obstore import ListResult, ListStream, ObjectMeta, OffsetRange, SuffixRange from obstore.store import ObjectStore as _UpstreamObjectStore - from zarr.core.buffer import Buffer, BufferPrototype + from zarr.core.buffer import Buffer __all__ = ["ObjectStore"] @@ -95,25 +98,33 @@ def __setstate__(self, state: dict[Any, Any]) -> None: self.__dict__.update(state) async def get( - self, key: str, prototype: BufferPrototype, byte_range: ByteRequest | None = None + self, key: str, prototype: BufferLike | None = None, byte_range: ByteRequest | None = None ) -> Buffer | None: # docstring inherited import obstore as obs + if prototype is None: + prototype = self._get_default_buffer_class() + # Extract buffer class from BufferLike + if isinstance(prototype, BufferPrototype): + buffer_cls = prototype.buffer + else: + buffer_cls = prototype + try: if byte_range is None: resp = await obs.get_async(self.store, key) - return prototype.buffer.from_bytes(await resp.bytes_async()) # type: ignore[arg-type] + return buffer_cls.from_bytes(await resp.bytes_async()) # type: ignore[arg-type] elif isinstance(byte_range, RangeByteRequest): bytes = await obs.get_range_async( self.store, key, start=byte_range.start, end=byte_range.end ) - return prototype.buffer.from_bytes(bytes) # type: ignore[arg-type] + return buffer_cls.from_bytes(bytes) # type: ignore[arg-type] elif isinstance(byte_range, OffsetByteRequest): resp = await obs.get_async( self.store, key, options={"range": {"offset": byte_range.offset}} ) - return prototype.buffer.from_bytes(await resp.bytes_async()) # type: ignore[arg-type] + return buffer_cls.from_bytes(await resp.bytes_async()) # type: ignore[arg-type] elif isinstance(byte_range, SuffixByteRequest): # some object stores (Azure) don't support suffix requests.
In this # case, our workaround is to first get the length of the object and then @@ -122,7 +133,7 @@ async def get( resp = await obs.get_async( self.store, key, options={"range": {"suffix": byte_range.suffix}} ) - return prototype.buffer.from_bytes(await resp.bytes_async()) # type: ignore[arg-type] + return buffer_cls.from_bytes(await resp.bytes_async()) # type: ignore[arg-type] except obs.exceptions.NotSupportedError: head_resp = await obs.head_async(self.store, key) file_size = head_resp["size"] @@ -133,7 +144,7 @@ async def get( start=file_size - suffix_len, length=suffix_len, ) - return prototype.buffer.from_bytes(buffer) # type: ignore[arg-type] + return buffer_cls.from_bytes(buffer) # type: ignore[arg-type] else: raise ValueError(f"Unexpected byte_range, got {byte_range}") except _ALLOWED_EXCEPTIONS: @@ -141,10 +152,16 @@ async def get( async def get_partial_values( self, - prototype: BufferPrototype, + prototype: BufferLike | None, key_ranges: Iterable[tuple[str, ByteRequest | None]], ) -> list[Buffer | None]: # docstring inherited + if prototype is None: + prototype = self._get_default_buffer_class() + # Extract buffer class from BufferLike - _get_partial_values expects BufferPrototype + if not isinstance(prototype, BufferPrototype): + # Convert raw buffer class to BufferPrototype + prototype = default_buffer_prototype() return await _get_partial_values(self.store, prototype=prototype, key_ranges=key_ranges) async def exists(self, key: str) -> bool: diff --git a/src/zarr/storage/_wrapper.py b/src/zarr/storage/_wrapper.py index 64a5b2d83c..ca3609009e 100644 --- a/src/zarr/storage/_wrapper.py +++ b/src/zarr/storage/_wrapper.py @@ -9,9 +9,8 @@ from zarr.abc.buffer import Buffer from zarr.abc.store import ByteRequest - from zarr.core.buffer import BufferPrototype -from zarr.abc.store import Store +from zarr.abc.store import BufferLike, Store T_Store = TypeVar("T_Store", bound=Store) @@ -85,13 +84,13 @@ def __repr__(self) -> str: return f"WrapperStore({self._store.__class__.__name__}, '{self._store}')" async def get( - self, key: str, prototype: BufferPrototype, byte_range: ByteRequest | None = None + self, key: str, prototype: BufferLike | None = None, byte_range: ByteRequest | None = None ) -> Buffer | None: return await self._store.get(key, prototype, byte_range) async def get_partial_values( self, - prototype: BufferPrototype, + prototype: BufferLike | None, key_ranges: Iterable[tuple[str, ByteRequest | None]], ) -> list[Buffer | None]: return await self._store.get_partial_values(prototype, key_ranges) @@ -139,7 +138,7 @@ def close(self) -> None: self._store.close() async def _get_many( - self, requests: Iterable[tuple[str, BufferPrototype, ByteRequest | None]] + self, requests: Iterable[tuple[str, BufferLike | None, ByteRequest | None]] ) -> AsyncGenerator[tuple[str, Buffer | None], None]: async for req in self._store._get_many(requests): yield req diff --git a/src/zarr/storage/_zip.py b/src/zarr/storage/_zip.py index 72bf9e335a..0348eeedd8 100644 --- a/src/zarr/storage/_zip.py +++ b/src/zarr/storage/_zip.py @@ -9,6 +9,7 @@ from typing import TYPE_CHECKING, Any, Literal from zarr.abc.store import ( + BufferLike, ByteRequest, OffsetByteRequest, RangeByteRequest, @@ -146,19 +147,24 @@ def __eq__(self, other: object) -> bool: def _get( self, key: str, - prototype: BufferPrototype, + prototype: BufferLike, byte_range: ByteRequest | None = None, ) -> Buffer | None: if not self._is_open: self._sync_open() + # Extract buffer class from BufferLike + if isinstance(prototype, 
BufferPrototype): + buffer_cls = prototype.buffer + else: + buffer_cls = prototype # docstring inherited try: with self._zf.open(key) as f: # will raise KeyError if byte_range is None: - return prototype.buffer.from_bytes(f.read()) + return buffer_cls.from_bytes(f.read()) elif isinstance(byte_range, RangeByteRequest): f.seek(byte_range.start) - return prototype.buffer.from_bytes(f.read(byte_range.end - f.tell())) + return buffer_cls.from_bytes(f.read(byte_range.end - f.tell())) size = f.seek(0, os.SEEK_END) if isinstance(byte_range, OffsetByteRequest): f.seek(byte_range.offset) @@ -166,17 +172,19 @@ def _get( f.seek(max(0, size - byte_range.suffix)) else: raise TypeError(f"Unexpected byte_range, got {byte_range}.") - return prototype.buffer.from_bytes(f.read()) + return buffer_cls.from_bytes(f.read()) except KeyError: return None async def get( self, key: str, - prototype: BufferPrototype, + prototype: BufferLike | None = None, byte_range: ByteRequest | None = None, ) -> Buffer | None: # docstring inherited + if prototype is None: + prototype = self._get_default_buffer_class() assert isinstance(key, str) with self._lock: @@ -184,10 +192,12 @@ async def get( async def get_partial_values( self, - prototype: BufferPrototype, + prototype: BufferLike | None, key_ranges: Iterable[tuple[str, ByteRequest | None]], ) -> list[Buffer | None]: # docstring inherited + if prototype is None: + prototype = self._get_default_buffer_class() out = [] with self._lock: for key, byte_range in key_ranges: diff --git a/src/zarr/testing/store.py b/src/zarr/testing/store.py index ad3b80da41..55e3687f20 100644 --- a/src/zarr/testing/store.py +++ b/src/zarr/testing/store.py @@ -1,6 +1,7 @@ from __future__ import annotations import asyncio +import json import pickle from abc import abstractmethod from typing import TYPE_CHECKING, Generic, TypeVar @@ -11,19 +12,19 @@ from typing import Any from zarr.abc.store import ByteRequest - from zarr.core.buffer.core import BufferPrototype import pytest from zarr.abc.store import ( + BufferLike, ByteRequest, OffsetByteRequest, RangeByteRequest, Store, SuffixByteRequest, ) -from zarr.core.buffer import Buffer, default_buffer_prototype -from zarr.core.sync import _collect_aiterator +from zarr.core.buffer import Buffer, cpu, default_buffer_prototype +from zarr.core.sync import _collect_aiterator, sync from zarr.storage._utils import _normalize_byte_range_index from zarr.testing.utils import assert_bytes_equal @@ -201,6 +202,15 @@ async def test_with_read_only_store(self, open_kwargs: dict[str, Any]) -> None: ): await reader.delete("foo") + @pytest.mark.parametrize( + "prototype", + [ + None, # Should use store's default buffer class + default_buffer_prototype(), # BufferPrototype instance + default_buffer_prototype().buffer, # Raw Buffer class (cpu.Buffer) + ], + ids=["prototype=None", "prototype=BufferPrototype", "prototype=Buffer"], + ) @pytest.mark.parametrize("key", ["c/0", "foo/c/0.0", "foo/0/0"]) @pytest.mark.parametrize( ("data", "byte_range"), @@ -212,13 +222,15 @@ async def test_with_read_only_store(self, open_kwargs: dict[str, Any]) -> None: (b"", None), ], ) - async def test_get(self, store: S, key: str, data: bytes, byte_range: ByteRequest) -> None: + async def test_get( + self, store: S, key: str, data: bytes, byte_range: ByteRequest, prototype: BufferLike | None + ) -> None: """ Ensure that data can be read from the store using the store.get method. 
""" data_buf = self.buffer_cls.from_bytes(data) await self.set(store, key, data_buf) - observed = await store.get(key, prototype=default_buffer_prototype(), byte_range=byte_range) + observed = await store.get(key, prototype=prototype, byte_range=byte_range) start, stop = _normalize_byte_range_index(data_buf, byte_range=byte_range) expected = data_buf[start:stop] assert_bytes_equal(observed, expected) @@ -331,6 +343,15 @@ async def test_set_many(self, store: S) -> None: for k, v in store_dict.items(): assert (await self.get(store, k)).to_bytes() == v.to_bytes() + @pytest.mark.parametrize( + "prototype", + [ + None, # Should use store's default buffer class + default_buffer_prototype(), # BufferPrototype instance + default_buffer_prototype().buffer, # Raw Buffer class (cpu.Buffer) + ], + ids=["prototype=None", "prototype=BufferPrototype", "prototype=Buffer"], + ) @pytest.mark.parametrize( "key_ranges", [ @@ -345,16 +366,14 @@ async def test_set_many(self, store: S) -> None: ], ) async def test_get_partial_values( - self, store: S, key_ranges: list[tuple[str, ByteRequest]] + self, store: S, key_ranges: list[tuple[str, ByteRequest]], prototype: BufferLike | None ) -> None: # put all of the data for key, _ in key_ranges: await self.set(store, key, self.buffer_cls.from_bytes(bytes(key, encoding="utf-8"))) # read back just part of it - observed_maybe = await store.get_partial_values( - prototype=default_buffer_prototype(), key_ranges=key_ranges - ) + observed_maybe = await store.get_partial_values(prototype=prototype, key_ranges=key_ranges) observed: list[Buffer] = [] expected: list[Buffer] = [] @@ -365,9 +384,7 @@ async def test_get_partial_values( for idx in range(len(observed)): key, byte_range = key_ranges[idx] - result = await store.get( - key, prototype=default_buffer_prototype(), byte_range=byte_range - ) + result = await store.get(key, prototype=cpu.Buffer, byte_range=byte_range) assert result is not None expected.append(result) @@ -526,6 +543,46 @@ async def test_set_if_not_exists(self, store: S) -> None: result = await store.get("k2", default_buffer_prototype()) assert result == new + async def test_get_bytes(self, store: S) -> None: + """ + Test that the get_bytes method reads bytes. + """ + data = b"hello world" + key = "zarr.json" + await self.set(store, key, self.buffer_cls.from_bytes(data)) + assert await store.get_bytes(key, prototype=default_buffer_prototype()) == data + with pytest.raises(FileNotFoundError): + await store.get_bytes("nonexistent_key", prototype=default_buffer_prototype()) + + def test_get_bytes_sync(self, store: S) -> None: + """ + Test that the get_bytes_sync method reads bytes. + """ + data = b"hello world" + key = "zarr.json" + sync(self.set(store, key, self.buffer_cls.from_bytes(data))) + assert store.get_bytes_sync(key, prototype=default_buffer_prototype()) == data + + async def test_get_json(self, store: S) -> None: + """ + Test that the get_json method reads json. + """ + data = {"foo": "bar"} + data_bytes = json.dumps(data).encode("utf-8") + key = "zarr.json" + await self.set(store, key, self.buffer_cls.from_bytes(data_bytes)) + assert await store.get_json(key, prototype=default_buffer_prototype()) == data + + def test_get_json_sync(self, store: S) -> None: + """ + Test that the get_json method reads json. 
+ """ + data = {"foo": "bar"} + data_bytes = json.dumps(data).encode("utf-8") + key = "zarr.json" + sync(self.set(store, key, self.buffer_cls.from_bytes(data_bytes))) + assert store.get_json_sync(key, prototype=default_buffer_prototype()) == data + class LatencyStore(WrapperStore[Store]): """ @@ -563,7 +620,7 @@ async def set(self, key: str, value: Buffer) -> None: await self._store.set(key, value) async def get( - self, key: str, prototype: BufferPrototype, byte_range: ByteRequest | None = None + self, key: str, prototype: BufferLike | None = None, byte_range: ByteRequest | None = None ) -> Buffer | None: """ Add latency to the ``get`` method. @@ -574,8 +631,12 @@ async def get( ---------- key : str The key to get - prototype : BufferPrototype - The BufferPrototype to use. + prototype : BufferLike | None, optional + The prototype of the output buffer. + Can be either a Buffer class or an instance of `BufferPrototype`, in which the + `buffer` attribute will be used. + If `None`, the default buffer class for this store will be retrieved via the + ``_get_default_buffer_class`` method. byte_range : ByteRequest, optional An optional byte range. diff --git a/tests/test_store/test_local.py b/tests/test_store/test_local.py index 6756bc83d9..fa4bc7cfc0 100644 --- a/tests/test_store/test_local.py +++ b/tests/test_store/test_local.py @@ -1,7 +1,9 @@ from __future__ import annotations +import json import pathlib import re +from typing import TYPE_CHECKING import numpy as np import pytest @@ -9,11 +11,15 @@ import zarr from zarr import create_array from zarr.core.buffer import Buffer, cpu +from zarr.core.sync import sync from zarr.storage import LocalStore from zarr.storage._local import _atomic_write from zarr.testing.store import StoreTests from zarr.testing.utils import assert_bytes_equal +if TYPE_CHECKING: + from zarr.core.buffer import BufferPrototype + class TestLocalStore(StoreTests[LocalStore, cpu.Buffer]): store_cls = LocalStore @@ -108,6 +114,54 @@ async def test_move( ): await store2.move(destination) + @pytest.mark.parametrize("buffer_cls", [None, cpu.buffer_prototype]) + async def test_get_bytes_with_prototype_none( + self, store: LocalStore, buffer_cls: None | BufferPrototype + ) -> None: + """Test that get_bytes works with prototype=None.""" + data = b"hello world" + key = "test_key" + await self.set(store, key, self.buffer_cls.from_bytes(data)) + + result = await store.get_bytes(key, prototype=buffer_cls) + assert result == data + + @pytest.mark.parametrize("buffer_cls", [None, cpu.buffer_prototype]) + def test_get_bytes_sync_with_prototype_none( + self, store: LocalStore, buffer_cls: None | BufferPrototype + ) -> None: + """Test that get_bytes_sync works with prototype=None.""" + data = b"hello world" + key = "test_key" + sync(self.set(store, key, self.buffer_cls.from_bytes(data))) + + result = store.get_bytes_sync(key, prototype=buffer_cls) + assert result == data + + @pytest.mark.parametrize("buffer_cls", [None, cpu.buffer_prototype]) + async def test_get_json_with_prototype_none( + self, store: LocalStore, buffer_cls: None | BufferPrototype + ) -> None: + """Test that get_json works with prototype=None.""" + data = {"foo": "bar", "number": 42} + key = "test.json" + await self.set(store, key, self.buffer_cls.from_bytes(json.dumps(data).encode())) + + result = await store.get_json(key, prototype=buffer_cls) + assert result == data + + @pytest.mark.parametrize("buffer_cls", [None, cpu.buffer_prototype]) + def test_get_json_sync_with_prototype_none( + self, store: LocalStore, 
buffer_cls: None | BufferPrototype + ) -> None: + """Test that get_json_sync works with prototype=None.""" + data = {"foo": "bar", "number": 42} + key = "test.json" + sync(self.set(store, key, self.buffer_cls.from_bytes(json.dumps(data).encode()))) + + result = store.get_json_sync(key, prototype=buffer_cls) + assert result == data + @pytest.mark.parametrize("exclusive", [True, False]) def test_atomic_write_successful(tmp_path: pathlib.Path, exclusive: bool) -> None: diff --git a/tests/test_store/test_memory.py b/tests/test_store/test_memory.py index 29fa9b2964..96b7fe9845 100644 --- a/tests/test_store/test_memory.py +++ b/tests/test_store/test_memory.py @@ -1,5 +1,6 @@ from __future__ import annotations +import json import re from typing import TYPE_CHECKING, Any @@ -9,12 +10,14 @@ import zarr from zarr.core.buffer import Buffer, cpu, gpu +from zarr.core.sync import sync from zarr.errors import ZarrUserWarning from zarr.storage import GpuMemoryStore, MemoryStore from zarr.testing.store import StoreTests from zarr.testing.utils import gpu_test if TYPE_CHECKING: + from zarr.core.buffer import BufferPrototype from zarr.core.common import ZarrFormat @@ -76,6 +79,54 @@ async def test_deterministic_size( np.testing.assert_array_equal(a[:3], 1) np.testing.assert_array_equal(a[3:], 0) + @pytest.mark.parametrize("buffer_cls", [None, cpu.buffer_prototype]) + async def test_get_bytes_with_prototype_none( + self, store: MemoryStore, buffer_cls: None | BufferPrototype + ) -> None: + """Test that get_bytes works with prototype=None.""" + data = b"hello world" + key = "test_key" + await self.set(store, key, self.buffer_cls.from_bytes(data)) + + result = await store.get_bytes(key, prototype=buffer_cls) + assert result == data + + @pytest.mark.parametrize("buffer_cls", [None, cpu.buffer_prototype]) + def test_get_bytes_sync_with_prototype_none( + self, store: MemoryStore, buffer_cls: None | BufferPrototype + ) -> None: + """Test that get_bytes_sync works with prototype=None.""" + data = b"hello world" + key = "test_key" + sync(self.set(store, key, self.buffer_cls.from_bytes(data))) + + result = store.get_bytes_sync(key, prototype=buffer_cls) + assert result == data + + @pytest.mark.parametrize("buffer_cls", [None, cpu.buffer_prototype]) + async def test_get_json_with_prototype_none( + self, store: MemoryStore, buffer_cls: None | BufferPrototype + ) -> None: + """Test that get_json works with prototype=None.""" + data = {"foo": "bar", "number": 42} + key = "test.json" + await self.set(store, key, self.buffer_cls.from_bytes(json.dumps(data).encode())) + + result = await store.get_json(key, prototype=buffer_cls) + assert result == data + + @pytest.mark.parametrize("buffer_cls", [None, cpu.buffer_prototype]) + def test_get_json_sync_with_prototype_none( + self, store: MemoryStore, buffer_cls: None | BufferPrototype + ) -> None: + """Test that get_json_sync works with prototype=None.""" + data = {"foo": "bar", "number": 42} + key = "test.json" + sync(self.set(store, key, self.buffer_cls.from_bytes(json.dumps(data).encode()))) + + result = store.get_json_sync(key, prototype=buffer_cls) + assert result == data + # TODO: fix this warning @pytest.mark.filterwarnings("ignore:Unclosed client session:ResourceWarning") diff --git a/tests/test_store/test_wrapper.py b/tests/test_store/test_wrapper.py index b34a63d5d0..c5f2240297 100644 --- a/tests/test_store/test_wrapper.py +++ b/tests/test_store/test_wrapper.py @@ -4,7 +4,7 @@ import pytest -from zarr.abc.store import ByteRequest, Store +from zarr.abc.store import 
BufferLike, ByteRequest, Store from zarr.core.buffer import Buffer from zarr.core.buffer.cpu import Buffer as CPUBuffer from zarr.core.buffer.cpu import buffer_prototype @@ -14,8 +14,6 @@ if TYPE_CHECKING: from pathlib import Path - from zarr.core.buffer.core import BufferPrototype - class StoreKwargs(TypedDict): store: LocalStore @@ -111,10 +109,13 @@ async def test_wrapped_get(store: Store, capsys: pytest.CaptureFixture[str]) -> # define a class that prints when it sets class NoisyGetter(WrapperStore[Any]): async def get( - self, key: str, prototype: BufferPrototype, byte_range: ByteRequest | None = None - ) -> None: + self, + key: str, + prototype: BufferLike | None = None, + byte_range: ByteRequest | None = None, + ) -> Buffer | None: print(f"getting {key}") - await super().get(key, prototype=prototype, byte_range=byte_range) + return await super().get(key, prototype=prototype, byte_range=byte_range) key = "foo" value = CPUBuffer.from_bytes(b"bar")
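Reviewer note: each store touched above repeats the same ``BufferLike`` normalization inline (if the value is a ``BufferPrototype``, take its ``buffer`` attribute, otherwise treat it as a Buffer class). As a sketch only, that pattern could be factored into a small helper such as the hypothetical ``resolve_buffer_cls`` below; nothing with this name exists in the patch.

from __future__ import annotations

from zarr.core.buffer import Buffer, BufferPrototype


def resolve_buffer_cls(
    prototype: type[Buffer] | BufferPrototype | None,
    default: type[Buffer],
) -> type[Buffer]:
    """Return the concrete Buffer class described by a BufferLike value.

    ``None`` falls back to the supplied default, a ``BufferPrototype``
    contributes its ``buffer`` attribute, and a Buffer class is returned
    unchanged.
    """
    if prototype is None:
        return default
    if isinstance(prototype, BufferPrototype):
        return prototype.buffer
    return prototype


# e.g. inside a Store.get implementation:
# buffer_cls = resolve_buffer_cls(prototype, self._get_default_buffer_class())
# return buffer_cls.from_bytes(raw)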