|
| 1 | +import math |
1 | 2 | import sys |
2 | 3 | from typing import Any, Literal |
3 | 4 |
|
4 | 5 | import hypothesis.extra.numpy as npst |
5 | 6 | import hypothesis.strategies as st |
6 | 7 | import numpy as np |
7 | | -from hypothesis import given, settings # noqa: F401 |
| 8 | +from hypothesis import event, given, settings # noqa: F401 |
8 | 9 | from hypothesis.strategies import SearchStrategy |
9 | 10 |
|
10 | 11 | import zarr |
|
28 | 29 | ) |
29 | 30 |
|
30 | 31 |
|
| 32 | +@st.composite # type: ignore[misc] |
| 33 | +def keys(draw: st.DrawFn, *, max_num_nodes: int | None = None) -> Any: |
| 34 | + return draw(st.lists(node_names, min_size=1, max_size=max_num_nodes).map("/".join)) |
| 35 | + |
| 36 | + |
| 37 | +@st.composite # type: ignore[misc] |
| 38 | +def paths(draw: st.DrawFn, *, max_num_nodes: int | None = None) -> Any: |
| 39 | + return draw(st.just("/") | keys(max_num_nodes=max_num_nodes)) |
| 40 | + |
| 41 | + |
31 | 42 | def v3_dtypes() -> st.SearchStrategy[np.dtype]: |
32 | 43 | return ( |
33 | 44 | npst.boolean_dtypes() |
@@ -87,17 +98,19 @@ def clear_store(x: Store) -> Store: |
87 | 98 | node_names = st.text(zarr_key_chars, min_size=1).filter( |
88 | 99 | lambda t: t not in (".", "..") and not t.startswith("__") |
89 | 100 | ) |
| 101 | +short_node_names = st.text(zarr_key_chars, max_size=3, min_size=1).filter( |
| 102 | + lambda t: t not in (".", "..") and not t.startswith("__") |
| 103 | +) |
90 | 104 | array_names = node_names |
91 | 105 | attrs = st.none() | st.dictionaries(_attr_keys, _attr_values) |
92 | | -keys = st.lists(node_names, min_size=1).map("/".join) |
93 | | -paths = st.just("/") | keys |
94 | 106 | # st.builds will only call a new store constructor for different keyword arguments |
95 | 107 | # i.e. stores.examples() will always return the same object per Store class. |
96 | 108 | # So we map a clear to reset the store. |
97 | 109 | stores = st.builds(MemoryStore, st.just({})).map(clear_store) |
98 | 110 | compressors = st.sampled_from([None, "default"]) |
99 | 111 | zarr_formats: st.SearchStrategy[ZarrFormat] = st.sampled_from([3, 2]) |
100 | | -array_shapes = npst.array_shapes(max_dims=4, min_side=0) |
| 112 | +# We de-prioritize arrays having dim sizes 0, 1, 2 |
| 113 | +array_shapes = npst.array_shapes(max_dims=4, min_side=3) | npst.array_shapes(max_dims=4, min_side=0) |
101 | 114 |
|
102 | 115 |
|
103 | 116 | @st.composite # type: ignore[misc] |
@@ -174,11 +187,19 @@ def chunk_shapes(draw: st.DrawFn, *, shape: tuple[int, ...]) -> tuple[int, ...]: |
174 | 187 | st.tuples(*[st.integers(min_value=0 if size == 0 else 1, max_value=size) for size in shape]) |
175 | 188 | ) |
176 | 189 | # 2. and now generate the chunks tuple |
177 | | - return tuple( |
| 190 | + chunks = tuple( |
178 | 191 | size // nchunks if nchunks > 0 else 0 |
179 | 192 | for size, nchunks in zip(shape, numchunks, strict=True) |
180 | 193 | ) |
181 | 194 |
|
| 195 | + for c in chunks: |
| 196 | + event("chunk size", c) |
| 197 | + |
| 198 | + if any((c != 0 and s % c != 0) for s, c in zip(shape, chunks, strict=True)): |
| 199 | + event("smaller last chunk") |
| 200 | + |
| 201 | + return chunks |
| 202 | + |
182 | 203 |
|
183 | 204 | @st.composite # type: ignore[misc] |
184 | 205 | def shard_shapes( |
@@ -211,7 +232,7 @@ def arrays( |
211 | 232 | shapes: st.SearchStrategy[tuple[int, ...]] = array_shapes, |
212 | 233 | compressors: st.SearchStrategy = compressors, |
213 | 234 | stores: st.SearchStrategy[StoreLike] = stores, |
214 | | - paths: st.SearchStrategy[str | None] = paths, |
| 235 | + paths: st.SearchStrategy[str | None] = paths(), # noqa: B008 |
215 | 236 | array_names: st.SearchStrategy = array_names, |
216 | 237 | arrays: st.SearchStrategy | None = None, |
217 | 238 | attrs: st.SearchStrategy = attrs, |
@@ -267,23 +288,56 @@ def arrays( |
267 | 288 | return a |
268 | 289 |
|
269 | 290 |
|
| 291 | +@st.composite # type: ignore[misc] |
| 292 | +def simple_arrays( |
| 293 | + draw: st.DrawFn, |
| 294 | + *, |
| 295 | + shapes: st.SearchStrategy[tuple[int, ...]] = array_shapes, |
| 296 | +) -> Any: |
| 297 | + return draw( |
| 298 | + arrays( |
| 299 | + shapes=shapes, |
| 300 | + paths=paths(max_num_nodes=2), |
| 301 | + array_names=short_node_names, |
| 302 | + attrs=st.none(), |
| 303 | + compressors=st.sampled_from([None, "default"]), |
| 304 | + ) |
| 305 | + ) |
| 306 | + |
| 307 | + |
270 | 308 | def is_negative_slice(idx: Any) -> bool: |
271 | 309 | return isinstance(idx, slice) and idx.step is not None and idx.step < 0 |
272 | 310 |
|
273 | 311 |
|
| 312 | +@st.composite # type: ignore[misc] |
| 313 | +def end_slices(draw: st.DrawFn, *, shape: tuple[int]) -> Any: |
| 314 | + """ |
| 315 | + A strategy that slices ranges that include the last chunk. |
| 316 | + This is intended to stress-test handling of a possibly smaller last chunk. |
| 317 | + """ |
| 318 | + slicers = [] |
| 319 | + for size in shape: |
| 320 | + start = draw(st.integers(min_value=size // 2, max_value=size - 1)) |
| 321 | + length = draw(st.integers(min_value=0, max_value=size - start)) |
| 322 | + slicers.append(slice(start, start + length)) |
| 323 | + event("drawing end slice") |
| 324 | + return tuple(slicers) |
| 325 | + |
| 326 | + |
274 | 327 | @st.composite # type: ignore[misc] |
275 | 328 | def basic_indices(draw: st.DrawFn, *, shape: tuple[int], **kwargs: Any) -> Any: |
276 | 329 | """Basic indices without unsupported negative slices.""" |
277 | | - return draw( |
278 | | - npst.basic_indices(shape=shape, **kwargs).filter( |
279 | | - lambda idxr: ( |
280 | | - not ( |
281 | | - is_negative_slice(idxr) |
282 | | - or (isinstance(idxr, tuple) and any(is_negative_slice(idx) for idx in idxr)) |
283 | | - ) |
| 330 | + strategy = npst.basic_indices(shape=shape, **kwargs).filter( |
| 331 | + lambda idxr: ( |
| 332 | + not ( |
| 333 | + is_negative_slice(idxr) |
| 334 | + or (isinstance(idxr, tuple) and any(is_negative_slice(idx) for idx in idxr)) |
284 | 335 | ) |
285 | 336 | ) |
286 | 337 | ) |
| 338 | + if math.prod(shape) >= 3: |
| 339 | + strategy = end_slices(shape=shape) | strategy |
| 340 | + return draw(strategy) |
287 | 341 |
|
288 | 342 |
|
289 | 343 | @st.composite # type: ignore[misc] |
|
0 commit comments