Skip to content

Commit 6c9a264

Browse files
authored
Merge pull request #392 from rikardn/immutablematrix
Better compatibility with sympy ImmutableMatrix (fixes #363)
2 parents 0ca50eb + b586d41 commit 6c9a264

File tree

2 files changed

+28
-10
lines changed

2 files changed

+28
-10
lines changed

symengine/lib/symengine_wrapper.pyx

Lines changed: 8 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -3895,7 +3895,7 @@ cdef class DenseMatrixBase(MatrixBase):
38953895
l.append(c2py(A.get(i, j))._sympy_())
38963896
s.append(l)
38973897
import sympy
3898-
return sympy.Matrix(s)
3898+
return sympy.ImmutableMatrix(s)
38993899

39003900
def _sage_(self):
39013901
s = []
@@ -3906,7 +3906,7 @@ cdef class DenseMatrixBase(MatrixBase):
39063906
l.append(c2py(A.get(i, j))._sage_())
39073907
s.append(l)
39083908
import sage.all as sage
3909-
return sage.Matrix(s)
3909+
return sage.Matrix(s, immutable=True)
39103910

39113911
def dump_real(self, double[::1] out):
39123912
cdef size_t ri, ci, nr, nc
@@ -4046,6 +4046,12 @@ cdef class ImmutableDenseMatrix(DenseMatrixBase):
40464046
def __setitem__(self, key, value):
40474047
raise TypeError("Cannot set values of {}".format(self.__class__))
40484048

4049+
def _applyfunc(self, f):
4050+
res = DenseMatrix(self)
4051+
res._applyfunc(f)
4052+
return ImmutableDenseMatrix(res)
4053+
4054+
40494055
ImmutableMatrix = ImmutableDenseMatrix
40504056

40514057

symengine/tests/test_matrices.py

Lines changed: 20 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -3,13 +3,23 @@
33
Rational, function_symbol, I, NonSquareMatrixError, ShapeError, zeros,
44
ones, eye, ImmutableMatrix)
55
from symengine.test_utilities import raises
6+
import unittest
67

78

89
try:
910
import numpy as np
10-
HAVE_NUMPY = True
11+
have_numpy = True
1112
except ImportError:
12-
HAVE_NUMPY = False
13+
have_numpy = False
14+
15+
try:
16+
import sympy
17+
from sympy.core.cache import clear_cache
18+
import atexit
19+
atexit.register(clear_cache)
20+
have_sympy = True
21+
except ImportError:
22+
have_sympy = False
1323

1424

1525
def test_init():
@@ -520,21 +530,18 @@ def test_reshape():
520530
assert C != A
521531

522532

523-
# @pytest.mark.skipif(not HAVE_NUMPY, reason='requires numpy')
533+
@unittest.skipIf(not have_numpy, 'requires numpy')
524534
def test_dump_real():
525-
if not HAVE_NUMPY: # nosetests work-around
526-
return
527535
ref = [1, 2, 3, 4]
528536
A = DenseMatrix(2, 2, ref)
529537
out = np.empty(4)
530538
A.dump_real(out)
531539
assert np.allclose(out, ref)
532540

533541

534-
# @pytest.mark.skipif(not HAVE_NUMPY, reason='requires numpy')
542+
543+
@unittest.skipIf(not have_numpy, 'requires numpy')
535544
def test_dump_complex():
536-
if not HAVE_NUMPY: # nosetests work-around
537-
return
538545
ref = [1j, 2j, 3j, 4j]
539546
A = DenseMatrix(2, 2, ref)
540547
out = np.empty(4, dtype=np.complex128)
@@ -741,3 +748,8 @@ def test_repr_latex():
741748
latex_string = testmat._repr_latex_()
742749
assert isinstance(latex_string, str)
743750
init_printing(False)
751+
752+
@unittest.skipIf(not have_sympy, "SymPy not installed")
753+
def test_simplify():
754+
A = ImmutableMatrix([1])
755+
assert type(A.simplify()) == type(A)

0 commit comments

Comments
 (0)