Skip to content

Commit a581a49

Browse files
committed
Refactor property-based tests; round trip ArrayV2Metadata
1 parent 9048d79 commit a581a49

File tree

9 files changed

+567
-117
lines changed

9 files changed

+567
-117
lines changed

src/zarr/core/metadata/v2.py

Lines changed: 40 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,7 @@
2020

2121
import json
2222
from dataclasses import dataclass, field, fields, replace
23+
import numbers
2324

2425
import numcodecs
2526
import numpy as np
@@ -144,18 +145,7 @@ def _json_convert(
144145
return o.name
145146
raise TypeError
146147

147-
def _sanitize_fill_value(fv: Any):
148-
if isinstance(fv, (float, np.floating)):
149-
if np.isnan(fv):
150-
fv = "NaN"
151-
elif np.isinf(fv):
152-
fv = "Infinity" if fv > 0 else "-Infinity"
153-
elif isinstance(fv, (np.complex64, np.complexfloating)):
154-
fv = [_sanitize_fill_value(fv.real), _sanitize_fill_value(fv.imag)]
155-
return fv
156-
157148
zarray_dict = self.to_dict()
158-
zarray_dict["fill_value"] = _sanitize_fill_value(zarray_dict["fill_value"])
159149
zattrs_dict = zarray_dict.pop("attributes", {})
160150
json_indent = config.get("json_indent")
161151
return {
@@ -167,11 +157,12 @@ def _sanitize_fill_value(fv: Any):
167157
),
168158
}
169159

160+
170161
@classmethod
171162
def from_dict(cls, data: dict[str, Any]) -> ArrayV2Metadata:
172-
# make a copy to protect the original from modification
163+
# Make a copy to protect the original from modification.
173164
_data = data.copy()
174-
# check that the zarr_format attribute is correct
165+
# Check that the zarr_format attribute is correct.
175166
_ = parse_zarr_format(_data.pop("zarr_format"))
176167
dtype = parse_dtype(_data["dtype"])
177168

@@ -180,20 +171,48 @@ def from_dict(cls, data: dict[str, Any]) -> ArrayV2Metadata:
180171
if fill_value_encoded is not None:
181172
fill_value = base64.standard_b64decode(fill_value_encoded)
182173
_data["fill_value"] = fill_value
183-
184-
# zarr v2 allowed arbitrary keys here.
185-
# We don't want the ArrayV2Metadata constructor to fail just because someone put an
186-
# extra key in the metadata.
174+
else:
175+
fill_value = _data.get("fill_value")
176+
if fill_value is not None:
177+
if np.issubdtype(dtype, np.datetime64):
178+
if fill_value == "NaT":
179+
_data["fill_value"] = np.array("NaT", dtype=dtype)[()]
180+
else:
181+
_data["fill_value"] = np.array(fill_value, dtype=dtype)[()]
182+
elif dtype.kind == "c" and isinstance(fill_value, list):
183+
if len(fill_value) == 2:
184+
val = complex(float(fill_value[0]), float(fill_value[1]))
185+
_data["fill_value"] = np.array(val, dtype=dtype)[()]
186+
elif dtype.kind in "f" and isinstance(fill_value, str):
187+
if fill_value in {"NaN", "Infinity", "-Infinity"}:
188+
_data["fill_value"] = np.array(fill_value, dtype=dtype)[()]
189+
# zarr v2 allowed arbitrary keys in the metadata.
190+
# Filter the keys to only those expected by the constructor.
187191
expected = {x.name for x in fields(cls)}
188-
# https://github.com/zarr-developers/zarr-python/issues/2269
189-
# handle the renames
190192
expected |= {"dtype", "chunks"}
191-
192193
_data = {k: v for k, v in _data.items() if k in expected}
193194

194195
return cls(**_data)
195196

197+
198+
196199
def to_dict(self) -> dict[str, JSON]:
200+
def _sanitize_fill_value(fv: Any):
201+
if fv is None:
202+
return fv
203+
elif isinstance(fv, np.datetime64):
204+
if np.isnat(fv):
205+
return "NaT"
206+
return np.datetime_as_string(fv)
207+
elif isinstance(fv, numbers.Real):
208+
if np.isnan(fv):
209+
fv = "NaN"
210+
elif np.isinf(fv):
211+
fv = "Infinity" if fv > 0 else "-Infinity"
212+
elif isinstance(fv, numbers.Complex):
213+
fv = [_sanitize_fill_value(fv.real), _sanitize_fill_value(fv.imag)]
214+
return fv
215+
197216
zarray_dict = super().to_dict()
198217

199218
if self.dtype.kind in "SV" and self.fill_value is not None:
@@ -203,6 +222,7 @@ def to_dict(self) -> dict[str, JSON]:
203222
fill_value = base64.standard_b64encode(cast(bytes, self.fill_value)).decode("ascii")
204223
zarray_dict["fill_value"] = fill_value
205224

225+
zarray_dict["fill_value"] = _sanitize_fill_value(zarray_dict["fill_value"])
206226
_ = zarray_dict.pop("dtype")
207227
dtype_json: JSON
208228
# In the case of zarr v2, the simplest i.e., '|VXX' dtype is represented as a string

src/zarr/testing/stateful.py

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -20,8 +20,7 @@
2020
from zarr.core.buffer import Buffer, BufferPrototype, cpu, default_buffer_prototype
2121
from zarr.core.sync import SyncMixin
2222
from zarr.storage import LocalStore, MemoryStore
23-
from zarr.testing.strategies import key_ranges, node_names, np_array_and_chunks, numpy_arrays
24-
from zarr.testing.strategies import keys as zarr_keys
23+
from zarr.testing.strategies import key_ranges, node_names, np_array_and_chunks, numpy_arrays, keys as zarr_keys
2524

2625
MAX_BINARY_SIZE = 100
2726

@@ -82,7 +81,7 @@ def add_group(self, name: str, data: DataObject) -> None:
8281
@rule(
8382
data=st.data(),
8483
name=node_names,
85-
array_and_chunks=np_array_and_chunks(arrays=numpy_arrays(zarr_formats=st.just(3))),
84+
array_and_chunks=np_array_and_chunks(nparrays=numpy_arrays(zarr_formats=st.just(3))),
8685
)
8786
def add_array(
8887
self,
Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,18 @@
1+
from .arrays import keys, key_ranges, node_names, np_array_and_chunks, basic_indices, numpy_arrays, zarr_arrays, keys, zarr_formats
2+
from .array_metadata_generators import array_metadata_v2, array_metadata_v3
3+
4+
5+
6+
__all__ = [
7+
"array_metadata_v2",
8+
"array_metadata_v3",
9+
"keys",
10+
"node_names",
11+
"np_array_and_chunks",
12+
"key_ranges",
13+
"zarr_arrays",
14+
"basic_indices",
15+
"numpy_arrays",
16+
"zarr_arrays",
17+
"zarr_formats"
18+
]
Lines changed: 121 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,121 @@
1+
from typing import Any, Iterable, Literal, Tuple, Dict
2+
import numpy as np
3+
import numpy.typing as npt
4+
import numcodecs
5+
from hypothesis import strategies as st
6+
import hypothesis.extra.numpy as npst
7+
from hypothesis import assume
8+
from dataclasses import dataclass, field
9+
10+
from zarr.codecs.bytes import BytesCodec
11+
from zarr.core.chunk_grids import RegularChunkGrid, ChunkGrid
12+
from zarr.core.chunk_key_encodings import DefaultChunkKeyEncoding, ChunkKeyEncoding
13+
from zarr.core.metadata.v2 import ArrayV2Metadata
14+
from zarr.core.metadata.v3 import ArrayV3Metadata
15+
from zarr.core.chunk_key_encodings import ChunkKeyEncoding, ChunkKeyEncodingLike
16+
17+
18+
from .dtypes import v2_dtypes, v3_dtypes
19+
20+
def simple_text():
21+
"""A strategy for generating simple text strings."""
22+
return st.text(st.characters(min_codepoint=32, max_codepoint=126), min_size=1, max_size=10)
23+
24+
25+
def simple_attrs():
26+
"""A strategy for generating simple attribute dictionaries."""
27+
return st.dictionaries(
28+
simple_text(),
29+
st.one_of(st.integers(),
30+
st.floats(allow_nan=False, allow_infinity=False),
31+
st.booleans(),
32+
simple_text()))
33+
34+
35+
def array_shapes(min_dims=1, max_dims=3, max_len=100):
36+
"""A strategy for generating array shapes."""
37+
return st.lists(st.integers(min_value=1, max_value=max_len), min_size=min_dims, max_size=max_dims)
38+
39+
40+
# def zarr_compressors():
41+
# """A strategy for generating Zarr compressors."""
42+
# return st.sampled_from([None, Blosc(), GZip(), Zstd(), LZ4()])
43+
44+
45+
# def zarr_codecs():
46+
# """A strategy for generating Zarr codecs."""
47+
# return st.sampled_from([BytesCodec(), Blosc(), GZip(), Zstd(), LZ4()])
48+
49+
50+
def zarr_filters():
51+
"""A strategy for generating Zarr filters."""
52+
return st.lists(st.just(numcodecs.Delta(dtype='i4')), min_size=0, max_size=2) # Example filter, expand as needed
53+
54+
55+
def zarr_storage_transformers():
56+
"""A strategy for generating Zarr storage transformers."""
57+
return st.lists(st.dictionaries(simple_text(), st.one_of(st.integers(), st.floats(), st.booleans(), simple_text())), min_size=0, max_size=2)
58+
59+
60+
@st.composite
61+
def array_metadata_v2(draw: st.DrawFn) -> ArrayV2Metadata:
62+
"""Generates valid ArrayV2Metadata objects for property-based testing."""
63+
dims = draw(st.integers(min_value=1, max_value=3)) # Limit dimensions for complexity
64+
shape = tuple(draw(array_shapes(min_dims=dims, max_dims=dims, max_len=100)))
65+
max_chunk_len = max(shape) if shape else 100
66+
chunks = tuple(draw(st.lists(st.integers(min_value=1, max_value=max_chunk_len), min_size=dims, max_size=dims)))
67+
68+
# Validate shape and chunks relationship
69+
assume(all(c <= s for s, c in zip(shape, chunks))) # Chunk size must be <= shape
70+
71+
dtype = draw(v2_dtypes())
72+
fill_value = draw(st.one_of([st.none(), npst.from_dtype(dtype)]))
73+
order = draw(st.sampled_from(["C", "F"]))
74+
dimension_separator = draw(st.sampled_from([".", "/"]))
75+
#compressor = draw(zarr_compressors())
76+
filters = tuple(draw(zarr_filters())) if draw(st.booleans()) else None
77+
attributes = draw(simple_attrs())
78+
79+
# Construct the metadata object. Type hints are crucial here for correctness.
80+
return ArrayV2Metadata(
81+
shape=shape,
82+
dtype=dtype,
83+
chunks=chunks,
84+
fill_value=fill_value,
85+
order=order,
86+
dimension_separator=dimension_separator,
87+
# compressor=compressor,
88+
filters=filters,
89+
attributes=attributes,
90+
)
91+
92+
93+
@st.composite
94+
def array_metadata_v3(draw: st.DrawFn) -> ArrayV3Metadata:
95+
"""Generates valid ArrayV3Metadata objects for property-based testing."""
96+
dims = draw(st.integers(min_value=1, max_value=3))
97+
shape = tuple(draw(array_shapes(min_dims=dims, max_dims=dims, max_len=100)))
98+
max_chunk_len = max(shape) if shape else 100
99+
chunks = tuple(draw(st.lists(st.integers(min_value=1, max_value=max_chunk_len), min_size=dims, max_size=dims)))
100+
assume(all(c <= s for s, c in zip(shape, chunks)))
101+
102+
dtype = draw(v3_dtypes())
103+
fill_value = draw(npst.from_dtype(dtype))
104+
chunk_grid = RegularChunkGrid(chunks) # Ensure chunks is passed as tuple.
105+
chunk_key_encoding = DefaultChunkKeyEncoding(separator="/") # Or st.sampled_from(["/", "."])
106+
#codecs = tuple(draw(st.lists(zarr_codecs(), min_size=0, max_size=3)))
107+
attributes = draw(simple_attrs())
108+
dimension_names = tuple(draw(st.lists(st.one_of(st.none(), simple_text()), min_size=dims, max_size=dims))) if draw(st.booleans()) else None
109+
storage_transformers = tuple(draw(zarr_storage_transformers()))
110+
111+
return ArrayV3Metadata(
112+
shape=shape,
113+
data_type=dtype,
114+
chunk_grid=chunk_grid,
115+
chunk_key_encoding=chunk_key_encoding,
116+
fill_value=fill_value,
117+
# codecs=codecs,
118+
attributes=attributes,
119+
dimension_names=dimension_names,
120+
storage_transformers=storage_transformers,
121+
)

0 commit comments

Comments
 (0)