Skip to content

Commit b24a300

Browse files
committed
add searchsorted
1 parent 74fdc68 commit b24a300

File tree

1 file changed

+11
-13
lines changed

1 file changed

+11
-13
lines changed

src/pint_array/__init__.py

Lines changed: 11 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -370,18 +370,17 @@ def func(x, /, *args, func_str=func_str, **kwargs):
370370

371371
setattr(mod, func_str, func)
372372

373-
# def searchsorted(x1, x2, /, *, side='left', sorter=None):
374-
# if sorter is not None:
375-
# x1 = take(x1, sorter)
376-
377-
# mask_count = xp.cumulative_sum(xp.astype(x1.mask, xp.int64))
378-
# x1_compressed = x1.data[~x1.mask]
379-
# count = xp.zeros(x1_compressed.size+1, dtype=xp.int64)
380-
# count[:-1] = mask_count[~x1.mask]
381-
# count[-1] = count[-2]
382-
# i = xp.searchsorted(x1_compressed, x2.data, side=side)
383-
# j = i + xp.take(count, i)
384-
# return MArray(j, mask=x2.mask)
373+
def searchsorted(x1, x2, /, *, side="left", sorter=None):
374+
if sorter is not None:
375+
x1 = take(x1, sorter)
376+
377+
magnitude_x1 = xp.asarray(x1.magnitude, copy=True)
378+
magnitude_x2 = x2.m_as(x1.units)
379+
380+
magnitude = xp.searchsorted(magnitude_x1, magnitude_x2, side=side)
381+
return ArrayUnitQuantity(magnitude, None)
382+
383+
mod.searchsorted = searchsorted
385384

386385
# ignore units of condition, convert x2 to units of x1
387386
def where(condition, x1, x2, /):
@@ -392,7 +391,6 @@ def where(condition, x1, x2, /):
392391
magnitude = xp.where(condition.magnitude, x1.magnitude, x2.m_as(units))
393392
return ArrayUnitQuantity(magnitude, units)
394393

395-
# mod.searchsorted = searchsorted
396394
mod.where = where
397395

398396
# strip_unit_input_output_ufuncs = ["isnan", "isinf", "isfinite", "signbit", "sign"]

0 commit comments

Comments
 (0)