Skip to content
Merged
Show file tree
Hide file tree
Changes from 5 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
25 changes: 25 additions & 0 deletions src/zarr/abc/store.py
Original file line number Diff line number Diff line change
Expand Up @@ -81,6 +81,31 @@ async def empty(self) -> bool: ...
@abstractmethod
async def clear(self) -> None: ...

@abstractmethod
def with_mode(self, mode: AccessModeLiteral) -> Self:
"""
Return a new store of the same type pointing to the same location with a new mode.

The returned Store is not automatically opened. Call :meth:`Store.open` before
using.

Parameters
----------
mode: AccessModeLiteral
The new mode to use.

Returns
-------
store:
A new store of the same type with the new mode.

Examples
--------
>>> writer = zarr.store.MemoryStore(mode="w")
>>> reader = writer.with_mode("r")
"""
...

@property
def mode(self) -> AccessMode:
"""Access mode of the store."""
Expand Down
2 changes: 1 addition & 1 deletion src/zarr/store/common.py
Original file line number Diff line number Diff line change
Expand Up @@ -90,7 +90,7 @@ async def make_store_path(
result = store_like
elif isinstance(store_like, Store):
if mode is not None:
assert AccessMode.from_literal(mode) == store_like.mode
store_like = store_like.with_mode(mode)
await store_like._ensure_open()
result = StorePath(store_like)
elif store_like is None:
Expand Down
5 changes: 5 additions & 0 deletions src/zarr/store/local.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,8 @@
from pathlib import Path
from typing import TYPE_CHECKING

from typing_extensions import Self

from zarr.abc.store import Store
from zarr.core.buffer import Buffer
from zarr.core.common import concurrent_map, to_thread
Expand Down Expand Up @@ -103,6 +105,9 @@ async def empty(self) -> bool:
else:
return True

def with_mode(self, mode: AccessModeLiteral) -> Self:
return type(self)(root=self.root, mode=mode)

def __str__(self) -> str:
return f"file://{self.root}"

Expand Down
23 changes: 19 additions & 4 deletions src/zarr/store/memory.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,8 @@
from collections.abc import AsyncGenerator, MutableMapping
from typing import TYPE_CHECKING

from typing_extensions import Self

from zarr.abc.store import Store
from zarr.core.buffer import Buffer, gpu
from zarr.core.common import concurrent_map
Expand All @@ -15,6 +17,12 @@
from zarr.core.common import AccessModeLiteral


# T = TypeVar("T", bound=Buffer | gpu.Buffer)


# class _MemoryStore


# TODO: this store could easily be extended to wrap any MutableMapping store from v2
# When that is done, the `MemoryStore` will just be a store that wraps a dict.
class MemoryStore(Store):
Expand Down Expand Up @@ -42,6 +50,9 @@ async def empty(self) -> bool:
async def clear(self) -> None:
self._store_dict.clear()

def with_mode(self, mode: AccessModeLiteral) -> Self:
return type(self)(store_dict=self._store_dict, mode=mode)

def __str__(self) -> str:
return f"memory://{id(self._store_dict)}"

Expand Down Expand Up @@ -156,19 +167,23 @@ class GpuMemoryStore(MemoryStore):
of the original location. This guarantees that chunks will always be in GPU
memory for downstream processing. For location agnostic use cases, it would
be better to use `MemoryStore` instead.

Parameters
----------
store_dict: MutableMapping, optional
A mutable mapping with string keys and :class:`zarr.core.buffer.gpu.Buffer`
values.
"""

_store_dict: MutableMapping[str, Buffer]

def __init__(
self,
store_dict: MutableMapping[str, Buffer] | None = None,
store_dict: MutableMapping[str, gpu.Buffer] | None = None,
*,
mode: AccessModeLiteral = "r",
) -> None:
super().__init__(mode=mode)
if store_dict:
self._store_dict = {k: gpu.Buffer.from_buffer(store_dict[k]) for k in iter(store_dict)}
super().__init__(store_dict=store_dict, mode=mode) # type: ignore[arg-type]

def __str__(self) -> str:
return f"gpumemory://{id(self._store_dict)}"
Expand Down
9 changes: 9 additions & 0 deletions src/zarr/store/remote.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
from typing import TYPE_CHECKING, Any

import fsspec
from typing_extensions import Self

from zarr.abc.store import Store
from zarr.store.common import _dereference_path
Expand Down Expand Up @@ -95,6 +96,14 @@ async def clear(self) -> None:
async def empty(self) -> bool:
return not await self.fs._find(self.path, withdirs=True)

def with_mode(self, mode: AccessModeLiteral) -> Self:
return type(self)(
fs=self.fs,
mode=mode,
path=self.path,
allowed_exceptions=self.allowed_exceptions,
)

def __repr__(self) -> str:
return f"<RemoteStore({type(self.fs).__name__}, {self.path})>"

Expand Down
5 changes: 5 additions & 0 deletions src/zarr/store/zip.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,8 @@
from pathlib import Path
from typing import TYPE_CHECKING, Any, Literal

from typing_extensions import Self

from zarr.abc.store import Store
from zarr.core.buffer import Buffer, BufferPrototype

Expand Down Expand Up @@ -115,6 +117,9 @@ async def empty(self) -> bool:
else:
return True

def with_mode(self, mode: ZipStoreAccessModeLiteral) -> Self: # type: ignore[override]
raise NotImplementedError("ZipStore cannot be reopened with a new mode.")

def __str__(self) -> str:
return f"zip://{self.path}"

Expand Down
28 changes: 27 additions & 1 deletion src/zarr/testing/store.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,11 @@
import pickle
from typing import Any, Generic, TypeVar
from typing import Any, Generic, TypeVar, cast

import pytest

from zarr.abc.store import AccessMode, Store
from zarr.core.buffer import Buffer, default_buffer_prototype
from zarr.core.common import AccessModeLiteral
from zarr.core.sync import _collect_aiterator
from zarr.store._utils import _normalize_interval_index
from zarr.testing.utils import assert_bytes_equal
Expand Down Expand Up @@ -251,3 +252,28 @@ async def test_list_dir(self, store: S) -> None:

keys_observed = await _collect_aiterator(store.list_dir(root + "/"))
assert sorted(keys_expected) == sorted(keys_observed)

async def test_with_mode(self, store: S) -> None:
data = b"0000"
self.set(store, "key", self.buffer_cls.from_bytes(data))
assert self.get(store, "key").to_bytes() == data

for mode in ["r", "a"]:
mode = cast(AccessModeLiteral, mode)
clone = store.with_mode(mode)
# await store.close()
await clone._ensure_open()
assert clone.mode == AccessMode.from_literal(mode)
assert isinstance(clone, type(store))

# earlier writes are visible
assert self.get(clone, "key").to_bytes() == data

# writes to original after with_mode is visible
self.set(store, "key-2", self.buffer_cls.from_bytes(data))
assert self.get(clone, "key-2").to_bytes() == data

if mode == "w":
# writes to clone is visible in the original
self.set(store, "key-3", self.buffer_cls.from_bytes(data))
assert self.get(clone, "key-3").to_bytes() == data
16 changes: 13 additions & 3 deletions tests/v3/test_store/test_memory.py
Original file line number Diff line number Diff line change
Expand Up @@ -58,9 +58,14 @@ def set(self, store: GpuMemoryStore, key: str, value: Buffer) -> None:
def get(self, store: MemoryStore, key: str) -> Buffer:
return store._store_dict[key]

@pytest.fixture(scope="function", params=[None, {}])
def store_kwargs(self, request) -> dict[str, str | None | dict[str, Buffer]]:
return {"store_dict": request.param, "mode": "r+"}
@pytest.fixture(scope="function", params=[None, True])
def store_kwargs(
self, request: pytest.FixtureRequest
) -> dict[str, str | None | dict[str, Buffer]]:
kwargs = {"store_dict": None, "mode": "r+"}
if request.param is True:
kwargs["store_dict"] = {}
return kwargs

@pytest.fixture(scope="function")
def store(self, store_kwargs: str | None | dict[str, gpu.Buffer]) -> GpuMemoryStore:
Expand All @@ -80,3 +85,8 @@ def test_store_supports_partial_writes(self, store: GpuMemoryStore) -> None:

def test_list_prefix(self, store: GpuMemoryStore) -> None:
assert True

def test_dict_reference(self, store: GpuMemoryStore) -> None:
store_dict = {}
result = GpuMemoryStore(store_dict=store_dict)
assert result._store_dict is store_dict
4 changes: 4 additions & 0 deletions tests/v3/test_store/test_zip.py
Original file line number Diff line number Diff line change
Expand Up @@ -96,3 +96,7 @@ def test_api_integration(self, store: ZipStore) -> None:
del root["bar"]

store.close()

async def test_with_mode(self, store: ZipStore) -> None:
with pytest.raises(NotImplementedError, match="new mode"):
await super().test_with_mode(store)