@@ -370,18 +370,17 @@ def func(x, /, *args, func_str=func_str, **kwargs):
370
370
371
371
setattr (mod , func_str , func )
372
372
373
- # def searchsorted(x1, x2, /, *, side='left', sorter=None):
374
- # if sorter is not None:
375
- # x1 = take(x1, sorter)
376
-
377
- # mask_count = xp.cumulative_sum(xp.astype(x1.mask, xp.int64))
378
- # x1_compressed = x1.data[~x1.mask]
379
- # count = xp.zeros(x1_compressed.size+1, dtype=xp.int64)
380
- # count[:-1] = mask_count[~x1.mask]
381
- # count[-1] = count[-2]
382
- # i = xp.searchsorted(x1_compressed, x2.data, side=side)
383
- # j = i + xp.take(count, i)
384
- # return MArray(j, mask=x2.mask)
373
+ def searchsorted (x1 , x2 , / , * , side = "left" , sorter = None ):
374
+ if sorter is not None :
375
+ x1 = take (x1 , sorter )
376
+
377
+ magnitude_x1 = xp .asarray (x1 .magnitude , copy = True )
378
+ magnitude_x2 = x2 .m_as (x1 .units )
379
+
380
+ magnitude = xp .searchsorted (magnitude_x1 , magnitude_x2 , side = side )
381
+ return ArrayUnitQuantity (magnitude , None )
382
+
383
+ mod .searchsorted = searchsorted
385
384
386
385
# ignore units of condition, convert x2 to units of x1
387
386
def where (condition , x1 , x2 , / ):
@@ -392,7 +391,6 @@ def where(condition, x1, x2, /):
392
391
magnitude = xp .where (condition .magnitude , x1 .magnitude , x2 .m_as (units ))
393
392
return ArrayUnitQuantity (magnitude , units )
394
393
395
- # mod.searchsorted = searchsorted
396
394
mod .where = where
397
395
398
396
# strip_unit_input_output_ufuncs = ["isnan", "isinf", "isfinite", "signbit", "sign"]
0 commit comments