
Commit 573dd75

Fix ffill, bfill bugs
1 parent 38af8d7 commit 573dd75

4 files changed (+62, -28 lines)

flox/aggregations.py (2 additions, 2 deletions)

@@ -654,8 +654,8 @@ def __post_init__(self):


 def reverse(a: AlignedArrays) -> AlignedArrays:
-    a.group_idx = a.group_idx[::-1]
-    a.array = a.array[::-1]
+    a.group_idx = a.group_idx[..., ::-1]
+    a.array = a.array[..., ::-1]
     return a


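Why the Ellipsis matters: `reverse` is the preprocess/finalize hook used to express `bfill` as a reversed `ffill` (see the property test below), and on multidimensional input the old slicing flipped the leading axis instead of the scan axis. A minimal NumPy sketch (illustration only, not part of the commit):

import numpy as np

# Toy array: [::-1] flips the first axis (rows), while [..., ::-1] flips only
# the last axis, which is the axis groupby_scan scans along.
a = np.array([[1, 2, 3],
              [4, 5, 6]])

print(a[::-1])       # [[4 5 6]
                     #  [1 2 3]]
print(a[..., ::-1])  # [[3 2 1]
                     #  [6 5 4]]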
flox/core.py (20 additions, 15 deletions)

@@ -2788,20 +2788,29 @@ def groupby_scan(
     if by_.shape[-1] == 1 or by_.shape == grp_shape:
         return array.astype(agg.dtype)

+    # Made a design choice here to have `preprocess` handle both array and group_idx
+    # Example: for reversing, we need to reverse the whole array, not just reverse
+    # each block independently
+    inp = AlignedArrays(array=array, group_idx=by_)
+    if agg.preprocess:
+        inp = agg.preprocess(inp)
+
     if not has_dask:
-        final_state = chunk_scan(
-            AlignedArrays(array=array, group_idx=by_), axis=single_axis, agg=agg, dtype=agg.dtype
-        )
-        return extract_array(final_state)
+        final_state = chunk_scan(inp, axis=single_axis, agg=agg, dtype=agg.dtype)
+        result = _finalize_scan(final_state)
     else:
-        return dask_groupby_scan(array, by_, axes=axis_, agg=agg)
+        result = dask_groupby_scan(inp.array, inp.group_idx, axes=axis_, agg=agg)
+
+    # Made a design choice here to have `postprocess` handle both array and group_idx
+    out = AlignedArrays(array=result, group_idx=by_)
+    if agg.finalize:
+        out = agg.finalize(out)
+    return out.array


 def chunk_scan(inp: AlignedArrays, *, axis: int, agg: Scan, dtype=None, keepdims=None) -> ScanState:
     assert axis == inp.array.ndim - 1

-    if agg.preprocess:
-        inp = agg.preprocess(inp)
     # I don't think we need to re-factorize here unless we are grouping by a dask array
     accumulated = generic_aggregate(
         inp.group_idx,
@@ -2813,8 +2822,6 @@ def chunk_scan(inp: AlignedArrays, *, axis: int, agg: Scan, dtype=None, keepdims
         fill_value=agg.identity,
     )
     result = AlignedArrays(array=accumulated, group_idx=inp.group_idx)
-    if agg.finalize:
-        result = agg.finalize(result)
     return ScanState(result=result, state=None)

@@ -2840,10 +2847,9 @@ def _zip(group_idx: np.ndarray, array: np.ndarray) -> AlignedArrays:
     return AlignedArrays(group_idx=group_idx, array=array)


-def extract_array(block: ScanState, finalize: Callable | None = None) -> np.ndarray:
+def _finalize_scan(block: ScanState) -> np.ndarray:
     assert block.result is not None
-    result = finalize(block.result) if finalize is not None else block.result
-    return result.array
+    return block.result.array


 def dask_groupby_scan(array, by, axes: T_Axes, agg: Scan) -> DaskArray:
@@ -2859,9 +2865,8 @@ def dask_groupby_scan(array, by, axes: T_Axes, agg: Scan) -> DaskArray:
     array, by = _unify_chunks(array, by)

     # 1. zip together group indices & array
-    to_map = _zip if agg.preprocess is None else tlz.compose(agg.preprocess, _zip)
     zipped = map_blocks(
-        to_map, by, array, dtype=array.dtype, meta=array._meta, name="groupby-scan-preprocess"
+        _zip, by, array, dtype=array.dtype, meta=array._meta, name="groupby-scan-preprocess"
     )

     scan_ = partial(chunk_scan, agg=agg)
@@ -2882,7 +2887,7 @@ def dask_groupby_scan(array, by, axes: T_Axes, agg: Scan) -> DaskArray:
     )

     # 3. Unzip and extract the final result array, discard groups
-    result = map_blocks(extract_array, accumulated, dtype=agg.dtype, finalize=agg.finalize)
+    result = map_blocks(_finalize_scan, accumulated, dtype=agg.dtype)

     assert result.chunks == array.chunks

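The new comments in `groupby_scan` explain the restructuring: `preprocess` and `finalize` now run once over the whole array, before blocking and after assembling the result, rather than inside `chunk_scan`, because a reversal applied per block is not a reversal of the array. A small dask sketch of that pitfall (toy values, illustration only, not flox code):

import numpy as np
import dask.array as da

# Reversing each block independently is not the same as reversing the whole
# array, so the reverse step cannot live inside the per-chunk scan.
x = np.arange(6)                   # [0 1 2 3 4 5]
d = da.from_array(x, chunks=3)     # blocks: [0 1 2] and [3 4 5]

whole = d[::-1].compute()                              # [5 4 3 2 1 0]
per_block = d.map_blocks(lambda b: b[::-1]).compute()  # [2 1 0 5 4 3]

assert not np.array_equal(whole, per_block)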
tests/test_core.py (27 additions, 3 deletions)

@@ -187,11 +187,11 @@ def test_groupby_reduce(
     assert_equal(expected_result, result)


-def gen_array_by(size, func):
+def gen_array_by(size, func: str):
     by = np.ones(size[-1])
     rng = np.random.default_rng(12345)
     array = rng.random(tuple(6 if s == 1 else s for s in size))
-    if "nan" in func and "nanarg" not in func:
+    if ("nan" in func or "fill" in func) and "nanarg" not in func:
         array[[1, 4, 5], ...] = np.nan
     elif "nanarg" in func and len(size) > 1:
         array[[1, 4, 5], 1] = np.nan
@@ -1810,7 +1810,7 @@ def test_nanlen_string(dtype, engine):
     assert_equal(expected, actual)


-def test_scans():
+def test_cumsum():
     array = np.array([1, 1, 1], dtype=np.uint64)
     by = np.array([0] * array.shape[-1])
     kwargs = {"func": "nancumsum", "axis": -1}
@@ -1823,3 +1823,27 @@ def test_scans():
     da = dask.array.from_array(array, chunks=2)
     actual = groupby_scan(da, by, **kwargs)
     assert_equal(expected, actual)
+
+
+@pytest.mark.parametrize(
+    "chunks",
+    [
+        pytest.param(-1, marks=requires_dask),
+        pytest.param(3, marks=requires_dask),
+        pytest.param(4, marks=requires_dask),
+    ],
+)
+@pytest.mark.parametrize("size", ((1, 12), (12,), (12, 9)))
+@pytest.mark.parametrize("add_nan_by", [True, False])
+@pytest.mark.parametrize("func", ["ffill", "bfill"])
+def test_ffill_bfill(chunks, size, add_nan_by, func):
+    array, by = gen_array_by(size, func)
+    if chunks:
+        array = dask.array.from_array(array, chunks=chunks)
+    if add_nan_by:
+        by[0:3] = np.nan
+    by = tuple(by)
+
+    expected = flox.groupby_scan(array.compute(), by, func=func)
+    actual = flox.groupby_scan(array, by, func=func)
+    assert_equal(expected, actual)

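For orientation, a short usage sketch of the behaviour the new `test_ffill_bfill` exercises (toy values, not the test's generated data): filling only propagates within each group, so a leading NaN in one group is not filled from the previous group.

import numpy as np
import flox

array = np.array([1.0, np.nan, np.nan, 5.0, np.nan])
by = np.array([0, 0, 1, 1, 1])

# Forward-fill within groups: group 0 -> [1, 1]; group 1 -> [nan, 5, 5]
filled = flox.groupby_scan(array, by, func="ffill")
print(filled)  # [ 1.  1. nan  5.  5.]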
tests/test_properties.py (13 additions, 8 deletions)

@@ -159,14 +159,19 @@ def test_ffill_bfill_reverse(data, array: dask.array.Array) -> None:
     def reverse(arr):
         return arr[..., ::-1]

-    for a in (array, array.compute()):
-        forward = groupby_scan(a, by, func="ffill")
-        backward_reversed = reverse(groupby_scan(reverse(a), reverse(by), func="bfill"))
-        assert_equal(forward, backward_reversed)
-
-        backward = groupby_scan(a, by, func="bfill")
-        forward_reversed = reverse(groupby_scan(reverse(a), reverse(by), func="ffill"))
-        assert_equal(forward_reversed, backward)
+    forward = groupby_scan(array, by, func="ffill")
+    as_numpy = groupby_scan(array.compute(), by, func="ffill")
+    assert_equal(forward, as_numpy)
+
+    backward = groupby_scan(array, by, func="bfill")
+    as_numpy = groupby_scan(array.compute(), by, func="bfill")
+    assert_equal(backward, as_numpy)
+
+    backward_reversed = reverse(groupby_scan(reverse(array), reverse(by), func="bfill"))
+    assert_equal(forward, backward_reversed)
+
+    forward_reversed = reverse(groupby_scan(reverse(array), reverse(by), func="ffill"))
+    assert_equal(forward_reversed, backward)


 @given(

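The invariant this property test relies on is that `bfill` equals `ffill` applied to the reversed inputs and reversed back along the scan axis. A toy check of that identity (illustration only, values are not Hypothesis-generated):

import numpy as np
import flox

def reverse(arr):
    return arr[..., ::-1]

array = np.array([np.nan, 2.0, np.nan, np.nan, 7.0])
by = np.array([0, 0, 0, 1, 1])

backward = flox.groupby_scan(array, by, func="bfill")
forward_reversed = reverse(flox.groupby_scan(reverse(array), reverse(by), func="ffill"))

# NaNs land in the same positions, so assert_array_equal treats them as equal.
np.testing.assert_array_equal(backward, forward_reversed)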