Skip to content

Commit 4f1d95a

Browse files
ENH: Add compressed= argument to ma.ndenumerate
1 parent ff3a9da commit 4f1d95a

File tree

4 files changed

+39
-8
lines changed

4 files changed

+39
-8
lines changed
Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
11
`ndenumerate` specialization for masked arrays
22
----------------------------------------------
33
The masked array module now provides the `numpy.ma.ndenumerate` function,
4-
an alternative to `numpy.ndenumerate` that skips masked values.
4+
an alternative to `numpy.ndenumerate` that skips masked values by default.

numpy/ma/extras.py

Lines changed: 28 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1520,18 +1520,26 @@ def __init__(self):
15201520
#---- Find unmasked data ---
15211521
#####--------------------------------------------------------------------------
15221522

1523-
def ndenumerate(a):
1523+
def ndenumerate(a, compressed=True):
15241524
"""
15251525
Multidimensional index iterator.
15261526
1527-
Return an iterator yielding pairs of array coordinates and values of
1528-
elements that are not masked.
1527+
Return an iterator yielding pairs of array coordinates and values,
1528+
skipping elements that are masked. With `compressed=False`,
1529+
`ma.masked` is yielded as the value of masked elements. This
1530+
behavior differs from that of `numpy.ndenumerate`, which yields the
1531+
value of the underlying data array.
1532+
1533+
Notes
1534+
-----
15291535
.. versionadded:: 1.23.0
1530-
1536+
15311537
Parameters
15321538
----------
15331539
a : array_like
15341540
An array with (possibly) masked elements.
1541+
compressed : bool, optional
1542+
If True (default), masked elements are skipped.
15351543
15361544
See Also
15371545
--------
@@ -1560,10 +1568,24 @@ def ndenumerate(a):
15601568
(1, 1) 4
15611569
(2, 0) 6
15621570
(2, 2) 8
1571+
1572+
>>> for index, x in np.ma.ndenumerate(a, compressed=False):
1573+
... print(index, x)
1574+
(0, 0) 0
1575+
(0, 1) 1
1576+
(0, 2) 2
1577+
(1, 0) --
1578+
(1, 1) 4
1579+
(1, 2) --
1580+
(2, 0) 6
1581+
(2, 1) --
1582+
(2, 2) 8
15631583
"""
1564-
for it, masked in zip(np.ndenumerate(a), getmaskarray(a).flat):
1565-
if not masked:
1584+
for it, mask in zip(np.ndenumerate(a), getmaskarray(a).flat):
1585+
if not mask:
15661586
yield it
1587+
elif not compressed:
1588+
yield it[0], masked
15671589

15681590

15691591
def flatnotmasked_edges(a):

numpy/ma/extras.pyi

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -74,7 +74,7 @@ class mr_class(MAxisConcatenator):
7474

7575
mr_: mr_class
7676

77-
def ndenumerate(a): ...
77+
def ndenumerate(a, compressed=...): ...
7878
def flatnotmasked_edges(a): ...
7979
def notmasked_edges(a, axis=...): ...
8080
def flatnotmasked_contiguous(a): ...

numpy/ma/tests/test_extras.py

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1658,14 +1658,20 @@ def test_ndenumerate_nomasked(self):
16581658
list(ndenumerate(ordinary)))
16591659
assert_equal(list(ndenumerate(ordinary)),
16601660
list(ndenumerate(with_mask)))
1661+
assert_equal(list(ndenumerate(with_mask)),
1662+
list(ndenumerate(with_mask, compressed=False)))
16611663

16621664
def test_ndenumerate_allmasked(self):
16631665
a = masked_all(())
16641666
b = masked_all((100,))
16651667
c = masked_all((2, 3, 4))
16661668
assert_equal(list(ndenumerate(a)), [])
16671669
assert_equal(list(ndenumerate(b)), [])
1670+
assert_equal(list(ndenumerate(b, compressed=False)),
1671+
list(zip(np.ndindex((100,)), 100 * [masked])))
16681672
assert_equal(list(ndenumerate(c)), [])
1673+
assert_equal(list(ndenumerate(c, compressed=False)),
1674+
list(zip(np.ndindex((2, 3, 4)), 2 * 3 * 4 * [masked])))
16691675

16701676
def test_ndenumerate_mixedmasked(self):
16711677
a = masked_array(np.arange(12).reshape((3, 4)),
@@ -1675,6 +1681,9 @@ def test_ndenumerate_mixedmasked(self):
16751681
items = [((1, 2), 6),
16761682
((2, 0), 8), ((2, 1), 9), ((2, 2), 10), ((2, 3), 11)]
16771683
assert_equal(list(ndenumerate(a)), items)
1684+
assert_equal(len(list(ndenumerate(a, compressed=False))), a.size)
1685+
for coordinate, value in ndenumerate(a, compressed=False):
1686+
assert_equal(a[coordinate], value)
16781687

16791688

16801689
class TestStack:

0 commit comments

Comments
 (0)