Skip to content

Commit 02d1204

Browse files
authored
Merge pull request numpy#20020 from joukewitteveen/ma-ndenumerate
ENH: add ndenumerate specialization for masked arrays
2 parents 42dc653 + 4f1d95a commit 02d1204

File tree

6 files changed

+120
-7
lines changed

6 files changed

+120
-7
lines changed
Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,4 @@
1+
`ndenumerate` specialization for masked arrays
2+
----------------------------------------------
3+
The masked array module now provides the `numpy.ma.ndenumerate` function,
4+
an alternative to `numpy.ndenumerate` that skips masked values by default.

doc/source/reference/routines.ma.rst

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -190,6 +190,7 @@ Finding masked data
190190
.. autosummary::
191191
:toctree: generated/
192192

193+
ma.ndenumerate
193194
ma.flatnotmasked_contiguous
194195
ma.flatnotmasked_edges
195196
ma.notmasked_contiguous

numpy/ma/__init__.pyi

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -216,6 +216,7 @@ from numpy.ma.extras import (
216216
masked_all_like as masked_all_like,
217217
median as median,
218218
mr_ as mr_,
219+
ndenumerate as ndenumerate,
219220
notmasked_contiguous as notmasked_contiguous,
220221
notmasked_edges as notmasked_edges,
221222
polyfit as polyfit,

numpy/ma/extras.py

Lines changed: 74 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -10,12 +10,12 @@
1010
"""
1111
__all__ = [
1212
'apply_along_axis', 'apply_over_axes', 'atleast_1d', 'atleast_2d',
13-
'atleast_3d', 'average', 'clump_masked', 'clump_unmasked',
14-
'column_stack', 'compress_cols', 'compress_nd', 'compress_rowcols',
15-
'compress_rows', 'count_masked', 'corrcoef', 'cov', 'diagflat', 'dot',
16-
'dstack', 'ediff1d', 'flatnotmasked_contiguous', 'flatnotmasked_edges',
17-
'hsplit', 'hstack', 'isin', 'in1d', 'intersect1d', 'mask_cols', 'mask_rowcols',
18-
'mask_rows', 'masked_all', 'masked_all_like', 'median', 'mr_',
13+
'atleast_3d', 'average', 'clump_masked', 'clump_unmasked', 'column_stack',
14+
'compress_cols', 'compress_nd', 'compress_rowcols', 'compress_rows',
15+
'count_masked', 'corrcoef', 'cov', 'diagflat', 'dot', 'dstack', 'ediff1d',
16+
'flatnotmasked_contiguous', 'flatnotmasked_edges', 'hsplit', 'hstack',
17+
'isin', 'in1d', 'intersect1d', 'mask_cols', 'mask_rowcols', 'mask_rows',
18+
'masked_all', 'masked_all_like', 'median', 'mr_', 'ndenumerate',
1919
'notmasked_contiguous', 'notmasked_edges', 'polyfit', 'row_stack',
2020
'setdiff1d', 'setxor1d', 'stack', 'unique', 'union1d', 'vander', 'vstack',
2121
]
@@ -1552,6 +1552,74 @@ def __init__(self):
15521552
#---- Find unmasked data ---
15531553
#####--------------------------------------------------------------------------
15541554

1555+
def ndenumerate(a, compressed=True):
1556+
"""
1557+
Multidimensional index iterator.
1558+
1559+
Return an iterator yielding pairs of array coordinates and values,
1560+
skipping elements that are masked. With `compressed=False`,
1561+
`ma.masked` is yielded as the value of masked elements. This
1562+
behavior differs from that of `numpy.ndenumerate`, which yields the
1563+
value of the underlying data array.
1564+
1565+
Notes
1566+
-----
1567+
.. versionadded:: 1.23.0
1568+
1569+
Parameters
1570+
----------
1571+
a : array_like
1572+
An array with (possibly) masked elements.
1573+
compressed : bool, optional
1574+
If True (default), masked elements are skipped.
1575+
1576+
See Also
1577+
--------
1578+
numpy.ndenumerate : Equivalent function ignoring any mask.
1579+
1580+
Examples
1581+
--------
1582+
>>> a = np.ma.arange(9).reshape((3, 3))
1583+
>>> a[1, 0] = np.ma.masked
1584+
>>> a[1, 2] = np.ma.masked
1585+
>>> a[2, 1] = np.ma.masked
1586+
>>> a
1587+
masked_array(
1588+
data=[[0, 1, 2],
1589+
[--, 4, --],
1590+
[6, --, 8]],
1591+
mask=[[False, False, False],
1592+
[ True, False, True],
1593+
[False, True, False]],
1594+
fill_value=999999)
1595+
>>> for index, x in np.ma.ndenumerate(a):
1596+
... print(index, x)
1597+
(0, 0) 0
1598+
(0, 1) 1
1599+
(0, 2) 2
1600+
(1, 1) 4
1601+
(2, 0) 6
1602+
(2, 2) 8
1603+
1604+
>>> for index, x in np.ma.ndenumerate(a, compressed=False):
1605+
... print(index, x)
1606+
(0, 0) 0
1607+
(0, 1) 1
1608+
(0, 2) 2
1609+
(1, 0) --
1610+
(1, 1) 4
1611+
(1, 2) --
1612+
(2, 0) 6
1613+
(2, 1) --
1614+
(2, 2) 8
1615+
"""
1616+
for it, mask in zip(np.ndenumerate(a), getmaskarray(a).flat):
1617+
if not mask:
1618+
yield it
1619+
elif not compressed:
1620+
yield it[0], masked
1621+
1622+
15551623
def flatnotmasked_edges(a):
15561624
"""
15571625
Find the indices of the first and last unmasked values.

numpy/ma/extras.pyi

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -74,6 +74,7 @@ class mr_class(MAxisConcatenator):
7474

7575
mr_: mr_class
7676

77+
def ndenumerate(a, compressed=...): ...
7778
def flatnotmasked_edges(a): ...
7879
def notmasked_edges(a, axis=...): ...
7980
def flatnotmasked_contiguous(a): ...

numpy/ma/tests/test_extras.py

Lines changed: 39 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -28,7 +28,7 @@
2828
ediff1d, apply_over_axes, apply_along_axis, compress_nd, compress_rowcols,
2929
mask_rowcols, clump_masked, clump_unmasked, flatnotmasked_contiguous,
3030
notmasked_contiguous, notmasked_edges, masked_all, masked_all_like, isin,
31-
diagflat, stack, vstack
31+
diagflat, ndenumerate, stack, vstack
3232
)
3333

3434

@@ -1671,6 +1671,44 @@ def test_shape_scalar(self):
16711671
assert_equal(b.mask.shape, b.data.shape)
16721672

16731673

1674+
class TestNDEnumerate:
1675+
1676+
def test_ndenumerate_nomasked(self):
1677+
ordinary = np.ndarray(6).reshape((1, 3, 2))
1678+
empty_mask = np.zeros_like(ordinary, dtype=bool)
1679+
with_mask = masked_array(ordinary, mask=empty_mask)
1680+
assert_equal(list(np.ndenumerate(ordinary)),
1681+
list(ndenumerate(ordinary)))
1682+
assert_equal(list(ndenumerate(ordinary)),
1683+
list(ndenumerate(with_mask)))
1684+
assert_equal(list(ndenumerate(with_mask)),
1685+
list(ndenumerate(with_mask, compressed=False)))
1686+
1687+
def test_ndenumerate_allmasked(self):
1688+
a = masked_all(())
1689+
b = masked_all((100,))
1690+
c = masked_all((2, 3, 4))
1691+
assert_equal(list(ndenumerate(a)), [])
1692+
assert_equal(list(ndenumerate(b)), [])
1693+
assert_equal(list(ndenumerate(b, compressed=False)),
1694+
list(zip(np.ndindex((100,)), 100 * [masked])))
1695+
assert_equal(list(ndenumerate(c)), [])
1696+
assert_equal(list(ndenumerate(c, compressed=False)),
1697+
list(zip(np.ndindex((2, 3, 4)), 2 * 3 * 4 * [masked])))
1698+
1699+
def test_ndenumerate_mixedmasked(self):
1700+
a = masked_array(np.arange(12).reshape((3, 4)),
1701+
mask=[[1, 1, 1, 1],
1702+
[1, 1, 0, 1],
1703+
[0, 0, 0, 0]])
1704+
items = [((1, 2), 6),
1705+
((2, 0), 8), ((2, 1), 9), ((2, 2), 10), ((2, 3), 11)]
1706+
assert_equal(list(ndenumerate(a)), items)
1707+
assert_equal(len(list(ndenumerate(a, compressed=False))), a.size)
1708+
for coordinate, value in ndenumerate(a, compressed=False):
1709+
assert_equal(a[coordinate], value)
1710+
1711+
16741712
class TestStack:
16751713

16761714
def test_stack_1d(self):

0 commit comments

Comments
 (0)