Skip to content

Commit c46640b

Browse files
committed
add where
1 parent 029476f commit c46640b

File tree

1 file changed

+25
-0
lines changed

1 file changed

+25
-0
lines changed

src/pint_array/__init__.py

Lines changed: 25 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -356,6 +356,31 @@ def func(x, /, *args, func_str=func_str, **kwargs):
356356

357357
setattr(mod, func_str, func)
358358

359+
# def searchsorted(x1, x2, /, *, side='left', sorter=None):
360+
# if sorter is not None:
361+
# x1 = take(x1, sorter)
362+
363+
# mask_count = xp.cumulative_sum(xp.astype(x1.mask, xp.int64))
364+
# x1_compressed = x1.data[~x1.mask]
365+
# count = xp.zeros(x1_compressed.size+1, dtype=xp.int64)
366+
# count[:-1] = mask_count[~x1.mask]
367+
# count[-1] = count[-2]
368+
# i = xp.searchsorted(x1_compressed, x2.data, side=side)
369+
# j = i + xp.take(count, i)
370+
# return MArray(j, mask=x2.mask)
371+
372+
# ignore units of condition, convert x2 to units of x1
373+
def where(condition, x1, x2, /):
374+
condition = asarray(condition)
375+
x1 = asarray(x1)
376+
x2 = asarray(x2)
377+
units = x1.units
378+
magnitude = xp.where(condition.magnitude, x1.magnitude, x2.m_as(units))
379+
return ArrayUnitQuantity(magnitude, units)
380+
381+
# mod.searchsorted = searchsorted
382+
mod.where = where
383+
359384
# strip_unit_input_output_ufuncs = ["isnan", "isinf", "isfinite", "signbit", "sign"]
360385
# matching_input_bare_output_ufuncs = [
361386
# "equal",

0 commit comments

Comments
 (0)