Skip to content

Commit 7e32463

Browse files
committed
Fix test
1 parent 026936a commit 7e32463

File tree

2 files changed

+6
-4
lines changed

2 files changed

+6
-4
lines changed

src/zarr/testing/strategies.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -165,13 +165,15 @@ def numpy_arrays(
165165
draw: st.DrawFn,
166166
*,
167167
shapes: st.SearchStrategy[tuple[int, ...]] = array_shapes,
168-
zarr_formats: st.SearchStrategy[ZarrFormat] = zarr_formats,
168+
dtype: np.dtype[Any] | None = None,
169+
zarr_formats: st.SearchStrategy[ZarrFormat] | None = zarr_formats,
169170
) -> Any:
170171
"""
171172
Generate numpy arrays that can be saved in the provided Zarr format.
172173
"""
173174
zarr_format = draw(zarr_formats)
174-
dtype = draw(v3_dtypes() if zarr_format == 3 else v2_dtypes())
175+
if dtype is None:
176+
dtype = draw(v3_dtypes() if zarr_format == 3 else v2_dtypes())
175177
if np.issubdtype(dtype, np.str_):
176178
safe_unicode_strings = safe_unicode_for_dtype(dtype)
177179
return draw(npst.arrays(dtype=dtype, shape=shapes, elements=safe_unicode_strings))

tests/test_properties.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -58,7 +58,7 @@ def test_basic_indexing(data: st.DataObject) -> None:
5858
actual = zarray[indexer]
5959
assert_array_equal(nparray[indexer], actual)
6060

61-
new_data = data.draw(npst.arrays(shape=st.just(actual.shape), dtype=nparray.dtype))
61+
new_data = data.draw(numpy_arrays(shapes=st.just(actual.shape), dtype=nparray.dtype))
6262
zarray[indexer] = new_data
6363
nparray[indexer] = new_data
6464
assert_array_equal(nparray, zarray[:])
@@ -80,7 +80,7 @@ def test_oindex(data: st.DataObject) -> None:
8080
if isinstance(idxr, np.ndarray) and idxr.size != np.unique(idxr).size:
8181
# behaviour of setitem with repeated indices is not guaranteed in practice
8282
assume(False)
83-
new_data = data.draw(npst.arrays(shape=st.just(actual.shape), dtype=nparray.dtype))
83+
new_data = data.draw(numpy_arrays(shapes=st.just(actual.shape), dtype=nparray.dtype))
8484
nparray[npindexer] = new_data
8585
zarray.oindex[zindexer] = new_data
8686
note((new_data, npindexer, nparray, zindexer, zarray[:]))

0 commit comments

Comments
 (0)