@@ -50,36 +50,39 @@ def _lerp(a, b, *, t, dtype, out=None):
50
50
def quantile_ (array , inv_idx , * , q , axis , skipna , group_idx , dtype = None , out = None ):
51
51
inv_idx = np .concatenate ((inv_idx , [array .shape [- 1 ]]))
52
52
53
- array_nanmask = isnull (array )
54
- actual_sizes = np .add .reduceat (~ array_nanmask , inv_idx [:- 1 ], axis = axis )
53
+ array_validmask = notnull (array )
54
+ actual_sizes = np .add .reduceat (array_validmask , inv_idx [:- 1 ], axis = axis )
55
55
newshape = (1 ,) * (array .ndim - 1 ) + (inv_idx .size - 1 ,)
56
56
full_sizes = np .reshape (np .diff (inv_idx ), newshape )
57
57
nanmask = full_sizes != actual_sizes
58
58
59
59
# The approach here is to use (complex_array.partition) because
60
60
# 1. The full np.lexsort((array, labels), axis=-1) is slow and unnecessary
61
61
# 2. Using record_array.partition(..., order=["labels", "array"]) is incredibly slow.
62
- # partition will first sort by real part, then by imaginary part, so it is a two element lex-partition.
63
- # So we set
62
+ # 3. For complex arrays, partition will first sort by real part, then by imaginary part, so it is a two-element
63
+ # lex-partition.
64
+ # Therefore we use approach (3) and set
64
65
# complex_array = group_idx + 1j * array
65
- # group_idx is an integer (guaranteed), but array can have NaNs. Now,
66
- # 1 + 1j*NaN = NaN + 1j * NaN
67
- # so we must replace all NaNs with the maximum array value in the group so these NaNs
68
- # get sorted to the end.
69
- # Partly inspired by https://krstn.eu/np.nanpercentile()-there-has-to-be-a-faster-way/
70
- # TODO: Don't know if this array has been copied in _prepare_for_flox. This is potentially wasteful
71
- array = np .where (array_nanmask , - np .inf , array )
72
- maxes = np .maximum .reduceat (array , inv_idx [:- 1 ], axis = axis )
73
- replacement = np .repeat (maxes , np .diff (inv_idx ), axis = axis )
74
- array [array_nanmask ] = replacement [array_nanmask ]
75
-
66
+ # group_idx is an integer (guaranteed), but array can have NaNs.
67
+ # Now NaN sorts after np.inf, i.e. NaNs are always placed last. For example:
68
+ # >>> c = (np.array([0, 1, 2, np.nan]) + np.array([np.nan, 2, 3, 4]) * 1j)
69
+ # >>> c.partition(2)
70
+ # >>> c
71
+ # array([ 1. +2.j, 2. +3.j, nan +4.j, nan+nanj])
72
+ # So we determine which indices we need using the fact that NaNs get sorted to the end.
73
+ # This *was* partly inspired by https://krstn.eu/np.nanpercentile()-there-has-to-be-a-faster-way/
74
+ # but no longer, now that partition is used and NaNs are not replaced.
76
75
qin = q
77
76
q = np .atleast_1d (qin )
78
77
q = np .reshape (q , (len (q ),) + (1 ,) * array .ndim )
79
78
80
79
# This is numpy's method="linear"
81
80
# TODO: could support all the interpolations here
82
- virtual_index = q * (actual_sizes - 1 ) + inv_idx [:- 1 ]
81
+ offset = actual_sizes .cumsum (axis = - 1 )
82
+ actual_sizes -= 1
83
+ virtual_index = q * actual_sizes
84
+ # virtual_index is relative to group starts, so now offset that
85
+ virtual_index [..., 1 :] += offset [..., :- 1 ]
83
86
84
87
is_scalar_q = is_scalar (qin )
85
88
if is_scalar_q :
@@ -103,7 +106,10 @@ def quantile_(array, inv_idx, *, q, axis, skipna, group_idx, dtype=None, out=Non
103
106
# partition the complex array in-place
104
107
labels_broadcast = np .broadcast_to (group_idx , array .shape )
105
108
with np .errstate (invalid = "ignore" ):
106
- cmplx = labels_broadcast + 1j * (array .view (int ) if array .dtype .kind in "Mm" else array )
109
+ cmplx = 1j * (array .view (int ) if array .dtype .kind in "Mm" else array )
110
+ # This is a very intentional way of handling `array` with -inf/+inf values :/
111
+ # a simple (labels + 1j * array) will yield `nan + inf * 1j` instead of `0 + inf * 1j`, because 1j * inf has a NaN real part
112
+ cmplx .real = labels_broadcast
107
113
cmplx .partition (kth = kth , axis = - 1 )
108
114
if is_scalar_q :
109
115
a_ = cmplx .imag
@@ -145,7 +151,8 @@ def _np_grouped_op(
145
151
(inv_idx ,) = flag .nonzero ()
146
152
147
153
if size is None :
148
- size = np .max (uniques ) + 1
154
+ # This is sorted, so the last value is the largest label
155
+ size = uniques [- 1 ] + 1
149
156
if dtype is None :
150
157
dtype = array .dtype
151
158
0 commit comments