Skip to content

Commit 7f98f45

Browse files
authored
Property tests with hypothesis (#348)
* Property tests with hypothesis * skip on minimal env * fix typing * fix test * fix mypy * remove docstring * try again * fix again * more fix * fix tests * Try fix * some debug logging instead of info * try `int8` * Update casting behaviour * More dtypes * Complex fixes * Revert "try `int8`" This reverts commit a9097c2. * fix dtype * skip complex var, std * Start fixing timedelta64 * fix casting * exclude timedelta64, datetime64 * tweak * filter out too_slow * update hypothesis cache * fix * fix more. * update caching strategy * WIP * Skip float16 * Attempt to increase numerical stablity of var, std * update tolerances * fix * update action * fixes * Trim CI * Cast to int64 instead of intp * revert? * [revert] * try again * debug logging * Revert "try again" This reverts commit a02d947. * adapt * Revert "Revert "try again"" This reverts commit 35ff742. * Fix cast * remove prints * Revert "[revert]" This reverts commit d143a98. * info -> debug * Fix quantiles * bring back notes * Small opt * Just skip var, std * Fix mypy * no-redef * try again
1 parent 07ad826 commit 7f98f45

File tree

10 files changed

+233
-57
lines changed

10 files changed

+233
-57
lines changed

.github/workflows/ci.yaml

Lines changed: 23 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -15,8 +15,8 @@ concurrency:
1515
cancel-in-progress: true
1616

1717
jobs:
18-
build:
19-
name: Build (${{ matrix.python-version }}, ${{ matrix.os }})
18+
test:
19+
name: Test (${{ matrix.python-version }}, ${{ matrix.os }})
2020
runs-on: ${{ matrix.os }}
2121
defaults:
2222
run:
@@ -48,7 +48,19 @@ jobs:
4848
- name: Install flox
4949
run: |
5050
python -m pip install --no-deps -e .
51+
52+
# https://github.com/actions/cache/blob/main/tips-and-workarounds.md#update-a-cache
53+
- name: Restore cached hypothesis directory
54+
id: restore-hypothesis-cache
55+
uses: actions/cache/restore@v4
56+
with:
57+
path: .hypothesis/
58+
key: cache-hypothesis-${{ runner.os }}-${{ matrix.python-version }}-${{ github.run_id }}
59+
restore-keys: |
60+
cache-hypothesis-${{ runner.os }}-${{ matrix.python-version }}-
61+
5162
- name: Run Tests
63+
id: status
5264
run: |
5365
pytest -n auto --cov=./ --cov-report=xml
5466
- name: Upload code coverage to Codecov
@@ -60,6 +72,15 @@ jobs:
6072
name: codecov-umbrella
6173
fail_ci_if_error: false
6274

75+
# explicitly save the cache so it gets updated, also do this even if it fails.
76+
- name: Save cached hypothesis directory
77+
id: save-hypothesis-cache
78+
if: always() && steps.status.outcome != 'skipped'
79+
uses: actions/cache/save@v4
80+
with:
81+
path: .hypothesis/
82+
key: cache-hypothesis-${{ runner.os }}-${{ matrix.python-version }}-${{ github.run_id }}
83+
6384
optional-deps:
6485
name: ${{ matrix.env }}
6586
runs-on: "ubuntu-latest"

ci/environment.yml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -25,3 +25,4 @@ dependencies:
2525
- toolz
2626
- numba
2727
- numbagg>=0.3
28+
- hypothesis

flox/aggregate_npg.py

Lines changed: 30 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -109,11 +109,28 @@ def _len(group_idx, array, engine, *, func, axis=-1, size=None, fill_value=None,
109109
nanlen = partial(_len, func="nanlen")
110110

111111

112+
def _var_std_wrapper(group_idx, array, engine, *, axis=-1, **kwargs):
113+
# Attempt to increase numerical stability by subtracting the first element.
114+
# https://en.wikipedia.org/wiki/Algorithms_for_calculating_variance
115+
# Cast any unsigned types first
116+
dtype = np.result_type(array, np.int8(-1) * array[0])
117+
array = array.astype(dtype, copy=False)
118+
first = _get_aggregate(engine).aggregate(group_idx, array, func="nanfirst", axis=axis)
119+
array = array - first[..., group_idx]
120+
return _get_aggregate(engine).aggregate(group_idx, array, axis=axis, **kwargs)
121+
122+
123+
var = partial(_var_std_wrapper, func="var")
124+
nanvar = partial(_var_std_wrapper, func="nanvar")
125+
std = partial(_var_std_wrapper, func="std")
126+
nanstd = partial(_var_std_wrapper, func="nanstd")
127+
128+
112129
def median(group_idx, array, engine, *, axis=-1, size=None, fill_value=None, dtype=None):
113130
return npg.aggregate_numpy.aggregate(
114131
group_idx,
115132
array,
116-
func=partial(_casting_wrapper, np.median, dtype=array.dtype),
133+
func=partial(_casting_wrapper, np.median, dtype=np.result_type(array.dtype)),
117134
axis=axis,
118135
size=size,
119136
fill_value=fill_value,
@@ -125,7 +142,7 @@ def nanmedian(group_idx, array, engine, *, axis=-1, size=None, fill_value=None,
125142
return npg.aggregate_numpy.aggregate(
126143
group_idx,
127144
array,
128-
func=partial(_casting_wrapper, np.nanmedian, dtype=array.dtype),
145+
func=partial(_casting_wrapper, np.nanmedian, dtype=np.result_type(array.dtype)),
129146
axis=axis,
130147
size=size,
131148
fill_value=fill_value,
@@ -137,7 +154,11 @@ def quantile(group_idx, array, engine, *, q, axis=-1, size=None, fill_value=None
137154
return npg.aggregate_numpy.aggregate(
138155
group_idx,
139156
array,
140-
func=partial(_casting_wrapper, partial(np.quantile, q=q), dtype=array.dtype),
157+
func=partial(
158+
_casting_wrapper,
159+
partial(np.quantile, q=q),
160+
dtype=np.result_type(dtype, array.dtype),
161+
),
141162
axis=axis,
142163
size=size,
143164
fill_value=fill_value,
@@ -149,7 +170,11 @@ def nanquantile(group_idx, array, engine, *, q, axis=-1, size=None, fill_value=N
149170
return npg.aggregate_numpy.aggregate(
150171
group_idx,
151172
array,
152-
func=partial(_casting_wrapper, partial(np.nanquantile, q=q), dtype=array.dtype),
173+
func=partial(
174+
_casting_wrapper,
175+
partial(np.nanquantile, q=q),
176+
dtype=np.result_type(dtype, array.dtype),
177+
),
153178
axis=axis,
154179
size=size,
155180
fill_value=fill_value,
@@ -163,7 +188,7 @@ def mode_(array, nan_policy, dtype):
163188
# npg splits `array` into object arrays for each group
164189
# scipy.stats.mode does not like that
165190
# here we cast back
166-
return mode(array.astype(dtype, copy=False), nan_policy=nan_policy, axis=-1).mode
191+
return mode(array.astype(dtype, copy=False), nan_policy=nan_policy, axis=-1, keepdims=True).mode
167192

168193

169194
def mode(group_idx, array, engine, *, axis=-1, size=None, fill_value=None, dtype=None):

flox/aggregations.py

Lines changed: 20 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -123,14 +123,27 @@ def _normalize_dtype(dtype: DTypeLike, array_dtype: np.dtype, fill_value=None) -
123123
return dtype
124124

125125

126+
def _maybe_promote_int(dtype) -> np.dtype:
127+
# https://numpy.org/doc/stable/reference/generated/numpy.prod.html
128+
# The dtype of a is used by default unless a has an integer dtype of less precision
129+
# than the default platform integer.
130+
if not isinstance(dtype, np.dtype):
131+
dtype = np.dtype(dtype)
132+
if dtype.kind == "i":
133+
dtype = np.result_type(dtype, np.intp)
134+
elif dtype.kind == "u":
135+
dtype = np.result_type(dtype, np.uintp)
136+
return dtype
137+
138+
126139
def _get_fill_value(dtype, fill_value):
127140
"""Returns dtype appropriate infinity. Returns +Inf equivalent for None."""
128141
if fill_value == dtypes.INF or fill_value is None:
129142
return dtypes.get_pos_infinity(dtype, max_for_int=True)
130143
if fill_value == dtypes.NINF:
131144
return dtypes.get_neg_infinity(dtype, min_for_int=True)
132145
if fill_value == dtypes.NA:
133-
if np.issubdtype(dtype, np.floating):
146+
if np.issubdtype(dtype, np.floating) or np.issubdtype(dtype, np.complexfloating):
134147
return np.nan
135148
# This is madness, but npg checks that fill_value is compatible
136149
# with array dtype even if the fill_value is never used.
@@ -524,10 +537,10 @@ def _pick_second(*x):
524537
# Support statistical quantities only blockwise
525538
# The parallel versions will be approximate and are hard to implement!
526539
median = Aggregation(
527-
name="median", fill_value=dtypes.NA, chunk=None, combine=None, final_dtype=np.float64
540+
name="median", fill_value=dtypes.NA, chunk=None, combine=None, final_dtype=np.floating
528541
)
529542
nanmedian = Aggregation(
530-
name="nanmedian", fill_value=dtypes.NA, chunk=None, combine=None, final_dtype=np.float64
543+
name="nanmedian", fill_value=dtypes.NA, chunk=None, combine=None, final_dtype=np.floating
531544
)
532545

533546

@@ -540,15 +553,15 @@ def quantile_new_dims_func(q) -> tuple[Dim]:
540553
fill_value=dtypes.NA,
541554
chunk=None,
542555
combine=None,
543-
final_dtype=np.float64,
556+
final_dtype=np.floating,
544557
new_dims_func=quantile_new_dims_func,
545558
)
546559
nanquantile = Aggregation(
547560
name="nanquantile",
548561
fill_value=dtypes.NA,
549562
chunk=None,
550563
combine=None,
551-
final_dtype=np.float64,
564+
final_dtype=np.floating,
552565
new_dims_func=quantile_new_dims_func,
553566
)
554567
mode = Aggregation(name="mode", fill_value=dtypes.NA, chunk=None, combine=None)
@@ -618,6 +631,8 @@ def _initialize_aggregation(
618631
)
619632

620633
final_dtype = _normalize_dtype(dtype_ or agg.dtype_init["final"], array_dtype, fill_value)
634+
if agg.name not in ["min", "max", "nanmin", "nanmax"]:
635+
final_dtype = _maybe_promote_int(final_dtype)
621636
agg.dtype = {
622637
"user": dtype, # Save to automatically choose an engine
623638
"final": final_dtype,

flox/core.py

Lines changed: 17 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -403,12 +403,12 @@ def invert(x) -> tuple[np.ndarray, ...]:
403403

404404
# 2. Every group is contained to one block, use blockwise here.
405405
if bitmask.shape[CHUNK_AXIS] == 1 or (chunks_per_label == 1).all():
406-
logger.info("find_group_cohorts: blockwise is preferred.")
406+
logger.debug("find_group_cohorts: blockwise is preferred.")
407407
return "blockwise", chunks_cohorts
408408

409409
# 3. Perfectly chunked so there is only a single cohort
410410
if len(chunks_cohorts) == 1:
411-
logger.info("Only found a single cohort. 'map-reduce' is preferred.")
411+
logger.debug("Only found a single cohort. 'map-reduce' is preferred.")
412412
return "map-reduce", chunks_cohorts if merge else {}
413413

414414
# 4. Our dataset has chunksize one along the axis,
@@ -418,7 +418,7 @@ def invert(x) -> tuple[np.ndarray, ...]:
418418
# 6. Existing cohorts don't overlap, great for time grouping with perfect chunking
419419
no_overlapping_cohorts = (np.bincount(np.concatenate(tuple(chunks_cohorts.keys()))) == 1).all()
420420
if one_group_per_chunk or single_chunks or no_overlapping_cohorts:
421-
logger.info("find_group_cohorts: cohorts is preferred, chunking is perfect.")
421+
logger.debug("find_group_cohorts: cohorts is preferred, chunking is perfect.")
422422
return "cohorts", chunks_cohorts
423423

424424
# We'll use containment to measure degree of overlap between labels.
@@ -451,7 +451,7 @@ def invert(x) -> tuple[np.ndarray, ...]:
451451
# 7. Groups seem fairly randomly distributed, use "map-reduce".
452452
if sparsity > MAX_SPARSITY_FOR_COHORTS:
453453
if not merge:
454-
logger.info(
454+
logger.debug(
455455
"find_group_cohorts: bitmask sparsity={}, merge=False, choosing 'map-reduce'".format( # noqa
456456
sparsity
457457
)
@@ -480,7 +480,7 @@ def invert(x) -> tuple[np.ndarray, ...]:
480480
containment.eliminate_zeros()
481481

482482
# Iterate over labels, beginning with those with most chunks
483-
logger.info("find_group_cohorts: merging cohorts")
483+
logger.debug("find_group_cohorts: merging cohorts")
484484
order = np.argsort(containment.sum(axis=LABEL_AXIS))[::-1]
485485
merged_cohorts = {}
486486
merged_keys = set()
@@ -1957,7 +1957,7 @@ def _validate_reindex(
19571957
any_by_dask: bool,
19581958
is_dask_array: bool,
19591959
) -> bool | None:
1960-
logger.info("Entering _validate_reindex: reindex is {}".format(reindex)) # noqa
1960+
# logger.debug("Entering _validate_reindex: reindex is {}".format(reindex)) # noqa
19611961

19621962
all_numpy = not is_dask_array and not any_by_dask
19631963
if reindex is True and not all_numpy:
@@ -1972,7 +1972,7 @@ def _validate_reindex(
19721972

19731973
if reindex is None:
19741974
if method is None:
1975-
logger.info("Leaving _validate_reindex: method = None, returning None")
1975+
# logger.debug("Leaving _validate_reindex: method = None, returning None")
19761976
return None
19771977

19781978
if all_numpy:
@@ -1999,7 +1999,7 @@ def _validate_reindex(
19991999
reindex = True
20002000

20012001
assert isinstance(reindex, bool)
2002-
logger.info("Leaving _validate_reindex: reindex is {}".format(reindex)) # noqa
2002+
logger.debug("Leaving _validate_reindex: reindex is {}".format(reindex)) # noqa
20032003

20042004
return reindex
20052005

@@ -2165,24 +2165,24 @@ def _choose_method(
21652165
method: T_MethodOpt, preferred_method: T_Method, agg: Aggregation, by, nax: int
21662166
) -> T_Method:
21672167
if method is None:
2168-
logger.info("_choose_method: method is None")
2168+
logger.debug("_choose_method: method is None")
21692169
if agg.chunk == (None,):
21702170
if preferred_method != "blockwise":
21712171
raise ValueError(
21722172
f"Aggregation {agg.name} is only supported for `method='blockwise'`, "
21732173
"but the chunking is not right."
21742174
)
2175-
logger.info("_choose_method: choosing 'blockwise'")
2175+
logger.debug("_choose_method: choosing 'blockwise'")
21762176
return "blockwise"
21772177

21782178
if nax != by.ndim:
2179-
logger.info("_choose_method: choosing 'map-reduce'")
2179+
logger.debug("_choose_method: choosing 'map-reduce'")
21802180
return "map-reduce"
21812181

21822182
if _is_arg_reduction(agg) and preferred_method == "blockwise":
21832183
return "cohorts"
21842184

2185-
logger.info("_choose_method: choosing preferred_method={}".format(preferred_method)) # noqa
2185+
logger.debug(f"_choose_method: choosing preferred_method={preferred_method}") # noqa
21862186
return preferred_method
21872187
else:
21882188
return method
@@ -2194,7 +2194,7 @@ def _choose_engine(by, agg: Aggregation):
21942194
not_arg_reduce = not _is_arg_reduction(agg)
21952195

21962196
if agg.name in ["quantile", "nanquantile", "median", "nanmedian"]:
2197-
logger.info(f"_choose_engine: Choosing 'flox' since {agg.name}")
2197+
logger.debug(f"_choose_engine: Choosing 'flox' since {agg.name}")
21982198
return "flox"
21992199

22002200
# numbagg only supports nan-skipping reductions
@@ -2206,14 +2206,14 @@ def _choose_engine(by, agg: Aggregation):
22062206
if agg.name in ["all", "any"] or (
22072207
not_arg_reduce and has_blockwise_nan_skipping and dtype is None
22082208
):
2209-
logger.info("_choose_engine: Choosing 'numbagg'")
2209+
logger.debug("_choose_engine: Choosing 'numbagg'")
22102210
return "numbagg"
22112211

22122212
if not_arg_reduce and (not is_duck_dask_array(by) and _issorted(by)):
2213-
logger.info("_choose_engine: Choosing 'flox'")
2213+
logger.debug("_choose_engine: Choosing 'flox'")
22142214
return "flox"
22152215
else:
2216-
logger.info("_choose_engine: Choosing 'numpy'")
2216+
logger.debug("_choose_engine: Choosing 'numpy'")
22172217
return "numpy"
22182218

22192219

@@ -2389,7 +2389,7 @@ def groupby_reduce(
23892389
if not is_duck_array(array):
23902390
array = np.asarray(array)
23912391
is_bool_array = np.issubdtype(array.dtype, bool)
2392-
array = array.astype(int) if is_bool_array else array
2392+
array = array.astype(np.intp) if is_bool_array else array
23932393

23942394
isbins = _atleast_1d(isbin, nby)
23952395

flox/xrdtypes.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -123,6 +123,10 @@ def get_neg_infinity(dtype, min_for_int=False):
123123
-------
124124
fill_value : positive infinity value corresponding to this dtype.
125125
"""
126+
127+
if np.issubdtype(dtype, (np.timedelta64, np.datetime64)):
128+
return dtype.type(np.iinfo(np.int64).min + 1)
129+
126130
if issubclass(dtype.type, np.floating):
127131
return -np.inf
128132

flox/xrutils.py

Lines changed: 14 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -10,19 +10,6 @@
1010
import pandas as pd
1111
from packaging.version import Version
1212

13-
try:
14-
import cftime
15-
except ImportError:
16-
cftime = None
17-
18-
19-
try:
20-
import dask.array
21-
22-
dask_array_type = dask.array.Array
23-
except ImportError:
24-
dask_array_type = () # type: ignore[assignment, misc]
25-
2613

2714
def module_available(module: str, minversion: Optional[str] = None) -> bool:
2815
"""Checks whether a module is installed without importing it.
@@ -55,6 +42,20 @@ def module_available(module: str, minversion: Optional[str] = None) -> bool:
5542
from numpy.core.numeric import normalize_axis_index # type: ignore[attr-defined]
5643

5744

45+
try:
46+
import cftime
47+
except ImportError:
48+
cftime = None
49+
50+
51+
try:
52+
import dask.array
53+
54+
dask_array_type = dask.array.Array
55+
except ImportError:
56+
dask_array_type = () # type: ignore[assignment, misc]
57+
58+
5859
def asarray(data, xp=np):
5960
return data if is_duck_array(data) else xp.asarray(data)
6061

0 commit comments

Comments
 (0)