Skip to content

Commit 9bd682c

Browse files
authored
Fix nanmax, nanmin bug (#411)
* Add numpy vs dask property test
* Fix nanmin, nanmax bug
1 parent 0c4b19f commit 9bd682c

File tree

3 files changed

+51
-6
lines changed

3 files changed

+51
-6
lines changed

flox/aggregations.py

Lines changed: 14 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -393,15 +393,17 @@ def _std_finalize(sumsq, sum_, count, ddof=0):
393393
"nanmin",
394394
chunk="nanmin",
395395
combine="nanmin",
396-
fill_value=dtypes.NA,
396+
fill_value=dtypes.INF,
397+
final_fill_value=dtypes.NA,
397398
preserves_dtype=True,
398399
)
399400
max_ = Aggregation("max", chunk="max", combine="max", fill_value=dtypes.NINF, preserves_dtype=True)
400401
nanmax = Aggregation(
401402
"nanmax",
402403
chunk="nanmax",
403404
combine="nanmax",
404-
fill_value=dtypes.NA,
405+
fill_value=dtypes.NINF,
406+
final_fill_value=dtypes.NA,
405407
preserves_dtype=True,
406408
)
407409

@@ -845,6 +847,16 @@ def _initialize_aggregation(
845847
# absent in one block, but present in another block
846848
# We set it for numpy to get nansum, nanprod tests to pass
847849
# where the identity element is 0, 1
850+
# Also needed for nanmin, nanmax where intermediate fill_value is +-np.inf,
851+
# but final_fill_value is dtypes.NA
852+
if (
853+
# TODO: this is a total hack, setting a default fill_value
854+
# even though numpy doesn't define identity for nanmin, nanmax
855+
agg.name in ["nanmin", "nanmax"] and min_count == 0
856+
):
857+
min_count = 1
858+
agg.fill_value["user"] = agg.fill_value["user"] or agg.fill_value[agg.name]
859+
848860
if min_count > 0:
849861
agg.min_count = min_count
850862
agg.numpy += ("nanlen",)

flox/core.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1119,7 +1119,6 @@ def _finalize_results(
11191119
agg: Aggregation,
11201120
axis: T_Axes,
11211121
expected_groups: pd.Index | None,
1122-
fill_value: Any,
11231122
reindex: bool,
11241123
) -> FinalResultsDict:
11251124
"""Finalize results by
@@ -1142,6 +1141,7 @@ def _finalize_results(
11421141
else:
11431142
finalized[agg.name] = agg.finalize(*squeezed["intermediates"], **agg.finalize_kwargs)
11441143

1144+
fill_value = agg.fill_value["user"]
11451145
if min_count > 0:
11461146
count_mask = counts < min_count
11471147
if count_mask.any():
@@ -1183,7 +1183,7 @@ def _aggregate(
11831183
) -> FinalResultsDict:
11841184
"""Final aggregation step of tree reduction"""
11851185
results = combine(x_chunk, agg, axis, keepdims, is_aggregate=True)
1186-
return _finalize_results(results, agg, axis, expected_groups, fill_value, reindex)
1186+
return _finalize_results(results, agg, axis, expected_groups, reindex)
11871187

11881188

11891189
def _expand_dims(results: IntermediateDict) -> IntermediateDict:
@@ -1449,7 +1449,7 @@ def _reduce_blockwise(
14491449
if _is_arg_reduction(agg):
14501450
results["intermediates"][0] = np.unravel_index(results["intermediates"][0], array.shape)[-1]
14511451

1452-
result = _finalize_results(results, agg, axis, expected_groups, fill_value=fill_value, reindex=reindex)
1452+
result = _finalize_results(results, agg, axis, expected_groups, reindex=reindex)
14531453
return result
14541454

14551455

@@ -1926,7 +1926,7 @@ def _groupby_combine(a, axis, dummy_axis, dtype, keepdims):
19261926
def _groupby_aggregate(a):
19271927
# Convert cubed dict to one that _finalize_results works with
19281928
results = {"groups": expected_groups, "intermediates": a.values()}
1929-
out = _finalize_results(results, agg, axis, expected_groups, fill_value, reindex)
1929+
out = _finalize_results(results, agg, axis, expected_groups, reindex)
19301930
return out[agg.name]
19311931

19321932
# convert list of dtypes to a structured dtype for cubed

tests/test_properties.py

Lines changed: 33 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -128,6 +128,39 @@ def test_groupby_reduce(data, array, func: str) -> None:
128128
assert_equal(expected, actual, tolerance)
129129

130130

131+
@given(
132+
data=st.data(),
133+
array=chunked_arrays(arrays=numeric_arrays),
134+
func=func_st,
135+
)
136+
def test_groupby_reduce_numpy_vs_dask(data, array, func: str) -> None:
137+
numpy_array = array.compute()
138+
# overflow behaviour differs between bincount and sum (for example)
139+
assume(not_overflowing_array(numpy_array))
140+
# TODO: fix var for complex numbers upstream
141+
assume(not (("quantile" in func or "var" in func or "std" in func) and array.dtype.kind == "c"))
142+
# # arg* with nans in array are weird
143+
assume("arg" not in func and not np.any(np.isnan(numpy_array.ravel())))
144+
if func in ["nanmedian", "nanquantile", "median", "quantile"]:
145+
array = array.rechunk({-1: -1})
146+
147+
axis = -1
148+
by = data.draw(by_arrays(shape=(array.shape[-1],)))
149+
kwargs = {"q": 0.8} if "quantile" in func else {}
150+
flox_kwargs: dict[str, Any] = {}
151+
152+
kwargs = dict(
153+
func=func,
154+
axis=axis,
155+
engine="numpy",
156+
**flox_kwargs,
157+
finalize_kwargs=kwargs,
158+
)
159+
result_dask, *_ = groupby_reduce(array, by, **kwargs)
160+
result_numpy, *_ = groupby_reduce(numpy_array, by, **kwargs)
161+
assert_equal(result_numpy, result_dask)
162+
163+
131164
@given(
132165
data=st.data(),
133166
array=chunked_arrays(arrays=numeric_arrays),

0 commit comments

Comments (0)