Skip to content

Commit 3a85bdf

Browse files
committed
add oindex property test
1 parent 74d0995 commit 3a85bdf

File tree

2 files changed

+48
-2
lines changed

2 files changed

+48
-2
lines changed

src/zarr/testing/strategies.py

Lines changed: 28 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -188,6 +188,34 @@ def basic_indices(draw: st.DrawFn, *, shape: tuple[int], **kwargs: Any) -> Any:
188188
)
189189

190190

191+
@st.composite # type: ignore[misc]
192+
def orthogonal_indices(
193+
draw: st.DrawFn, *, shape: tuple[int]
194+
) -> tuple[tuple[np.ndarray[Any, Any], ...], tuple[np.ndarray[Any, Any], ...]]:
195+
"""
196+
Strategy that returns
197+
(1) a tuple of integer arrays used for orthogonal indexing of Zarr arrays.
198+
(2) an tuple of integer arrays that can be used for equivalent indexing of numpy arrays
199+
"""
200+
zindexer = []
201+
npindexer = []
202+
ndim = len(shape)
203+
for axis, size in enumerate(shape):
204+
(idxr,) = draw(
205+
npst.integer_array_indices(
206+
shape=(size,), result_shape=npst.array_shapes(min_side=1, max_side=size, max_dims=1)
207+
)
208+
# | npst.basic_indices(shape=(size,), allow_ellipsis=False)
209+
)
210+
zindexer.append(idxr)
211+
if isinstance(idxr, np.ndarray):
212+
newshape = [1] * ndim
213+
newshape[axis] = idxr.size
214+
idxr = idxr.reshape(newshape)
215+
npindexer.append(idxr)
216+
return tuple(zindexer), np.broadcast_arrays(*npindexer)
217+
218+
191219
def key_ranges(
192220
keys: SearchStrategy = node_names, max_size: int | None = None
193221
) -> SearchStrategy[list[int]]:

tests/test_properties.py

Lines changed: 20 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -6,9 +6,15 @@
66

77
import hypothesis.extra.numpy as npst
88
import hypothesis.strategies as st
9-
from hypothesis import given
9+
from hypothesis import given, settings
1010

11-
from zarr.testing.strategies import arrays, basic_indices, numpy_arrays, zarr_formats
11+
from zarr.testing.strategies import (
12+
arrays,
13+
basic_indices,
14+
numpy_arrays,
15+
orthogonal_indices,
16+
zarr_formats,
17+
)
1218

1319

1420
@given(data=st.data(), zarr_format=zarr_formats)
@@ -32,6 +38,18 @@ def test_basic_indexing(data: st.DataObject) -> None:
3238
assert_array_equal(nparray, zarray[:])
3339

3440

41+
@settings(report_multiple_bugs=False)
42+
@given(data=st.data())
43+
def test_oindex(data: st.DataObject) -> None:
44+
# integer_array_indices can't handle 0-size dimensions.
45+
zarray = data.draw(arrays(shapes=npst.array_shapes(max_dims=4, min_side=1)))
46+
nparray = zarray[:]
47+
48+
zindexer, npindexer = data.draw(orthogonal_indices(shape=nparray.shape))
49+
actual = zarray.oindex[zindexer]
50+
assert_array_equal(nparray[npindexer], actual)
51+
52+
3553
@given(data=st.data())
3654
def test_vindex(data: st.DataObject) -> None:
3755
# integer_array_indices can't handle 0-size dimensions.

0 commit comments

Comments
 (0)