Skip to content

Commit 5564c4f

Browse files
committed
dask support
1 parent 275f574 commit 5564c4f

File tree

2 files changed

+23
-8
lines changed

2 files changed

+23
-8
lines changed

flox/aggregations.py

Lines changed: 5 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -518,8 +518,8 @@ def _pick_second(*x):
518518

519519
first = Aggregation("first", chunk=None, combine=None, fill_value=0)
520520
last = Aggregation("last", chunk=None, combine=None, fill_value=0)
521-
nanfirst = Aggregation("nanfirst", chunk="nanfirst", combine="nanfirst", fill_value=np.nan)
522-
nanlast = Aggregation("nanlast", chunk="nanlast", combine="nanlast", fill_value=np.nan)
521+
nanfirst = Aggregation("nanfirst", chunk="nanfirst", combine=xrutils.nanfirst, fill_value=np.nan)
522+
nanlast = Aggregation("nanlast", chunk="nanlast", combine=xrutils.nanlast, fill_value=np.nan)
523523

524524
all_ = Aggregation(
525525
"all",
@@ -577,8 +577,8 @@ def topk_new_dims_func(k) -> tuple[Dim]:
577577
topk = Aggregation(
578578
name="topk",
579579
fill_value=dtypes.NINF,
580-
chunk=None,
581-
combine=None,
580+
chunk="topk",
581+
combine=xrutils.topk,
582582
final_dtype=None,
583583
new_dims_func=topk_new_dims_func,
584584
)
@@ -881,10 +881,7 @@ def _initialize_aggregation(
881881
simple_combine: list[Callable | None] = []
882882
for combine in agg.combine:
883883
if isinstance(combine, str):
884-
if combine in ["nanfirst", "nanlast"]:
885-
simple_combine.append(getattr(xrutils, combine))
886-
else:
887-
simple_combine.append(getattr(np, combine))
884+
simple_combine.append(getattr(np, combine))
888885
else:
889886
simple_combine.append(combine)
890887

flox/xrutils.py

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -378,3 +378,21 @@ def nanlast(values, axis, keepdims=False):
378378
return np.expand_dims(result, axis=axis)
379379
else:
380380
return result
381+
382+
383+
def topk(a, k, axis, keepdims):
384+
"""Chunk and combine function of topk
385+
386+
Extract the k largest elements from a on the given axis.
387+
If k is negative, extract the -k smallest elements instead.
388+
Note that, unlike in the parent function, the returned elements
389+
are not sorted internally.
390+
"""
391+
assert keepdims is True
392+
axis = axis[0]
393+
if abs(k) >= a.shape[axis]:
394+
return a
395+
396+
a = np.partition(a, -k, axis=axis)
397+
k_slice = slice(-k, None) if k > 0 else slice(-k)
398+
return a[tuple(k_slice if i == axis else slice(None) for i in range(a.ndim))]

0 commit comments

Comments
 (0)