@@ -930,7 +930,7 @@ cdef class Basic(object):
930
930
if (len (f) != 1 ):
931
931
raise RuntimeError (" Variable w.r.t should be given" )
932
932
return self ._diff(f.pop())
933
- return diff (self , * args)
933
+ return _diff (self , * args)
934
934
935
935
def subs_dict (Basic self not None , *args ):
936
936
warnings.warn(" subs_dict() is deprecated. Use subs() instead" , DeprecationWarning )
@@ -3687,7 +3687,7 @@ cdef class DenseMatrixBase(MatrixBase):
3687
3687
return R
3688
3688
3689
3689
def diff (self , *args ):
3690
- return diff (self , * args)
3690
+ return _diff (self , * args)
3691
3691
3692
3692
# TODO: implement this in C++
3693
3693
def subs (self , *args ):
@@ -4063,15 +4063,23 @@ def module_cleanup():
4063
4063
import atexit
4064
4064
atexit.register(module_cleanup)
4065
4065
4066
+
4066
4067
def diff (expr , *args ):
4067
- cdef Basic ex = sympify(expr)
4068
+ if isinstance (expr, MatrixBase):
4069
+ # Don't sympify matrices so that mutable matrices
4070
+ # return mutable matrices
4071
+ return _diff(expr, * args)
4072
+ return _diff(sympify(expr), * args)
4073
+
4074
+
4075
+ def _diff (expr , *args ):
4068
4076
cdef Basic prev
4069
4077
cdef Basic b
4070
4078
cdef size_t i
4071
4079
cdef size_t length = len (args)
4072
4080
4073
4081
if not length:
4074
- return ex
4082
+ return expr
4075
4083
4076
4084
cdef size_t l = 0
4077
4085
cdef Basic cur_arg, next_arg
@@ -4083,20 +4091,20 @@ def diff(expr, *args):
4083
4091
4084
4092
if l + 1 == length:
4085
4093
# No next argument, differentiate with no integer argument
4086
- return ex ._diff(cur_arg)
4094
+ return expr ._diff(cur_arg)
4087
4095
4088
4096
next_arg = sympify(args[l + 1 ])
4089
4097
# Check if the next arg was derivative order
4090
4098
if isinstance (next_arg, Integer):
4091
4099
i = int (next_arg)
4092
4100
for _ in range (i):
4093
- ex = ex ._diff(cur_arg)
4101
+ expr = expr ._diff(cur_arg)
4094
4102
l += 2
4095
4103
if l == length:
4096
- return ex
4104
+ return expr
4097
4105
cur_arg = sympify(args[l])
4098
4106
else :
4099
- ex = ex ._diff(cur_arg)
4107
+ expr = expr ._diff(cur_arg)
4100
4108
l += 1
4101
4109
cur_arg = next_arg
4102
4110
0 commit comments