Skip to content

Commit 7cc1bc2

Browse files
committed
Add compressor, codec pipeline strategy
1 parent 870265a commit 7cc1bc2

File tree

1 file changed

+24
-6
lines changed

1 file changed

+24
-6
lines changed

src/zarr/testing/strategies.py

Lines changed: 24 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
import sys
2-
from typing import Any
2+
from typing import Any, Literal
33

44
import hypothesis.extra.numpy as npst
55
import hypothesis.strategies as st
@@ -86,11 +86,30 @@ def safe_unicode_for_dtype(dtype: np.dtype[np.str_]) -> st.SearchStrategy[str]:
8686
# i.e. stores.examples() will always return the same object per Store class.
8787
# So we map a clear to reset the store.
8888
stores = st.builds(MemoryStore, st.just({})).map(lambda x: sync(x.clear()))
89-
compressors = st.sampled_from([None, "default"])
9089
zarr_formats: st.SearchStrategy[ZarrFormat] = st.sampled_from([2, 3])
9190
array_shapes = npst.array_shapes(max_dims=4, min_side=0)
9291

9392

93+
@st.composite # type: ignore[misc]
94+
def codecs(
95+
draw: st.DrawFn,
96+
*,
97+
zarr_formats: st.SearchStrategy[Literal[2, 3]] = zarr_formats,
98+
dtypes: st.SearchStrategy[np.dtype] | None = None,
99+
) -> Any:
100+
zarr_format = draw(zarr_formats)
101+
codec_kwargs = {"filters": draw(st.none() | st.just(()))}
102+
zarr_codecs = st.one_of(
103+
st.builds(zarr.codecs.ZstdCodec, level=st.integers(min_value=0, max_value=9)),
104+
# TODO: other codecs
105+
)
106+
if zarr_format == 2:
107+
codec_kwargs["compressors"] = draw(st.none() | st.just(()))
108+
else:
109+
codec_kwargs["compressors"] = draw(st.none() | st.just(()) | zarr_codecs)
110+
return codec_kwargs
111+
112+
94113
@st.composite # type: ignore[misc]
95114
def numpy_arrays(
96115
draw: st.DrawFn,
@@ -139,12 +158,12 @@ def arrays(
139158
draw: st.DrawFn,
140159
*,
141160
shapes: st.SearchStrategy[tuple[int, ...]] = array_shapes,
142-
compressors: st.SearchStrategy = compressors,
143161
stores: st.SearchStrategy[StoreLike] = stores,
144162
paths: st.SearchStrategy[str | None] = paths,
145163
array_names: st.SearchStrategy = array_names,
146164
arrays: st.SearchStrategy | None = None,
147165
attrs: st.SearchStrategy = attrs,
166+
codecs: st.SearchStrategy = codecs,
148167
zarr_formats: st.SearchStrategy = zarr_formats,
149168
) -> Array:
150169
store = draw(stores)
@@ -157,21 +176,20 @@ def arrays(
157176
nparray, chunks = draw(np_array_and_chunks(arrays=arrays))
158177
# test that None works too.
159178
fill_value = draw(st.one_of([st.none(), npst.from_dtype(nparray.dtype)]))
160-
# compressor = draw(compressors)
161179

162180
expected_attrs = {} if attributes is None else attributes
163181

164182
array_path = _dereference_path(path, name)
165183
root = zarr.open_group(store, mode="w", zarr_format=zarr_format)
166-
184+
codec_kwargs = draw(codecs(zarr_formats=st.just(zarr_format), dtypes=st.just(nparray.dtype)))
167185
a = root.create_array(
168186
array_path,
169187
shape=nparray.shape,
170188
chunks=chunks,
171189
dtype=nparray.dtype,
172190
attributes=attributes,
173-
# compressor=compressor, # FIXME
174191
fill_value=fill_value,
192+
**codec_kwargs,
175193
)
176194

177195
assert isinstance(a, Array)

0 commit comments

Comments
 (0)