Skip to content

Commit 93800aa

Browse files
committed
Handle dtypes.NA properly for datetime/timedelta
1 parent 4b04fde commit 93800aa

File tree

2 files changed: +12 -11 lines changed

flox/aggregations.py

Lines changed: 9 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -158,12 +158,12 @@ def _get_fill_value(dtype, fill_value):
158158
return np.nan
159159
# This is madness, but npg checks that fill_value is compatible
160160
# with array dtype even if the fill_value is never used.
161-
elif (
162-
np.issubdtype(dtype, np.integer)
163-
or np.issubdtype(dtype, np.timedelta64)
164-
or np.issubdtype(dtype, np.datetime64)
165-
):
161+
elif np.issubdtype(dtype, np.integer):
166162
return dtypes.get_neg_infinity(dtype, min_for_int=True)
163+
elif np.issubdtype(dtype, np.timedelta64):
164+
return np.timedelta64("NaT")
165+
elif np.issubdtype(dtype, np.datetime64):
166+
return np.datetime64("NaT")
167167
else:
168168
return None
169169
return fill_value
@@ -435,9 +435,9 @@ def _std_finalize(sumsq, sum_, count, ddof=0):
435435

436436

437437
min_ = Aggregation("min", chunk="min", combine="min", fill_value=dtypes.INF)
438-
nanmin = Aggregation("nanmin", chunk="nanmin", combine="nanmin", fill_value=np.nan)
438+
nanmin = Aggregation("nanmin", chunk="nanmin", combine="nanmin", fill_value=dtypes.NA)
439439
max_ = Aggregation("max", chunk="max", combine="max", fill_value=dtypes.NINF)
440-
nanmax = Aggregation("nanmax", chunk="nanmax", combine="nanmax", fill_value=np.nan)
440+
nanmax = Aggregation("nanmax", chunk="nanmax", combine="nanmax", fill_value=dtypes.NA)
441441

442442

443443
def argreduce_preprocess(array, axis):
@@ -741,15 +741,15 @@ def scan_binary_op(left_state: ScanState, right_state: ScanState, *, agg: Scan)
741741
binary_op=None,
742742
reduction="nanlast",
743743
scan="ffill",
744-
identity=np.nan,
744+
identity=dtypes.NA,
745745
mode="concat_then_scan",
746746
)
747747
bfill = Scan(
748748
"bfill",
749749
binary_op=None,
750750
reduction="nanlast",
751751
scan="ffill",
752-
identity=np.nan,
752+
identity=dtypes.NA,
753753
mode="concat_then_scan",
754754
preprocess=reverse,
755755
finalize=reverse,

tests/test_properties.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -218,10 +218,11 @@ def test_first_last(data, array: dask.array.Array, func: str) -> None:
218218
@settings(report_multiple_bugs=False)
219219
@given(data=st.data(), array=chunked_arrays())
220220
def test_topk_max_min(data, array):
221-
"top 1 == max; top -1 == min"
221+
"top 1 == nanmax; top -1 == nanmin"
222222
size = array.shape[-1]
223+
note(array.compute())
223224
by = data.draw(by_arrays(shape=(size,)))
224-
k, npfunc = data.draw(st.sampled_from([(1, "max"), (-1, "min")]))
225+
k, npfunc = data.draw(st.sampled_from([(1, "nanmax"), (-1, "nanmin")]))
225226

226227
for a in (array, array.compute()):
227228
actual, _ = groupby_reduce(a, by, func="topk", finalize_kwargs={"k": k})

0 commit comments

Comments (0)