Skip to content

Commit 8808166

Browse files
authored
Merge pull request #366 from isuruf/matrix-diff
Fix Matrix.diff
2 parents 0529d31 + 95e5321 commit 8808166

File tree

2 files changed

+24
-8
lines changed

2 files changed

+24
-8
lines changed

symengine/lib/symengine_wrapper.pyx

Lines changed: 16 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -930,7 +930,7 @@ cdef class Basic(object):
930930
if (len(f) != 1):
931931
raise RuntimeError("Variable w.r.t should be given")
932932
return self._diff(f.pop())
933-
return diff(self, *args)
933+
return _diff(self, *args)
934934

935935
def subs_dict(Basic self not None, *args):
936936
warnings.warn("subs_dict() is deprecated. Use subs() instead", DeprecationWarning)
@@ -3687,7 +3687,7 @@ cdef class DenseMatrixBase(MatrixBase):
36873687
return R
36883688

36893689
def diff(self, *args):
3690-
return diff(self, *args)
3690+
return _diff(self, *args)
36913691

36923692
#TODO: implement this in C++
36933693
def subs(self, *args):
@@ -4063,15 +4063,23 @@ def module_cleanup():
40634063
import atexit
40644064
atexit.register(module_cleanup)
40654065

4066+
40664067
def diff(expr, *args):
4067-
cdef Basic ex = sympify(expr)
4068+
if isinstance(expr, MatrixBase):
4069+
# Don't sympify matrices so that mutable matrices
4070+
# return mutable matrices
4071+
return _diff(expr, *args)
4072+
return _diff(sympify(expr), *args)
4073+
4074+
4075+
def _diff(expr, *args):
40684076
cdef Basic prev
40694077
cdef Basic b
40704078
cdef size_t i
40714079
cdef size_t length = len(args)
40724080

40734081
if not length:
4074-
return ex
4082+
return expr
40754083

40764084
cdef size_t l = 0
40774085
cdef Basic cur_arg, next_arg
@@ -4083,20 +4091,20 @@ def diff(expr, *args):
40834091

40844092
if l + 1 == length:
40854093
# No next argument, differentiate with no integer argument
4086-
return ex._diff(cur_arg)
4094+
return expr._diff(cur_arg)
40874095

40884096
next_arg = sympify(args[l + 1])
40894097
# Check if the next arg was derivative order
40904098
if isinstance(next_arg, Integer):
40914099
i = int(next_arg)
40924100
for _ in range(i):
4093-
ex = ex._diff(cur_arg)
4101+
expr = expr._diff(cur_arg)
40944102
l += 2
40954103
if l == length:
4096-
return ex
4104+
return expr
40974105
cur_arg = sympify(args[l])
40984106
else:
4099-
ex = ex._diff(cur_arg)
4107+
expr = expr._diff(cur_arg)
41004108
l += 1
41014109
cur_arg = next_arg
41024110

symengine/tests/test_matrices.py

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -655,6 +655,14 @@ def test_cross():
655655
DenseMatrix(1, 2, [1, 1]).cross(DenseMatrix(1, 2, [1, 1])))
656656

657657

658+
def test_diff():
659+
x = symbols("x")
660+
M = DenseMatrix(1, 2, [x**2, x])
661+
result = M.diff(x)
662+
assert isinstance(result, DenseMatrix)
663+
assert result == DenseMatrix(1, 2, [2*x, 1])
664+
665+
658666
def test_immutablematrix():
659667
A = ImmutableMatrix([[1, 2, 3], [4, 5, 6], [7, 8, 9]])
660668
assert A.shape == (3, 3)

0 commit comments

Comments
 (0)