@@ -356,6 +356,31 @@ def func(x, /, *args, func_str=func_str, **kwargs):
356
356
357
357
setattr (mod , func_str , func )
358
358
359
+ # def searchsorted(x1, x2, /, *, side='left', sorter=None):
360
+ # if sorter is not None:
361
+ # x1 = take(x1, sorter)
362
+
363
+ # mask_count = xp.cumulative_sum(xp.astype(x1.mask, xp.int64))
364
+ # x1_compressed = x1.data[~x1.mask]
365
+ # count = xp.zeros(x1_compressed.size+1, dtype=xp.int64)
366
+ # count[:-1] = mask_count[~x1.mask]
367
+ # count[-1] = count[-2]
368
+ # i = xp.searchsorted(x1_compressed, x2.data, side=side)
369
+ # j = i + xp.take(count, i)
370
+ # return MArray(j, mask=x2.mask)
371
+
372
+ # ignore units of condition, convert x2 to units of x1
373
+ def where (condition , x1 , x2 , / ):
374
+ condition = asarray (condition )
375
+ x1 = asarray (x1 )
376
+ x2 = asarray (x2 )
377
+ units = x1 .units
378
+ magnitude = xp .where (condition .magnitude , x1 .magnitude , x2 .m_as (units ))
379
+ return ArrayUnitQuantity (magnitude , units )
380
+
381
+ # mod.searchsorted = searchsorted
382
+ mod .where = where
383
+
359
384
# strip_unit_input_output_ufuncs = ["isnan", "isinf", "isfinite", "signbit", "sign"]
360
385
# matching_input_bare_output_ufuncs = [
361
386
# "equal",
0 commit comments