@@ -2788,20 +2788,29 @@ def groupby_scan(
2788
2788
if by_ .shape [- 1 ] == 1 or by_ .shape == grp_shape :
2789
2789
return array .astype (agg .dtype )
2790
2790
2791
+ # Made a design choice here to have `preprocess` handle both array and group_idx
2792
+ # Example: for reversing, we need to reverse the whole array, not just reverse
2793
+ # each block independently
2794
+ inp = AlignedArrays (array = array , group_idx = by_ )
2795
+ if agg .preprocess :
2796
+ inp = agg .preprocess (inp )
2797
+
2791
2798
if not has_dask :
2792
- final_state = chunk_scan (
2793
- AlignedArrays (array = array , group_idx = by_ ), axis = single_axis , agg = agg , dtype = agg .dtype
2794
- )
2795
- return extract_array (final_state )
2799
+ final_state = chunk_scan (inp , axis = single_axis , agg = agg , dtype = agg .dtype )
2800
+ result = _finalize_scan (final_state )
2796
2801
else :
2797
- return dask_groupby_scan (array , by_ , axes = axis_ , agg = agg )
2802
+ result = dask_groupby_scan (inp .array , inp .group_idx , axes = axis_ , agg = agg )
2803
+
2804
+ # Made a design choice here to have `postprocess` handle both array and group_idx
2805
+ out = AlignedArrays (array = result , group_idx = by_ )
2806
+ if agg .finalize :
2807
+ out = agg .finalize (out )
2808
+ return out .array
2798
2809
2799
2810
2800
2811
def chunk_scan (inp : AlignedArrays , * , axis : int , agg : Scan , dtype = None , keepdims = None ) -> ScanState :
2801
2812
assert axis == inp .array .ndim - 1
2802
2813
2803
- if agg .preprocess :
2804
- inp = agg .preprocess (inp )
2805
2814
# I don't think we need to re-factorize here unless we are grouping by a dask array
2806
2815
accumulated = generic_aggregate (
2807
2816
inp .group_idx ,
@@ -2813,8 +2822,6 @@ def chunk_scan(inp: AlignedArrays, *, axis: int, agg: Scan, dtype=None, keepdims
2813
2822
fill_value = agg .identity ,
2814
2823
)
2815
2824
result = AlignedArrays (array = accumulated , group_idx = inp .group_idx )
2816
- if agg .finalize :
2817
- result = agg .finalize (result )
2818
2825
return ScanState (result = result , state = None )
2819
2826
2820
2827
@@ -2840,10 +2847,9 @@ def _zip(group_idx: np.ndarray, array: np.ndarray) -> AlignedArrays:
2840
2847
return AlignedArrays (group_idx = group_idx , array = array )
2841
2848
2842
2849
2843
- def extract_array (block : ScanState , finalize : Callable | None = None ) -> np .ndarray :
2850
+ def _finalize_scan (block : ScanState ) -> np .ndarray :
2844
2851
assert block .result is not None
2845
- result = finalize (block .result ) if finalize is not None else block .result
2846
- return result .array
2852
+ return block .result .array
2847
2853
2848
2854
2849
2855
def dask_groupby_scan (array , by , axes : T_Axes , agg : Scan ) -> DaskArray :
@@ -2859,9 +2865,8 @@ def dask_groupby_scan(array, by, axes: T_Axes, agg: Scan) -> DaskArray:
2859
2865
array , by = _unify_chunks (array , by )
2860
2866
2861
2867
# 1. zip together group indices & array
2862
- to_map = _zip if agg .preprocess is None else tlz .compose (agg .preprocess , _zip )
2863
2868
zipped = map_blocks (
2864
- to_map , by , array , dtype = array .dtype , meta = array ._meta , name = "groupby-scan-preprocess"
2869
+ _zip , by , array , dtype = array .dtype , meta = array ._meta , name = "groupby-scan-preprocess"
2865
2870
)
2866
2871
2867
2872
scan_ = partial (chunk_scan , agg = agg )
@@ -2882,7 +2887,7 @@ def dask_groupby_scan(array, by, axes: T_Axes, agg: Scan) -> DaskArray:
2882
2887
)
2883
2888
2884
2889
# 3. Unzip and extract the final result array, discard groups
2885
- result = map_blocks (extract_array , accumulated , dtype = agg .dtype , finalize = agg . finalize )
2890
+ result = map_blocks (_finalize_scan , accumulated , dtype = agg .dtype )
2886
2891
2887
2892
assert result .chunks == array .chunks
2888
2893
0 commit comments