
Commit 728a2b9

Add topk
1 parent 1c10b74 commit 728a2b9

File tree

4 files changed: +113 −53 lines changed

flox/aggregate_flox.py

Lines changed: 77 additions & 42 deletions

@@ -46,74 +46,107 @@ def _lerp(a, b, *, t, dtype, out=None):
     return out


-def quantile_(array, inv_idx, *, q, axis, skipna, group_idx, dtype=None, out=None):
-    inv_idx = np.concatenate((inv_idx, [array.shape[-1]]))
+def quantile_or_topk(
+    array, inv_idx, *, q=None, k=None, axis, skipna, group_idx, dtype=None, out=None
+):
+    assert q or k

-    array_nanmask = isnull(array)
-    actual_sizes = np.add.reduceat(~array_nanmask, inv_idx[:-1], axis=axis)
-    newshape = (1,) * (array.ndim - 1) + (inv_idx.size - 1,)
-    full_sizes = np.reshape(np.diff(inv_idx), newshape)
-    nanmask = full_sizes != actual_sizes
+    inv_idx = np.concatenate((inv_idx, [array.shape[-1]]))

-    # The approach here is to use (complex_array.partition) because
+    # The approach for quantiles and topk, both of which are basically grouped partitions,
+    # is to use (complex_array.partition) because
     # 1. The full np.lexsort((array, labels), axis=-1) is slow and unnecessary
     # 2. Using record_array.partition(..., order=["labels", "array"]) is incredibly slow.
-    # partition will first sort by real part, then by imaginary part, so it is a two element lex-partition.
-    # So we set
+    # partition will first sort by real part, then by imaginary part, so it is a two element
+    # lex-partition. Therefore we set
     # complex_array = group_idx + 1j * array
     # group_idx is an integer (guaranteed), but array can have NaNs. Now,
     #     1 + 1j*NaN = NaN + 1j * NaN
     # so we must replace all NaNs with the maximum array value in the group so these NaNs
     # get sorted to the end.
+
+    # Replace NaNs with the maximum value for each group.
     # Partly inspired by https://krstn.eu/np.nanpercentile()-there-has-to-be-a-faster-way/
-    # TODO: Don't know if this array has been copied in _prepare_for_flox. This is potentially wasteful
+    array_nanmask = isnull(array)
+    actual_sizes = np.add.reduceat(~array_nanmask, inv_idx[:-1], axis=axis)
+    newshape = (1,) * (array.ndim - 1) + (inv_idx.size - 1,)
+    full_sizes = np.reshape(np.diff(inv_idx), newshape)
+    nanmask = full_sizes != actual_sizes
+    # TODO: Don't know if this array has been copied in _prepare_for_flox.
+    # This is potentially wasteful
     array = np.where(array_nanmask, -np.inf, array)
     maxes = np.maximum.reduceat(array, inv_idx[:-1], axis=axis)
     replacement = np.repeat(maxes, np.diff(inv_idx), axis=axis)
     array[array_nanmask] = replacement[array_nanmask]

-    qin = q
-    q = np.atleast_1d(qin)
-    q = np.reshape(q, (len(q),) + (1,) * array.ndim)
-
-    # This is numpy's method="linear"
-    # TODO: could support all the interpolations here
-    virtual_index = q * (actual_sizes - 1) + inv_idx[:-1]
+    param = q or k
+    if k is not None:
+        assert k > 0
+        is_scalar_param = False
+        param = np.arange(k)
+    else:
+        is_scalar_param = is_scalar(q)
+    param = np.atleast_1d(param)
+    param = np.reshape(param, (param.size,) + (1,) * array.ndim)

-    is_scalar_q = is_scalar(qin)
-    if is_scalar_q:
-        virtual_index = virtual_index.squeeze(axis=0)
+    if is_scalar_param:
         idxshape = array.shape[:-1] + (actual_sizes.shape[-1],)
     else:
-        idxshape = (q.shape[0],) + array.shape[:-1] + (actual_sizes.shape[-1],)
+        idxshape = (param.shape[0],) + array.shape[:-1] + (actual_sizes.shape[-1],)

-    lo_ = np.floor(
-        virtual_index, casting="unsafe", out=np.empty(virtual_index.shape, dtype=np.int64)
-    )
-    hi_ = np.ceil(
-        virtual_index, casting="unsafe", out=np.empty(virtual_index.shape, dtype=np.int64)
-    )
-    kth = np.unique(np.concatenate([lo_.reshape(-1), hi_.reshape(-1)]))
+    if q is not None:
+        # This is numpy's method="linear"
+        # TODO: could support all the interpolations here
+        virtual_index = param * (actual_sizes - 1) + inv_idx[:-1]
+
+        if is_scalar_param:
+            virtual_index = virtual_index.squeeze(axis=0)
+
+        lo_ = np.floor(
+            virtual_index, casting="unsafe", out=np.empty(virtual_index.shape, dtype=np.int64)
+        )
+        hi_ = np.ceil(
+            virtual_index, casting="unsafe", out=np.empty(virtual_index.shape, dtype=np.int64)
+        )
+        kth = np.unique(np.concatenate([lo_.reshape(-1), hi_.reshape(-1)]))
+
+    else:
+        virtual_index = (actual_sizes - k) + inv_idx[:-1]
+        kth = np.unique(virtual_index)
+        kth = kth[kth > 0]
+        k_offset = np.arange(k).reshape((k,) + (1,) * virtual_index.ndim)
+        lo_ = k_offset + virtual_index[np.newaxis, ...]

     # partition the complex array in-place
     labels_broadcast = np.broadcast_to(group_idx, array.shape)
     with np.errstate(invalid="ignore"):
         cmplx = labels_broadcast + 1j * array
     cmplx.partition(kth=kth, axis=-1)
-    if is_scalar_q:
+
+    if is_scalar_param:
         a_ = cmplx.imag
     else:
-        a_ = np.broadcast_to(cmplx.imag, (q.shape[0],) + array.shape)
+        a_ = np.broadcast_to(cmplx.imag, (param.shape[0],) + array.shape)

-    # get bounds, Broadcast to (num quantiles, ..., num labels)
     loval = np.take_along_axis(a_, np.broadcast_to(lo_, idxshape), axis=axis)
-    hival = np.take_along_axis(a_, np.broadcast_to(hi_, idxshape), axis=axis)
-
-    # TODO: could support all the interpolations here
-    gamma = np.broadcast_to(virtual_index, idxshape) - lo_
-    result = _lerp(loval, hival, t=gamma, out=out, dtype=dtype)
+    if q is not None:
+        # get bounds, broadcast to (num quantiles, ..., num labels)
+        hival = np.take_along_axis(a_, np.broadcast_to(hi_, idxshape), axis=axis)
+
+        # TODO: could support all the interpolations here
+        gamma = np.broadcast_to(virtual_index, idxshape) - lo_
+        result = _lerp(loval, hival, t=gamma, out=out, dtype=dtype)
+    else:
+        result = loval
+        result[lo_ < 0] = np.nan
     if not skipna and np.any(nanmask):
         result[..., nanmask] = np.nan
+    if k is not None:
+        result = result.astype(array.dtype, copy=False)
+        np.copyto(out, result)
     return result


@@ -138,10 +171,11 @@ def _np_grouped_op(

     if out is None:
         q = kwargs.get("q", None)
-        if q is None:
+        k = kwargs.get("k", None)
+        if not q and not k:
             out = np.full(array.shape[:-1] + (size,), fill_value=fill_value, dtype=dtype)
         else:
-            nq = len(np.atleast_1d(q))
+            nq = len(np.atleast_1d(q)) if q is not None else k
             out = np.full((nq,) + array.shape[:-1] + (size,), fill_value=fill_value, dtype=dtype)
     kwargs["group_idx"] = group_idx

@@ -178,10 +212,11 @@ def _nan_grouped_op(group_idx, array, func, fillna, *args, **kwargs):
 nanmax = partial(_nan_grouped_op, func=max, fillna=-np.inf)
 min = partial(_np_grouped_op, op=np.minimum.reduceat)
 nanmin = partial(_nan_grouped_op, func=min, fillna=np.inf)
-quantile = partial(_np_grouped_op, op=partial(quantile_, skipna=False))
-nanquantile = partial(_np_grouped_op, op=partial(quantile_, skipna=True))
-median = partial(partial(_np_grouped_op, q=0.5), op=partial(quantile_, skipna=False))
-nanmedian = partial(partial(_np_grouped_op, q=0.5), op=partial(quantile_, skipna=True))
+quantile = partial(_np_grouped_op, op=partial(quantile_or_topk, skipna=False))
+topk = partial(_np_grouped_op, op=partial(quantile_or_topk, skipna=True))
+nanquantile = partial(_np_grouped_op, op=partial(quantile_or_topk, skipna=True))
+median = partial(partial(_np_grouped_op, q=0.5), op=partial(quantile_or_topk, skipna=False))
+nanmedian = partial(partial(_np_grouped_op, q=0.5), op=partial(quantile_or_topk, skipna=True))
 # TODO: all, any
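The comments in `quantile_or_topk` describe the complex-partition trick in the abstract. Below is a minimal, self-contained sketch of that trick for the top-k case; the sample arrays and the group-boundary bookkeeping are illustrative only, not flox's internal API:

import numpy as np

group_idx = np.array([0, 0, 0, 1, 1, 1, 1])  # sorted group labels
array = np.array([3.0, 1.0, 5.0, 2.0, 9.0, 4.0, 7.0])
k = 2

# Real part carries the group label, imaginary part the value, so a single
# partition is a two-element lex-partition: by group first, then by value.
cmplx = group_idx + 1j * array

# Group end offsets. After partitioning, the k largest values of each group
# occupy the last k slots of that group's block.
ends = np.concatenate([np.flatnonzero(np.diff(group_idx)) + 1, [array.size]])
kth = np.unique(np.concatenate([ends - i for i in range(1, k + 1)]))
cmplx.partition(kth=kth)

print([np.sort(cmplx.imag[e - k : e]) for e in ends])
# [array([3., 5.]), array([7., 9.])]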

flox/aggregations.py

Lines changed: 13 additions & 0 deletions

@@ -554,6 +554,10 @@ def quantile_new_dims_func(q) -> tuple[Dim]:
     return (Dim(name="quantile", values=q),)


+def topk_new_dims_func(k) -> tuple[Dim]:
+    return (Dim(name="k", values=np.arange(k)),)
+
+
 quantile = Aggregation(
     name="quantile",
     fill_value=dtypes.NA,
@@ -570,6 +574,14 @@ def quantile_new_dims_func(q) -> tuple[Dim]:
     final_dtype=np.floating,
     new_dims_func=quantile_new_dims_func,
 )
+topk = Aggregation(
+    name="topk",
+    fill_value=dtypes.NINF,
+    chunk=None,
+    combine=None,
+    final_dtype=None,
+    new_dims_func=topk_new_dims_func,
+)
 mode = Aggregation(name="mode", fill_value=dtypes.NA, chunk=None, combine=None)
 nanmode = Aggregation(name="nanmode", fill_value=dtypes.NA, chunk=None, combine=None)

@@ -769,6 +781,7 @@ def scan_binary_op(left_state: ScanState, right_state: ScanState, *, agg: Scan)
     "nanquantile": nanquantile,
     "mode": mode,
     "nanmode": nanmode,
+    "topk": topk,
     # "cumsum": cumsum,
     "nancumsum": nancumsum,
     "ffill": ffill,

flox/core.py

Lines changed: 7 additions & 2 deletions

@@ -42,6 +42,7 @@
     _initialize_aggregation,
     generic_aggregate,
     quantile_new_dims_func,
+    topk_new_dims_func,
 )
 from .cache import memoize
 from .xrutils import (
@@ -1081,6 +1082,10 @@ def chunk_reduce(
             new_dims_shape = tuple(
                 dim.size for dim in quantile_new_dims_func(**kw) if not dim.is_scalar
             )
+        elif reduction == "topk":
+            new_dims_shape = tuple(
+                dim.size for dim in topk_new_dims_func(**kw) if not dim.is_scalar
+            )
         else:
             new_dims_shape = tuple()
         result = result.reshape(new_dims_shape + final_array_shape[:-1] + found_groups_shape)
@@ -2205,7 +2210,7 @@ def _choose_engine(by, agg: Aggregation):

     not_arg_reduce = not _is_arg_reduction(agg)

-    if agg.name in ["quantile", "nanquantile", "median", "nanmedian"]:
+    if agg.name in ["quantile", "nanquantile", "median", "nanmedian", "topk"]:
         logger.debug(f"_choose_engine: Choosing 'flox' since {agg.name}")
         return "flox"

@@ -2258,7 +2263,7 @@ def groupby_reduce(
         equality check are for dimensions of size 1 in `by`.
     func : {"all", "any", "count", "sum", "nansum", "mean", "nanmean", \
         "max", "nanmax", "min", "nanmin", "argmax", "nanargmax", "argmin", "nanargmin", \
-        "quantile", "nanquantile", "median", "nanmedian", "mode", "nanmode", \
+        "quantile", "nanquantile", "median", "nanmedian", "topk", "mode", "nanmode", \
         "first", "nanfirst", "last", "nanlast"} or Aggregation
         Single function name or an Aggregation instance
     expected_groups : (optional) Sequence
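A hedged usage sketch for the new aggregation; the commit does not show a call site, so passing `k` through `finalize_kwargs` (the way quantile receives `q`) is an assumption:

import numpy as np
from flox.core import groupby_reduce

array = np.array([3.0, 1.0, 5.0, 2.0, 9.0, 4.0, 7.0])
by = np.array([0, 0, 0, 1, 1, 1, 1])

# Assumed invocation: k forwarded via finalize_kwargs, mirroring quantile's q.
result, groups = groupby_reduce(array, by, func="topk", finalize_kwargs={"k": 2})
# Per chunk_reduce above, the new "k" dimension is inserted first,
# so result.shape should be (2, num_groups).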

flox/xarray.py

Lines changed: 16 additions & 9 deletions

@@ -9,7 +9,13 @@
 from packaging.version import Version
 from xarray.core.duck_array_ops import _datetime_nanmin

-from .aggregations import Aggregation, Dim, _atleast_1d, quantile_new_dims_func
+from .aggregations import (
+    Aggregation,
+    Dim,
+    _atleast_1d,
+    quantile_new_dims_func,
+    topk_new_dims_func,
+)
 from .core import (
     _convert_expected_groups_to_index,
     _get_expected_groups,
@@ -92,7 +98,7 @@ def xarray_reduce(
         Variables with which to group by ``obj``
     func : {"all", "any", "count", "sum", "nansum", "mean", "nanmean", \
         "max", "nanmax", "min", "nanmin", "argmax", "nanargmax", "argmin", "nanargmin", \
-        "quantile", "nanquantile", "median", "nanmedian", "mode", "nanmode", \
+        "quantile", "nanquantile", "median", "nanmedian", "topk", "mode", "nanmode", \
         "first", "nanfirst", "last", "nanlast"} or Aggregation
         Single function name or an Aggregation instance
     expected_groups : str or sequence
@@ -390,17 +396,18 @@ def wrapper(array, *by, func, skipna, core_dims, **kwargs):

         result, *groups = groupby_reduce(array, *by, func=func, **kwargs)

-        # Transpose the new quantile dimension to the end. This is ugly,
+        # Transpose the new quantile or topk dimension to the end. This is ugly,
         # but new core dimensions are expected at the end :/
         # but groupby_reduce inserts them at the beginning
         if func in ["quantile", "nanquantile"]:
             (newdim,) = quantile_new_dims_func(**finalize_kwargs)
-            if not newdim.is_scalar:
-                # NOTE: _restore_dim_order will move any new dims to the end anyway.
-                # This transpose simply makes it easy to specify output_core_dims
-                # output dim order: (*broadcast_dims, *group_dims, quantile_dim)
-                result = np.moveaxis(result, 0, -1)
-
+        elif func == "topk":
+            (newdim,) = topk_new_dims_func(**finalize_kwargs)
+        if not newdim.is_scalar:
+            # NOTE: _restore_dim_order will move any new dims to the end anyway.
+            # This transpose simply makes it easy to specify output_core_dims
+            # output dim order: (*broadcast_dims, *group_dims, quantile_dim)
+            result = np.moveaxis(result, 0, -1)
         # Output of count has an int dtype.
         if requires_numeric and func != "count":
             if is_npdatetime:
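And an end-to-end sketch on the xarray side; the `k` keyword reaching `finalize_kwargs` through `xarray_reduce` is an assumption based on how `q` flows to quantile:

import numpy as np
import xarray as xr
from flox.xarray import xarray_reduce

da = xr.DataArray(
    np.arange(10.0), dims="x", coords={"labels": ("x", np.repeat([0, 1], 5))}
)

# Assumed call pattern, mirroring quantile's q.
out = xarray_reduce(da, "labels", func="topk", k=3)
# The moveaxis in wrapper above puts the new dimension last,
# so out should come back with dims roughly ("labels", "k").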
