Commit 22140eb

Add more array types for property tests
1 parent: 6a7b932

2 files changed: 101 additions, 42 deletions

tests/__init__.py

Lines changed: 1 addition & 1 deletion
@@ -113,7 +113,7 @@ def assert_equal(a, b, tolerance=None):
     else:
         a_eager, b_eager = a, b

-    if a.dtype.kind in "SUMm":
+    if a.dtype.kind in "SUMmO":
         np.testing.assert_equal(a_eager, b_eager)
     else:
         np.testing.assert_allclose(a_eager, b_eager, equal_nan=True, **tolerance)
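The added "O" routes object-dtype arrays through exact equality, alongside bytes ("S"), unicode ("U"), datetime64 ("M"), and timedelta64 ("m"). A minimal sketch (not part of the diff, assuming cftime.num2date returns an object-dtype array of cftime datetimes) of why that matters for the new cftime test data:

# Sketch: cftime datetimes are stored in object-dtype arrays, so kind "O"
# must take the exact-equality branch rather than assert_allclose.
import cftime
import numpy as np

arr = cftime.num2date([0, 1, 2], units="days since 2000-01-01", calendar="noleap")
assert arr.dtype.kind == "O"              # object dtype -> kind "O"
np.testing.assert_equal(arr, arr.copy())  # exact comparison works for object arrays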

tests/test_properties.py

Lines changed: 100 additions & 41 deletions
@@ -3,7 +3,9 @@

 pytest.importorskip("hypothesis")
 pytest.importorskip("dask")
+pytest.importorskip("cftime")

+import cftime
 import dask
 import hypothesis.extra.numpy as npst
 import hypothesis.strategies as st
@@ -66,11 +68,55 @@ def supported_dtypes() -> st.SearchStrategy[np.dtype]:
     elements={"allow_subnormal": False}, shape=npst.array_shapes(), dtype=supported_dtypes()
 )

+calendars = st.sampled_from(
+    [
+        "standard",
+        "gregorian",
+        "proleptic_gregorian",
+        "noleap",
+        "365_day",
+        "360_day",
+        "julian",
+        "all_leap",
+        "366_day",
+    ]
+)
+
+
+@st.composite
+def units(draw, *, calendar: str):
+    choices = ["days", "hours", "minutes", "seconds", "milliseconds", "microseconds"]
+    if calendar == "360_day":
+        choices += ["months"]
+    elif calendar == "noleap":
+        choices += ["common_years"]
+    time_units = draw(st.sampled_from(choices))
+
+    dt = draw(st.datetimes())
+    year, month, day = dt.year, dt.month, dt.day
+    if calendar == "360_day":
+        month %= 30
+    return f"{time_units} since {year}-{month}-{day}"

-def by_arrays(shape):
-    return npst.arrays(
-        dtype=npst.integer_dtypes(endianness="=") | npst.unicode_string_dtypes(endianness="="),
-        shape=shape,
+
+@st.composite
+def cftime_arrays(draw, *, shape, calendars=calendars, elements=None):
+    if elements is None:
+        elements = {"min_value": -10_000, "max_value": 10_000}
+    cal = draw(calendars)
+    values = draw(npst.arrays(dtype=np.int64, shape=shape, elements=elements))
+    unit = draw(units(calendar=cal))
+    return cftime.num2date(values, units=unit, calendar=cal)
+
+
+def by_arrays(shape, *, elements=None):
+    return st.one_of(
+        npst.arrays(
+            dtype=npst.integer_dtypes(endianness="=") | npst.unicode_string_dtypes(endianness="="),
+            shape=shape,
+            elements=elements,
+        ),
+        cftime_arrays(shape=shape, elements=elements),
     )


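A hypothetical usage sketch (not part of this commit): with the widened by_arrays, a test drawing group labels can now receive integer, unicode-string, or cftime datetime arrays, the last of which are object dtype.

# Hypothetical sketch; assumes by_arrays from the strategies above is in scope.
from hypothesis import given

@given(by=by_arrays(shape=(5,)))
def test_by_labels_cover_all_kinds(by):
    assert by.shape == (5,)
    assert by.dtype.kind in "iUO"  # integer, unicode string, or cftime (object) labels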
@@ -87,8 +133,43 @@ def not_overflowing_array(array) -> bool:
     return result


-@given(array=numeric_arrays, dtype=by_dtype_st, func=func_st)
-def test_groupby_reduce(array, dtype, func):
+@st.composite
+def chunks(draw, *, shape: tuple[int, ...]) -> tuple[tuple[int, ...], ...]:
+    chunks = []
+    for size in shape:
+        if size > 1:
+            nchunks = draw(st.integers(min_value=1, max_value=size - 1))
+            dividers = sorted(
+                set(draw(st.integers(min_value=1, max_value=size - 1)) for _ in range(nchunks - 1))
+            )
+            chunks.append(tuple(a - b for a, b in zip(dividers + [size], [0] + dividers)))
+        else:
+            chunks.append((1,))
+    return tuple(chunks)
+
+
+@st.composite
+def chunked_arrays(draw, *, chunks=chunks, arrays=numeric_arrays, from_array=dask.array.from_array):
+    array = draw(arrays)
+    chunks = draw(chunks(shape=array.shape))
+
+    if array.dtype.kind in "cf":
+        nan_idx = draw(
+            st.lists(
+                st.integers(min_value=0, max_value=array.shape[-1] - 1),
+                max_size=array.shape[-1] - 1,
+                unique=True,
+            )
+        )
+        if nan_idx:
+            array[..., nan_idx] = np.nan
+
+    return from_array(array, chunks=chunks)
+
+
+# TODO: migrate to by_arrays but with constant value
+@given(data=st.data(), array=numeric_arrays, func=func_st)
+def test_groupby_reduce(data, array, func):
     # overflow behaviour differs between bincount and sum (for example)
     assume(not_overflowing_array(array))
     # TODO: fix var for complex numbers upstream
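A hypothetical property check (not in the commit) of the relocated strategies: chunks splits each axis at a sorted set of dividers, so the drawn per-axis chunk sizes always sum back to that axis length, and chunked_arrays hands exactly those chunks to dask.array.from_array.

# Hypothetical sketch; assumes chunked_arrays from the test module above is in scope.
from hypothesis import given

@given(darr=chunked_arrays())
def test_chunks_cover_every_axis(darr):
    for sizes, length in zip(darr.chunks, darr.shape):
        assert sum(sizes) == length  # dividers partition the axis with no gaps or overlap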
@@ -97,7 +178,19 @@ def test_groupby_reduce(array, dtype, func):
     assume("arg" not in func and not np.any(np.isnan(array).ravel()))

     axis = -1
-    by = np.ones((array.shape[-1],), dtype=dtype)
+    by = data.draw(
+        by_arrays(
+            elements={
+                "alphabet": st.just("a"),
+                "min_value": 1,
+                "max_value": 1,
+                "min_size": 1,
+                "max_size": 1,
+            },
+            shape=array.shape[-1],
+        )
+    )
+    assert len(np.unique(by)) == 1
     kwargs = {"q": 0.8} if "quantile" in func else {}
     flox_kwargs = {}
     with np.errstate(invalid="ignore", divide="ignore"):
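The single elements dict mixes integer bounds with string alphabet/size options; this relies on Hypothesis's from_dtype applying only the keywords relevant to each dtype (an assumption, though one the diff itself depends on), so every branch of by_arrays draws a constant label and the reduction sees exactly one group. A sketch of that assumed behaviour:

# Sketch of the assumed from_dtype behaviour: keywords irrelevant to a dtype are ignored.
import hypothesis.extra.numpy as npst
import hypothesis.strategies as st
import numpy as np

elements = {"alphabet": st.just("a"), "min_value": 1, "max_value": 1, "min_size": 1, "max_size": 1}
int_elems = npst.from_dtype(np.dtype("int64"), **elements)  # only min_value/max_value apply -> always 1
str_elems = npst.from_dtype(np.dtype("U1"), **elements)     # only alphabet/min_size/max_size apply -> always "a"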
@@ -133,40 +226,6 @@ def test_groupby_reduce(array, dtype, func):
     assert_equal(expected, actual, tolerance)


-@st.composite
-def chunks(draw, *, shape: tuple[int, ...]) -> tuple[tuple[int, ...], ...]:
-    chunks = []
-    for size in shape:
-        if size > 1:
-            nchunks = draw(st.integers(min_value=1, max_value=size - 1))
-            dividers = sorted(
-                set(draw(st.integers(min_value=1, max_value=size - 1)) for _ in range(nchunks - 1))
-            )
-            chunks.append(tuple(a - b for a, b in zip(dividers + [size], [0] + dividers)))
-        else:
-            chunks.append((1,))
-    return tuple(chunks)
-
-
-@st.composite
-def chunked_arrays(draw, *, chunks=chunks, arrays=numeric_arrays, from_array=dask.array.from_array):
-    array = draw(arrays)
-    chunks = draw(chunks(shape=array.shape))
-
-    if array.dtype.kind in "cf":
-        nan_idx = draw(
-            st.lists(
-                st.integers(min_value=0, max_value=array.shape[-1] - 1),
-                max_size=array.shape[-1] - 1,
-                unique=True,
-            )
-        )
-        if nan_idx:
-            array[..., nan_idx] = np.nan
-
-    return from_array(array, chunks=chunks)
-
-
 @given(
     data=st.data(),
     array=chunked_arrays(),
