Skip to content

Commit aed6c57

Browse files
authored
Merge pull request numpy#21262 from ganesh-k13/kron_21257_ma
ENH: Masked Array support for `np.kron`
2 parents b2e7534 + 8092911 commit aed6c57

File tree

3 files changed

+57
-17
lines changed

3 files changed

+57
-17
lines changed
Lines changed: 34 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,34 @@
1+
``np.kron`` now maintains subclass information
2+
----------------------------------------------
3+
``np.kron`` maintains subclass information now such as masked arrays
4+
while computing the Kronecker product of the inputs
5+
6+
.. code-block:: python
7+
8+
>>> x = ma.array([[1, 2], [3, 4]], mask=[[0, 1], [1, 0]])
9+
>>> np.kron(x,x)
10+
masked_array(
11+
data=[[1, --, --, --],
12+
[--, 4, --, --],
13+
[--, --, 4, --],
14+
[--, --, --, 16]],
15+
mask=[[False, True, True, True],
16+
[ True, False, True, True],
17+
[ True, True, False, True],
18+
[ True, True, True, False]],
19+
fill_value=999999)
20+
21+
.. warning::
22+
``np.kron`` output now follows ``ufunc`` ordering (``multiply``)
23+
to determine the output class type
24+
25+
.. code-block:: python
26+
27+
>>> class myarr(np.ndarray):
28+
>>> __array_priority__ = -1
29+
>>> a = np.ones([2, 2])
30+
>>> ma = myarray(a.shape, a.dtype, a.data)
31+
>>> type(np.kron(a, ma)) == np.ndarray
32+
False # Before it was True
33+
>>> type(np.kron(a, ma)) == myarr
34+
True

numpy/lib/shape_base.py

Lines changed: 6 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -1159,23 +1159,19 @@ def kron(a, b):
11591159
bs = (1,)*max(0, nda-ndb) + bs
11601160

11611161
# Compute the product
1162-
a_arr = _nx.asarray(a).reshape(a.size, 1)
1163-
b_arr = _nx.asarray(b).reshape(1, b.size)
1164-
result = a_arr * b_arr
1162+
a_arr = a.reshape(a.size, 1)
1163+
b_arr = b.reshape(1, b.size)
1164+
is_any_mat = isinstance(a_arr, matrix) or isinstance(b_arr, matrix)
1165+
# In case of `mat`, convert result to `array`
1166+
result = _nx.multiply(a_arr, b_arr, subok=(not is_any_mat))
11651167

11661168
# Reshape back
11671169
result = result.reshape(as_+bs)
11681170
transposer = _nx.arange(nd*2).reshape([2, nd]).ravel(order='f')
11691171
result = result.transpose(transposer)
11701172
result = result.reshape(_nx.multiply(as_, bs))
11711173

1172-
wrapper = get_array_prepare(a, b)
1173-
if wrapper is not None:
1174-
result = wrapper(result)
1175-
wrapper = get_array_wrap(a, b)
1176-
if wrapper is not None:
1177-
result = wrapper(result)
1178-
return result
1174+
return result if not is_any_mat else matrix(result, copy=False)
11791175

11801176

11811177
def _tile_dispatcher(A, reps):

numpy/lib/tests/test_shape_base.py

Lines changed: 17 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -646,21 +646,31 @@ def test_basic(self):
646646
class TestKron:
647647
def test_return_type(self):
648648
class myarray(np.ndarray):
649-
__array_priority__ = 0.0
649+
__array_priority__ = 1.0
650650

651651
a = np.ones([2, 2])
652652
ma = myarray(a.shape, a.dtype, a.data)
653653
assert_equal(type(kron(a, a)), np.ndarray)
654654
assert_equal(type(kron(ma, ma)), myarray)
655-
assert_equal(type(kron(a, ma)), np.ndarray)
655+
assert_equal(type(kron(a, ma)), myarray)
656656
assert_equal(type(kron(ma, a)), myarray)
657657

658-
def test_kron_smoke(self):
659-
a = np.ones([3, 3])
660-
b = np.ones([3, 3])
661-
k = np.ones([9, 9])
658+
@pytest.mark.parametrize(
659+
"array_class", [np.asarray, np.mat]
660+
)
661+
def test_kron_smoke(self, array_class):
662+
a = array_class(np.ones([3, 3]))
663+
b = array_class(np.ones([3, 3]))
664+
k = array_class(np.ones([9, 9]))
665+
666+
assert_array_equal(np.kron(a, b), k)
667+
668+
def test_kron_ma(self):
669+
x = np.ma.array([[1, 2], [3, 4]], mask=[[0, 1], [1, 0]])
670+
k = np.ma.array(np.diag([1, 4, 4, 16]),
671+
mask=~np.array(np.identity(4), dtype=bool))
662672

663-
assert np.array_equal(np.kron(a, b), k), "Smoke test for kron failed"
673+
assert_array_equal(k, np.kron(x, x))
664674

665675
@pytest.mark.parametrize(
666676
"shape_a,shape_b", [

0 commit comments

Comments
 (0)