Skip to content

Commit 39d8dcf

Browse files
committed
Type strategies
1 parent 6fbd0fd commit 39d8dcf

File tree

2 files changed

+52
-24
lines changed

2 files changed

+52
-24
lines changed

tests/strategies.py

Lines changed: 34 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,7 @@
1+
from __future__ import annotations
2+
3+
from typing import Any, Callable
4+
15
import cftime
26
import dask
37
import hypothesis.extra.numpy as npst
@@ -6,6 +10,8 @@
610

711
from . import ALL_FUNCS, SCIPY_STATS_FUNCS
812

13+
Chunks = tuple[tuple[int, ...], ...]
14+
915

1016
def supported_dtypes() -> st.SearchStrategy[np.dtype]:
1117
return (
@@ -38,7 +44,6 @@ def supported_dtypes() -> st.SearchStrategy[np.dtype]:
3844
elements={"allow_subnormal": False}, shape=npst.array_shapes(), dtype=supported_dtypes()
3945
)
4046

41-
4247
calendars = st.sampled_from(
4348
[
4449
"standard",
@@ -55,7 +60,7 @@ def supported_dtypes() -> st.SearchStrategy[np.dtype]:
5560

5661

5762
@st.composite
58-
def units(draw, *, calendar: str):
63+
def units(draw, *, calendar: str) -> str:
5964
choices = ["days", "hours", "minutes", "seconds", "milliseconds", "microseconds"]
6065
if calendar == "360_day":
6166
choices += ["months"]
@@ -67,20 +72,36 @@ def units(draw, *, calendar: str):
6772
year, month, day = dt.year, dt.month, dt.day
6873
if calendar == "360_day":
6974
day = min(day, 30)
75+
if calendar in ["360_day", "365_day", "noleap"] and month == 2 and day == 29:
76+
day = 28
77+
7078
return f"{time_units} since {year}-{month}-{day}"
7179

7280

7381
@st.composite
def cftime_arrays(
    draw: st.DrawFn,
    *,
    shape: tuple[int, ...],
    calendars: st.SearchStrategy[str] = calendars,
    elements: dict[str, Any] | None = None,
) -> np.ndarray[Any, Any]:
    """Draw an array of cftime datetimes.

    A calendar is drawn from ``calendars``, an int64 array of offsets with the
    requested ``shape`` is drawn, a reference unit string is drawn from
    ``units`` for that calendar, and the offsets are converted with
    ``cftime.num2date``.

    Parameters
    ----------
    draw
        Draw function injected by ``st.composite``.
    shape
        Shape passed to ``npst.arrays`` for the integer offsets.
    calendars
        Strategy producing the calendar name.
    elements
        Extra keyword arguments forwarded as ``elements`` to ``npst.arrays``.
        ``min_value`` and ``max_value`` default to -10_000 and 10_000 when
        not supplied. The mapping passed in is not modified.

    Returns
    -------
    The result of ``cftime.num2date`` applied to the drawn offsets.
    """
    # Merge into a fresh dict: caller-supplied keys win over the defaults
    # (same result as ``setdefault``) but the caller's dict is never mutated,
    # so a dict shared between strategies does not silently pick up bounds.
    elements = {"min_value": -10_000, "max_value": 10_000, **(elements or {})}
    cal = draw(calendars)
    values = draw(npst.arrays(dtype=np.int64, shape=shape, elements=elements))
    unit = draw(units(calendar=cal))
    return cftime.num2date(values, units=unit, calendar=cal)
8197

8298

83-
def by_arrays(shape, *, elements=None):
99+
def by_arrays(
100+
shape: tuple[int, ...], *, elements: dict[str, Any] | None = None
101+
) -> st.SearchStrategy[np.ndarray[Any, Any]]:
102+
if elements is None:
103+
elements = {}
104+
elements.setdefault("alphabet", st.characters(exclude_categories=("C",)))
84105
return st.one_of(
85106
npst.arrays(
86107
dtype=npst.integer_dtypes(endianness="=") | npst.unicode_string_dtypes(endianness="="),
@@ -92,7 +113,7 @@ def by_arrays(shape, *, elements=None):
92113

93114

94115
@st.composite
95-
def chunks(draw, *, shape: tuple[int, ...]) -> tuple[tuple[int, ...], ...]:
116+
def chunks(draw: st.DrawFn, *, shape: tuple[int, ...]) -> Chunks:
96117
chunks = []
97118
for size in shape:
98119
if size > 1:
@@ -107,7 +128,13 @@ def chunks(draw, *, shape: tuple[int, ...]) -> tuple[tuple[int, ...], ...]:
107128

108129

109130
@st.composite
110-
def chunked_arrays(draw, *, chunks=chunks, arrays=numeric_arrays, from_array=dask.array.from_array):
131+
def chunked_arrays(
132+
draw: st.DrawFn,
133+
*,
134+
chunks: Callable[..., st.SearchStrategy[Chunks]] = chunks,
135+
arrays=all_arrays,
136+
from_array: Callable = dask.array.from_array,
137+
) -> dask.array.Array:
111138
array = draw(arrays)
112139
chunks = draw(chunks(shape=array.shape))
113140

tests/test_properties.py

Lines changed: 18 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,5 @@
1+
from typing import Any, Callable
2+
13
import pandas as pd
24
import pytest
35

@@ -14,7 +16,7 @@
1416
from flox.core import groupby_reduce, groupby_scan
1517

1618
from . import assert_equal
17-
from .strategies import all_arrays, by_arrays, chunked_arrays, func_st, numeric_arrays
19+
from .strategies import by_arrays, chunked_arrays, func_st, numeric_arrays
1820

1921
dask.config.set(scheduler="sync")
2022

@@ -31,14 +33,14 @@ def bfill(array, axis, dtype=None):
3133
)[::-1]
3234

3335

34-
NUMPY_SCAN_FUNCS = {
36+
NUMPY_SCAN_FUNCS: dict[str, Callable] = {
3537
"nancumsum": np.nancumsum,
3638
"ffill": ffill,
3739
"bfill": bfill,
3840
} # "cumsum": np.cumsum,
3941

4042

41-
def not_overflowing_array(array) -> bool:
43+
def not_overflowing_array(array: np.ndarray[Any, Any]) -> bool:
4244
if array.dtype.kind == "f":
4345
info = np.finfo(array.dtype)
4446
elif array.dtype.kind in ["i", "u"]:
@@ -51,9 +53,8 @@ def not_overflowing_array(array) -> bool:
5153
return result
5254

5355

54-
# TODO: migrate to by_arrays but with constant value
5556
@given(data=st.data(), array=numeric_arrays, func=func_st)
56-
def test_groupby_reduce(data, array, func):
57+
def test_groupby_reduce(data, array, func: str) -> None:
5758
# overflow behaviour differs between bincount and sum (for example)
5859
assume(not_overflowing_array(array))
5960
# TODO: fix var for complex numbers upstream
@@ -71,14 +72,14 @@ def test_groupby_reduce(data, array, func):
7172
"min_size": 1,
7273
"max_size": 1,
7374
},
74-
shape=array.shape[-1],
75+
shape=(array.shape[-1],),
7576
)
7677
)
7778
assert len(np.unique(by)) == 1
7879
kwargs = {"q": 0.8} if "quantile" in func else {}
79-
flox_kwargs = {}
80+
flox_kwargs: dict[str, Any] = {}
8081
with np.errstate(invalid="ignore", divide="ignore"):
81-
actual, _ = groupby_reduce(
82+
actual, *_ = groupby_reduce(
8283
array, by, func=func, axis=axis, engine="numpy", **flox_kwargs, finalize_kwargs=kwargs
8384
)
8485

@@ -112,10 +113,10 @@ def test_groupby_reduce(data, array, func):
112113

113114
@given(
114115
data=st.data(),
115-
array=chunked_arrays(),
116+
array=chunked_arrays(arrays=numeric_arrays),
116117
func=st.sampled_from(tuple(NUMPY_SCAN_FUNCS)),
117118
)
118-
def test_scans(data, array, func):
119+
def test_scans(data, array: dask.array.Array, func: str) -> None:
119120
assume(not_overflowing_array(np.asarray(array)))
120121

121122
by = data.draw(by_arrays(shape=(array.shape[-1],)))
@@ -148,7 +149,7 @@ def test_scans(data, array, func):
148149

149150

150151
@given(data=st.data(), array=chunked_arrays())
151-
def test_ffill_bfill_reverse(data, array):
152+
def test_ffill_bfill_reverse(data, array: dask.array.Array) -> None:
152153
# TODO: test NaT and timedelta, datetime
153154
assume(not_overflowing_array(np.asarray(array)))
154155
by = data.draw(by_arrays(shape=(array.shape[-1],)))
@@ -168,10 +169,10 @@ def reverse(arr):
168169

169170
@given(
170171
data=st.data(),
171-
array=chunked_arrays(arrays=all_arrays),
172+
array=chunked_arrays(),
172173
func=st.sampled_from(["first", "last", "nanfirst", "nanlast"]),
173174
)
174-
def test_first_last(data, array, func):
175+
def test_first_last(data, array: dask.array.Array, func: str) -> None:
175176
by = data.draw(by_arrays(shape=(array.shape[-1],)))
176177

177178
INVERSES = {"first": "last", "last": "first", "nanfirst": "nanlast", "nanlast": "nanfirst"}
@@ -183,8 +184,8 @@ def test_first_last(data, array, func):
183184
array = array.rechunk((*array.chunks[:-1], -1))
184185

185186
for arr in [array, array.compute()]:
186-
forward, fg = groupby_reduce(arr, by, func=func, engine="flox")
187-
reverse, rg = groupby_reduce(arr[..., ::-1], by[..., ::-1], func=inverse, engine="flox")
187+
forward, *fg = groupby_reduce(arr, by, func=func, engine="flox")
188+
reverse, *rg = groupby_reduce(arr[..., ::-1], by[..., ::-1], func=inverse, engine="flox")
188189

189190
assert forward.dtype == reverse.dtype
190191
assert forward.dtype == arr.dtype
@@ -196,6 +197,6 @@ def test_first_last(data, array, func):
196197
if mate in ["first", "last"]:
197198
array = array.rechunk((*array.chunks[:-1], -1))
198199

199-
first, _ = groupby_reduce(array, by, func=func, engine="flox")
200-
second, _ = groupby_reduce(array, by, func=mate, engine="flox")
200+
first, *_ = groupby_reduce(array, by, func=func, engine="flox")
201+
second, *_ = groupby_reduce(array, by, func=mate, engine="flox")
201202
assert_equal(first, second)

0 commit comments

Comments
 (0)