1- from typing import Any
1+ from typing import Any , Literal
22
33import hypothesis .extra .numpy as npst
44import hypothesis .strategies as st
1919)
2020
2121
22- def dtypes () -> st .SearchStrategy [np .dtype ]:
22+ def v3_dtypes () -> st .SearchStrategy [np .dtype ]:
2323 return (
2424 npst .boolean_dtypes ()
2525 | npst .integer_dtypes (endianness = "=" )
2626 | npst .unsigned_integer_dtypes (endianness = "=" )
2727 | npst .floating_dtypes (endianness = "=" )
2828 | npst .complex_number_dtypes (endianness = "=" )
29+ # | npst.byte_string_dtypes(endianness="=")
2930 # | npst.unicode_string_dtypes()
3031 # | npst.datetime64_dtypes()
3132 # | npst.timedelta64_dtypes()
3233 )
3334
3435
36+ def v2_dtypes () -> st .SearchStrategy [np .dtype ]:
37+ return (
38+ npst .boolean_dtypes ()
39+ | npst .integer_dtypes (endianness = "=" )
40+ | npst .unsigned_integer_dtypes (endianness = "=" )
41+ | npst .floating_dtypes (endianness = "=" )
42+ | npst .complex_number_dtypes (endianness = "=" )
43+ | npst .byte_string_dtypes (endianness = "=" )
44+ | npst .unicode_string_dtypes (endianness = "=" )
45+ | npst .datetime64_dtypes ()
46+ # | npst.timedelta64_dtypes()
47+ )
48+
49+
3550# From https://zarr-specs.readthedocs.io/en/latest/v3/core/v3.0.html#node-names
3651# 1. must not be the empty string ("")
3752# 2. must not include the character "/"
@@ -46,18 +61,29 @@ def dtypes() -> st.SearchStrategy[np.dtype]:
4661array_names = node_names
4762attrs = st .none () | st .dictionaries (_attr_keys , _attr_values )
4863paths = st .lists (node_names , min_size = 1 ).map (lambda x : "/" .join (x )) | st .just ("/" )
49- np_arrays = npst .arrays (
50- dtype = dtypes (),
51- shape = npst .array_shapes (max_dims = 4 ),
52- )
5364stores = st .builds (MemoryStore , st .just ({}), mode = st .just ("w" ))
5465compressors = st .sampled_from ([None , "default" ])
55- zarr_formats = st .sampled_from ([2 , 3 ])
66+ zarr_formats : st .SearchStrategy [Literal [2 , 3 ]] = st .sampled_from ([2 , 3 ])
67+ array_shapes = npst .array_shapes (max_dims = 4 )
68+
69+
70+ @st .composite # type: ignore[misc]
71+ def numpy_arrays (
72+ draw : st .DrawFn ,
73+ * ,
74+ shapes : st .SearchStrategy [tuple [int , ...]] = array_shapes ,
75+ zarr_formats : st .SearchStrategy [Literal [2 , 3 ]] = zarr_formats ,
76+ ) -> Any :
77+ """
78+ Generate numpy arrays that can be saved in the provided Zarr format.
79+ """
80+ zarr_format = draw (zarr_formats )
81+ return draw (npst .arrays (dtype = v3_dtypes () if zarr_format == 3 else v2_dtypes (), shape = shapes ))
5682
5783
5884@st .composite # type: ignore[misc]
5985def np_array_and_chunks (
60- draw : st .DrawFn , * , arrays : st .SearchStrategy [np .ndarray ] = np_arrays
86+ draw : st .DrawFn , * , arrays : st .SearchStrategy [np .ndarray ] = numpy_arrays
6187) -> tuple [np .ndarray , tuple [int ]]: # type: ignore[type-arg]
6288 """A hypothesis strategy to generate small sized random arrays.
6389
@@ -76,20 +102,23 @@ def np_array_and_chunks(
76102def arrays (
77103 draw : st .DrawFn ,
78104 * ,
105+ shapes : st .SearchStrategy [tuple [int , ...]] = array_shapes ,
79106 compressors : st .SearchStrategy = compressors ,
80107 stores : st .SearchStrategy [StoreLike ] = stores ,
81- arrays : st .SearchStrategy [np .ndarray ] = np_arrays ,
82108 paths : st .SearchStrategy [None | str ] = paths ,
83109 array_names : st .SearchStrategy = array_names ,
110+ arrays : st .SearchStrategy | None = None ,
84111 attrs : st .SearchStrategy = attrs ,
85112 zarr_formats : st .SearchStrategy = zarr_formats ,
86113) -> Array :
87114 store = draw (stores )
88- nparray , chunks = draw (np_array_and_chunks (arrays = arrays ))
89115 path = draw (paths )
90116 name = draw (array_names )
91117 attributes = draw (attrs )
92118 zarr_format = draw (zarr_formats )
119+ if arrays is None :
120+ arrays = numpy_arrays (shapes = shapes , zarr_formats = st .just (zarr_format ))
121+ nparray , chunks = draw (np_array_and_chunks (arrays = arrays ))
93122 # test that None works too.
94123 fill_value = draw (st .one_of ([st .none (), npst .from_dtype (nparray .dtype )]))
95124 # compressor = draw(compressors)
0 commit comments