7
7
import pandas as pd
8
8
import xarray as xr
9
9
from packaging .version import Version
10
+ from xarray .core .duck_array_ops import _datetime_nanmin
10
11
11
12
from .aggregations import Aggregation , Dim , _atleast_1d , quantile_new_dims_func
12
13
from .core import (
17
18
)
18
19
from .core import rechunk_for_blockwise as rechunk_array_for_blockwise
19
20
from .core import rechunk_for_cohorts as rechunk_array_for_cohorts
21
+ from .xrutils import _contains_cftime_datetimes , _to_pytimedelta , datetime_to_numeric
20
22
21
23
if TYPE_CHECKING :
22
24
from xarray .core .types import T_DataArray , T_Dataset
@@ -364,6 +366,22 @@ def wrapper(array, *by, func, skipna, core_dims, **kwargs):
364
366
if "nan" not in func and func not in ["all" , "any" , "count" ]:
365
367
func = f"nan{ func } "
366
368
369
+ # Flox's count works with non-numeric data, and it's faster than converting.
370
+ requires_numeric = func not in ["count" , "any" , "all" ] or (
371
+ func == "count" and kwargs ["engine" ] != "flox"
372
+ )
373
+ if requires_numeric :
374
+ is_npdatetime = array .dtype .kind in "Mm"
375
+ is_cftime = _contains_cftime_datetimes (array )
376
+ if is_npdatetime :
377
+ offset = _datetime_nanmin (array )
378
+ # xarray always uses np.datetime64[ns] for np.datetime64 data
379
+ dtype = "timedelta64[ns]"
380
+ array = datetime_to_numeric (array , offset )
381
+ elif is_cftime :
382
+ offset = array .min ()
383
+ array = datetime_to_numeric (array , offset , datetime_unit = "us" )
384
+
367
385
result , * groups = groupby_reduce (array , * by , func = func , ** kwargs )
368
386
369
387
# Transpose the new quantile dimension to the end. This is ugly.
@@ -377,6 +395,13 @@ def wrapper(array, *by, func, skipna, core_dims, **kwargs):
377
395
# output dim order: (*broadcast_dims, *group_dims, quantile_dim)
378
396
result = np .moveaxis (result , 0 , - 1 )
379
397
398
+ # Output of count has an int dtype.
399
+ if requires_numeric and func != "count" :
400
+ if is_npdatetime :
401
+ return result .astype (dtype ) + offset
402
+ elif is_cftime :
403
+ return _to_pytimedelta (result , unit = "us" ) + offset
404
+
380
405
return result
381
406
382
407
# These data variables do not have any of the core dimension,
0 commit comments