Skip to content

Commit 63eb1ea

Browse files
committed
Update strategy priorities:
1. Emphasize arrays of side > 1
2. Emphasize indexing the last chunk for both setitem & getitem
1 parent 1a351a4 commit 63eb1ea

File tree

2 files changed

+56
-13
lines changed

2 files changed

+56
-13
lines changed

src/zarr/testing/strategies.py

Lines changed: 52 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,11 @@
1+
import math
12
import sys
23
from typing import Any, Literal
34

45
import hypothesis.extra.numpy as npst
56
import hypothesis.strategies as st
67
import numpy as np
7-
from hypothesis import given, settings # noqa: F401
8+
from hypothesis import event, given, settings # noqa: F401
89
from hypothesis.strategies import SearchStrategy
910

1011
import zarr
@@ -97,7 +98,8 @@ def clear_store(x: Store) -> Store:
9798
stores = st.builds(MemoryStore, st.just({})).map(clear_store)
9899
compressors = st.sampled_from([None, "default"])
99100
zarr_formats: st.SearchStrategy[ZarrFormat] = st.sampled_from([3, 2])
100-
array_shapes = npst.array_shapes(max_dims=4, min_side=0)
101+
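# De-prioritize (but still allow) arrays with dimension sizes 0, 1, or 2: shapes come from either a min_side=3 strategy or an unrestricted min_side=0 one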
102+
array_shapes = npst.array_shapes(max_dims=4, min_side=3) | npst.array_shapes(max_dims=4, min_side=0)
101103

102104

103105
@st.composite # type: ignore[misc]
@@ -174,11 +176,19 @@ def chunk_shapes(draw: st.DrawFn, *, shape: tuple[int, ...]) -> tuple[int, ...]:
174176
st.tuples(*[st.integers(min_value=0 if size == 0 else 1, max_value=size) for size in shape])
175177
)
176178
# 2. and now generate the chunks tuple
177-
return tuple(
179+
chunks = tuple(
178180
size // nchunks if nchunks > 0 else 0
179181
for size, nchunks in zip(shape, numchunks, strict=True)
180182
)
181183

184+
for c in chunks:
185+
event("chunk size", c)
186+
187+
if any((c != 0 and s % c != 0) for s, c in zip(shape, chunks, strict=True)):
188+
event("smaller last chunk")
189+
190+
return chunks
191+
182192

183193
@st.composite # type: ignore[misc]
184194
def shard_shapes(
@@ -267,23 +277,55 @@ def arrays(
267277
return a
268278

269279

280+
@st.composite # type: ignore[misc]
281+
def simple_arrays(
282+
draw: st.DrawFn,
283+
*,
284+
shapes: st.SearchStrategy[tuple[int, ...]] = array_shapes,
285+
) -> Any:
286+
return draw(
287+
arrays(
288+
shapes=shapes,
289+
attrs=st.none(),
290+
paths=paths(max_num_nodes=2),
291+
compressors=st.sampled_from([None, "default"]),
292+
)
293+
)
294+
295+
270296
def is_negative_slice(idx: Any) -> bool:
271297
return isinstance(idx, slice) and idx.step is not None and idx.step < 0
272298

273299

300+
@st.composite # type: ignore[misc]
301+
def end_slices(draw: st.DrawFn, *, shape: tuple[int]) -> Any:
302+
"""
303+
A strategy that slices ranges that include the last chunk.
304+
This is intended to stress-test handling of a possibly smaller last chunk.
305+
"""
306+
slicers = []
307+
for size in shape:
308+
start = draw(st.integers(min_value=size // 2, max_value=size - 1))
309+
length = draw(st.integers(min_value=0, max_value=size - start))
310+
slicers.append(slice(start, start + length))
311+
event("drawing end slice")
312+
return tuple(slicers)
313+
314+
274315
@st.composite # type: ignore[misc]
275316
def basic_indices(draw: st.DrawFn, *, shape: tuple[int], **kwargs: Any) -> Any:
276317
"""Basic indices without unsupported negative slices."""
277-
return draw(
278-
npst.basic_indices(shape=shape, **kwargs).filter(
279-
lambda idxr: (
280-
not (
281-
is_negative_slice(idxr)
282-
or (isinstance(idxr, tuple) and any(is_negative_slice(idx) for idx in idxr))
283-
)
318+
strategy = npst.basic_indices(shape=shape, **kwargs).filter(
319+
lambda idxr: (
320+
not (
321+
is_negative_slice(idxr)
322+
or (isinstance(idxr, tuple) and any(is_negative_slice(idx) for idx in idxr))
284323
)
285324
)
286325
)
326+
if math.prod(shape) >= 3:
327+
strategy = end_slices(shape=shape) | strategy
328+
return draw(strategy)
287329

288330

289331
@st.composite # type: ignore[misc]

tests/test_properties.py

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,7 @@
1818
basic_indices,
1919
numpy_arrays,
2020
orthogonal_indices,
21+
simple_arrays,
2122
stores,
2223
zarr_formats,
2324
)
@@ -50,7 +51,7 @@ def test_array_creates_implicit_groups(array):
5051

5152
@given(data=st.data())
5253
def test_basic_indexing(data: st.DataObject) -> None:
53-
zarray = data.draw(arrays())
54+
zarray = data.draw(simple_arrays())
5455
nparray = zarray[:]
5556
indexer = data.draw(basic_indices(shape=nparray.shape))
5657
actual = zarray[indexer]
@@ -65,7 +66,7 @@ def test_basic_indexing(data: st.DataObject) -> None:
6566
@given(data=st.data())
6667
def test_oindex(data: st.DataObject) -> None:
6768
# integer_array_indices can't handle 0-size dimensions.
68-
zarray = data.draw(arrays(shapes=npst.array_shapes(max_dims=4, min_side=1)))
69+
zarray = data.draw(simple_arrays(shapes=npst.array_shapes(max_dims=4, min_side=1)))
6970
nparray = zarray[:]
7071

7172
zindexer, npindexer = data.draw(orthogonal_indices(shape=nparray.shape))
@@ -82,7 +83,7 @@ def test_oindex(data: st.DataObject) -> None:
8283
@given(data=st.data())
8384
def test_vindex(data: st.DataObject) -> None:
8485
# integer_array_indices can't handle 0-size dimensions.
85-
zarray = data.draw(arrays(shapes=npst.array_shapes(max_dims=4, min_side=1)))
86+
zarray = data.draw(simple_arrays(shapes=npst.array_shapes(max_dims=4, min_side=1)))
8687
nparray = zarray[:]
8788

8889
indexer = data.draw(

0 commit comments

Comments
 (0)