Skip to content

Commit 996ff2a

Browse files
committed
dask support
1 parent 889be0c commit 996ff2a

File tree

2 files changed

+21
-6
lines changed

2 files changed

+21
-6
lines changed

flox/aggregations.py

Lines changed: 3 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -586,8 +586,8 @@ def topk_new_dims_func(k) -> tuple[Dim]:
586586
topk = Aggregation(
587587
name="topk",
588588
fill_value=dtypes.NINF,
589-
chunk=None,
590-
combine=None,
589+
chunk="topk",
590+
combine=xrutils.topk,
591591
final_dtype=None,
592592
new_dims_func=topk_new_dims_func,
593593
)
@@ -890,10 +890,7 @@ def _initialize_aggregation(
890890
simple_combine: list[Callable | None] = []
891891
for combine in agg.combine:
892892
if isinstance(combine, str):
893-
if combine in ["nanfirst", "nanlast"]:
894-
simple_combine.append(getattr(xrutils, combine))
895-
else:
896-
simple_combine.append(getattr(np, combine))
893+
simple_combine.append(getattr(np, combine))
897894
else:
898895
simple_combine.append(combine)
899896

flox/xrutils.py

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -378,3 +378,21 @@ def nanlast(values, axis, keepdims=False):
378378
return np.expand_dims(result, axis=axis)
379379
else:
380380
return result
381+
382+
383+
def topk(a, k, axis, keepdims):
384+
"""Chunk and combine function of topk
385+
386+
Extract the k largest elements from a on the given axis.
387+
If k is negative, extract the -k smallest elements instead.
388+
Note that, unlike in the parent function, the returned elements
389+
are not sorted internally.
390+
"""
391+
assert keepdims is True
392+
axis = axis[0]
393+
if abs(k) >= a.shape[axis]:
394+
return a
395+
396+
a = np.partition(a, -k, axis=axis)
397+
k_slice = slice(-k, None) if k > 0 else slice(-k)
398+
return a[tuple(k_slice if i == axis else slice(None) for i in range(a.ndim))]

0 commit comments

Comments
 (0)