Skip to content

Commit 82a272d

Browse files
committed
Fix test
1 parent 1b7cf9a commit 82a272d

File tree

2 files changed

+6
-4
lines changed

2 files changed

+6
-4
lines changed

src/zarr/testing/strategies.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -165,13 +165,15 @@ def numpy_arrays(
165165
draw: st.DrawFn,
166166
*,
167167
shapes: st.SearchStrategy[tuple[int, ...]] = array_shapes,
168-
zarr_formats: st.SearchStrategy[ZarrFormat] = zarr_formats,
168+
dtype: np.dtype[Any] | None = None,
169+
zarr_formats: st.SearchStrategy[ZarrFormat] | None = zarr_formats,
169170
) -> Any:
170171
"""
171172
Generate numpy arrays that can be saved in the provided Zarr format.
172173
"""
173174
zarr_format = draw(zarr_formats)
174-
dtype = draw(v3_dtypes() if zarr_format == 3 else v2_dtypes())
175+
if dtype is None:
176+
dtype = draw(v3_dtypes() if zarr_format == 3 else v2_dtypes())
175177
if np.issubdtype(dtype, np.str_):
176178
safe_unicode_strings = safe_unicode_for_dtype(dtype)
177179
return draw(npst.arrays(dtype=dtype, shape=shapes, elements=safe_unicode_strings))

tests/test_properties.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -57,7 +57,7 @@ def test_basic_indexing(data: st.DataObject) -> None:
5757
actual = zarray[indexer]
5858
assert_array_equal(nparray[indexer], actual)
5959

60-
new_data = data.draw(npst.arrays(shape=st.just(actual.shape), dtype=nparray.dtype))
60+
new_data = data.draw(numpy_arrays(shapes=st.just(actual.shape), dtype=nparray.dtype))
6161
zarray[indexer] = new_data
6262
nparray[indexer] = new_data
6363
assert_array_equal(nparray, zarray[:])
@@ -74,7 +74,7 @@ def test_oindex(data: st.DataObject) -> None:
7474
assert_array_equal(nparray[npindexer], actual)
7575

7676
assume(zarray.shards is None) # GH2834
77-
new_data = data.draw(npst.arrays(shape=st.just(actual.shape), dtype=nparray.dtype))
77+
new_data = data.draw(numpy_arrays(shapes=st.just(actual.shape), dtype=nparray.dtype))
7878
nparray[npindexer] = new_data
7979
zarray.oindex[zindexer] = new_data
8080
assert_array_equal(nparray, zarray[:])

0 commit comments

Comments
 (0)