Skip to content

Commit e49478c

Browse files
authored
Merge pull request numpy#21485 from WarrenWeckesser/average-keepdims
ENH: Add 'keepdims' to 'average()' and 'ma.average()'.
2 parents 369a677 + b89939b commit e49478c

File tree

5 files changed

+122
-11
lines changed
Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,5 @@
1+
``keepdims`` parameter for ``average``
2+
--------------------------------------
3+
The parameter ``keepdims`` was added to the functions `numpy.average`
4+
and `numpy.ma.average`. The parameter has the same meaning as it
5+
does in reduction functions such as `numpy.sum` or `numpy.mean`.

numpy/lib/function_base.py

Lines changed: 29 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -388,12 +388,14 @@ def iterable(y):
388388
return True
389389

390390

391-
def _average_dispatcher(a, axis=None, weights=None, returned=None):
391+
def _average_dispatcher(a, axis=None, weights=None, returned=None, *,
392+
keepdims=None):
392393
return (a, weights)
393394

394395

395396
@array_function_dispatch(_average_dispatcher)
396-
def average(a, axis=None, weights=None, returned=False):
397+
def average(a, axis=None, weights=None, returned=False, *,
398+
keepdims=np._NoValue):
397399
"""
398400
Compute the weighted average along the specified axis.
399401
@@ -428,6 +430,14 @@ def average(a, axis=None, weights=None, returned=False):
428430
is returned, otherwise only the average is returned.
429431
If `weights=None`, `sum_of_weights` is equivalent to the number of
430432
elements over which the average is taken.
433+
keepdims : bool, optional
434+
If this is set to True, the axes which are reduced are left
435+
in the result as dimensions with size one. With this option,
436+
the result will broadcast correctly against the original `a`.
437+
*Note:* `keepdims` will not work with instances of `numpy.matrix`
438+
or other classes whose methods do not support `keepdims`.
439+
440+
.. versionadded:: 1.23.0
431441
432442
Returns
433443
-------
@@ -471,7 +481,7 @@ def average(a, axis=None, weights=None, returned=False):
471481
>>> np.average(np.arange(1, 11), weights=np.arange(10, 0, -1))
472482
4.0
473483
474-
>>> data = np.arange(6).reshape((3,2))
484+
>>> data = np.arange(6).reshape((3, 2))
475485
>>> data
476486
array([[0, 1],
477487
[2, 3],
@@ -488,11 +498,24 @@ def average(a, axis=None, weights=None, returned=False):
488498
>>> avg = np.average(a, weights=w)
489499
>>> print(avg.dtype)
490500
complex256
501+
502+
With ``keepdims=True``, the following result has shape (3, 1).
503+
504+
>>> np.average(data, axis=1, keepdims=True)
505+
array([[0.5],
506+
[2.5],
507+
[4.5]])
491508
"""
492509
a = np.asanyarray(a)
493510

511+
if keepdims is np._NoValue:
512+
# Don't pass on the keepdims argument if one wasn't given.
513+
keepdims_kw = {}
514+
else:
515+
keepdims_kw = {'keepdims': keepdims}
516+
494517
if weights is None:
495-
avg = a.mean(axis)
518+
avg = a.mean(axis, **keepdims_kw)
496519
scl = avg.dtype.type(a.size/avg.size)
497520
else:
498521
wgt = np.asanyarray(weights)
@@ -524,7 +547,8 @@ def average(a, axis=None, weights=None, returned=False):
524547
raise ZeroDivisionError(
525548
"Weights sum to zero, can't be normalized")
526549

527-
avg = np.multiply(a, wgt, dtype=result_dtype).sum(axis)/scl
550+
avg = np.multiply(a, wgt,
551+
dtype=result_dtype).sum(axis, **keepdims_kw) / scl
528552

529553
if returned:
530554
if scl.shape != avg.shape:

numpy/lib/tests/test_function_base.py

Lines changed: 27 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -305,6 +305,29 @@ def test_basic(self):
305305
assert_almost_equal(y5.mean(0), average(y5, 0))
306306
assert_almost_equal(y5.mean(1), average(y5, 1))
307307

308+
@pytest.mark.parametrize(
309+
'x, axis, expected_avg, weights, expected_wavg, expected_wsum',
310+
[([1, 2, 3], None, [2.0], [3, 4, 1], [1.75], [8.0]),
311+
([[1, 2, 5], [1, 6, 11]], 0, [[1.0, 4.0, 8.0]],
312+
[1, 3], [[1.0, 5.0, 9.5]], [[4, 4, 4]])],
313+
)
314+
def test_basic_keepdims(self, x, axis, expected_avg,
315+
weights, expected_wavg, expected_wsum):
316+
avg = np.average(x, axis=axis, keepdims=True)
317+
assert avg.shape == np.shape(expected_avg)
318+
assert_array_equal(avg, expected_avg)
319+
320+
wavg = np.average(x, axis=axis, weights=weights, keepdims=True)
321+
assert wavg.shape == np.shape(expected_wavg)
322+
assert_array_equal(wavg, expected_wavg)
323+
324+
wavg, wsum = np.average(x, axis=axis, weights=weights, returned=True,
325+
keepdims=True)
326+
assert wavg.shape == np.shape(expected_wavg)
327+
assert_array_equal(wavg, expected_wavg)
328+
assert wsum.shape == np.shape(expected_wsum)
329+
assert_array_equal(wsum, expected_wsum)
330+
308331
def test_weights(self):
309332
y = np.arange(10)
310333
w = np.arange(10)
@@ -1242,11 +1265,11 @@ def test_no_trim(self):
12421265
res = trim_zeros(arr)
12431266
assert_array_equal(arr, res)
12441267

1245-
12461268
def test_list_to_list(self):
12471269
res = trim_zeros(self.a.tolist())
12481270
assert isinstance(res, list)
12491271

1272+
12501273
class TestExtins:
12511274

12521275
def test_basic(self):
@@ -1759,6 +1782,7 @@ def test_frompyfunc_leaks(self, name, incr):
17591782
finally:
17601783
gc.enable()
17611784

1785+
17621786
class TestDigitize:
17631787

17641788
def test_forward(self):
@@ -2339,6 +2363,7 @@ def test_complex(self):
23392363
with pytest.raises(TypeError, match="i0 not supported for complex values"):
23402364
res = i0(a)
23412365

2366+
23422367
class TestKaiser:
23432368

23442369
def test_simple(self):
@@ -3474,6 +3499,7 @@ def test_quantile_scalar_nan(self):
34743499
assert np.isscalar(actual)
34753500
assert_equal(np.quantile(a, 0.5), np.nan)
34763501

3502+
34773503
class TestLerp:
34783504
@hypothesis.given(t0=st.floats(allow_nan=False, allow_infinity=False,
34793505
min_value=0, max_value=1),

numpy/ma/extras.py

Lines changed: 36 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -475,6 +475,7 @@ def apply_over_axes(func, a, axes):
475475
"an array of the correct shape")
476476
return val
477477

478+
478479
if apply_over_axes.__doc__ is not None:
479480
apply_over_axes.__doc__ = np.apply_over_axes.__doc__[
480481
:np.apply_over_axes.__doc__.find('Notes')].rstrip() + \
@@ -524,7 +525,8 @@ def apply_over_axes(func, a, axes):
524525
"""
525526

526527

527-
def average(a, axis=None, weights=None, returned=False):
528+
def average(a, axis=None, weights=None, returned=False, *,
529+
keepdims=np._NoValue):
528530
"""
529531
Return the weighted average of array over the given axis.
530532
@@ -550,6 +552,14 @@ def average(a, axis=None, weights=None, returned=False):
550552
Flag indicating whether a tuple ``(result, sum of weights)``
551553
should be returned as output (True), or just the result (False).
552554
Default is False.
555+
keepdims : bool, optional
556+
If this is set to True, the axes which are reduced are left
557+
in the result as dimensions with size one. With this option,
558+
the result will broadcast correctly against the original `a`.
559+
*Note:* `keepdims` will not work with instances of `numpy.matrix`
560+
or other classes whose methods do not support `keepdims`.
561+
562+
.. versionadded:: 1.23.0
553563
554564
Returns
555565
-------
@@ -582,14 +592,29 @@ def average(a, axis=None, weights=None, returned=False):
582592
mask=[False, False],
583593
fill_value=1e+20)
584594
595+
With ``keepdims=True``, the following result has shape (3, 1).
596+
597+
>>> np.ma.average(x, axis=1, keepdims=True)
598+
masked_array(
599+
data=[[0.5],
600+
[2.5],
601+
[4.5]],
602+
mask=False,
603+
fill_value=1e+20)
585604
"""
586605
a = asarray(a)
587606
m = getmask(a)
588607

589608
# inspired by 'average' in numpy/lib/function_base.py
590609

610+
if keepdims is np._NoValue:
611+
# Don't pass on the keepdims argument if one wasn't given.
612+
keepdims_kw = {}
613+
else:
614+
keepdims_kw = {'keepdims': keepdims}
615+
591616
if weights is None:
592-
avg = a.mean(axis)
617+
avg = a.mean(axis, **keepdims_kw)
593618
scl = avg.dtype.type(a.count(axis))
594619
else:
595620
wgt = asarray(weights)
@@ -621,7 +646,8 @@ def average(a, axis=None, weights=None, returned=False):
621646
wgt.mask |= a.mask
622647

623648
scl = wgt.sum(axis=axis, dtype=result_dtype)
624-
avg = np.multiply(a, wgt, dtype=result_dtype).sum(axis)/scl
649+
avg = np.multiply(a, wgt,
650+
dtype=result_dtype).sum(axis, **keepdims_kw) / scl
625651

626652
if returned:
627653
if scl.shape != avg.shape:
@@ -713,6 +739,7 @@ def median(a, axis=None, out=None, overwrite_input=False, keepdims=False):
713739
else:
714740
return r
715741

742+
716743
def _median(a, axis=None, out=None, overwrite_input=False):
717744
# when an unmasked NaN is present return it, so we need to sort the NaN
718745
# values behind the mask
@@ -840,6 +867,7 @@ def compress_nd(x, axis=None):
840867
data = data[(slice(None),)*ax + (~m.any(axis=axes),)]
841868
return data
842869

870+
843871
def compress_rowcols(x, axis=None):
844872
"""
845873
Suppress the rows and/or columns of a 2-D array that contain
@@ -912,6 +940,7 @@ def compress_rows(a):
912940
raise NotImplementedError("compress_rows works for 2D arrays only.")
913941
return compress_rowcols(a, 0)
914942

943+
915944
def compress_cols(a):
916945
"""
917946
Suppress whole columns of a 2-D array that contain masked values.
@@ -929,6 +958,7 @@ def compress_cols(a):
929958
raise NotImplementedError("compress_cols works for 2D arrays only.")
930959
return compress_rowcols(a, 1)
931960

961+
932962
def mask_rows(a, axis=np._NoValue):
933963
"""
934964
Mask rows of a 2D array that contain masked values.
@@ -979,6 +1009,7 @@ def mask_rows(a, axis=np._NoValue):
9791009
"will raise TypeError", DeprecationWarning, stacklevel=2)
9801010
return mask_rowcols(a, 0)
9811011

1012+
9821013
def mask_cols(a, axis=np._NoValue):
9831014
"""
9841015
Mask columns of a 2D array that contain masked values.
@@ -1516,6 +1547,7 @@ def __init__(self):
15161547

15171548
mr_ = mr_class()
15181549

1550+
15191551
#####--------------------------------------------------------------------------
15201552
#---- Find unmasked data ---
15211553
#####--------------------------------------------------------------------------
@@ -1682,6 +1714,7 @@ def flatnotmasked_contiguous(a):
16821714
i += n
16831715
return result
16841716

1717+
16851718
def notmasked_contiguous(a, axis=None):
16861719
"""
16871720
Find contiguous unmasked data in a masked array along the given axis.

numpy/ma/tests/test_extras.py

Lines changed: 25 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -75,7 +75,7 @@ def test_masked_all_with_object_nested(self):
7575
assert_equal(len(masked_arr['b']['c']), 1)
7676
assert_equal(masked_arr['b']['c'].shape, (1, 1))
7777
assert_equal(masked_arr['b']['c']._fill_value.shape, ())
78-
78+
7979
def test_masked_all_with_object(self):
8080
# same as above except that the array is not nested
8181
my_dtype = np.dtype([('b', (object, (1,)))])
@@ -292,6 +292,29 @@ def test_complex(self):
292292
assert_almost_equal(wav1.real, expected1.real)
293293
assert_almost_equal(wav1.imag, expected1.imag)
294294

295+
@pytest.mark.parametrize(
296+
'x, axis, expected_avg, weights, expected_wavg, expected_wsum',
297+
[([1, 2, 3], None, [2.0], [3, 4, 1], [1.75], [8.0]),
298+
([[1, 2, 5], [1, 6, 11]], 0, [[1.0, 4.0, 8.0]],
299+
[1, 3], [[1.0, 5.0, 9.5]], [[4, 4, 4]])],
300+
)
301+
def test_basic_keepdims(self, x, axis, expected_avg,
302+
weights, expected_wavg, expected_wsum):
303+
avg = np.ma.average(x, axis=axis, keepdims=True)
304+
assert avg.shape == np.shape(expected_avg)
305+
assert_array_equal(avg, expected_avg)
306+
307+
wavg = np.ma.average(x, axis=axis, weights=weights, keepdims=True)
308+
assert wavg.shape == np.shape(expected_wavg)
309+
assert_array_equal(wavg, expected_wavg)
310+
311+
wavg, wsum = np.ma.average(x, axis=axis, weights=weights,
312+
returned=True, keepdims=True)
313+
assert wavg.shape == np.shape(expected_wavg)
314+
assert_array_equal(wavg, expected_wavg)
315+
assert wsum.shape == np.shape(expected_wsum)
316+
assert_array_equal(wsum, expected_wsum)
317+
295318
def test_masked_weights(self):
296319
# Test with masked weights.
297320
# (Regression test for https://github.com/numpy/numpy/issues/10438)
@@ -335,6 +358,7 @@ def test_masked_weights(self):
335358
assert_almost_equal(avg_masked, avg_expected)
336359
assert_equal(avg_masked.mask, avg_expected.mask)
337360

361+
338362
class TestConcatenator:
339363
# Tests for mr_, the equivalent of r_ for masked arrays.
340364

@@ -1642,7 +1666,6 @@ def test_shape_scalar(self):
16421666
assert_equal(a.mask.shape, a.shape)
16431667
assert_equal(a.data.shape, a.shape)
16441668

1645-
16461669
b = diagflat(1.0)
16471670
assert_equal(b.shape, (1, 1))
16481671
assert_equal(b.mask.shape, b.data.shape)

0 commit comments

Comments
 (0)