Skip to content

Commit 6fbd0fd

Browse files
committed
Refactor hypothesis strategies
1 parent 22140eb commit 6fbd0fd

File tree

2 files changed

+127
-118
lines changed

2 files changed

+127
-118
lines changed

tests/strategies.py

Lines changed: 125 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,125 @@
1+
import cftime
2+
import dask
3+
import hypothesis.extra.numpy as npst
4+
import hypothesis.strategies as st
5+
import numpy as np
6+
7+
from . import ALL_FUNCS, SCIPY_STATS_FUNCS
8+
9+
10+
def supported_dtypes() -> st.SearchStrategy[np.dtype]:
    """Build a strategy that draws a single NumPy dtype.

    Returns
    -------
    st.SearchStrategy[np.dtype]
        Union of signed integer, unsigned integer, 32- and 64-bit floating,
        complex, datetime64, timedelta64 and unicode string dtypes. Every
        member is requested with ``endianness="="``.
    """
    dtype_families = (
        npst.integer_dtypes(endianness="="),
        npst.unsigned_integer_dtypes(endianness="="),
        npst.floating_dtypes(endianness="=", sizes=(32, 64)),
        npst.complex_number_dtypes(endianness="="),
        npst.datetime64_dtypes(endianness="="),
        npst.timedelta64_dtypes(endianness="="),
        npst.unicode_string_dtypes(endianness="="),
    )
    return st.one_of(*dtype_families)
20+
21+
22+
# TODO: stop excluding everything but U
# The filter below drops dtypes whose NumPy kind code is "c" (complex),
# "m" (timedelta64), "M" (datetime64) or "U" (unicode); ``by_dtype_st`` keeps
# every supported dtype.
array_dtype_st = supported_dtypes().filter(lambda dtype: dtype.kind not in "cmMU")
by_dtype_st = supported_dtypes()

# Function names that are left out of ``func_st``.
NON_NUMPY_FUNCS = [
    "first",
    "last",
    "nanfirst",
    "nanlast",
    "count",
    "any",
    "all",
    *SCIPY_STATS_FUNCS,
]
SKIPPED_FUNCS = ["var", "std", "nanvar", "nanstd"]

func_st = st.sampled_from(
    [name for name in ALL_FUNCS if not (name in NON_NUMPY_FUNCS or name in SKIPPED_FUNCS)]
)

# Both strategies draw a shape from ``npst.array_shapes()`` and pass
# ``allow_subnormal=False`` for the elements; they differ only in the dtype
# strategy they use.
numeric_arrays = npst.arrays(
    dtype=array_dtype_st,
    shape=npst.array_shapes(),
    elements={"allow_subnormal": False},
)
all_arrays = npst.arrays(
    dtype=supported_dtypes(),
    shape=npst.array_shapes(),
    elements={"allow_subnormal": False},
)
40+
41+
42+
# Calendar names that ``cftime_arrays`` samples from by default; each drawn
# name is passed to ``cftime.num2date`` as its ``calendar`` argument.
calendars = st.sampled_from(
    [
        "standard", "gregorian", "proleptic_gregorian",
        "noleap", "365_day",
        "360_day",
        "julian",
        "all_leap", "366_day"
    ]
)
55+
56+
57+
@st.composite
58+
def units(draw, *, calendar: str):
59+
choices = ["days", "hours", "minutes", "seconds", "milliseconds", "microseconds"]
60+
if calendar == "360_day":
61+
choices += ["months"]
62+
elif calendar == "noleap":
63+
choices += ["common_years"]
64+
time_units = draw(st.sampled_from(choices))
65+
66+
dt = draw(st.datetimes())
67+
year, month, day = dt.year, dt.month, dt.day
68+
if calendar == "360_day":
69+
day = min(day, 30)
70+
return f"{time_units} since {year}-{month}-{day}"
71+
72+
73+
@st.composite
def cftime_arrays(draw, *, shape, calendars=calendars, elements=None):
    """Draw an array of cftime dates decoded from random integer offsets.

    Parameters
    ----------
    draw
        Draw function supplied by ``st.composite``.
    shape
        Shape (or strategy for shapes) forwarded to ``npst.arrays``.
    calendars
        Strategy that yields the calendar name; defaults to the module-level
        ``calendars`` strategy.
    elements
        Forwarded to ``npst.arrays`` for the int64 offsets. When ``None``,
        offsets are limited to the range -10_000 to 10_000.

    Returns
    -------
    The result of ``cftime.num2date`` for the drawn offsets, units string and
    calendar.
    """
    offset_elements = (
        {"min_value": -10_000, "max_value": 10_000} if elements is None else elements
    )
    # Keep the draw order calendar -> offsets -> units.
    calendar = draw(calendars)
    offsets = draw(npst.arrays(dtype=np.int64, shape=shape, elements=offset_elements))
    reference = draw(units(calendar=calendar))
    return cftime.num2date(offsets, units=reference, calendar=calendar)
81+
82+
83+
def by_arrays(shape, *, elements=None):
    """Build a strategy for integer, unicode-string or cftime arrays.

    Parameters
    ----------
    shape
        Shape (or strategy for shapes) used for both alternatives.
    elements
        Forwarded unchanged to ``npst.arrays`` and to ``cftime_arrays``.

    Returns
    -------
    A strategy that draws either a NumPy array with a native-endian integer
    or unicode dtype, or an array produced by ``cftime_arrays``.
    """
    # presumably these arrays serve as group labels ("by"); verify against callers
    numpy_dtypes = st.one_of(
        npst.integer_dtypes(endianness="="),
        npst.unicode_string_dtypes(endianness="="),
    )
    numpy_labels = npst.arrays(dtype=numpy_dtypes, shape=shape, elements=elements)
    cftime_labels = cftime_arrays(shape=shape, elements=elements)
    return numpy_labels | cftime_labels
92+
93+
94+
@st.composite
def chunks(draw, *, shape: tuple[int, ...]) -> tuple[tuple[int, ...], ...]:
    """Draw explicit chunk sizes for every axis of an array of the given shape.

    Parameters
    ----------
    draw
        Draw function supplied by ``st.composite``.
    shape : tuple of int
        Length of each axis.

    Returns
    -------
    tuple of tuple of int
        One tuple of chunk lengths per axis; the lengths of each tuple sum to
        the corresponding axis length.
    """
    per_axis = []
    for size in shape:
        if size > 1:
            nchunks = draw(st.integers(min_value=1, max_value=size - 1))
            # Draw ``nchunks - 1`` interior split points. Duplicates collapse in
            # the set, so the axis may end up with fewer than ``nchunks`` chunks.
            dividers = sorted(
                {draw(st.integers(min_value=1, max_value=size - 1)) for _ in range(nchunks - 1)}
            )
            starts = [0] + dividers
            stops = dividers + [size]
            per_axis.append(tuple(stop - start for start, stop in zip(starts, stops)))
        else:
            # Nothing to split: one chunk spanning the whole axis. Using ``size``
            # rather than a literal 1 keeps the chunk total equal to the axis
            # length for zero-length axes as well.
            per_axis.append((size,))
    return tuple(per_axis)
107+
108+
109+
@st.composite
def chunked_arrays(draw, *, chunks=chunks, arrays=numeric_arrays, from_array=dask.array.from_array):
    """Draw an array, optionally insert NaNs, and wrap it with ``from_array``.

    Parameters
    ----------
    draw
        Draw function supplied by ``st.composite``.
    chunks
        Callable that takes ``shape=...`` and returns a strategy for chunk
        sizes; defaults to the module-level ``chunks`` strategy.
    arrays
        Strategy that yields the underlying array.
    from_array
        Callable invoked as ``from_array(array, chunks=...)`` to build the
        result; defaults to ``dask.array.from_array``.

    Returns
    -------
    Whatever ``from_array`` returns for the drawn array and chunk sizes.
    """
    array = draw(arrays)
    chunk_sizes = draw(chunks(shape=array.shape))

    if array.dtype.kind in "cf":
        # For float and complex data, overwrite a drawn set of positions along
        # the last axis with NaN. At most ``last_axis_len - 1`` positions are
        # drawn, so at least one position is never overwritten here.
        last_axis_len = array.shape[-1]
        nan_positions = draw(
            st.lists(
                st.integers(min_value=0, max_value=last_axis_len - 1),
                max_size=last_axis_len - 1,
                unique=True,
            )
        )
        if nan_positions:
            array[..., nan_positions] = np.nan

    return from_array(array, chunks=chunk_sizes)

tests/test_properties.py

Lines changed: 2 additions & 118 deletions
Original file line numberDiff line numberDiff line change
@@ -5,17 +5,16 @@
55
pytest.importorskip("dask")
66
pytest.importorskip("cftime")
77

8-
import cftime
98
import dask
10-
import hypothesis.extra.numpy as npst
119
import hypothesis.strategies as st
1210
import numpy as np
1311
from hypothesis import assume, given, note
1412

1513
import flox
1614
from flox.core import groupby_reduce, groupby_scan
1715

18-
from . import ALL_FUNCS, SCIPY_STATS_FUNCS, assert_equal
16+
from . import assert_equal
17+
from .strategies import all_arrays, by_arrays, chunked_arrays, func_st, numeric_arrays
1918

2019
dask.config.set(scheduler="sync")
2120

@@ -32,94 +31,13 @@ def bfill(array, axis, dtype=None):
3231
)[::-1]
3332

3433

35-
NON_NUMPY_FUNCS = ["first", "last", "nanfirst", "nanlast", "count", "any", "all"] + list(
36-
SCIPY_STATS_FUNCS
37-
)
38-
SKIPPED_FUNCS = ["var", "std", "nanvar", "nanstd"]
3934
NUMPY_SCAN_FUNCS = {
4035
"nancumsum": np.nancumsum,
4136
"ffill": ffill,
4237
"bfill": bfill,
4338
} # "cumsum": np.cumsum,
4439

4540

46-
def supported_dtypes() -> st.SearchStrategy[np.dtype]:
47-
return (
48-
npst.integer_dtypes(endianness="=")
49-
| npst.unsigned_integer_dtypes(endianness="=")
50-
| npst.floating_dtypes(endianness="=", sizes=(32, 64))
51-
| npst.complex_number_dtypes(endianness="=")
52-
| npst.datetime64_dtypes(endianness="=")
53-
| npst.timedelta64_dtypes(endianness="=")
54-
| npst.unicode_string_dtypes(endianness="=")
55-
)
56-
57-
58-
# TODO: stop excluding everything but U
59-
array_dtype_st = supported_dtypes().filter(lambda x: x.kind not in "cmMU")
60-
by_dtype_st = supported_dtypes()
61-
func_st = st.sampled_from(
62-
[f for f in ALL_FUNCS if f not in NON_NUMPY_FUNCS and f not in SKIPPED_FUNCS]
63-
)
64-
numeric_arrays = npst.arrays(
65-
elements={"allow_subnormal": False}, shape=npst.array_shapes(), dtype=array_dtype_st
66-
)
67-
all_arrays = npst.arrays(
68-
elements={"allow_subnormal": False}, shape=npst.array_shapes(), dtype=supported_dtypes()
69-
)
70-
71-
calendars = st.sampled_from(
72-
[
73-
"standard",
74-
"gregorian",
75-
"proleptic_gregorian",
76-
"noleap",
77-
"365_day",
78-
"360_day",
79-
"julian",
80-
"all_leap",
81-
"366_day",
82-
]
83-
)
84-
85-
86-
@st.composite
87-
def units(draw, *, calendar: str):
88-
choices = ["days", "hours", "minutes", "seconds", "milliseconds", "microseconds"]
89-
if calendar == "360_day":
90-
choices += ["months"]
91-
elif calendar == "noleap":
92-
choices += ["common_years"]
93-
time_units = draw(st.sampled_from(choices))
94-
95-
dt = draw(st.datetimes())
96-
year, month, day = dt.year, dt.month, dt.day
97-
if calendar == "360_day":
98-
month %= 30
99-
return f"{time_units} since {year}-{month}-{day}"
100-
101-
102-
@st.composite
103-
def cftime_arrays(draw, *, shape, calendars=calendars, elements=None):
104-
if elements is None:
105-
elements = {"min_value": -10_000, "max_value": 10_000}
106-
cal = draw(calendars)
107-
values = draw(npst.arrays(dtype=np.int64, shape=shape, elements=elements))
108-
unit = draw(units(calendar=cal))
109-
return cftime.num2date(values, units=unit, calendar=cal)
110-
111-
112-
def by_arrays(shape, *, elements=None):
113-
return st.one_of(
114-
npst.arrays(
115-
dtype=npst.integer_dtypes(endianness="=") | npst.unicode_string_dtypes(endianness="="),
116-
shape=shape,
117-
elements=elements,
118-
),
119-
cftime_arrays(shape=shape, elements=elements),
120-
)
121-
122-
12341
def not_overflowing_array(array) -> bool:
12442
if array.dtype.kind == "f":
12543
info = np.finfo(array.dtype)
@@ -133,40 +51,6 @@ def not_overflowing_array(array) -> bool:
13351
return result
13452

13553

136-
@st.composite
137-
def chunks(draw, *, shape: tuple[int, ...]) -> tuple[tuple[int, ...], ...]:
138-
chunks = []
139-
for size in shape:
140-
if size > 1:
141-
nchunks = draw(st.integers(min_value=1, max_value=size - 1))
142-
dividers = sorted(
143-
set(draw(st.integers(min_value=1, max_value=size - 1)) for _ in range(nchunks - 1))
144-
)
145-
chunks.append(tuple(a - b for a, b in zip(dividers + [size], [0] + dividers)))
146-
else:
147-
chunks.append((1,))
148-
return tuple(chunks)
149-
150-
151-
@st.composite
152-
def chunked_arrays(draw, *, chunks=chunks, arrays=numeric_arrays, from_array=dask.array.from_array):
153-
array = draw(arrays)
154-
chunks = draw(chunks(shape=array.shape))
155-
156-
if array.dtype.kind in "cf":
157-
nan_idx = draw(
158-
st.lists(
159-
st.integers(min_value=0, max_value=array.shape[-1] - 1),
160-
max_size=array.shape[-1] - 1,
161-
unique=True,
162-
)
163-
)
164-
if nan_idx:
165-
array[..., nan_idx] = np.nan
166-
167-
return from_array(array, chunks=chunks)
168-
169-
17054
# TODO: migrate to by_arrays but with constant value
17155
@given(data=st.data(), array=numeric_arrays, func=func_st)
17256
def test_groupby_reduce(data, array, func):

0 commit comments

Comments
 (0)