@@ -47,7 +47,17 @@ def _lerp(a, b, *, t, dtype, out=None):
47
47
48
48
49
49
def quantile_or_topk (
50
- array , inv_idx , * , q = None , k = None , axis , skipna , group_idx , dtype = None , out = None
50
+ array ,
51
+ inv_idx ,
52
+ * ,
53
+ q = None ,
54
+ k = None ,
55
+ axis ,
56
+ skipna ,
57
+ group_idx ,
58
+ dtype = None ,
59
+ out = None ,
60
+ fill_value = None ,
51
61
):
52
62
assert q or k
53
63
@@ -81,9 +91,8 @@ def quantile_or_topk(
81
91
82
92
param = q or k
83
93
if k is not None :
84
- assert k > 0
85
94
is_scalar_param = False
86
- param = np .arange (k )
95
+ param = np .arange (abs ( k ) )
87
96
else :
88
97
is_scalar_param = is_scalar (q )
89
98
param = np .atleast_1d (param )
@@ -111,10 +120,10 @@ def quantile_or_topk(
111
120
kth = np .unique (np .concatenate ([lo_ .reshape (- 1 ), hi_ .reshape (- 1 )]))
112
121
113
122
else :
114
- virtual_index = ( actual_sizes - k ) + inv_idx [: - 1 ]
123
+ virtual_index = inv_idx [: - 1 ] + (( actual_sizes - k ) if k > 0 else abs ( k ) - 1 )
115
124
kth = np .unique (virtual_index )
116
125
kth = kth [kth > 0 ]
117
- k_offset = np . arange ( k ). reshape ((k ,) + (1 ,) * virtual_index .ndim )
126
+ k_offset = param . reshape ((abs ( k ) ,) + (1 ,) * virtual_index .ndim )
118
127
lo_ = k_offset + virtual_index [np .newaxis , ...]
119
128
120
129
# partition the complex array in-place
@@ -137,15 +146,12 @@ def quantile_or_topk(
137
146
gamma = np .broadcast_to (virtual_index , idxshape ) - lo_
138
147
result = _lerp (loval , hival , t = gamma , out = out , dtype = dtype )
139
148
else :
140
- import ipdb
141
-
142
- ipdb .set_trace ()
143
149
result = loval
144
- result [lo_ < 0 ] = np . nan
150
+ result [lo_ < 0 ] = fill_value
145
151
if not skipna and np .any (nanmask ):
146
- result [..., nanmask ] = np . nan
152
+ result [..., nanmask ] = fill_value
147
153
if k is not None :
148
- result = result .astype (array . dtype , copy = False )
154
+ result = result .astype (dtype , copy = False )
149
155
np .copyto (out , result )
150
156
return result
151
157
@@ -175,9 +181,10 @@ def _np_grouped_op(
175
181
if not q and not k :
176
182
out = np .full (array .shape [:- 1 ] + (size ,), fill_value = fill_value , dtype = dtype )
177
183
else :
178
- nq = len (np .atleast_1d (q )) if q is not None else k
184
+ nq = len (np .atleast_1d (q )) if q is not None else abs ( k )
179
185
out = np .full ((nq ,) + array .shape [:- 1 ] + (size ,), fill_value = fill_value , dtype = dtype )
180
186
kwargs ["group_idx" ] = group_idx
187
+ kwargs ["fill_value" ] = fill_value
181
188
182
189
if (len (uniques ) == size ) and (uniques == np .arange (size , like = array )).all ():
183
190
# The previous version of this if condition
0 commit comments