Skip to content

Commit cdb2417

Browse files
authored
Optimize quantile. (#409)
1 parent 3853101 commit cdb2417

File tree

1 file changed

+25
-18
lines changed

1 file changed

+25
-18
lines changed

flox/aggregate_flox.py

Lines changed: 25 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -50,36 +50,39 @@ def _lerp(a, b, *, t, dtype, out=None):
5050
def quantile_(array, inv_idx, *, q, axis, skipna, group_idx, dtype=None, out=None):
5151
inv_idx = np.concatenate((inv_idx, [array.shape[-1]]))
5252

53-
array_nanmask = isnull(array)
54-
actual_sizes = np.add.reduceat(~array_nanmask, inv_idx[:-1], axis=axis)
53+
array_validmask = notnull(array)
54+
actual_sizes = np.add.reduceat(array_validmask, inv_idx[:-1], axis=axis)
5555
newshape = (1,) * (array.ndim - 1) + (inv_idx.size - 1,)
5656
full_sizes = np.reshape(np.diff(inv_idx), newshape)
5757
nanmask = full_sizes != actual_sizes
5858

5959
# The approach here is to use (complex_array.partition) because
6060
# 1. The full np.lexsort((array, labels), axis=-1) is slow and unnecessary
6161
# 2. Using record_array.partition(..., order=["labels", "array"]) is incredibly slow.
62-
# partition will first sort by real part, then by imaginary part, so it is a two element lex-partition.
63-
# So we set
62+
# 3. For complex arrays, partition will first sort by real part, then by imaginary part, so it is a two element
63+
# lex-partition.
64+
# Therefore we use approach (3) and set
6465
# complex_array = group_idx + 1j * array
65-
# group_idx is an integer (guaranteed), but array can have NaNs. Now,
66-
# 1 + 1j*NaN = NaN + 1j * NaN
67-
# so we must replace all NaNs with the maximum array value in the group so these NaNs
68-
# get sorted to the end.
69-
# Partly inspired by https://krstn.eu/np.nanpercentile()-there-has-to-be-a-faster-way/
70-
# TODO: Don't know if this array has been copied in _prepare_for_flox. This is potentially wasteful
71-
array = np.where(array_nanmask, -np.inf, array)
72-
maxes = np.maximum.reduceat(array, inv_idx[:-1], axis=axis)
73-
replacement = np.repeat(maxes, np.diff(inv_idx), axis=axis)
74-
array[array_nanmask] = replacement[array_nanmask]
75-
66+
# group_idx is an integer (guaranteed), but array can have NaNs.
67+
# Now the sort order of np.nan is bigger than np.inf
68+
# >>> c = (np.array([0, 1, 2, np.nan]) + np.array([np.nan, 2, 3, 4]) * 1j)
69+
# >>> c.partition(2)
70+
# >>> c
71+
# array([ 1. +2.j, 2. +3.j, nan +4.j, nan+nanj])
72+
# So we determine which indices we need using the fact that NaNs get sorted to the end.
73+
# This *was* partly inspired by https://krstn.eu/np.nanpercentile()-there-has-to-be-a-faster-way/
74+
# but not any more now that I use partition and avoid replacing NaNs
7675
qin = q
7776
q = np.atleast_1d(qin)
7877
q = np.reshape(q, (len(q),) + (1,) * array.ndim)
7978

8079
# This is numpy's method="linear"
8180
# TODO: could support all the interpolations here
82-
virtual_index = q * (actual_sizes - 1) + inv_idx[:-1]
81+
offset = actual_sizes.cumsum(axis=-1)
82+
actual_sizes -= 1
83+
virtual_index = q * actual_sizes
84+
# virtual_index is relative to group starts, so now offset that
85+
virtual_index[..., 1:] += offset[..., :-1]
8386

8487
is_scalar_q = is_scalar(qin)
8588
if is_scalar_q:
@@ -103,7 +106,10 @@ def quantile_(array, inv_idx, *, q, axis, skipna, group_idx, dtype=None, out=Non
103106
# partition the complex array in-place
104107
labels_broadcast = np.broadcast_to(group_idx, array.shape)
105108
with np.errstate(invalid="ignore"):
106-
cmplx = labels_broadcast + 1j * (array.view(int) if array.dtype.kind in "Mm" else array)
109+
cmplx = 1j * (array.view(int) if array.dtype.kind in "Mm" else array)
110+
# This is a very intentional way of handling `array` with -inf/+inf values :/
111+
# a simple (labels + 1j * array) will yield `nan+inf * 1j` instead of `0 + inf * j`
112+
cmplx.real = labels_broadcast
107113
cmplx.partition(kth=kth, axis=-1)
108114
if is_scalar_q:
109115
a_ = cmplx.imag
@@ -145,7 +151,8 @@ def _np_grouped_op(
145151
(inv_idx,) = flag.nonzero()
146152

147153
if size is None:
148-
size = np.max(uniques) + 1
154+
# This is sorted, so the last value is the largest label
155+
size = uniques[-1] + 1
149156
if dtype is None:
150157
dtype = array.dtype
151158

0 commit comments

Comments
 (0)