Skip to content

Commit 018ecee

Browse files
committed
Fix Matrix.diff
1 parent 0529d31 commit 018ecee

File tree

2 files changed

+11
-6
lines changed

2 files changed

+11
-6
lines changed

symengine/lib/symengine_wrapper.pyx

Lines changed: 5 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -4064,14 +4064,13 @@ import atexit
40644064
atexit.register(module_cleanup)
40654065

40664066
def diff(expr, *args):
4067-
cdef Basic ex = sympify(expr)
40684067
cdef Basic prev
40694068
cdef Basic b
40704069
cdef size_t i
40714070
cdef size_t length = len(args)
40724071

40734072
if not length:
4074-
return ex
4073+
return expr
40754074

40764075
cdef size_t l = 0
40774076
cdef Basic cur_arg, next_arg
@@ -4083,20 +4082,20 @@ def diff(expr, *args):
40834082

40844083
if l + 1 == length:
40854084
# No next argument, differentiate with no integer argument
4086-
return ex._diff(cur_arg)
4085+
return expr._diff(cur_arg)
40874086

40884087
next_arg = sympify(args[l + 1])
40894088
# Check if the next arg was derivative order
40904089
if isinstance(next_arg, Integer):
40914090
i = int(next_arg)
40924091
for _ in range(i):
4093-
ex = ex._diff(cur_arg)
4092+
expr = expr._diff(cur_arg)
40944093
l += 2
40954094
if l == length:
4096-
return ex
4095+
return expr
40974096
cur_arg = sympify(args[l])
40984097
else:
4099-
ex = ex._diff(cur_arg)
4098+
expr = expr._diff(cur_arg)
41004099
l += 1
41014100
cur_arg = next_arg
41024101

symengine/tests/test_matrices.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -655,6 +655,12 @@ def test_cross():
655655
DenseMatrix(1, 2, [1, 1]).cross(DenseMatrix(1, 2, [1, 1])))
656656

657657

658+
def test_diff():
659+
x = symbols("x")
660+
M = DenseMatrix(1, 2, [x**2, x])
661+
assert M.diff(x) == DenseMatrix(1, 2, [2*x, 1])
662+
663+
658664
def test_immutablematrix():
659665
A = ImmutableMatrix([[1, 2, 3], [4, 5, 6], [7, 8, 9]])
660666
assert A.shape == (3, 3)

0 commit comments

Comments
 (0)