Commit 1c10b74

Add scans (#370)
* Add scans
* grouped reduce
* Some fixes.
* Updates for ffill
* Better ffill
* Support numpy
* cleanup
* more tests
* Fix ffill
* [WIP] expand tests
* Fixes. we need two versions of binary_op
* Fix ffill again
* Disable cumsum for now.
* Fixes.
* Fix tests: Remove overflowing test cases, proper fill_value
* typing
* Fix tests
* Try and avoid some roundoff error
* Skip float32 for cumsum
* fix min deps test
* Another fix
* Silence warnings
* Cleanup
* Add docs
* fix
* bfill
* Fix test
* hypothesis: Better CI profile
* Small change.
* Add hypothesis to all envs
* Generate chunking along all dimensions
* lint
* more guards
* more guards
* fix
* Fix typing
* cleanup
* fix mypy
* Add comments
1 parent 04338d4 commit 1c10b74

15 files changed: +688 −49 lines changed

.github/workflows/ci.yaml

Lines changed: 1 addition & 1 deletion
@@ -62,7 +62,7 @@ jobs:
       - name: Run Tests
         id: status
         run: |
-          pytest -n auto --cov=./ --cov-report=xml
+          pytest -n auto --cov=./ --cov-report=xml --hypothesis-profile ci
       - name: Upload code coverage to Codecov
         uses: codecov/[email protected]
         with:
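The new `--hypothesis-profile ci` flag assumes a profile named `ci` is registered before test collection, typically in a `conftest.py`. A minimal sketch of such a registration (the specific settings below are illustrative assumptions, not the values used in this commit):

```python
# conftest.py -- hypothetical sketch; the profile this commit registers may differ.
from hypothesis import HealthCheck, settings

# Register a "ci" profile that trades speed for thoroughness; it is then
# selected on the pytest command line with `--hypothesis-profile ci`.
settings.register_profile(
    "ci",
    max_examples=1000,
    deadline=None,  # CI runners can be slow; disable the per-example deadline
    suppress_health_check=[HealthCheck.too_slow],
)
```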

ci/minimal-requirements.yml

Lines changed: 1 addition & 0 deletions
@@ -3,6 +3,7 @@ channels:
   - conda-forge
 dependencies:
   - codecov
+  - hypothesis
   - pip
   - pytest
   - pytest-cov

ci/no-dask.yml

Lines changed: 1 addition & 0 deletions
@@ -4,6 +4,7 @@ channels:
 dependencies:
   - codecov
   - pandas
+  - hypothesis
   - cftime
   - numpy>=1.22
   - scipy

ci/no-numba.yml

Lines changed: 1 addition & 0 deletions
@@ -7,6 +7,7 @@ dependencies:
   - cftime
   - codecov
   - dask-core
+  - hypothesis
   - pandas
   - numpy>=1.22
   - scipy

docs/source/api.rst

Lines changed: 3 additions & 0 deletions
@@ -10,6 +10,7 @@ Functions
    :toctree: generated/

    groupby_reduce
+   groupby_scan
    xarray.xarray_reduce

 Rechunking
@@ -40,5 +41,7 @@ Aggregation Objects
    :toctree: generated/

    Aggregation
+   Scan
+
    aggregations.sum_
    aggregations.nansum
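The two new entries document the public scan API. A hedged usage sketch, assuming `groupby_scan` follows the same calling convention as `groupby_reduce` (array and group labels positionally, plus a `func` string naming a registered scan) and returns an array with the input's shape:

```python
import numpy as np
import flox

values = np.array([1.0, 2.0, np.nan, 3.0, 4.0])
labels = np.array([0, 0, 0, 1, 1])

# Cumulative sum within each group, skipping NaNs.
result = flox.groupby_scan(values, labels, func="nancumsum")
# expected: [1., 3., 3., 3., 7.]
```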

docs/source/user-stories/hourly-climatology.html

Lines changed: 1 addition & 1 deletion
@@ -14033,7 +14033,7 @@
 "result = _execute_task(task, cache)\n",
 "return func(*(_execute_task(a, cache) for a in args))\n",
 "ret = self.first(*args, **kwargs)\n",
-"group_idx, array = _prepare_for_flox(group_idx, array)\n",
+"group_idx, array, _ = _prepare_for_flox(group_idx, array)\n",
 "return group_idx, found_groups, grp_shape, ngroups, size, props\n",
 "",
 "return _wrapfunc(a, 'searchsorted', v, side=side, sorter=sorter)\n",

flox/__init__.py

Lines changed: 1 addition & 1 deletion
@@ -3,7 +3,7 @@
 """Top-level module for flox ."""
 from . import cache
 from .aggregations import Aggregation  # noqa
-from .core import groupby_reduce, rechunk_for_blockwise, rechunk_for_cohorts  # noqa
+from .core import groupby_reduce, groupby_scan, rechunk_for_blockwise, rechunk_for_cohorts  # noqa


 def _get_version():

flox/aggregate_flox.py

Lines changed: 28 additions & 1 deletion
@@ -14,11 +14,12 @@ def _prepare_for_flox(group_idx, array):
     issorted = (group_idx[:-1] <= group_idx[1:]).all()
     if issorted:
         ordered_array = array
+        perm = slice(None)
     else:
         perm = group_idx.argsort(kind="stable")
         group_idx = group_idx[..., perm]
         ordered_array = array[..., perm]
-    return group_idx, ordered_array
+    return group_idx, ordered_array, perm


 def _lerp(a, b, *, t, dtype, out=None):
@@ -226,3 +227,29 @@ def nanmean(group_idx, array, *, axis=-1, size=None, fill_value=None, dtype=None
     with np.errstate(invalid="ignore", divide="ignore"):
         out /= nanlen(group_idx, array, size=size, axis=axis, fill_value=0)
     return out
+
+
+def ffill(group_idx, array, *, axis, **kwargs):
+    group_idx, array, perm = _prepare_for_flox(group_idx, array)
+    shape = array.shape
+    ndim = array.ndim
+    assert axis == (ndim - 1), (axis, ndim - 1)
+
+    flag = np.concatenate((np.array([True], like=array), group_idx[1:] != group_idx[:-1]))
+    (group_starts,) = flag.nonzero()
+
+    # https://stackoverflow.com/questions/41190852/most-efficient-way-to-forward-fill-nan-values-in-numpy-array
+    mask = np.isnan(array)
+    # modified from the SO answer, just reset the index at the start of every group!
+    mask[..., np.asarray(group_starts)] = False
+
+    idx = np.where(mask, 0, np.arange(shape[axis]))
+    np.maximum.accumulate(idx, axis=axis, out=idx)
+    slc = [
+        np.arange(k)[tuple([slice(None) if dim == i else np.newaxis for dim in range(ndim)])]
+        for i, k in enumerate(shape)
+    ]
+    slc[axis] = idx
+
+    invert_perm = slice(None) if isinstance(perm, slice) else np.argsort(perm, kind="stable")
+    return array[tuple(slc)][..., invert_perm]
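`_prepare_for_flox` now also returns the permutation that sorted `group_idx`, so `ffill` can undo the sort after filling. A small check of the behavior above (not part of the commit; `ffill` is the module-level function added in this diff):

```python
import numpy as np
from flox.aggregate_flox import ffill

array = np.array([1.0, np.nan, np.nan, 5.0, np.nan])
group_idx = np.array([0, 0, 1, 1, 1])

# NaNs are filled from the last valid value within the same group only:
# position 2 opens group 1, so its NaN has nothing to fill from and survives.
print(ffill(group_idx, array, axis=array.ndim - 1))
# [ 1.  1. nan  5.  5.]
```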

flox/aggregations.py

Lines changed: 179 additions & 2 deletions
@@ -3,11 +3,13 @@
 import copy
 import logging
 import warnings
+from collections.abc import Sequence
 from dataclasses import dataclass
 from functools import cached_property, partial
 from typing import TYPE_CHECKING, Any, Callable, Literal, TypedDict

 import numpy as np
+import pandas as pd
 from numpy.typing import ArrayLike, DTypeLike

 from . import aggregate_flox, aggregate_npg, xrutils
@@ -19,6 +21,7 @@


 logger = logging.getLogger("flox")
+T_ScanBinaryOpMode = Literal["apply_binary_op", "concat_then_scan"]


 def _is_arg_reduction(func: str | Aggregation) -> bool:
@@ -63,6 +66,9 @@ def generic_aggregate(
     dtype=None,
     **kwargs,
 ):
+    if func == "identity":
+        return array
+
     if engine == "flox":
         try:
             method = getattr(aggregate_flox, func)
@@ -567,7 +573,171 @@ def quantile_new_dims_func(q) -> tuple[Dim]:
 mode = Aggregation(name="mode", fill_value=dtypes.NA, chunk=None, combine=None)
 nanmode = Aggregation(name="nanmode", fill_value=dtypes.NA, chunk=None, combine=None)

-aggregations = {
+
+@dataclass
+class Scan:
+    # This dataclass is separate from Aggregation since there's not much in common
+    # between reductions and scans
+    name: str
+    # binary operation (e.g. np.add)
+    # Must be None for mode="concat_then_scan"
+    binary_op: Callable | None
+    # in-memory grouped scan function (e.g. cumsum)
+    scan: str
+    # Grouped reduction that yields the last result of the scan (e.g. sum)
+    reduction: str
+    # Identity element
+    identity: Any
+    # dtype of result
+    dtype: Any = None
+    # "Mode" of applying the binary op.
+    # For np.add we apply the op directly to the `state` array and the `current` array.
+    # For ffill, bfill we concat `state` to `current` and then run the scan again.
+    mode: T_ScanBinaryOpMode = "apply_binary_op"
+    preprocess: Callable | None = None
+    finalize: Callable | None = None
+
+
+def concatenate(arrays: Sequence[AlignedArrays], axis=-1, out=None) -> AlignedArrays:
+    group_idx = np.concatenate([a.group_idx for a in arrays], axis=axis)
+    array = np.concatenate([a.array for a in arrays], axis=axis)
+    return AlignedArrays(array=array, group_idx=group_idx)
+
+
+@dataclass
+class AlignedArrays:
+    """Simple xarray-DataArray-like data class with two aligned arrays."""
+
+    array: np.ndarray
+    group_idx: np.ndarray
+
+    def __post_init__(self):
+        assert self.array.shape[-1] == self.group_idx.size
+
+    def last(self) -> AlignedArrays:
+        from flox.core import chunk_reduce
+
+        reduced = chunk_reduce(
+            self.array,
+            self.group_idx,
+            func=("nanlast",),
+            axis=-1,
+            # TODO: automate?
+            engine="flox",
+            dtype=self.array.dtype,
+            fill_value=_get_fill_value(self.array.dtype, dtypes.NA),
+            expected_groups=None,
+        )
+        return AlignedArrays(array=reduced["intermediates"][0], group_idx=reduced["groups"])
+
+
+@dataclass
+class ScanState:
+    """Dataclass representing intermediates for scan."""
+
+    # last value of each group seen so far
+    state: AlignedArrays | None
+    # intermediate result
+    result: AlignedArrays | None
+
+    def __post_init__(self):
+        assert (self.state is not None) or (self.result is not None)
+
+
+def reverse(a: AlignedArrays) -> AlignedArrays:
+    a.group_idx = a.group_idx[::-1]
+    a.array = a.array[::-1]
+    return a
+
+
+def scan_binary_op(left_state: ScanState, right_state: ScanState, *, agg: Scan) -> ScanState:
+    from .core import reindex_
+
+    assert left_state.state is not None
+    left = left_state.state
+    right = right_state.result if right_state.result is not None else right_state.state
+    assert right is not None
+
+    if agg.mode == "apply_binary_op":
+        assert agg.binary_op is not None
+        # Implements the groupby binary operation.
+        reindexed = reindex_(
+            left.array,
+            from_=pd.Index(left.group_idx),
+            # can't use right.group_idx since we need to do the indexing later
+            to=pd.RangeIndex(right.group_idx.max() + 1),
+            fill_value=agg.identity,
+            axis=-1,
+        )
+        result = AlignedArrays(
+            array=agg.binary_op(reindexed[..., right.group_idx], right.array),
+            group_idx=right.group_idx,
+        )
+
+    elif agg.mode == "concat_then_scan":
+        # Implements the binary-op portion of the scan as a concatenate-then-scan.
+        # This is useful for `ffill`, and presumably more generalized scans.
+        assert agg.binary_op is None
+        concat = concatenate([left, right], axis=-1)
+        final_value = generic_aggregate(
+            concat.group_idx,
+            concat.array,
+            func=agg.scan,
+            axis=concat.array.ndim - 1,
+            engine="flox",
+            fill_value=agg.identity,
+        )
+        result = AlignedArrays(
+            array=final_value[..., left.group_idx.size :], group_idx=right.group_idx
+        )
+    else:
+        raise ValueError(f"Unknown binary op application mode: {agg.mode!r}")
+
+    # This is quite important. We need to update the state seen so far and propagate that.
+    # So we must account for what we know when entering this function: i.e. `left`.
+    # TODO: this is a bit wasteful since it will sort again, but for now let's focus on
+    # correctness and DRY
+    lasts = concatenate([left, result]).last()
+
+    return ScanState(
+        state=lasts,
+        # The binary op is called on the results of the reduction too when building up the tree.
+        # We need to be careful and assign those results only to `state` and not the final result.
+        # Up above, `result` is privileged when it exists.
+        result=None if right_state.result is None else result,
+    )
+
+
+# TODO: numpy_groupies cumsum is broken when NaNs are present.
+# cumsum = Scan("cumsum", binary_op=np.add, reduction="sum", scan="cumsum", identity=0)
+nancumsum = Scan("nancumsum", binary_op=np.add, reduction="nansum", scan="nancumsum", identity=0)
+# ffill uses the identity for the scan; at the binary-op stage,
+# we concatenate the blockwise-reduced values with the original block,
+# and then execute the scan again.
+# TODO: consider adding chunk="identity" here, like with reductions, as an optimization
+ffill = Scan(
+    "ffill",
+    binary_op=None,
+    reduction="nanlast",
+    scan="ffill",
+    identity=np.nan,
+    mode="concat_then_scan",
+)
+bfill = Scan(
+    "bfill",
+    binary_op=None,
+    reduction="nanlast",
+    scan="ffill",
+    identity=np.nan,
+    mode="concat_then_scan",
+    preprocess=reverse,
+    finalize=reverse,
+)
+# TODO: not implemented in numpy_groupies
+# cumprod = Scan("cumprod", binary_op=np.multiply, preop="prod", scan="cumprod")
+
+
+AGGREGATIONS: dict[str, Aggregation | Scan] = {
     "any": any_,
     "all": all_,
     "count": count,
@@ -599,6 +769,10 @@ def quantile_new_dims_func(q) -> tuple[Dim]:
     "nanquantile": nanquantile,
     "mode": mode,
     "nanmode": nanmode,
+    # "cumsum": cumsum,
+    "nancumsum": nancumsum,
+    "ffill": ffill,
+    "bfill": bfill,
 }

@@ -610,11 +784,14 @@ def _initialize_aggregation(
     min_count: int,
     finalize_kwargs: dict[Any, Any] | None,
 ) -> Aggregation:
+    agg: Aggregation
     if not isinstance(func, Aggregation):
         try:
             # TODO: need better interface
             # we set dtype, fillvalue on reduction later. so deepcopy now
-            agg = copy.deepcopy(aggregations[func])
+            agg_ = copy.deepcopy(AGGREGATIONS[func])
+            assert isinstance(agg_, Aggregation)
+            agg = agg_
         except KeyError:
             raise NotImplementedError(f"Reduction {func!r} not implemented yet")
     elif isinstance(func, Aggregation):
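With `AGGREGATIONS` now mapping names to either `Aggregation` or `Scan`, lookups must narrow the type, which is what the new `assert isinstance(agg_, Aggregation)` does. A hypothetical inspection of the registry (not code from this commit):

```python
from flox.aggregations import AGGREGATIONS, Scan

agg = AGGREGATIONS["ffill"]
if isinstance(agg, Scan):
    # A Scan bundles the blockwise scan, the carry reduction, and the combine mode.
    print(agg.name, agg.scan, agg.reduction, agg.mode)
    # ffill ffill nanlast concat_then_scan
```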
