@@ -43,7 +43,7 @@ def paths(draw: st.DrawFn, *, max_num_nodes: int | None = None) -> str:
4343 return draw (st .just ("/" ) | keys (max_num_nodes = max_num_nodes ))
4444
4545
46- def v3_dtypes () -> st .SearchStrategy [np .dtype [Any ]]:
46+ def dtypes () -> st .SearchStrategy [np .dtype [Any ]]:
4747 return (
4848 npst .boolean_dtypes ()
4949 | npst .integer_dtypes (endianness = "=" )
@@ -57,18 +57,12 @@ def v3_dtypes() -> st.SearchStrategy[np.dtype[Any]]:
5757 )
5858
5959
60+ def v3_dtypes () -> st .SearchStrategy [np .dtype [Any ]]:
61+ return dtypes ()
62+
63+
6064def v2_dtypes () -> st .SearchStrategy [np .dtype [Any ]]:
61- return (
62- npst .boolean_dtypes ()
63- | npst .integer_dtypes (endianness = "=" )
64- | npst .unsigned_integer_dtypes (endianness = "=" )
65- | npst .floating_dtypes (endianness = "=" )
66- | npst .complex_number_dtypes (endianness = "=" )
67- | npst .byte_string_dtypes (endianness = "=" )
68- | npst .unicode_string_dtypes (endianness = "=" )
69- | npst .datetime64_dtypes (endianness = "=" )
70- | npst .timedelta64_dtypes (endianness = "=" )
71- )
65+ return dtypes ()
7266
7367
7468def safe_unicode_for_dtype (dtype : np .dtype [np .str_ ]) -> st .SearchStrategy [str ]:
@@ -144,7 +138,7 @@ def array_metadata(
144138 shape = draw (array_shapes ())
145139 ndim = len (shape )
146140 chunk_shape = draw (array_shapes (min_dims = ndim , max_dims = ndim ))
147- np_dtype = draw (v3_dtypes ())
141+ np_dtype = draw (dtypes ())
148142 dtype = get_data_type_from_native_dtype (np_dtype )
149143 fill_value = draw (npst .from_dtype (np_dtype ))
150144 if zarr_format == 2 :
@@ -179,14 +173,12 @@ def numpy_arrays(
179173 * ,
180174 shapes : st .SearchStrategy [tuple [int , ...]] = array_shapes ,
181175 dtype : np .dtype [Any ] | None = None ,
182- zarr_formats : st .SearchStrategy [ZarrFormat ] = zarr_formats ,
183176) -> npt .NDArray [Any ]:
184177 """
185178 Generate numpy arrays that can be saved in the provided Zarr format.
186179 """
187- zarr_format = draw (zarr_formats )
188180 if dtype is None :
189- dtype = draw (v3_dtypes () if zarr_format == 3 else v2_dtypes ())
181+ dtype = draw (dtypes ())
190182 if np .issubdtype (dtype , np .str_ ):
191183 safe_unicode_strings = safe_unicode_for_dtype (dtype )
192184 return draw (npst .arrays (dtype = dtype , shape = shapes , elements = safe_unicode_strings ))
@@ -255,17 +247,24 @@ def arrays(
255247 attrs : st .SearchStrategy = attrs ,
256248 zarr_formats : st .SearchStrategy = zarr_formats ,
257249) -> Array :
258- store = draw (stores )
259- path = draw (paths )
260- name = draw (array_names )
261- attributes = draw (attrs )
262- zarr_format = draw (zarr_formats )
250+ store = draw (stores , label = "store" )
251+ path = draw (paths , label = "array parent" )
252+ name = draw (array_names , label = "array name" )
253+ attributes = draw (attrs , label = "attributes" )
254+ zarr_format = draw (zarr_formats , label = "zarr format" )
263255 if arrays is None :
264- arrays = numpy_arrays (shapes = shapes , zarr_formats = st .just (zarr_format ))
265- nparray = draw (arrays )
266- chunk_shape = draw (chunk_shapes (shape = nparray .shape ))
256+ arrays = numpy_arrays (shapes = shapes )
257+ nparray = draw (arrays , label = "array data" )
258+ chunk_shape = draw (chunk_shapes (shape = nparray .shape ), label = "chunk shape" )
259+ extra_kwargs = {}
267260 if zarr_format == 3 and all (c > 0 for c in chunk_shape ):
268- shard_shape = draw (st .none () | shard_shapes (shape = nparray .shape , chunk_shape = chunk_shape ))
261+ shard_shape = draw (
262+ st .none () | shard_shapes (shape = nparray .shape , chunk_shape = chunk_shape ),
263+ label = "shard shape" ,
264+ )
265+ extra_kwargs ["dimension_names" ] = draw (
266+ dimension_names (ndim = nparray .ndim ), label = "dimension names"
267+ )
269268 else :
270269 shard_shape = None
271270 # test that None works too.
@@ -286,6 +285,7 @@ def arrays(
286285 attributes = attributes ,
287286 # compressor=compressor, # FIXME
288287 fill_value = fill_value ,
288+ ** extra_kwargs ,
289289 )
290290
291291 assert isinstance (a , Array )
@@ -385,13 +385,19 @@ def orthogonal_indices(
385385 npindexer = []
386386 ndim = len (shape )
387387 for axis , size in enumerate (shape ):
388- val = draw (
389- npst .integer_array_indices (
388+ if size != 0 :
389+ strategy = npst .integer_array_indices (
390390 shape = (size ,), result_shape = npst .array_shapes (min_side = 1 , max_side = size , max_dims = 1 )
391- )
392- | basic_indices (min_dims = 1 , shape = (size ,), allow_ellipsis = False )
393- .map (lambda x : (x ,) if not isinstance (x , tuple ) else x ) # bare ints, slices
394- .filter (bool ) # skip empty tuple
391+ ) | basic_indices (min_dims = 1 , shape = (size ,), allow_ellipsis = False )
392+ else :
393+ strategy = basic_indices (min_dims = 1 , shape = (size ,), allow_ellipsis = False )
394+
395+ val = draw (
396+ strategy
397+ # bare ints, slices
398+ .map (lambda x : (x ,) if not isinstance (x , tuple ) else x )
399+ # skip empty tuple
400+ .filter (bool )
395401 )
396402 (idxr ,) = val
397403 if isinstance (idxr , int ):
0 commit comments