@@ -246,7 +246,23 @@ def astype(x, dtype, /, *, copy=True, device=None):
246
246
247
247
mod .astype = astype
248
248
249
- # Handle functions that ignore units on input and output
249
+ # Functions with output units equal to input units
250
+ for func_str in (
251
+ "max" ,
252
+ "min" ,
253
+ "mean" ,
254
+ ):
255
+
256
+ def func (x , / , * args , func_str = func_str , ** kwargs ):
257
+ x = asarray (x )
258
+ magnitude = xp .asarray (x .magnitude , copy = True )
259
+ xp_func = getattr (xp , func_str )
260
+ magnitude = xp_func (magnitude , * args , ** kwargs )
261
+ return ArrayUnitQuantity (magnitude , x .units )
262
+
263
+ setattr (mod , func_str , func )
264
+
265
+ # Functions which ignore units on input and output
250
266
for func_str in (
251
267
"ones_like" ,
252
268
"zeros_like" ,
@@ -261,7 +277,7 @@ def func(x, /, *args, func_str=func_str, **kwargs):
261
277
x = asarray (x )
262
278
magnitude = xp .asarray (x .magnitude , copy = True )
263
279
xp_func = getattr (xp , func_str )
264
- magnitude = xp_func (x , * args , ** kwargs )
280
+ magnitude = xp_func (magnitude , * args , ** kwargs )
265
281
return ArrayUnitQuantity (magnitude , None )
266
282
267
283
setattr (mod , func_str , func )
@@ -281,7 +297,7 @@ def func(x, /, *args, func_str=func_str, **kwargs):
281
297
magnitude = xp .asarray (x .magnitude , copy = True )
282
298
units = x .units
283
299
xp_func = getattr (xp , func_str )
284
- magnitude = xp_func (x , * args , ** kwargs )
300
+ magnitude = xp_func (magnitude , * args , ** kwargs )
285
301
units = (1 * units + 1 * units ).units
286
302
return ArrayUnitQuantity (magnitude , units )
287
303
@@ -290,16 +306,28 @@ def func(x, /, *args, func_str=func_str, **kwargs):
290
306
# output_unit="variance":
291
307
# square of `x.units`,
292
308
# unless non-multiplicative, which raises `OffsetUnitCalculusError`
293
- def var (x , / , * , axis = None , correction = 0.0 , keepdims = False ):
309
+ def var (x , / , * args , ** kwargs ):
294
310
x = asarray (x )
295
311
magnitude = xp .asarray (x .magnitude , copy = True )
296
312
units = x .units
297
- magnitude = xp .var (x , axis = axis , correction = correction , keepdims = keepdims )
313
+ magnitude = xp .var (magnitude , * args , ** kwargs )
298
314
units = ((1 * units + 1 * units ) ** 2 ).units
299
315
return ArrayUnitQuantity (magnitude , units )
300
316
301
317
mod .var = var
302
318
319
+ # Output unit is the product of the input unit with itself along axis,
320
+ # or the input unit to the power of the size of the array for axis=None
321
+ def prod (x , / , * args , axis = None , ** kwargs ):
322
+ x = asarray (x )
323
+ magnitude = xp .asarray (x .magnitude , copy = True )
324
+ exponent = magnitude .shape [axis ] if axis is not None else magnitude .size
325
+ units = x .units ** exponent
326
+ magnitude = xp .prod (magnitude , * args , axis = axis , ** kwargs )
327
+ return ArrayUnitQuantity (magnitude , units )
328
+
329
+ mod .prod = prod
330
+
303
331
# "mul": product of all units in `all_args`
304
332
# - "delta": `first_input_units`, unless non-multiplicative,
305
333
# which uses delta version
0 commit comments