Skip to content

Commit 1ddb1f3

Browse files
authored
Merge pull request #374 from rikardn/immutable
Mul and Add of immutable and dense matrices gives immutable
2 parents 4978e0b + 1c90da1 commit 1ddb1f3

File tree

2 files changed

+24
-3
lines changed

2 files changed

+24
-3
lines changed

symengine/lib/symengine_wrapper.pyx

Lines changed: 15 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -3570,19 +3570,31 @@ cdef class DenseMatrixBase(MatrixBase):
35703570

35713571
def add_matrix(self, A):
35723572
cdef MatrixBase A_ = sympify(A)
3573-
cdef DenseMatrixBase result = self.__class__(self.nrows(), self.ncols())
3573+
if isinstance(A, ImmutableDenseMatrix):
3574+
cls = A.__class__
3575+
else:
3576+
cls = self.__class__
3577+
cdef DenseMatrixBase result = cls(self.nrows(), self.ncols())
35743578
deref(self.thisptr).add_matrix(deref(A_.thisptr), deref(result.thisptr))
35753579
return result
35763580

35773581
def mul_matrix(self, A):
35783582
cdef MatrixBase A_ = sympify(A)
3579-
cdef DenseMatrixBase result = self.__class__(self.nrows(), A.ncols())
3583+
if isinstance(A, ImmutableDenseMatrix):
3584+
cls = A.__class__
3585+
else:
3586+
cls = self.__class__
3587+
cdef DenseMatrixBase result = cls(self.nrows(), A.ncols())
35803588
deref(self.thisptr).mul_matrix(deref(A_.thisptr), deref(result.thisptr))
35813589
return result
35823590

35833591
def multiply_elementwise(self, A):
35843592
cdef MatrixBase A_ = sympify(A)
3585-
cdef DenseMatrixBase result = self.__class__(self.nrows(), self.ncols())
3593+
if isinstance(A, ImmutableDenseMatrix):
3594+
cls = A.__class__
3595+
else:
3596+
cls = self.__class__
3597+
cdef DenseMatrixBase result = cls(self.nrows(), self.ncols())
35863598
deref(self.thisptr).elementwise_mul_matrix(deref(A_.thisptr), deref(result.thisptr))
35873599
return result
35883600

symengine/tests/test_matrices.py

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -708,6 +708,15 @@ def test_immutablematrix():
708708
assert isinstance(Z, ImmutableMatrix)
709709
assert Z == ImmutableMatrix([[1, 2], [3, 4], [5, 6]])
710710

711+
# Operations of one immutable and one mutable matrix should give immutable result
712+
X = ImmutableMatrix([1])
713+
Y = DenseMatrix([1])
714+
assert type(X + Y) == ImmutableMatrix
715+
assert type(Y + X) == ImmutableMatrix
716+
assert type(X * Y) == ImmutableMatrix
717+
assert type(Y * X) == ImmutableMatrix
718+
719+
711720
def test_atoms():
712721
a = Symbol("a")
713722
b = Symbol("b")

0 commit comments

Comments
 (0)