Skip to content

Commit c68a8b6

Browse files
committed
TST: np.kron tests refinement
* Added `mat` cases to smoke tests * Changed type checks to handle new change which uses ufuncs order for result determination * Added cases for `ma` to check subclass info retention
1 parent 730f315 commit c68a8b6

File tree

1 file changed

+17
-7
lines changed

1 file changed

+17
-7
lines changed

numpy/lib/tests/test_shape_base.py

Lines changed: 17 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -646,21 +646,31 @@ def test_basic(self):
646646
class TestKron:
647647
def test_return_type(self):
648648
class myarray(np.ndarray):
649-
__array_priority__ = 0.0
649+
__array_priority__ = 1.0
650650

651651
a = np.ones([2, 2])
652652
ma = myarray(a.shape, a.dtype, a.data)
653653
assert_equal(type(kron(a, a)), np.ndarray)
654654
assert_equal(type(kron(ma, ma)), myarray)
655-
assert_equal(type(kron(a, ma)), np.ndarray)
655+
assert_equal(type(kron(a, ma)), myarray)
656656
assert_equal(type(kron(ma, a)), myarray)
657657

658-
def test_kron_smoke(self):
659-
a = np.ones([3, 3])
660-
b = np.ones([3, 3])
661-
k = np.ones([9, 9])
658+
@pytest.mark.parametrize(
659+
"array_class", [np.asarray, np.mat]
660+
)
661+
def test_kron_smoke(self, array_class):
662+
a = array_class(np.ones([3, 3]))
663+
b = array_class(np.ones([3, 3]))
664+
k = array_class(np.ones([9, 9]))
665+
666+
assert_array_equal(np.kron(a, b), k)
667+
668+
def test_kron_ma(self):
669+
x = np.ma.array([[1, 2], [3, 4]], mask=[[0, 1], [1, 0]])
670+
k = np.ma.array(np.diag([1, 4, 4, 16]),
671+
mask=~np.array(np.identity(4), dtype=bool))
662672

663-
assert np.array_equal(np.kron(a, b), k), "Smoke test for kron failed"
673+
assert_array_equal(k, np.kron(x, x))
664674

665675
@pytest.mark.parametrize(
666676
"shape_a,shape_b", [

0 commit comments

Comments
 (0)