@@ -46,74 +46,107 @@ def _lerp(a, b, *, t, dtype, out=None):
46
46
return out
47
47
48
48
49
- def quantile_ (array , inv_idx , * , q , axis , skipna , group_idx , dtype = None , out = None ):
50
- inv_idx = np .concatenate ((inv_idx , [array .shape [- 1 ]]))
49
+ def quantile_or_topk (
50
+ array , inv_idx , * , q = None , k = None , axis , skipna , group_idx , dtype = None , out = None
51
+ ):
52
+ assert q or k
51
53
52
- array_nanmask = isnull (array )
53
- actual_sizes = np .add .reduceat (~ array_nanmask , inv_idx [:- 1 ], axis = axis )
54
- newshape = (1 ,) * (array .ndim - 1 ) + (inv_idx .size - 1 ,)
55
- full_sizes = np .reshape (np .diff (inv_idx ), newshape )
56
- nanmask = full_sizes != actual_sizes
54
+ inv_idx = np .concatenate ((inv_idx , [array .shape [- 1 ]]))
57
55
58
- # The approach here is to use (complex_array.partition) because
56
+ # The approach for quantiles and topk, both of which are basically grouped partition,
57
+ # here is to use (complex_array.partition) because
59
58
# 1. The full np.lexsort((array, labels), axis=-1) is slow and unnecessary
60
59
# 2. Using record_array.partition(..., order=["labels", "array"]) is incredibly slow.
61
- # partition will first sort by real part, then by imaginary part, so it is a two element lex-partition.
62
- # So we set
60
+ # partition will first sort by real part, then by imaginary part, so it is a two element
61
+ # lex-partition. Therefore we set
63
62
# complex_array = group_idx + 1j * array
64
63
# group_idx is an integer (guaranteed), but array can have NaNs. Now,
65
64
# 1 + 1j*NaN = NaN + 1j * NaN
66
65
# so we must replace all NaNs with the maximum array value in the group so these NaNs
67
66
# get sorted to the end.
67
+
68
+ # Replace NaNs with the maximum value for each group.
68
69
# Partly inspired by https://krstn.eu/np.nanpercentile()-there-has-to-be-a-faster-way/
69
- # TODO: Don't know if this array has been copied in _prepare_for_flox. This is potentially wasteful
70
+ array_nanmask = isnull (array )
71
+ actual_sizes = np .add .reduceat (~ array_nanmask , inv_idx [:- 1 ], axis = axis )
72
+ newshape = (1 ,) * (array .ndim - 1 ) + (inv_idx .size - 1 ,)
73
+ full_sizes = np .reshape (np .diff (inv_idx ), newshape )
74
+ nanmask = full_sizes != actual_sizes
75
+ # TODO: Don't know if this array has been copied in _prepare_for_flox.
76
+ # This is potentially wasteful
70
77
array = np .where (array_nanmask , - np .inf , array )
71
78
maxes = np .maximum .reduceat (array , inv_idx [:- 1 ], axis = axis )
72
79
replacement = np .repeat (maxes , np .diff (inv_idx ), axis = axis )
73
80
array [array_nanmask ] = replacement [array_nanmask ]
74
81
75
- qin = q
76
- q = np .atleast_1d (qin )
77
- q = np .reshape (q , (len (q ),) + (1 ,) * array .ndim )
78
-
79
- # This is numpy's method="linear"
80
- # TODO: could support all the interpolations here
81
- virtual_index = q * (actual_sizes - 1 ) + inv_idx [:- 1 ]
82
+ param = q or k
83
+ if k is not None :
84
+ assert k > 0
85
+ is_scalar_param = False
86
+ param = np .arange (k )
87
+ else :
88
+ is_scalar_param = is_scalar (q )
89
+ param = np .atleast_1d (param )
90
+ param = np .reshape (param , (param .size ,) + (1 ,) * array .ndim )
82
91
83
- is_scalar_q = is_scalar (qin )
84
- if is_scalar_q :
85
- virtual_index = virtual_index .squeeze (axis = 0 )
92
+ if is_scalar_param :
86
93
idxshape = array .shape [:- 1 ] + (actual_sizes .shape [- 1 ],)
87
94
else :
88
- idxshape = (q .shape [0 ],) + array .shape [:- 1 ] + (actual_sizes .shape [- 1 ],)
95
+ idxshape = (param .shape [0 ],) + array .shape [:- 1 ] + (actual_sizes .shape [- 1 ],)
89
96
90
- lo_ = np .floor (
91
- virtual_index , casting = "unsafe" , out = np .empty (virtual_index .shape , dtype = np .int64 )
92
- )
93
- hi_ = np .ceil (
94
- virtual_index , casting = "unsafe" , out = np .empty (virtual_index .shape , dtype = np .int64 )
95
- )
96
- kth = np .unique (np .concatenate ([lo_ .reshape (- 1 ), hi_ .reshape (- 1 )]))
97
+ if q is not None :
98
+ # This is numpy's method="linear"
99
+ # TODO: could support all the interpolations here
100
+ virtual_index = param * (actual_sizes - 1 ) + inv_idx [:- 1 ]
101
+
102
+ if is_scalar_param :
103
+ virtual_index = virtual_index .squeeze (axis = 0 )
104
+
105
+ lo_ = np .floor (
106
+ virtual_index , casting = "unsafe" , out = np .empty (virtual_index .shape , dtype = np .int64 )
107
+ )
108
+ hi_ = np .ceil (
109
+ virtual_index , casting = "unsafe" , out = np .empty (virtual_index .shape , dtype = np .int64 )
110
+ )
111
+ kth = np .unique (np .concatenate ([lo_ .reshape (- 1 ), hi_ .reshape (- 1 )]))
112
+
113
+ else :
114
+ virtual_index = (actual_sizes - k ) + inv_idx [:- 1 ]
115
+ kth = np .unique (virtual_index )
116
+ kth = kth [kth > 0 ]
117
+ k_offset = np .arange (k ).reshape ((k ,) + (1 ,) * virtual_index .ndim )
118
+ lo_ = k_offset + virtual_index [np .newaxis , ...]
97
119
98
120
# partition the complex array in-place
99
121
labels_broadcast = np .broadcast_to (group_idx , array .shape )
100
122
with np .errstate (invalid = "ignore" ):
101
123
cmplx = labels_broadcast + 1j * array
102
124
cmplx .partition (kth = kth , axis = - 1 )
103
- if is_scalar_q :
125
+
126
+ if is_scalar_param :
104
127
a_ = cmplx .imag
105
128
else :
106
- a_ = np .broadcast_to (cmplx .imag , (q .shape [0 ],) + array .shape )
129
+ a_ = np .broadcast_to (cmplx .imag , (param .shape [0 ],) + array .shape )
107
130
108
- # get bounds, Broadcast to (num quantiles, ..., num labels)
109
131
loval = np .take_along_axis (a_ , np .broadcast_to (lo_ , idxshape ), axis = axis )
110
- hival = np .take_along_axis (a_ , np .broadcast_to (hi_ , idxshape ), axis = axis )
132
+ if q is not None :
133
+ # get bounds, Broadcast to (num quantiles, ..., num labels)
134
+ hival = np .take_along_axis (a_ , np .broadcast_to (hi_ , idxshape ), axis = axis )
135
+
136
+ # TODO: could support all the interpolations here
137
+ gamma = np .broadcast_to (virtual_index , idxshape ) - lo_
138
+ result = _lerp (loval , hival , t = gamma , out = out , dtype = dtype )
139
+ else :
140
+ import ipdb
111
141
112
- # TODO: could support all the interpolations here
113
- gamma = np . broadcast_to ( virtual_index , idxshape ) - lo_
114
- result = _lerp ( loval , hival , t = gamma , out = out , dtype = dtype )
142
+ ipdb . set_trace ()
143
+ result = loval
144
+ result [ lo_ < 0 ] = np . nan
115
145
if not skipna and np .any (nanmask ):
116
146
result [..., nanmask ] = np .nan
147
+ if k is not None :
148
+ result = result .astype (array .dtype , copy = False )
149
+ np .copyto (out , result )
117
150
return result
118
151
119
152
@@ -138,10 +171,11 @@ def _np_grouped_op(
138
171
139
172
if out is None :
140
173
q = kwargs .get ("q" , None )
141
- if q is None :
174
+ k = kwargs .get ("k" , None )
175
+ if not q and not k :
142
176
out = np .full (array .shape [:- 1 ] + (size ,), fill_value = fill_value , dtype = dtype )
143
177
else :
144
- nq = len (np .atleast_1d (q ))
178
+ nq = len (np .atleast_1d (q )) if q is not None else k
145
179
out = np .full ((nq ,) + array .shape [:- 1 ] + (size ,), fill_value = fill_value , dtype = dtype )
146
180
kwargs ["group_idx" ] = group_idx
147
181
@@ -178,10 +212,11 @@ def _nan_grouped_op(group_idx, array, func, fillna, *args, **kwargs):
178
212
nanmax = partial (_nan_grouped_op , func = max , fillna = - np .inf )
179
213
min = partial (_np_grouped_op , op = np .minimum .reduceat )
180
214
nanmin = partial (_nan_grouped_op , func = min , fillna = np .inf )
181
- quantile = partial (_np_grouped_op , op = partial (quantile_ , skipna = False ))
182
- nanquantile = partial (_np_grouped_op , op = partial (quantile_ , skipna = True ))
183
- median = partial (partial (_np_grouped_op , q = 0.5 ), op = partial (quantile_ , skipna = False ))
184
- nanmedian = partial (partial (_np_grouped_op , q = 0.5 ), op = partial (quantile_ , skipna = True ))
215
+ quantile = partial (_np_grouped_op , op = partial (quantile_or_topk , skipna = False ))
216
+ topk = partial (_np_grouped_op , op = partial (quantile_or_topk , skipna = True ))
217
+ nanquantile = partial (_np_grouped_op , op = partial (quantile_or_topk , skipna = True ))
218
+ median = partial (partial (_np_grouped_op , q = 0.5 ), op = partial (quantile_or_topk , skipna = False ))
219
+ nanmedian = partial (partial (_np_grouped_op , q = 0.5 ), op = partial (quantile_or_topk , skipna = True ))
185
220
# TODO: all, any
186
221
187
222
0 commit comments