Skip to content

Commit 889be0c

Browse files
committed
Negative k
1 parent 650088b commit 889be0c

File tree

2 files changed

+22
-13
lines changed

2 files changed

+22
-13
lines changed

flox/aggregate_flox.py

Lines changed: 19 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -47,7 +47,17 @@ def _lerp(a, b, *, t, dtype, out=None):
4747

4848

4949
def quantile_or_topk(
50-
array, inv_idx, *, q=None, k=None, axis, skipna, group_idx, dtype=None, out=None
50+
array,
51+
inv_idx,
52+
*,
53+
q=None,
54+
k=None,
55+
axis,
56+
skipna,
57+
group_idx,
58+
dtype=None,
59+
out=None,
60+
fill_value=None,
5161
):
5262
assert q or k
5363

@@ -81,9 +91,8 @@ def quantile_or_topk(
8191

8292
param = q or k
8393
if k is not None:
84-
assert k > 0
8594
is_scalar_param = False
86-
param = np.arange(k)
95+
param = np.arange(abs(k))
8796
else:
8897
is_scalar_param = is_scalar(q)
8998
param = np.atleast_1d(param)
@@ -111,10 +120,10 @@ def quantile_or_topk(
111120
kth = np.unique(np.concatenate([lo_.reshape(-1), hi_.reshape(-1)]))
112121

113122
else:
114-
virtual_index = (actual_sizes - k) + inv_idx[:-1]
123+
virtual_index = inv_idx[:-1] + ((actual_sizes - k) if k > 0 else abs(k) - 1)
115124
kth = np.unique(virtual_index)
116125
kth = kth[kth > 0]
117-
k_offset = np.arange(k).reshape((k,) + (1,) * virtual_index.ndim)
126+
k_offset = param.reshape((abs(k),) + (1,) * virtual_index.ndim)
118127
lo_ = k_offset + virtual_index[np.newaxis, ...]
119128

120129
# partition the complex array in-place
@@ -137,15 +146,12 @@ def quantile_or_topk(
137146
gamma = np.broadcast_to(virtual_index, idxshape) - lo_
138147
result = _lerp(loval, hival, t=gamma, out=out, dtype=dtype)
139148
else:
140-
import ipdb
141-
142-
ipdb.set_trace()
143149
result = loval
144-
result[lo_ < 0] = np.nan
150+
result[lo_ < 0] = fill_value
145151
if not skipna and np.any(nanmask):
146-
result[..., nanmask] = np.nan
152+
result[..., nanmask] = fill_value
147153
if k is not None:
148-
result = result.astype(array.dtype, copy=False)
154+
result = result.astype(dtype, copy=False)
149155
np.copyto(out, result)
150156
return result
151157

@@ -175,9 +181,10 @@ def _np_grouped_op(
175181
if not q and not k:
176182
out = np.full(array.shape[:-1] + (size,), fill_value=fill_value, dtype=dtype)
177183
else:
178-
nq = len(np.atleast_1d(q)) if q is not None else k
184+
nq = len(np.atleast_1d(q)) if q is not None else abs(k)
179185
out = np.full((nq,) + array.shape[:-1] + (size,), fill_value=fill_value, dtype=dtype)
180186
kwargs["group_idx"] = group_idx
187+
kwargs["fill_value"] = fill_value
181188

182189
if (len(uniques) == size) and (uniques == np.arange(size, like=array)).all():
183190
# The previous version of this if condition

flox/aggregations.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -564,7 +564,7 @@ def quantile_new_dims_func(q) -> tuple[Dim]:
564564

565565

566566
def topk_new_dims_func(k) -> tuple[Dim]:
567-
return (Dim(name="k", values=np.arange(k)),)
567+
return (Dim(name="k", values=np.arange(abs(k))),)
568568

569569

570570
quantile = Aggregation(
@@ -848,6 +848,8 @@ def _initialize_aggregation(
848848
),
849849
}
850850

851+
if agg.name == "topk" and finalize_kwargs["k"] < 0:
852+
agg.fill_value["intermediate"] = (dtypes.INF,)
851853
# Replace sentinel fill values according to dtype
852854
agg.fill_value["user"] = fill_value
853855
agg.fill_value["intermediate"] = tuple(

0 commit comments

Comments
 (0)