
Commit c72bd32

Support cumsum (#451)
1 parent 8bbcf95 commit c72bd32

6 files changed: +90 −20 lines changed


.pre-commit-config.yaml

Lines changed: 0 additions & 5 deletions
@@ -10,11 +10,6 @@ repos:
         args: ["--fix", "--show-fixes"]
       - id: ruff-format
 
-  - repo: https://github.com/pre-commit/mirrors-prettier
-    rev: "v4.0.0-alpha.8"
-    hooks:
-      - id: prettier
-
   - repo: https://github.com/pre-commit/pre-commit-hooks
     rev: v5.0.0
     hooks:

flox/aggregate_flox.py

Lines changed: 36 additions & 0 deletions
@@ -276,3 +276,39 @@ def ffill(group_idx, array, *, axis, **kwargs):
 
     invert_perm = slice(None) if isinstance(perm, slice) else np.argsort(perm, kind="stable")
     return array[tuple(slc)][..., invert_perm]
+
+
+def _np_grouped_scan(group_idx, array, *, axis: int, skipna: bool, **kwargs):
+    handle_nans = not skipna and array.dtype.kind in "cfO"
+
+    group_idx, array, perm = _prepare_for_flox(group_idx, array)
+    ndim = array.ndim
+    assert axis == (ndim - 1), (axis, ndim - 1)
+
+    flag = np.concatenate((np.asarray([True], like=group_idx), group_idx[1:] != group_idx[:-1]))
+    (inv_idx,) = flag.nonzero()
+    segment_lengths = np.add.reduceat(np.ones(group_idx.shape), inv_idx, dtype=np.int64)
+
+    # TODO: set dtype to float properly for handle_nans?
+    accum = np.nancumsum(array, axis=axis)
+
+    if len(inv_idx) > 1:
+        first_group_idx = inv_idx[1]
+        # extract cumulative sum _before_ start of group
+        prev_group_cumsum = accum[..., inv_idx[1:] - 1]
+        accum[..., first_group_idx:] -= np.repeat(prev_group_cumsum, segment_lengths[1:], axis=axis)
+
+    if handle_nans:
+        mask = isnull(array)
+        accummask = np.cumsum(mask, axis=-1, dtype=np.uint64)
+        if len(inv_idx) > 1:
+            prev_group_cumsum = accummask[..., inv_idx[1:] - 1]
+            accummask[..., first_group_idx:] -= np.repeat(prev_group_cumsum, segment_lengths[1:], axis=axis)
+        accum[accummask > 0] = np.nan
+
+    invert_perm = slice(None) if isinstance(perm, slice) else np.argsort(perm, kind="stable")
+    return accum[..., invert_perm]
+
+
+cumsum = partial(_np_grouped_scan, skipna=False)
+nancumsum = partial(_np_grouped_scan, skipna=True)
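
For reference, a minimal sketch of the segment-reset trick that _np_grouped_scan uses, run outside of flox on a small, already-sorted group_idx (the real function sorts via _prepare_for_flox and handles NaN propagation with a separate mask pass; this sketch uses plain np.cumsum and illustrative values only):

import numpy as np

group_idx = np.array([0, 0, 1, 1, 1, 2])          # already sorted by group
array = np.array([1.0, 2.0, 3.0, 4.0, 5.0, 6.0])

# mark where each group starts and how long each segment is
flag = np.concatenate(([True], group_idx[1:] != group_idx[:-1]))
(inv_idx,) = flag.nonzero()                                           # [0, 2, 5]
segment_lengths = np.add.reduceat(np.ones_like(group_idx), inv_idx)   # [2, 3, 1]

# one global cumsum, then subtract the running total accrued before each group's start
accum = np.cumsum(array)                          # [1, 3, 6, 10, 15, 21]
prev_group_cumsum = accum[inv_idx[1:] - 1]        # totals just before groups 1 and 2: [3, 15]
accum[inv_idx[1]:] -= np.repeat(prev_group_cumsum, segment_lengths[1:])
print(accum)                                      # [1. 3. 3. 7. 12. 6.] -> per-group cumulative sums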

flox/aggregations.py

Lines changed: 4 additions & 4 deletions
@@ -631,7 +631,7 @@ def last(self) -> AlignedArrays:
         reduced = chunk_reduce(
             self.array,
             self.group_idx,
-            func=("nanlast",),
+            func=("last",),
             axis=-1,
             # TODO: automate?
             engine="flox",
@@ -699,6 +699,7 @@ def scan_binary_op(left_state: ScanState, right_state: ScanState, *, agg: Scan)
             fill_value=agg.identity,
         )
         result = AlignedArrays(array=final_value[..., left.group_idx.size :], group_idx=right.group_idx)
+
     else:
         raise ValueError(f"Unknown binary op application mode: {agg.mode!r}")
 
@@ -717,8 +718,7 @@ def scan_binary_op(left_state: ScanState, right_state: ScanState, *, agg: Scan)
 )
 
 
-# TODO: numpy_groupies cumsum is a broken when NaNs are present.
-# cumsum = Scan("cumsum", binary_op=np.add, reduction="sum", scan="cumsum", identity=0)
+cumsum = Scan("cumsum", binary_op=np.add, reduction="sum", scan="cumsum", identity=0)
 nancumsum = Scan("nancumsum", binary_op=np.add, reduction="nansum", scan="nancumsum", identity=0)
 # ffill uses the identity for scan, and then at the binary-op state,
 # we concatenate the blockwise-reduced values with the original block,
@@ -782,7 +782,7 @@ def scan_binary_op(left_state: ScanState, right_state: ScanState, *, agg: Scan)
     "nanquantile": nanquantile,
     "mode": mode,
     "nanmode": nanmode,
-    # "cumsum": cumsum,
+    "cumsum": cumsum,
     "nancumsum": nancumsum,
     "ffill": ffill,
     "bfill": bfill,

flox/xarray.py

Lines changed: 3 additions & 3 deletions
@@ -1,7 +1,7 @@
 from __future__ import annotations
 
 from collections.abc import Hashable, Iterable, Sequence
-from typing import TYPE_CHECKING, Any
+from typing import TYPE_CHECKING, Any, cast
 
 import numpy as np
 import pandas as pd
@@ -249,7 +249,7 @@ def xarray_reduce(
             grouper_dims.append(d)
 
     if isinstance(obj, xr.Dataset):
-        ds = obj
+        ds = cast(xr.Dataset, obj)
     else:
         ds = obj._to_temp_dataset()
 
@@ -295,7 +295,7 @@ def xarray_reduce(
         not set(grouper_dims).issubset(set(variable.dims)) for variable in ds.data_vars.values()
     )
     if needs_broadcast:
-        ds_broad = xr.broadcast(ds, *by_da, exclude=exclude_dims)[0]
+        ds_broad = cast(xr.Dataset, xr.broadcast(ds, *by_da, exclude=exclude_dims)[0])
     else:
         ds_broad = ds

tests/test_core.py

Lines changed: 44 additions & 6 deletions
@@ -1974,18 +1974,56 @@ def test_nanlen_string(dtype, engine) -> None:
     assert_equal(expected, actual)
 
 
-def test_cumsum() -> None:
-    array = np.array([1, 1, 1], dtype=np.uint64)
+@pytest.mark.parametrize(
+    "array",
+    [
+        np.array([1, 1, 1, 2, 3, 4, 5], dtype=np.uint64),
+        np.array([1, 1, 1, 2, np.nan, 4, 5], dtype=np.float64),
+    ],
+)
+@pytest.mark.parametrize("func", ["cumsum", "nancumsum"])
+def test_cumsum_simple(array, func) -> None:
     by = np.array([0] * array.shape[-1])
-    expected = np.nancumsum(array, axis=-1)
+    expected = getattr(np, func)(array, axis=-1)
 
-    actual = groupby_scan(array, by, func="nancumsum", axis=-1)
-    assert_equal(expected, actual)
+    actual = groupby_scan(array, by, func=func, axis=-1)
+    assert_equal(actual, expected)
+
+    if has_dask:
+        da = dask.array.from_array(array, chunks=2)
+        actual = groupby_scan(da, by, func=func, axis=-1)
+        assert_equal(actual, expected)
+
+
+def test_cumsum() -> None:
+    array = np.array(
+        [
+            [1, 2, np.nan, 4, 5],
+            [3, np.nan, 4, 6, 6],
+        ]
+    )
+    by = [0, 1, 1, 0, 1]
+
+    expected = np.array(
+        [
+            [1, 2, np.nan, 5, np.nan],
+            [3, np.nan, np.nan, 9, np.nan],
+        ]
+    )
+    actual = groupby_scan(array, by, func="cumsum", axis=-1)
+    assert_equal(actual, expected)
+    if has_dask:
+        da = dask.array.from_array(array, chunks=2)
+        actual = groupby_scan(da, by, func="cumsum", axis=-1)
+        assert_equal(actual, expected)
 
+    expected = np.array([[1, 2, 2, 5, 7], [3, 0, 4, 9, 10]], dtype=np.float64)
+    actual = groupby_scan(array, by, func="nancumsum", axis=-1)
+    assert_equal(actual, expected)
     if has_dask:
         da = dask.array.from_array(array, chunks=2)
         actual = groupby_scan(da, by, func="nancumsum", axis=-1)
-        assert_equal(expected, actual)
+        assert_equal(actual, expected)
 
 
 @pytest.mark.parametrize(

tests/test_properties.py

Lines changed: 3 additions & 2 deletions
@@ -56,10 +56,11 @@ def bfill(array, axis, dtype=None):
 
 
 NUMPY_SCAN_FUNCS: dict[str, Callable] = {
+    "cumsum": np.cumsum,
     "nancumsum": np.nancumsum,
     "ffill": ffill,
     "bfill": bfill,
-}  # "cumsum": np.cumsum,
+}
 
 
 def not_overflowing_array(array: np.ndarray[Any, Any]) -> bool:
@@ -210,7 +211,7 @@ def test_groupby_reduce_numpy_vs_other(data, array, func: str) -> None:
     array=chunked_arrays(arrays=numeric_like_arrays),
     func=st.sampled_from(tuple(NUMPY_SCAN_FUNCS)),
 )
-def test_scans(data, array: dask.array.Array, func: str) -> None:
+def test_scans_against_numpy(data, array: dask.array.Array, func: str) -> None:
     if "cum" in func:
         assume(not_overflowing_array(np.asarray(array)))
