Add topk #374

Draft: dcherian wants to merge 59 commits into main from topk. The diff shown below reflects changes from 25 of the 59 commits.

Commits
650088b
Add topk
dcherian Jul 27, 2024
889be0c
Negative k
dcherian Jul 28, 2024
996ff2a
dask support
dcherian Jul 28, 2024
776d233
test
dcherian Jul 28, 2024
a5eb7b9
wip
dcherian Jul 28, 2024
4fa9a4c
fix
dcherian Jul 28, 2024
4b04fde
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Jul 28, 2024
93800aa
Handle dtypes.NA properly for datetime/timedelta
dcherian Jul 31, 2024
80c67f4
Fix
dcherian Jul 31, 2024
7056d18
Merge branch 'main' into topk
dcherian Aug 7, 2024
44f5f3f
Merge branch 'main' into topk
dcherian Jan 7, 2025
c924017
Fixes
dcherian Jan 7, 2025
7a794ba
one more fix
dcherian Jan 7, 2025
eec4dd4
fix
dcherian Jan 7, 2025
6ac9a1f
one more fix
dcherian Jan 7, 2025
83594e8
Fixes.
dcherian Jan 7, 2025
740f85f
WIP
dcherian Jan 7, 2025
5d64fd9
Merge branch 'main' into topk
dcherian Jan 7, 2025
e177efd
fixes
dcherian Jan 7, 2025
9393470
fix
dcherian Jan 7, 2025
17eb915
cleanup
dcherian Jan 7, 2025
dc0df3e
works?
dcherian Jan 7, 2025
83ae5d8
fix quantile
dcherian Jan 7, 2025
95d20b8
optimize xrutils.topk
dcherian Jan 7, 2025
0b9fafc
Merge branch 'main' into topk
dcherian Jan 8, 2025
caa98b8
Update tests/test_properties.py
dcherian Jan 8, 2025
820d46c
generalize new_dims_func
dcherian Jan 13, 2025
17a4d5d
Merge branch 'main' into topk
dcherian Jan 13, 2025
6aa923a
Revert "generalize new_dims_func"
dcherian Jan 13, 2025
16b0bac
Merge branch 'main' into topk
dcherian Jan 13, 2025
2c6d486
Support bool
dcherian Jan 13, 2025
0dcd87c
more skipping
dcherian Jan 13, 2025
9b874ea
fix
dcherian Jan 14, 2025
adebbec
more xfail
dcherian Jan 15, 2025
ace2af5
Merge branch 'main' into topk
dcherian Jan 19, 2025
4f35230
cleanup
dcherian Jan 19, 2025
cd2f150
one more xfail
dcherian Jan 19, 2025
70e6f22
typing
dcherian Jan 19, 2025
5d45603
minor docs
dcherian Jan 19, 2025
096f6b9
disable log in CI
dcherian Jan 19, 2025
0277cb9
Fix boolean
dcherian Jan 19, 2025
6c7e84a
bool -> bool_
dcherian Jan 20, 2025
43c3408
update int limits
dcherian Jan 20, 2025
01eabfb
fix rtd
dcherian Jan 20, 2025
6e4ce69
Add note
dcherian Jan 20, 2025
4500c7e
Merge branch 'main' into topk
dcherian Jan 24, 2025
8f60477
Add unit test
dcherian Jan 24, 2025
15fcfa1
WIP
dcherian Jan 24, 2025
a5bcc5b
fix
dcherian Jan 24, 2025
489c843
Merge branch 'main' into topk
dcherian Mar 18, 2025
91e1d07
Switch DUMMY_AXIS to 0
dcherian Mar 18, 2025
2d868fe
More support for edge cases
dcherian Mar 18, 2025
d244d60
minor
dcherian Mar 18, 2025
8319f7f
[WIP] failing test
dcherian Mar 18, 2025
d21eec5
Merge branch 'main' into topk
dcherian Jul 16, 2025
dfb1e88
fix expected
dcherian Mar 26, 2025
8b31f5d
Revert "[WIP] failing test"
dcherian Mar 26, 2025
fce4f2b
[revert] failing test
dcherian Mar 26, 2025
0f7ee05
fix
dcherian Jul 16, 2025
135 changes: 89 additions & 46 deletions flox/aggregate_flox.py
@@ -47,14 +47,32 @@ def _lerp(a, b, *, t, dtype, out=None):
return out


def quantile_(array, inv_idx, *, q, axis, skipna, group_idx, dtype=None, out=None):
def quantile_or_topk(
array,
inv_idx,
*,
q=None,
k=None,
axis,
skipna,
group_idx,
dtype=None,
out=None,
fill_value=None,
):
assert q is not None or k is not None
assert axis == -1

inv_idx = np.concatenate((inv_idx, [array.shape[-1]]))

array_validmask = notnull(array)
actual_sizes = np.add.reduceat(array_validmask, inv_idx[:-1], axis=axis)
newshape = (1,) * (array.ndim - 1) + (inv_idx.size - 1,)
full_sizes = np.reshape(np.diff(inv_idx), newshape)
nanmask = full_sizes != actual_sizes
if k is not None:
nanmask = actual_sizes < abs(k)
else:
full_sizes = np.reshape(np.diff(inv_idx), newshape)
nanmask = full_sizes != actual_sizes

# The approach here is to use (complex_array.partition) because
# 1. The full np.lexsort((array, labels), axis=-1) is slow and unnecessary
@@ -72,36 +90,44 @@ def quantile_(array, inv_idx, *, q, axis, skipna, group_idx, dtype=None, out=None):
# So we determine which indices we need using the fact that NaNs get sorted to the end.
# This *was* partly inspired by https://krstn.eu/np.nanpercentile()-there-has-to-be-a-faster-way/
# but not any more now that I use partition and avoid replacing NaNs
qin = q
q = np.atleast_1d(qin)
q = np.reshape(q, (len(q),) + (1,) * array.ndim)
if k is not None:
is_scalar_param = False
param = np.arange(abs(k))
else:
is_scalar_param = is_scalar(q)
param = np.atleast_1d(q)
param = np.reshape(param, (param.size,) + (1,) * array.ndim)

# This is numpy's method="linear"
# TODO: could support all the interpolations here
offset = actual_sizes.cumsum(axis=-1)
actual_sizes -= 1
virtual_index = q * actual_sizes
# virtual_index is relative to group starts, so now offset that
virtual_index[..., 1:] += offset[..., :-1]

is_scalar_q = is_scalar(qin)
if is_scalar_q:
virtual_index = virtual_index.squeeze(axis=0)
idxshape = array.shape[:-1] + (actual_sizes.shape[-1],)
else:
idxshape = (q.shape[0],) + array.shape[:-1] + (actual_sizes.shape[-1],)
# For topk(.., k=+1 or -1), we always return the singleton dimension.
idxshape = (param.shape[0],) + array.shape[:-1] + (actual_sizes.shape[-1],)

lo_ = np.floor(
virtual_index,
casting="unsafe",
out=np.empty(virtual_index.shape, dtype=np.int64),
)
hi_ = np.ceil(
virtual_index,
casting="unsafe",
out=np.empty(virtual_index.shape, dtype=np.int64),
)
kth = np.unique(np.concatenate([lo_.reshape(-1), hi_.reshape(-1)]))
if q is not None:
# This is numpy's method="linear"
# TODO: could support all the interpolations here
actual_sizes -= 1
virtual_index = param * actual_sizes
# virtual_index is relative to group starts, so now offset that
virtual_index[..., 1:] += offset[..., :-1]

if is_scalar_param:
virtual_index = virtual_index.squeeze(axis=0)
idxshape = array.shape[:-1] + (actual_sizes.shape[-1],)

lo_ = np.floor(virtual_index, casting="unsafe", out=np.empty(virtual_index.shape, dtype=np.int64))
hi_ = np.ceil(virtual_index, casting="unsafe", out=np.empty(virtual_index.shape, dtype=np.int64))
kth = np.unique(np.concatenate([lo_.reshape(-1), hi_.reshape(-1)]))

else:
virtual_index = (actual_sizes - k) if k > 0 else (np.zeros_like(actual_sizes) + abs(k) - 1)
# virtual_index is relative to group starts, so now offset that
virtual_index[..., 1:] += offset[..., :-1]
kth = np.unique(virtual_index)
kth = kth[kth >= 0]
k_offset = param.reshape((abs(k),) + (1,) * virtual_index.ndim)
lo_ = k_offset + virtual_index[np.newaxis, ...]

# partition the complex array in-place
labels_broadcast = np.broadcast_to(group_idx, array.shape)
@@ -111,20 +137,34 @@ def quantile_(array, inv_idx, *, q, axis, skipna, group_idx, dtype=None, out=None):
# a simple (labels + 1j * array) will yield `nan+inf * 1j` instead of `0 + inf * j`
cmplx.real = labels_broadcast
cmplx.partition(kth=kth, axis=-1)
if is_scalar_q:
a_ = cmplx.imag
else:
a_ = np.broadcast_to(cmplx.imag, (q.shape[0],) + array.shape)

# get bounds, Broadcast to (num quantiles, ..., num labels)
loval = np.take_along_axis(a_, np.broadcast_to(lo_, idxshape), axis=axis)
hival = np.take_along_axis(a_, np.broadcast_to(hi_, idxshape), axis=axis)
a_ = cmplx.imag
if not is_scalar_param:
a_ = np.broadcast_to(cmplx.imag, (param.shape[0],) + array.shape)

# TODO: could support all the interpolations here
gamma = np.broadcast_to(virtual_index, idxshape) - lo_
result = _lerp(loval, hival, t=gamma, out=out, dtype=dtype)
if not skipna and np.any(nanmask):
result[..., nanmask] = np.nan
if array.dtype.kind in "Mm":
a_ = a_.view(array.dtype)

loval = np.take_along_axis(a_, np.broadcast_to(lo_, idxshape), axis=axis)
if q is not None:
# get bounds, Broadcast to (num quantiles, ..., num labels)
hival = np.take_along_axis(a_, np.broadcast_to(hi_, idxshape), axis=axis)

# TODO: could support all the interpolations here
gamma = np.broadcast_to(virtual_index, idxshape) - lo_
result = _lerp(loval, hival, t=gamma, out=out, dtype=dtype)
if not skipna and np.any(nanmask):
result[..., nanmask] = fill_value
else:
result = loval
# The first clause is True if numel in group < abs(k)
badmask = np.broadcast_to(lo_ < 0, idxshape) | nanmask
result[badmask] = fill_value

if k is not None:
result = result.astype(dtype, copy=False)
if out is not None:
np.copyto(out, result)
return result
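
The comment block above is the heart of this implementation: group labels go into the real part and data values into the imaginary part of a complex array, so a single `partition` call both keeps groups contiguous and orders values within each group, after which each group's top-k entries can be gathered by index. Below is a minimal, self-contained sketch of that idea on toy data; it partitions at every index it reads and ignores the NaN, datetime, and fill-value handling that `quantile_or_topk` adds.

```python
import numpy as np

# Two groups whose members are already contiguous (as after flox's factorize step).
group_idx = np.array([0, 0, 0, 0, 1, 1, 1])
array = np.array([5.0, 1.0, 9.0, 3.0, 2.0, 8.0, 4.0])
k = 2  # top-2 per group

# Labels in the real part, values in the imaginary part. NumPy orders complex
# numbers by real part first, then imaginary part, so partitioning sorts
# values within each group without mixing groups.
cmplx = group_idx + 1j * array

sizes = np.array([4, 3])      # elements per group
ends = sizes.cumsum()         # one past each group's last slot -> [4, 7]
starts = ends - k             # first of the k largest slots per group -> [2, 5]

# Partition at every slot we plan to read; each of those slots then holds
# its fully sorted value. (The PR computes a smaller `kth` set.)
wanted = np.arange(k)[:, None] + starts[None, :]   # shape (k, ngroups)
cmplx.partition(kth=np.unique(wanted))

topk = cmplx.imag[wanted]     # shape (k, ngroups)
print(topk)
# [[5. 4.]
#  [9. 8.]]   second-largest and largest value in each group
```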


@@ -158,12 +198,14 @@ def _np_grouped_op(

if out is None:
q = kwargs.get("q", None)
if q is None:
k = kwargs.get("k", None)
if q is None and k is None:
out = np.full(array.shape[:-1] + (size,), fill_value=fill_value, dtype=dtype)
else:
nq = len(np.atleast_1d(q))
nq = len(np.atleast_1d(q)) if q is not None else abs(k)
out = np.full((nq,) + array.shape[:-1] + (size,), fill_value=fill_value, dtype=dtype)
kwargs["group_idx"] = group_idx
kwargs["fill_value"] = fill_value

if (len(uniques) == size) and (uniques == np.arange(size, like=array)).all():
# The previous version of this if condition
@@ -200,10 +242,11 @@ def _nan_grouped_op(group_idx, array, func, fillna, *args, **kwargs):
nanmax = partial(_nan_grouped_op, func=max, fillna=dtypes.NINF)
min = partial(_np_grouped_op, op=np.minimum.reduceat)
nanmin = partial(_nan_grouped_op, func=min, fillna=dtypes.INF)
quantile = partial(_np_grouped_op, op=partial(quantile_, skipna=False))
nanquantile = partial(_np_grouped_op, op=partial(quantile_, skipna=True))
median = partial(partial(_np_grouped_op, q=0.5), op=partial(quantile_, skipna=False))
nanmedian = partial(partial(_np_grouped_op, q=0.5), op=partial(quantile_, skipna=True))
topk = partial(_np_grouped_op, op=partial(quantile_or_topk, skipna=True))
quantile = partial(_np_grouped_op, op=partial(quantile_or_topk, skipna=False))
nanquantile = partial(_np_grouped_op, op=partial(quantile_or_topk, skipna=True))
median = partial(partial(_np_grouped_op, q=0.5), op=partial(quantile_or_topk, skipna=False))
nanmedian = partial(partial(_np_grouped_op, q=0.5), op=partial(quantile_or_topk, skipna=True))
# TODO: all, any


29 changes: 26 additions & 3 deletions flox/aggregations.py
@@ -551,6 +551,10 @@ def quantile_new_dims_func(q) -> tuple[Dim]:
return (Dim(name="quantile", values=q),)


def topk_new_dims_func(k) -> tuple[Dim]:
return (Dim(name="k", values=np.arange(abs(k))),)


# if the input contains integers or floats smaller than float64,
# the output data-type is float64. Otherwise, the output data-type is the same as that
# of the input.
@@ -572,6 +576,15 @@ def quantile_new_dims_func(q) -> tuple[Dim]:
)
mode = Aggregation(name="mode", fill_value=dtypes.NA, chunk=None, combine=None, preserves_dtype=True)
nanmode = Aggregation(name="nanmode", fill_value=dtypes.NA, chunk=None, combine=None, preserves_dtype=True)
topk = Aggregation(
name="topk",
fill_value=(dtypes.NINF, 0),
final_fill_value=dtypes.NA,
chunk=("topk", "nanlen"),
combine=(xrutils.topk, "sum"),
new_dims_func=topk_new_dims_func,
preserves_dtype=True,
)


@dataclass
@@ -769,6 +782,7 @@ def scan_binary_op(left_state: ScanState, right_state: ScanState, *, agg: Scan)
"nanquantile": nanquantile,
"mode": mode,
"nanmode": nanmode,
"topk": topk,
# "cumsum": cumsum,
"nancumsum": nancumsum,
"ffill": ffill,
@@ -823,6 +837,12 @@ def _initialize_aggregation(
),
}

if finalize_kwargs is not None:
assert isinstance(finalize_kwargs, dict)
agg.finalize_kwargs = finalize_kwargs

if agg.name == "topk" and agg.finalize_kwargs["k"] < 0:
agg.fill_value["intermediate"] = (dtypes.INF, 0)
# Replace sentinel fill values according to dtype
agg.fill_value["user"] = fill_value
agg.fill_value["intermediate"] = tuple(
@@ -838,9 +858,8 @@
else:
agg.fill_value["numpy"] = (fv,)

if finalize_kwargs is not None:
assert isinstance(finalize_kwargs, dict)
agg.finalize_kwargs = finalize_kwargs
if agg.name == "topk":
min_count = max(min_count or 0, abs(agg.finalize_kwargs["k"]))

# This is needed for the dask pathway.
# Because we use intermediate fill_value since a group could be
@@ -878,6 +897,10 @@
else:
simple_combine.append(getattr(np, combine))
else:
# TODO: bah, we need to pass `k` to the combine topk function
# this is ugly.
if agg.name == "topk" and not isinstance(combine, str):
combine = partial(combine, **agg.finalize_kwargs)
simple_combine.append(combine)

agg.simple_combine = tuple(simple_combine)
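For dask arrays, `chunk=("topk", "nanlen")` computes a per-chunk top-k plus a count of valid values, and `combine=(xrutils.topk, "sum")` reduces those intermediates with top-k again. The intermediate fill value of `(dtypes.NINF, 0)` (flipped to `(dtypes.INF, 0)` for negative `k`, i.e. bottom-k) pads groups that are missing from a chunk with values that can never win the combine, and bumping `min_count` to `abs(k)` masks groups that end up with fewer than `abs(k)` valid values. A rough sketch of why the tree reduction works, using a hypothetical helper rather than the PR's `xrutils.topk`:

```python
import numpy as np

def topk_1d(values, k):
    """Return the k largest values of a 1D array (ascending order)."""
    part = np.partition(values, -k)
    return np.sort(part[-k:])

k = 3
# Per-chunk intermediates for one group: a chunk with fewer than k valid
# values is padded with -inf, the intermediate fill value.
chunk_a = np.array([4.0, 9.0, 1.0])
chunk_b = np.array([-np.inf, 7.0, 8.0])

# Combine step: concatenate along the reduced axis and take top-k again.
combined = topk_1d(np.concatenate([chunk_a, chunk_b]), k)
print(combined)   # [7. 8. 9.] -- the -inf padding never survives
```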
16 changes: 13 additions & 3 deletions flox/core.py
@@ -42,6 +42,7 @@
_initialize_aggregation,
generic_aggregate,
quantile_new_dims_func,
topk_new_dims_func,
)
from .cache import memoize
from .xrutils import (
@@ -970,7 +971,7 @@ def chunk_reduce(
nfuncs = len(funcs)
dtypes = _atleast_1d(dtype, nfuncs)
fill_values = _atleast_1d(fill_value, nfuncs)
kwargss = _atleast_1d({}, nfuncs) if kwargs is None else kwargs
kwargss = _atleast_1d({} if kwargs is None else kwargs, nfuncs)

if isinstance(axis, Sequence):
axes: T_Axes = axis
@@ -1091,6 +1092,8 @@ def chunk_reduce(
# TODO: Figure out how to generalize this
if reduction in ("quantile", "nanquantile"):
new_dims_shape = tuple(dim.size for dim in quantile_new_dims_func(**kw) if not dim.is_scalar)
elif reduction == "topk":
new_dims_shape = tuple(dim.size for dim in topk_new_dims_func(**kw) if not dim.is_scalar)
else:
new_dims_shape = tuple()
result = result.reshape(new_dims_shape + final_array_shape[:-1] + found_groups_shape)
@@ -1147,6 +1150,7 @@ def _finalize_results(
if count_mask.any():
# For one count_mask.any() prevents promoting bool to dtype(fill_value) unless
# necessary
fill_value = fill_value or agg.fill_value[agg.name]
if fill_value is None:
raise ValueError("Filling is required but fill_value is None.")
# This allows us to match xarray's type promotion rules
@@ -1649,6 +1653,9 @@ def dask_groupby_agg(
# use the "non dask" code path, but applied blockwise
blockwise_method = partial(_reduce_blockwise, agg=agg, fill_value=fill_value, reindex=reindex)
else:
extra = {}
if agg.name == "topk":
extra["kwargs"] = (agg.finalize_kwargs, *(({},) * (len(agg.chunk) - 1)))
# choose `chunk_reduce` or `chunk_argreduce`
blockwise_method = partial(
_get_chunk_reduction(agg.reduction_type),
Expand All @@ -1657,6 +1664,7 @@ def dask_groupby_agg(
dtype=agg.dtype["intermediate"],
reindex=reindex,
user_dtype=agg.dtype["user"],
**extra,
)
if do_simple_combine:
# Add a dummy dimension that then gets reduced over
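
The `extra["kwargs"]` tuple built above lines up one keyword-argument dict per entry of `agg.chunk`, so only the `topk` chunk function receives `k` while `nanlen` gets no extra arguments. A tiny illustration with hypothetical values:

```python
# agg.chunk == ("topk", "nanlen"); one kwargs dict per chunk function.
agg_chunk = ("topk", "nanlen")
finalize_kwargs = {"k": 3}
extra_kwargs = (finalize_kwargs, *(({},) * (len(agg_chunk) - 1)))
assert extra_kwargs == ({"k": 3}, {})
```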
@@ -2227,7 +2235,7 @@ def _choose_engine(by, agg: Aggregation):

not_arg_reduce = not _is_arg_reduction(agg)

if agg.name in ["quantile", "nanquantile", "median", "nanmedian"]:
if agg.name in ["quantile", "nanquantile", "median", "nanmedian", "topk"]:
logger.debug(f"_choose_engine: Choosing 'flox' since {agg.name}")
return "flox"

@@ -2278,7 +2286,7 @@ def groupby_reduce(
equality check are for dimensions of size 1 in `by`.
func : {"all", "any", "count", "sum", "nansum", "mean", "nanmean", \
"max", "nanmax", "min", "nanmin", "argmax", "nanargmax", "argmin", "nanargmin", \
"quantile", "nanquantile", "median", "nanmedian", "mode", "nanmode", \
"quantile", "nanquantile", "median", "nanmedian", "topk", "mode", "nanmode", \
"first", "nanfirst", "last", "nanlast"} or Aggregation
Single function name or an Aggregation instance
expected_groups : (optional) Sequence
@@ -2387,6 +2395,8 @@
"Use engine='flox' instead (it is also much faster), "
"or set engine=None to use the default."
)
if func == "topk" and (finalize_kwargs is None or "k" not in finalize_kwargs):
raise ValueError("Please pass `k` in ``finalize_kwargs`` for topk calculations.")

bys: T_Bys = tuple(np.asarray(b) if not is_duck_array(b) else b for b in by)
nby = len(bys)
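As drafted here, `topk` is requested through `groupby_reduce` with `k` passed via `finalize_kwargs`, and `topk_new_dims_func` adds a new dimension of length `abs(k)` at the front of the result (a negative `k` selects the smallest values instead). A hedged usage sketch:

```python
import numpy as np
from flox import groupby_reduce

array = np.array([1.0, 5.0, 3.0, 2.0, 8.0, 4.0])
by = np.array([0, 0, 0, 1, 1, 1])

# k must be supplied through finalize_kwargs; omitting it raises a ValueError.
result, groups = groupby_reduce(array, by, func="topk", finalize_kwargs={"k": 2})
print(result.shape)   # expected (k, num_groups) == (2, 2)
print(groups)         # [0 1]
```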
27 changes: 18 additions & 9 deletions flox/xarray.py
@@ -9,7 +9,13 @@
from packaging.version import Version
from xarray.core.duck_array_ops import _datetime_nanmin

from .aggregations import Aggregation, Dim, _atleast_1d, quantile_new_dims_func
from .aggregations import (
Aggregation,
Dim,
_atleast_1d,
quantile_new_dims_func,
topk_new_dims_func,
)
from .core import (
_convert_expected_groups_to_index,
_get_expected_groups,
@@ -92,7 +98,7 @@ def xarray_reduce(
Variables with which to group by ``obj``
func : {"all", "any", "count", "sum", "nansum", "mean", "nanmean", \
"max", "nanmax", "min", "nanmin", "argmax", "nanargmax", "argmin", "nanargmin", \
"quantile", "nanquantile", "median", "nanmedian", "mode", "nanmode", \
"quantile", "nanquantile", "median", "nanmedian", "topk", "mode", "nanmode", \
"first", "nanfirst", "last", "nanlast"} or Aggregation
Single function name or an Aggregation instance
expected_groups : str or sequence
@@ -384,17 +390,20 @@ def wrapper(array, *by, func, skipna, core_dims, **kwargs):

result, *groups = groupby_reduce(array, *by, func=func, **kwargs)

# Transpose the new quantile dimension to the end. This is ugly.
# Transpose the new quantile or topk dimension to the end. This is ugly.
# but new core dimensions are expected at the end :/
# but groupby_reduce inserts them at the beginning
if func in ["quantile", "nanquantile"]:
(newdim,) = quantile_new_dims_func(**finalize_kwargs)
if not newdim.is_scalar:
# NOTE: _restore_dim_order will move any new dims to the end anyway.
# This transpose is simply makes it easy to specify output_core_dims
# output dim order: (*broadcast_dims, *group_dims, quantile_dim)
result = np.moveaxis(result, 0, -1)

elif func == "topk":
(newdim,) = topk_new_dims_func(**finalize_kwargs)
else:
newdim = None
if newdim is not None and not newdim.is_scalar:
# NOTE: _restore_dim_order will move any new dims to the end anyway.
# This transpose is simply makes it easy to specify output_core_dims
# output dim order: (*broadcast_dims, *group_dims, quantile_dim)
result = np.moveaxis(result, 0, -1)
# Output of count has an int dtype.
if requires_numeric and func != "count":
if is_npdatetime:
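Through the xarray layer, the wrapper above moves the new `k` dimension to the end so it can be declared as an output core dimension. A hedged sketch of what a call might look like, assuming `k` is forwarded into `finalize_kwargs` the same way `q` is for quantile:

```python
import numpy as np
import xarray as xr
from flox.xarray import xarray_reduce

da = xr.DataArray(
    np.arange(10.0),
    dims="x",
    coords={"labels": ("x", np.repeat(["a", "b"], 5))},
)

out = xarray_reduce(da, "labels", func="topk", k=3)
print(out.dims)   # expected ("labels", "k")
```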