Skip to content

Commit d78e384

Browse files
committed
Merge branch 'v3' of https://github.com/zarr-developers/zarr-python into fix/dask-compat
2 parents 0d89912 + dd03ff0 commit d78e384

File tree

7 files changed

+230
-19
lines changed

7 files changed

+230
-19
lines changed

.pre-commit-config.yaml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,7 @@ default_language_version:
77
python: python3
88
repos:
99
- repo: https://github.com/astral-sh/ruff-pre-commit
10-
rev: v0.6.4
10+
rev: v0.6.5
1111
hooks:
1212
- id: ruff
1313
args: ["--fix", "--show-fixes"]

src/zarr/core/buffer/core.py

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -462,8 +462,12 @@ def __len__(self) -> int:
462462
def __repr__(self) -> str:
463463
return f"<NDBuffer shape={self.shape} dtype={self.dtype} {self._data!r}>"
464464

465-
def all_equal(self, other: Any) -> bool:
466-
return bool((self._data == other).all())
465+
def all_equal(self, other: Any, equal_nan: bool = True) -> bool:
466+
"""Compare to `other` using np.array_equal."""
467+
# use array_equal to obtain equal_nan=True functionality
468+
data, other = np.broadcast_arrays(self._data, other)
469+
result = np.array_equal(self._data, other, equal_nan=equal_nan)
470+
return result
467471

468472
def fill(self, value: Any) -> None:
469473
self._data.fill(value)

src/zarr/testing/store.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -251,10 +251,10 @@ async def test_list_dir(self, store: S) -> None:
251251
async def test_set_get(self, store_kwargs: dict[str, Any]) -> None:
252252
kwargs = {**store_kwargs, **{"mode": "w"}}
253253
store = self.store_cls(**kwargs)
254-
await zarr.api.asynchronous.open_array(store=store, path="a", mode="w", shape=(4,))
254+
await zarr.api.asynchronous.open_array(store=store, path="a", shape=(4,))
255255
keys = [x async for x in store.list()]
256256
assert keys == ["a/zarr.json"]
257257

258258
# no errors
259-
await zarr.api.asynchronous.open_array(store=store, path="a", mode="r")
260-
await zarr.api.asynchronous.open_array(store=store, path="a", mode="a")
259+
await zarr.api.asynchronous.open_array(store=store, path="a")
260+
await zarr.api.asynchronous.open_array(store=store, path="a")

tests/v3/test_array.py

Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -138,6 +138,25 @@ def test_array_v3_fill_value(store: MemoryStore, fill_value: int, dtype_str: str
138138
assert arr.fill_value.dtype == arr.dtype
139139

140140

141+
@pytest.mark.parametrize("store", ["memory"], indirect=True)
142+
async def test_array_v3_nan_fill_value(store: MemoryStore) -> None:
143+
shape = (10,)
144+
arr = Array.create(
145+
store=store,
146+
shape=shape,
147+
dtype=np.float64,
148+
zarr_format=3,
149+
chunk_shape=shape,
150+
fill_value=np.nan,
151+
)
152+
arr[:] = np.nan
153+
154+
assert np.isnan(arr.fill_value)
155+
assert arr.fill_value.dtype == arr.dtype
156+
# all fill value chunk is an empty chunk, and should not be written
157+
assert len([a async for a in store.list_prefix("/")]) == 0
158+
159+
141160
@pytest.mark.parametrize("store", ("local",), indirect=["store"])
142161
@pytest.mark.parametrize("zarr_format", (2, 3))
143162
async def test_serializable_async_array(

tests/v3/test_group.py

Lines changed: 153 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -6,14 +6,15 @@
66
import numpy as np
77
import pytest
88

9+
import zarr
910
from zarr import Array, AsyncArray, AsyncGroup, Group
1011
from zarr.abc.store import Store
1112
from zarr.core.buffer import default_buffer_prototype
1213
from zarr.core.common import JSON, ZarrFormat
1314
from zarr.core.group import GroupMetadata
1415
from zarr.core.sync import sync
1516
from zarr.errors import ContainsArrayError, ContainsGroupError
16-
from zarr.storage import LocalStore, StorePath
17+
from zarr.storage import LocalStore, MemoryStore, StorePath
1718
from zarr.storage.common import make_store_path
1819

1920
from .conftest import parse_store
@@ -699,3 +700,154 @@ def test_serializable_sync_group(store: LocalStore, zarr_format: ZarrFormat) ->
699700
actual = pickle.loads(p)
700701

701702
assert actual == expected
703+
704+
705+
async def test_group_members_async(store: LocalStore | MemoryStore) -> None:
706+
group = AsyncGroup(
707+
GroupMetadata(),
708+
store_path=StorePath(store=store, path="root"),
709+
)
710+
a0 = await group.create_array("a0", shape=(1,))
711+
g0 = await group.create_group("g0")
712+
a1 = await g0.create_array("a1", shape=(1,))
713+
g1 = await g0.create_group("g1")
714+
a2 = await g1.create_array("a2", shape=(1,))
715+
g2 = await g1.create_group("g2")
716+
717+
# immediate children
718+
children = sorted([x async for x in group.members()], key=lambda x: x[0])
719+
assert children == [
720+
("a0", a0),
721+
("g0", g0),
722+
]
723+
724+
nmembers = await group.nmembers()
725+
assert nmembers == 2
726+
727+
# partial
728+
children = sorted([x async for x in group.members(max_depth=1)], key=lambda x: x[0])
729+
expected = [
730+
("a0", a0),
731+
("g0", g0),
732+
("g0/a1", a1),
733+
("g0/g1", g1),
734+
]
735+
assert children == expected
736+
nmembers = await group.nmembers(max_depth=1)
737+
assert nmembers == 4
738+
739+
# all children
740+
all_children = sorted([x async for x in group.members(max_depth=None)], key=lambda x: x[0])
741+
expected = [
742+
("a0", a0),
743+
("g0", g0),
744+
("g0/a1", a1),
745+
("g0/g1", g1),
746+
("g0/g1/a2", a2),
747+
("g0/g1/g2", g2),
748+
]
749+
assert all_children == expected
750+
751+
nmembers = await group.nmembers(max_depth=None)
752+
assert nmembers == 6
753+
754+
with pytest.raises(ValueError, match="max_depth"):
755+
[x async for x in group.members(max_depth=-1)]
756+
757+
758+
async def test_require_group(store: LocalStore | MemoryStore, zarr_format: ZarrFormat) -> None:
759+
root = await AsyncGroup.create(store=store, zarr_format=zarr_format)
760+
761+
# create foo group
762+
_ = await root.create_group("foo", attributes={"foo": 100})
763+
764+
# test that we can get the group using require_group
765+
foo_group = await root.require_group("foo")
766+
assert foo_group.attrs == {"foo": 100}
767+
768+
# test that we can get the group using require_group and overwrite=True
769+
foo_group = await root.require_group("foo", overwrite=True)
770+
771+
_ = await foo_group.create_array(
772+
"bar", shape=(10,), dtype="uint8", chunk_shape=(2,), attributes={"foo": 100}
773+
)
774+
775+
# test that overwriting a group w/ children fails
776+
# TODO: figure out why ensure_no_existing_node is not catching the foo.bar array
777+
#
778+
# with pytest.raises(ContainsArrayError):
779+
# await root.require_group("foo", overwrite=True)
780+
781+
# test that requiring a group where an array is fails
782+
with pytest.raises(TypeError):
783+
await foo_group.require_group("bar")
784+
785+
786+
async def test_require_groups(store: LocalStore | MemoryStore, zarr_format: ZarrFormat) -> None:
787+
root = await AsyncGroup.create(store=store, zarr_format=zarr_format)
788+
# create foo group
789+
_ = await root.create_group("foo", attributes={"foo": 100})
790+
# create bar group
791+
_ = await root.create_group("bar", attributes={"bar": 200})
792+
793+
foo_group, bar_group = await root.require_groups("foo", "bar")
794+
assert foo_group.attrs == {"foo": 100}
795+
assert bar_group.attrs == {"bar": 200}
796+
797+
# get a mix of existing and new groups
798+
foo_group, spam_group = await root.require_groups("foo", "spam")
799+
assert foo_group.attrs == {"foo": 100}
800+
assert spam_group.attrs == {}
801+
802+
# no names
803+
no_group = await root.require_groups()
804+
assert no_group == ()
805+
806+
807+
async def test_create_dataset(store: LocalStore | MemoryStore, zarr_format: ZarrFormat) -> None:
808+
root = await AsyncGroup.create(store=store, zarr_format=zarr_format)
809+
with pytest.warns(DeprecationWarning):
810+
foo = await root.create_dataset("foo", shape=(10,), dtype="uint8")
811+
assert foo.shape == (10,)
812+
813+
with pytest.raises(ContainsArrayError), pytest.warns(DeprecationWarning):
814+
await root.create_dataset("foo", shape=(100,), dtype="int8")
815+
816+
_ = await root.create_group("bar")
817+
with pytest.raises(ContainsGroupError), pytest.warns(DeprecationWarning):
818+
await root.create_dataset("bar", shape=(100,), dtype="int8")
819+
820+
821+
async def test_require_array(store: LocalStore | MemoryStore, zarr_format: ZarrFormat) -> None:
822+
root = await AsyncGroup.create(store=store, zarr_format=zarr_format)
823+
foo1 = await root.require_array("foo", shape=(10,), dtype="i8", attributes={"foo": 101})
824+
assert foo1.attrs == {"foo": 101}
825+
foo2 = await root.require_array("foo", shape=(10,), dtype="i8")
826+
assert foo2.attrs == {"foo": 101}
827+
828+
# exact = False
829+
_ = await root.require_array("foo", shape=10, dtype="f8")
830+
831+
# errors w/ exact True
832+
with pytest.raises(TypeError, match="Incompatible dtype"):
833+
await root.require_array("foo", shape=(10,), dtype="f8", exact=True)
834+
835+
with pytest.raises(TypeError, match="Incompatible shape"):
836+
await root.require_array("foo", shape=(100, 100), dtype="i8")
837+
838+
with pytest.raises(TypeError, match="Incompatible dtype"):
839+
await root.require_array("foo", shape=(10,), dtype="f4")
840+
841+
_ = await root.create_group("bar")
842+
with pytest.raises(TypeError, match="Incompatible object"):
843+
await root.require_array("bar", shape=(10,), dtype="int8")
844+
845+
846+
async def test_open_mutable_mapping():
847+
group = await zarr.api.asynchronous.open_group(store={}, mode="w")
848+
assert isinstance(group.store_path.store, MemoryStore)
849+
850+
851+
def test_open_mutable_mapping_sync():
852+
group = zarr.open_group(store={}, mode="w")
853+
assert isinstance(group.store_path.store, MemoryStore)

tests/v3/test_properties.py

Lines changed: 0 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -18,11 +18,6 @@ def test_roundtrip(data: st.DataObject) -> None:
1818

1919

2020
@given(data=st.data())
21-
# The filter warning here is to silence an occasional warning in NDBuffer.all_equal
22-
# See https://github.com/zarr-developers/zarr-python/pull/2118#issuecomment-2310280899
23-
# Uncomment the next line to reproduce the original failure.
24-
# @reproduce_failure('6.111.2', b'AXicY2FgZGRAB/8/ndR2z7nkDZEDADWpBL4=')
25-
@pytest.mark.filterwarnings("ignore::RuntimeWarning")
2621
def test_basic_indexing(data: st.DataObject) -> None:
2722
zarray = data.draw(arrays())
2823
nparray = zarray[:]
@@ -37,11 +32,6 @@ def test_basic_indexing(data: st.DataObject) -> None:
3732

3833

3934
@given(data=st.data())
40-
# The filter warning here is to silence an occasional warning in NDBuffer.all_equal
41-
# See https://github.com/zarr-developers/zarr-python/pull/2118#issuecomment-2310280899
42-
# Uncomment the next line to reproduce the original failure.
43-
# @reproduce_failure('6.111.2', b'AXicY2FgZGRAB/8/eLmF7qr/C5EDADZUBRM=')
44-
@pytest.mark.filterwarnings("ignore::RuntimeWarning")
4535
def test_vindex(data: st.DataObject) -> None:
4636
zarray = data.draw(arrays())
4737
nparray = zarray[:]

tests/v3/test_store/test_memory.py

Lines changed: 48 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -4,9 +4,10 @@
44

55
import pytest
66

7-
from zarr.core.buffer import Buffer, cpu
8-
from zarr.storage.memory import MemoryStore
7+
from zarr.core.buffer import Buffer, cpu, gpu
8+
from zarr.storage.memory import GpuMemoryStore, MemoryStore
99
from zarr.testing.store import StoreTests
10+
from zarr.testing.utils import gpu_test
1011

1112

1213
class TestMemoryStore(StoreTests[MemoryStore, cpu.Buffer]):
@@ -56,3 +57,48 @@ def test_serizalizable_store(self, store: MemoryStore) -> None:
5657

5758
with pytest.raises(NotImplementedError):
5859
pickle.dumps(store)
60+
61+
62+
@gpu_test
63+
class TestGpuMemoryStore(StoreTests[GpuMemoryStore, gpu.Buffer]):
64+
store_cls = GpuMemoryStore
65+
buffer_cls = gpu.Buffer
66+
67+
def set(self, store: GpuMemoryStore, key: str, value: Buffer) -> None:
68+
store._store_dict[key] = value
69+
70+
def get(self, store: MemoryStore, key: str) -> Buffer:
71+
return store._store_dict[key]
72+
73+
@pytest.fixture(scope="function", params=[None, {}])
74+
def store_kwargs(self, request) -> dict[str, str | None | dict[str, Buffer]]:
75+
return {"store_dict": request.param, "mode": "r+"}
76+
77+
@pytest.fixture(scope="function")
78+
def store(self, store_kwargs: str | None | dict[str, gpu.Buffer]) -> GpuMemoryStore:
79+
return self.store_cls(**store_kwargs)
80+
81+
def test_store_repr(self, store: GpuMemoryStore) -> None:
82+
assert str(store) == f"gpumemory://{id(store._store_dict)}"
83+
84+
def test_store_supports_writes(self, store: GpuMemoryStore) -> None:
85+
assert store.supports_writes
86+
87+
def test_store_supports_listing(self, store: GpuMemoryStore) -> None:
88+
assert store.supports_listing
89+
90+
def test_store_supports_partial_writes(self, store: GpuMemoryStore) -> None:
91+
assert store.supports_partial_writes
92+
93+
def test_list_prefix(self, store: GpuMemoryStore) -> None:
94+
assert True
95+
96+
def test_serizalizable_store(self, store: MemoryStore) -> None:
97+
with pytest.raises(NotImplementedError):
98+
store.__getstate__()
99+
100+
with pytest.raises(NotImplementedError):
101+
store.__setstate__({})
102+
103+
with pytest.raises(NotImplementedError):
104+
pickle.dumps(store)

0 commit comments

Comments
 (0)