Skip to content

Commit 67d4521

Browse files
committed
Add array_metadata strategy
1 parent f4278a5 commit 67d4521

File tree

2 files changed

+84
-4
lines changed

2 files changed

+84
-4
lines changed

src/zarr/testing/strategies.py

Lines changed: 61 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
import sys
2-
from typing import Any
2+
from typing import Any, Literal
33

44
import hypothesis.extra.numpy as npst
55
import hypothesis.strategies as st
@@ -8,9 +8,10 @@
88
from hypothesis.strategies import SearchStrategy
99

1010
import zarr
11-
from zarr.abc.store import RangeByteRequest
11+
from zarr.abc.store import RangeByteRequest, Store
1212
from zarr.core.array import Array
1313
from zarr.core.common import ZarrFormat
14+
from zarr.core.metadata import ArrayV2Metadata, ArrayV3Metadata
1415
from zarr.core.sync import sync
1516
from zarr.storage import MemoryStore, StoreLike
1617
from zarr.storage._common import _dereference_path
@@ -67,6 +68,11 @@ def safe_unicode_for_dtype(dtype: np.dtype[np.str_]) -> st.SearchStrategy[str]:
6768
)
6869

6970

71+
def clear_store(x: Store) -> Store:
72+
sync(x.clear())
73+
return x
74+
75+
7076
# From https://zarr-specs.readthedocs.io/en/latest/v3/core/v3.0.html#node-names
7177
# 1. must not be the empty string ("")
7278
# 2. must not include the character "/"
@@ -85,12 +91,64 @@ def safe_unicode_for_dtype(dtype: np.dtype[np.str_]) -> st.SearchStrategy[str]:
8591
# st.builds will only call a new store constructor for different keyword arguments
8692
# i.e. stores.examples() will always return the same object per Store class.
8793
# So we map a clear to reset the store.
88-
stores = st.builds(MemoryStore, st.just({})).map(lambda x: sync(x.clear()))
94+
stores = st.builds(MemoryStore, st.just({})).map(clear_store)
8995
compressors = st.sampled_from([None, "default"])
9096
zarr_formats: st.SearchStrategy[ZarrFormat] = st.sampled_from([2, 3])
9197
array_shapes = npst.array_shapes(max_dims=4, min_side=0)
9298

9399

100+
@st.composite # type: ignore[misc]
101+
def dimension_names(draw: st.DrawFn, *, ndim: int | None = None) -> list[None | str] | None:
102+
simple_text = st.text(zarr_key_chars, min_size=0)
103+
return draw(st.none() | st.lists(st.none() | simple_text, min_size=ndim, max_size=ndim)) # type: ignore[no-any-return]
104+
105+
106+
@st.composite # type: ignore[misc]
107+
def array_metadata(
108+
draw: st.DrawFn,
109+
*,
110+
array_shapes: st.SearchStrategy[tuple[int, ...]] = npst.array_shapes,
111+
zarr_formats: st.SearchStrategy[Literal[2, 3]] = zarr_formats,
112+
attributes: st.SearchStrategy[dict[str, Any]] = attrs,
113+
) -> ArrayV2Metadata | ArrayV3Metadata:
114+
from zarr.codecs.bytes import BytesCodec
115+
from zarr.core.chunk_grids import RegularChunkGrid
116+
from zarr.core.chunk_key_encodings import DefaultChunkKeyEncoding
117+
from zarr.core.metadata.v3 import ArrayV3Metadata
118+
119+
zarr_format = draw(zarr_formats)
120+
# separator = draw(st.sampled_from(['/', '\\']))
121+
shape = draw(array_shapes())
122+
ndim = len(shape)
123+
chunk_shape = draw(array_shapes(min_dims=ndim, max_dims=ndim))
124+
dtype = draw(v3_dtypes())
125+
fill_value = draw(npst.from_dtype(dtype))
126+
if zarr_format == 2:
127+
return ArrayV2Metadata(
128+
shape=shape,
129+
chunks=chunk_shape,
130+
dtype=dtype,
131+
fill_value=fill_value,
132+
order=draw(st.sampled_from(["C", "F"])),
133+
attributes=draw(attributes),
134+
dimension_separator=draw(st.sampled_from([".", "/"])),
135+
filters=None,
136+
compressor=None,
137+
)
138+
else:
139+
return ArrayV3Metadata(
140+
shape=shape,
141+
data_type=dtype,
142+
chunk_grid=RegularChunkGrid(chunk_shape=chunk_shape),
143+
fill_value=fill_value,
144+
attributes=draw(attributes),
145+
dimension_names=draw(dimension_names(ndim=ndim)),
146+
chunk_key_encoding=DefaultChunkKeyEncoding(separator="/"), # FIXME
147+
codecs=[BytesCodec()],
148+
storage_transformers=(),
149+
)
150+
151+
94152
@st.composite # type: ignore[misc]
95153
def numpy_arrays(
96154
draw: st.DrawFn,

tests/test_properties.py

Lines changed: 23 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2,13 +2,24 @@
22
import pytest
33
from numpy.testing import assert_array_equal
44

5+
from zarr.core.buffer import default_buffer_prototype
6+
57
pytest.importorskip("hypothesis")
68

79
import hypothesis.extra.numpy as npst
810
import hypothesis.strategies as st
911
from hypothesis import given
1012

11-
from zarr.testing.strategies import arrays, basic_indices, numpy_arrays, zarr_formats
13+
from zarr.abc.store import Store
14+
from zarr.core.metadata import ArrayV2Metadata, ArrayV3Metadata
15+
from zarr.testing.strategies import (
16+
array_metadata,
17+
arrays,
18+
basic_indices,
19+
numpy_arrays,
20+
stores,
21+
zarr_formats,
22+
)
1223

1324

1425
@given(data=st.data(), zarr_format=zarr_formats)
@@ -47,6 +58,17 @@ def test_vindex(data: st.DataObject) -> None:
4758
assert_array_equal(nparray[indexer], actual)
4859

4960

61+
@given(store=stores, meta=array_metadata()) # type: ignore[misc]
62+
async def test_roundtrip_array_metadata(
63+
store: Store, meta: ArrayV2Metadata | ArrayV3Metadata
64+
) -> None:
65+
asdict = meta.to_buffer_dict(prototype=default_buffer_prototype())
66+
for key, expected in asdict.items():
67+
await store.set(f"0/{key}", expected)
68+
actual = await store.get(f"0/{key}", prototype=default_buffer_prototype())
69+
assert actual == expected
70+
71+
5072
# @st.composite
5173
# def advanced_indices(draw, *, shape):
5274
# basic_idxr = draw(

0 commit comments

Comments
 (0)