F = TypeVar("F", bound=Callable[..., Any])


def with_frequency(frequency: float) -> Callable[[F], F]:
    """Build a decorator that lets a function pass its guard only part of the time.

    The decorated function is wrapped in a ``precondition`` whose predicate
    keeps an evaluation counter on the object it is called with and passes
    only when the fractional part of ``count * frequency`` is at least
    ``1 - frequency``. With ``frequency=0.25`` the predicate passes on the
    3rd, 7th, 11th, ... evaluation.

    No randomness is involved: the result depends solely on how many times
    the predicate has been evaluated on that object, which is what keeps
    hypothesis replays deterministic.

    Parameters
    ----------
    frequency : float
        Fraction of predicate evaluations that should pass.

    Returns
    -------
    Callable[[F], F]
        A decorator returning the precondition-guarded function.
    """

    def decorator(func: F) -> F:
        # One counter attribute per decorated function name, stored on the
        # object handed to the predicate.
        counter_name = f"__{func.__name__}_counter"

        @functools.wraps(func)
        def passthrough(*args: Any, **kwargs: Any) -> Any:
            return func(*args, **kwargs)

        def should_run(machine: Any) -> bool:
            # The counter advances on every evaluation, pass or fail.
            evaluations = getattr(machine, counter_name, 0) + 1
            setattr(machine, counter_name, evaluations)
            phase = (evaluations * frequency) % 1.0
            return phase >= (1.0 - frequency)

        guarded = precondition(should_run)(passthrough)
        return cast(F, guarded)

    return decorator
@precondition(lambda self: bool(self.all_arrays))
@rule(data=st.data())
def check_array(self, data: DataObject) -> None:
    """Assert that one tracked array reads back identically from store and model.

    The array path is drawn from the sorted contents of ``self.all_arrays``;
    both arrays are read in full with ``[:]`` and compared.
    """
    candidates = st.sampled_from(sorted(self.all_arrays))
    array_path = data.draw(candidates)
    store_values = zarr.open_array(self.store, path=array_path)[:]
    model_values = zarr.open_array(self.model, path=array_path)[:]
    np.testing.assert_equal(store_values, model_values)


@precondition(lambda self: bool(self.all_arrays))
@rule(data=st.data())
def overwrite_array_orthogonal_indexing(self, data: DataObject) -> None:
    """Write the same drawn data into the model and store arrays via ``oindex``.

    An array path is drawn from ``self.all_arrays``, an orthogonal indexer is
    drawn for the model array's shape, and new data matching the shape of
    that selection is assigned to the model array first, then the store array.
    """
    target = data.draw(st.sampled_from(sorted(self.all_arrays)))
    model_array = zarr.open_array(path=target, store=self.model)
    store_array = zarr.open_array(path=target, store=self.store)
    # `indexer` must keep this name: the note below embeds it via `{indexer=}`.
    indexer, _ = data.draw(orthogonal_indices(shape=model_array.shape))
    note(f"overwriting array orthogonal {indexer=}")
    # The shape of the drawn data is taken from the model array's current selection.
    selection_shape = model_array.oindex[indexer].shape  # type: ignore[union-attr]
    replacement = data.draw(npst.arrays(shape=selection_shape, dtype=model_array.dtype))
    for destination in (model_array, store_array):
        destination.oindex[indexer] = replacement
def v3_dtypes() -> st.SearchStrategy[np.dtype[Any]]:
    """Return the dtype strategy produced by ``dtypes()``."""
    return dtypes()


def v2_dtypes() -> st.SearchStrategy[np.dtype[Any]]:
    """Return the dtype strategy produced by ``dtypes()``."""
    return dtypes()
""" - zarr_format = draw(zarr_formats) if dtype is None: - dtype = draw(v3_dtypes() if zarr_format == 3 else v2_dtypes()) + dtype = draw(dtypes()) if np.issubdtype(dtype, np.str_): safe_unicode_strings = safe_unicode_for_dtype(dtype) return draw(npst.arrays(dtype=dtype, shape=shapes, elements=safe_unicode_strings)) @@ -255,17 +247,24 @@ def arrays( attrs: st.SearchStrategy = attrs, zarr_formats: st.SearchStrategy = zarr_formats, ) -> Array: - store = draw(stores) - path = draw(paths) - name = draw(array_names) - attributes = draw(attrs) - zarr_format = draw(zarr_formats) + store = draw(stores, label="store") + path = draw(paths, label="array parent") + name = draw(array_names, label="array name") + attributes = draw(attrs, label="attributes") + zarr_format = draw(zarr_formats, label="zarr format") if arrays is None: - arrays = numpy_arrays(shapes=shapes, zarr_formats=st.just(zarr_format)) - nparray = draw(arrays) - chunk_shape = draw(chunk_shapes(shape=nparray.shape)) + arrays = numpy_arrays(shapes=shapes) + nparray = draw(arrays, label="array data") + chunk_shape = draw(chunk_shapes(shape=nparray.shape), label="chunk shape") + extra_kwargs = {} if zarr_format == 3 and all(c > 0 for c in chunk_shape): - shard_shape = draw(st.none() | shard_shapes(shape=nparray.shape, chunk_shape=chunk_shape)) + shard_shape = draw( + st.none() | shard_shapes(shape=nparray.shape, chunk_shape=chunk_shape), + label="shard shape", + ) + extra_kwargs["dimension_names"] = draw( + dimension_names(ndim=nparray.ndim), label="dimension names" + ) else: shard_shape = None # test that None works too. 
@pytest.mark.filterwarnings("ignore::zarr.core.dtype.common.UnstableSpecificationWarning")
@given(data=st.data())
def test_array_roundtrip(data: st.DataObject) -> None:
    """A drawn numpy array equals the full read-back of the zarr array built from it."""
    source = data.draw(numpy_arrays())
    # Pin the array data so the drawn zarr array is built from `source`.
    stored = data.draw(arrays(arrays=st.just(source)))
    assert_array_equal(source, stored[:])