Skip to content

Commit 46d6a08

Browse files
committed
manipulation functions
1 parent 32a6b2f commit 46d6a08

File tree

1 file changed

+48
-44
lines changed

1 file changed

+48
-44
lines changed

src/pint_array/__init__.py

Lines changed: 48 additions & 44 deletions
Original file line numberDiff line numberDiff line change
@@ -231,50 +231,54 @@ def fun(*args, func_str=func_str, units=None, **kwargs):
231231
setattr(mod, func_str, fun)
232232

233233
## Manipulation Functions ##
234-
# first_arg_arrays = {'broadcast_arrays', 'concat', 'stack', 'meshgrid'}
235-
# output_arrays = {'broadcast_arrays', 'unstack', 'meshgrid'}
236-
237-
# def get_manip_fun(name):
238-
# def manip_fun(x, *args, **kwargs):
239-
# x = (asarray(x) if name not in first_arg_arrays
240-
# else [asarray(xi) for xi in x])
241-
# mask = (x.mask if name not in first_arg_arrays
242-
# else [xi.mask for xi in x])
243-
# data = (x.data if name not in first_arg_arrays
244-
# else [xi.data for xi in x])
245-
246-
# fun = getattr(xp, name)
247-
248-
# if name == "repeat":
249-
# args = list(args)
250-
# repeats = args[0]
251-
# if hasattr(repeats, 'mask') and xp.any(repeats.mask):
252-
# message = (
253-
# "Correct behavior when `repeats` is a masked array is "
254-
# "ambiguous, and no convention is supported at this time.")
255-
# raise NotImplementedError(message)
256-
# elif hasattr(repeats, 'mask'):
257-
# repeats = repeats.data
258-
# args[0] = repeats
259-
260-
# if name in {'broadcast_arrays', 'meshgrid'}:
261-
# res = fun(*data, *args, **kwargs)
262-
# mask = fun(*mask, *args, **kwargs)
263-
# else:
264-
# res = fun(data, *args, **kwargs)
265-
# mask = fun(mask, *args, **kwargs)
266-
267-
# out = (MArray(res, mask) if name not in output_arrays
268-
# else tuple(MArray(resi, maski) for resi, maski in zip(res, mask)))
269-
# return out
270-
# return manip_fun
271-
272-
# creation_manip_functions = ['tril', 'triu', 'meshgrid']
273-
# manip_names = ['broadcast_arrays', 'broadcast_to', 'concat', 'expand_dims',
274-
# 'flip', 'moveaxis', 'permute_dims', 'repeat', 'reshape',
275-
# 'roll', 'squeeze', 'stack', 'tile', 'unstack']
276-
# for name in manip_names + creation_manip_functions:
277-
# setattr(mod, name, get_manip_fun(name))
234+
first_arg_arrays = {"broadcast_arrays", "concat", "stack", "meshgrid"}
235+
output_arrays = {"broadcast_arrays", "unstack", "meshgrid"}
236+
237+
def get_manip_fun(func_str):
238+
def manip_fun(x, *args, **kwargs):
239+
xp_func = getattr(xp, func_str)
240+
241+
if func_str not in first_arg_arrays:
242+
x = asarray(x)
243+
magnitude = xp.asarray(x.magnitude, copy=True)
244+
units = x.units
245+
magnitude = xp_func(magnitude, *args, **kwargs)
246+
247+
else:
248+
x = [asarray(x_i) for x_i in x]
249+
units = x[0].units
250+
magnitude = [xp.asarray(x[0].magnitude, copy=True)]
251+
for x_i in x[1:]:
252+
magnitude.append(x_i.m_as(units))
253+
magnitude = xp_func(*magnitude, *args, **kwargs)
254+
255+
if name in output_arrays:
256+
return tuple(
257+
ArrayUnitQuantity(magnitude_i, units) for magnitude_i in magnitude
258+
)
259+
return ArrayUnitQuantity(magnitude, units)
260+
261+
return manip_fun
262+
263+
creation_manip_functions = ["tril", "triu", "meshgrid"]
264+
manip_names = [
265+
"broadcast_arrays",
266+
"broadcast_to",
267+
"concat",
268+
"expand_dims",
269+
"flip",
270+
"moveaxis",
271+
"permute_dims",
272+
"repeat",
273+
"reshape",
274+
"roll",
275+
"squeeze",
276+
"stack",
277+
"tile",
278+
"unstack",
279+
]
280+
for name in manip_names + creation_manip_functions:
281+
setattr(mod, name, get_manip_fun(name))
278282

279283
## Data Type Functions and Data Types ##
280284
dtype_fun_names = ["can_cast", "finfo", "iinfo", "isdtype"]

0 commit comments

Comments
 (0)