86
86
# Sentinel axis position; -2 puts it just before the trailing (group) axis.
# NOTE(review): its consumers are not visible in this fragment — confirm usage.
DUMMY_AXIS = - 2
87
87
88
88
89
def _postprocess_numbagg(result, *, func, fill_value, size, seen_groups):
    """Account for numbagg not providing a fill_value kwarg.

    numbagg writes its own hard-coded default into groups that are absent
    from this chunk; when the caller requested a different ``fill_value``,
    overwrite those absent-group slots after the fact.
    """
    from .aggregate_numbagg import DEFAULT_FILL_VALUE

    if not isinstance(func, str) or func not in DEFAULT_FILL_VALUE:
        # Nothing to fix up for callables or funcs numbagg has no default for.
        return result

    numbagg_default = DEFAULT_FILL_VALUE[func]
    # Only mask when the requested fill_value actually differs from what
    # numbagg already wrote (NaN == NaN counts as equal here).
    must_mask = fill_value is not None and not np.array_equal(
        fill_value, numbagg_default, equal_nan=True
    )
    if must_mask:
        all_groups = np.arange(size)
        unseen = np.isin(all_groups, seen_groups, assume_unique=True, invert=True)
        if unseen.any():
            result[..., all_groups[unseen]] = fill_value
    return result
108
+
89
109
def _issorted (arr : np .ndarray ) -> bool :
90
110
return bool ((arr [:- 1 ] <= arr [1 :]).all ())
91
111
@@ -780,7 +800,11 @@ def chunk_reduce(
780
800
group_idx , grps , found_groups_shape , _ , size , props = factorize_ (
781
801
(by ,), axes , expected_groups = (expected_groups ,), reindex = reindex , sort = sort
782
802
)
783
- groups = grps [0 ]
803
+ (groups ,) = grps
804
+
805
+ # do this *before* possible broadcasting below.
806
+ # factorize_ has already taken care of offsetting
807
+ seen_groups = _unique (group_idx )
784
808
785
809
order = "C"
786
810
if nax > 1 :
@@ -850,6 +874,16 @@ def chunk_reduce(
850
874
result = generic_aggregate (
851
875
group_idx , array , axis = - 1 , engine = engine , func = reduction , ** kw_func
852
876
).astype (dt , copy = False )
877
+ if engine == "numbagg" :
878
+ result = _postprocess_numbagg (
879
+ result ,
880
+ func = reduction ,
881
+ size = size ,
882
+ fill_value = fv ,
883
+ # Unfortunately, we cannot reuse found_groups, it has not
884
+ # been "offset" and is really expected_groups in nearly all cases
885
+ seen_groups = seen_groups ,
886
+ )
853
887
if np .any (props .nanmask ):
854
888
# remove NaN group label which should be last
855
889
result = result [..., :- 1 ]
@@ -1053,6 +1087,8 @@ def _grouped_combine(
1053
1087
"""Combine intermediates step of tree reduction."""
1054
1088
from dask .utils import deepmap
1055
1089
1090
+ combine = agg .combine
1091
+
1056
1092
if isinstance (x_chunk , dict ):
1057
1093
# Only one block at final step; skip one extra groupby
1058
1094
return x_chunk
@@ -1093,7 +1129,8 @@ def _grouped_combine(
1093
1129
results = chunk_argreduce (
1094
1130
array_idx ,
1095
1131
groups ,
1096
- func = agg .combine [slicer ], # count gets treated specially next
1132
+ # count gets treated specially next
1133
+ func = combine [slicer ], # type: ignore[arg-type]
1097
1134
axis = axis ,
1098
1135
expected_groups = None ,
1099
1136
fill_value = agg .fill_value ["intermediate" ][slicer ],
@@ -1127,9 +1164,10 @@ def _grouped_combine(
1127
1164
elif agg .reduction_type == "reduce" :
1128
1165
# Here we reduce the intermediates individually
1129
1166
results = {"groups" : None , "intermediates" : []}
1130
- for idx , (combine , fv , dtype ) in enumerate (
1131
- zip (agg . combine , agg .fill_value ["intermediate" ], agg .dtype ["intermediate" ])
1167
+ for idx , (combine_ , fv , dtype ) in enumerate (
1168
+ zip (combine , agg .fill_value ["intermediate" ], agg .dtype ["intermediate" ])
1132
1169
):
1170
+ assert combine_ is not None
1133
1171
array = _conc2 (x_chunk , key1 = "intermediates" , key2 = idx , axis = axis )
1134
1172
if array .shape [- 1 ] == 0 :
1135
1173
# all empty when combined
@@ -1143,7 +1181,7 @@ def _grouped_combine(
1143
1181
_results = chunk_reduce (
1144
1182
array ,
1145
1183
groups ,
1146
- func = combine ,
1184
+ func = combine_ ,
1147
1185
axis = axis ,
1148
1186
expected_groups = None ,
1149
1187
fill_value = (fv ,),
@@ -1788,8 +1826,13 @@ def _choose_engine(by, agg: Aggregation):
1788
1826
1789
1827
# numbagg only supports nan-skipping reductions
1790
1828
# without dtype specified
1791
- if HAS_NUMBAGG and "nan" in agg .name :
1792
- if not_arg_reduce and dtype is None :
1829
+ has_blockwise_nan_skipping = (agg .chunk [0 ] is None and "nan" in agg .name ) or any (
1830
+ (isinstance (func , str ) and "nan" in func ) for func in agg .chunk
1831
+ )
1832
+ if HAS_NUMBAGG :
1833
+ if agg .name in ["all" , "any" ] or (
1834
+ not_arg_reduce and has_blockwise_nan_skipping and dtype is None
1835
+ ):
1793
1836
return "numbagg"
1794
1837
1795
1838
if not_arg_reduce and (not is_duck_dask_array (by ) and _issorted (by )):
@@ -2050,7 +2093,7 @@ def groupby_reduce(
2050
2093
nax = len (axis_ )
2051
2094
2052
2095
# When axis is a subset of possible values; then npg will
2053
- # apply it to groups that don't exist along a particular axis (for e.g.)
2096
+ # apply the fill_value to groups that don't exist along a particular axis (for e.g.)
2054
2097
# since these count as a group that is absent. thoo!
2055
2098
# fill_value applies to all-NaN groups as well as labels in expected_groups that are not found.
2056
2099
# The only way to do this consistently is mask out using min_count
@@ -2090,8 +2133,7 @@ def groupby_reduce(
2090
2133
# TODO: How else to narrow that array.chunks is there?
2091
2134
assert isinstance (array , DaskArray )
2092
2135
2093
- # TODO: fix typing of FuncTuple in Aggregation
2094
- if agg .chunk [0 ] is None and method != "blockwise" : # type: ignore[unreachable]
2136
+ if agg .chunk [0 ] is None and method != "blockwise" :
2095
2137
raise NotImplementedError (
2096
2138
f"Aggregation { agg .name !r} is only implemented for dask arrays when method='blockwise'."
2097
2139
f"Received method={ method !r} "
0 commit comments