Skip to content

Commit 8407c64

Browse files
committed
feat: add wrapperstore
1 parent 498cb78 commit 8407c64

File tree

3 files changed

+192
-0
lines changed

3 files changed

+192
-0
lines changed

src/zarr/storage/__init__.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@
33
from zarr.storage.logging import LoggingStore
44
from zarr.storage.memory import MemoryStore
55
from zarr.storage.remote import RemoteStore
6+
from zarr.storage.wrapper import WrapperStore
67
from zarr.storage.zip import ZipStore
78

89
__all__ = [
@@ -12,6 +13,7 @@
1213
"RemoteStore",
1314
"StoreLike",
1415
"StorePath",
16+
"WrapperStore",
1517
"ZipStore",
1618
"make_store_path",
1719
]

src/zarr/storage/wrapper.py

Lines changed: 144 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,144 @@
1+
from __future__ import annotations
2+
3+
from typing import TYPE_CHECKING, Generic, TypeVar
4+
5+
if TYPE_CHECKING:
6+
from collections.abc import AsyncGenerator, Iterable
7+
from types import TracebackType
8+
from typing import Any, Self
9+
10+
from zarr.abc.store import ByteRangeRequest
11+
from zarr.core.buffer import Buffer, BufferPrototype
12+
from zarr.core.common import AccessModeLiteral, BytesLike
13+
14+
from zarr.abc.store import AccessMode, Store
15+
16+
T_Wrapped = TypeVar("T_Wrapped", bound=Store)
17+
18+
19+
class WrapperStore(Store, Generic[T_Wrapped]):
20+
"""
21+
A store class that wraps an existing ``Store`` instance.
22+
By default all of the store methods are delegated to the wrapped store instance, which is
23+
accessible via the ``._wrapped`` attribute of this class.
24+
25+
Use this class to modify or extend the behavior of the other store classes.
26+
"""
27+
28+
_wrapped: T_Wrapped
29+
30+
def __init__(self, wrapped: T_Wrapped) -> None:
31+
self._wrapped = wrapped
32+
33+
@classmethod
34+
async def open(
35+
cls: type[Self], wrapped_class: type[T_Wrapped], *args: Any, **kwargs: Any
36+
) -> Self:
37+
wrapped = wrapped_class(*args, **kwargs)
38+
await wrapped._open()
39+
return cls(wrapped=wrapped)
40+
41+
def __enter__(self) -> Self:
42+
return type(self)(self._wrapped.__enter__())
43+
44+
def __exit__(
45+
self,
46+
exc_type: type[BaseException] | None,
47+
exc_value: BaseException | None,
48+
traceback: TracebackType | None,
49+
) -> None:
50+
return self._wrapped.__exit__(exc_type, exc_value, traceback)
51+
52+
async def _open(self) -> None:
53+
await self._wrapped._open()
54+
55+
async def _ensure_open(self) -> None:
56+
await self._wrapped._ensure_open()
57+
58+
async def empty(self) -> bool:
59+
return await self._wrapped.empty()
60+
61+
async def clear(self) -> None:
62+
return await self._wrapped.clear()
63+
64+
def with_mode(self, mode: AccessModeLiteral) -> Self:
65+
return type(self)(wrapped=self._wrapped.with_mode(mode=mode))
66+
67+
@property
68+
def mode(self) -> AccessMode:
69+
return self._wrapped._mode
70+
71+
def _check_writable(self) -> None:
72+
return self._wrapped._check_writable()
73+
74+
def __eq__(self, value: object) -> bool:
75+
return type(self) is type(value) and self._wrapped.__eq__(value)
76+
77+
async def get(
78+
self, key: str, prototype: BufferPrototype, byte_range: ByteRangeRequest | None = None
79+
) -> Buffer | None:
80+
return await self._wrapped.get(key, prototype, byte_range)
81+
82+
async def get_partial_values(
83+
self,
84+
prototype: BufferPrototype,
85+
key_ranges: Iterable[tuple[str, ByteRangeRequest]],
86+
) -> list[Buffer | None]:
87+
return await self._wrapped.get_partial_values(prototype, key_ranges)
88+
89+
async def exists(self, key: str) -> bool:
90+
return await self._wrapped.exists(key)
91+
92+
async def set(self, key: str, value: Buffer) -> None:
93+
await self._wrapped.set(key, value)
94+
95+
async def set_if_not_exists(self, key: str, value: Buffer) -> None:
96+
return await self._wrapped.set_if_not_exists(key, value)
97+
98+
async def _set_many(self, values: Iterable[tuple[str, Buffer]]) -> None:
99+
await self._wrapped._set_many(values)
100+
101+
@property
102+
def supports_writes(self) -> bool:
103+
return self._wrapped.supports_writes
104+
105+
@property
106+
def supports_deletes(self) -> bool:
107+
return self._wrapped.supports_deletes
108+
109+
async def delete(self, key: str) -> None:
110+
await self._wrapped.delete(key)
111+
112+
@property
113+
def supports_partial_writes(self) -> bool:
114+
return self._wrapped.supports_partial_writes
115+
116+
async def set_partial_values(
117+
self, key_start_values: Iterable[tuple[str, int, BytesLike]]
118+
) -> None:
119+
return await self._wrapped.set_partial_values(key_start_values)
120+
121+
@property
122+
def supports_listing(self) -> bool:
123+
return self._wrapped.supports_listing
124+
125+
def list(self) -> AsyncGenerator[str]:
126+
return self._wrapped.list()
127+
128+
def list_prefix(self, prefix: str) -> AsyncGenerator[str]:
129+
return self._wrapped.list_prefix(prefix)
130+
131+
def list_dir(self, prefix: str) -> AsyncGenerator[str]:
132+
return self._wrapped.list_dir(prefix)
133+
134+
async def delete_dir(self, prefix: str) -> None:
135+
return await self._wrapped.delete_dir(prefix)
136+
137+
def close(self) -> None:
138+
self._wrapped.close()
139+
140+
async def _get_many(
141+
self, requests: Iterable[tuple[str, BufferPrototype, ByteRangeRequest | None]]
142+
) -> AsyncGenerator[tuple[str, Buffer | None], None]:
143+
async for req in self._wrapped._get_many(requests):
144+
yield req

tests/test_store/test_wrapper.py

Lines changed: 46 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,46 @@
1+
from __future__ import annotations
2+
3+
from typing import TYPE_CHECKING
4+
5+
import pytest
6+
7+
from zarr.core.buffer.cpu import Buffer, buffer_prototype
8+
from zarr.storage.wrapper import WrapperStore
9+
10+
if TYPE_CHECKING:
11+
from zarr.abc.store import Store
12+
from zarr.core.buffer.core import BufferPrototype
13+
14+
15+
@pytest.mark.parametrize("store", ["local", "memory", "zip"], indirect=True)
16+
async def test_wrapped_set(store: Store, capsys: pytest.CaptureFixture[str]) -> None:
17+
# define a class that prints when it sets
18+
class NoisySetter(WrapperStore):
19+
async def set(self, key: str, value: Buffer) -> None:
20+
print(f"setting {key}")
21+
await super().set(key, value)
22+
23+
key = "foo"
24+
value = Buffer.from_bytes(b"bar")
25+
store_wrapped = NoisySetter(store)
26+
await store_wrapped.set(key, value)
27+
captured = capsys.readouterr()
28+
assert f"setting {key}" in captured.out
29+
assert await store_wrapped.get(key, buffer_prototype) == value
30+
31+
32+
@pytest.mark.parametrize("store", ["local", "memory", "zip"], indirect=True)
33+
async def test_wrapped_get(store: Store, capsys: pytest.CaptureFixture[str]) -> None:
34+
# define a class that prints when it sets
35+
class NoisySetter(WrapperStore):
36+
def get(self, key: str, prototype: BufferPrototype) -> None:
37+
print(f"getting {key}")
38+
return super().get(key, prototype=prototype)
39+
40+
key = "foo"
41+
value = Buffer.from_bytes(b"bar")
42+
store_wrapped = NoisySetter(store)
43+
await store_wrapped.set(key, value)
44+
assert await store_wrapped.get(key, buffer_prototype) == value
45+
captured = capsys.readouterr()
46+
assert f"getting {key}" in captured.out

0 commit comments

Comments
 (0)