Skip to content

Commit ca3ce7e

Browse files
committed
add max, min, mean, prod
1 parent b9fc5c1 commit ca3ce7e

File tree

1 file changed

+33
-5
lines changed

1 file changed

+33
-5
lines changed

src/pint_array/__init__.py

Lines changed: 33 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -246,7 +246,23 @@ def astype(x, dtype, /, *, copy=True, device=None):
246246

247247
mod.astype = astype
248248

249-
# Handle functions that ignore units on input and output
249+
# Functions with output units equal to input units
250+
for func_str in (
251+
"max",
252+
"min",
253+
"mean",
254+
):
255+
256+
def func(x, /, *args, func_str=func_str, **kwargs):
257+
x = asarray(x)
258+
magnitude = xp.asarray(x.magnitude, copy=True)
259+
xp_func = getattr(xp, func_str)
260+
magnitude = xp_func(magnitude, *args, **kwargs)
261+
return ArrayUnitQuantity(magnitude, x.units)
262+
263+
setattr(mod, func_str, func)
264+
265+
# Functions which ignore units on input and output
250266
for func_str in (
251267
"ones_like",
252268
"zeros_like",
@@ -261,7 +277,7 @@ def func(x, /, *args, func_str=func_str, **kwargs):
261277
x = asarray(x)
262278
magnitude = xp.asarray(x.magnitude, copy=True)
263279
xp_func = getattr(xp, func_str)
264-
magnitude = xp_func(x, *args, **kwargs)
280+
magnitude = xp_func(magnitude, *args, **kwargs)
265281
return ArrayUnitQuantity(magnitude, None)
266282

267283
setattr(mod, func_str, func)
@@ -281,7 +297,7 @@ def func(x, /, *args, func_str=func_str, **kwargs):
281297
magnitude = xp.asarray(x.magnitude, copy=True)
282298
units = x.units
283299
xp_func = getattr(xp, func_str)
284-
magnitude = xp_func(x, *args, **kwargs)
300+
magnitude = xp_func(magnitude, *args, **kwargs)
285301
units = (1 * units + 1 * units).units
286302
return ArrayUnitQuantity(magnitude, units)
287303

@@ -290,16 +306,28 @@ def func(x, /, *args, func_str=func_str, **kwargs):
290306
# output_unit="variance":
291307
# square of `x.units`,
292308
# unless non-multiplicative, which raises `OffsetUnitCalculusError`
293-
def var(x, /, *, axis=None, correction=0.0, keepdims=False):
309+
def var(x, /, *args, **kwargs):
294310
x = asarray(x)
295311
magnitude = xp.asarray(x.magnitude, copy=True)
296312
units = x.units
297-
magnitude = xp.var(x, axis=axis, correction=correction, keepdims=keepdims)
313+
magnitude = xp.var(magnitude, *args, **kwargs)
298314
units = ((1 * units + 1 * units) ** 2).units
299315
return ArrayUnitQuantity(magnitude, units)
300316

301317
mod.var = var
302318

319+
# Output unit is the product of the input unit with itself along axis,
320+
# or the input unit to the power of the size of the array for axis=None
321+
def prod(x, /, *args, axis=None, **kwargs):
322+
x = asarray(x)
323+
magnitude = xp.asarray(x.magnitude, copy=True)
324+
exponent = magnitude.shape[axis] if axis is not None else magnitude.size
325+
units = x.units**exponent
326+
magnitude = xp.prod(magnitude, *args, axis=axis, **kwargs)
327+
return ArrayUnitQuantity(magnitude, units)
328+
329+
mod.prod = prod
330+
303331
# "mul": product of all units in `all_args`
304332
# - "delta": `first_input_units`, unless non-multiplicative,
305333
# which uses delta version

0 commit comments

Comments
 (0)