Skip to content

Commit eeaa4f1

Browse files
committed
set functions
1 parent 4e3c56b commit eeaa4f1

File tree

1 file changed

+29
-0
lines changed

1 file changed

+29
-0
lines changed

src/pint_array/__init__.py

Lines changed: 29 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@
55
Pint interoperability with array API standard arrays.
66
"""
77

8+
import collections
89
import importlib
910
import sys
1011
import textwrap
@@ -555,6 +556,34 @@ def matrix_transpose(x):
555556

556557
mod.matrix_transpose = matrix_transpose
557558

559+
## Set Functions ##
560+
def get_set_fun(func_str):
561+
def set_fun(x, /):
562+
x = asarray(x)
563+
units = x.units
564+
magnitude = xp.asarray(x.magnitude, copy=True)
565+
566+
xp_func = getattr(xp, func_str)
567+
res = xp_func(magnitude)
568+
if func_str == "unique_values":
569+
return ArrayUnitQuantity(res, units)
570+
571+
fields = res._fields
572+
name_tuple = res.__class__.__name__
573+
result_class = collections.namedtuple(name_tuple, fields)
574+
575+
result_list = []
576+
for res_i, field_i in zip(res, fields, strict=False):
577+
units_i = units if field_i == "values" else None
578+
result_list.append(ArrayUnitQuantity(res_i, units_i))
579+
return result_class(*result_list)
580+
581+
return set_fun
582+
583+
unique_names = ["unique_values", "unique_counts", "unique_inverse", "unique_all"]
584+
for name in unique_names:
585+
setattr(mod, name, get_set_fun(name))
586+
558587
# Handle functions with output unit defined by operation
559588

560589
# output_unit="sum":

0 commit comments

Comments
 (0)