Skip to content

Commit 72a54f3

Browse files
authored
Merge pull request #226 from isuruf/diff
Implement diff(ex, x, n)
2 parents 979dcd4 + a803486 commit 72a54f3

File tree

2 files changed

+39
-15
lines changed

2 files changed

+39
-15
lines changed

symengine/lib/symengine_wrapper.pyx

Lines changed: 32 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -868,14 +868,16 @@ cdef class Basic(object):
868868
def expand(Basic self not None, cppbool deep=True):
869869
return c2py(symengine.expand(self.thisptr, deep))
870870

871-
def diff(Basic self not None, x = None):
872-
if x is None:
871+
def _diff(Basic self not None, Basic x):
872+
return c2py(symengine.diff(self.thisptr, x.thisptr))
873+
874+
def diff(self, *args):
875+
if len(args) == 0:
873876
f = self.free_symbols
874877
if (len(f) != 1):
875878
raise RuntimeError("Variable w.r.t should be given")
876-
return self.diff(f.pop())
877-
cdef Basic s = sympify(x)
878-
return c2py(symengine.diff(self.thisptr, s.thisptr))
879+
return self._diff(f.pop())
880+
return diff(self, *args)
879881

880882
def subs_dict(Basic self not None, *args):
881883
warnings.warn("subs_dict() is deprecated. Use subs() instead", DeprecationWarning)
@@ -3430,13 +3432,15 @@ cdef class DenseMatrixBase(MatrixBase):
34303432
cdef _DictBasic D = get_dict(*args)
34313433
return self.applyfunc(lambda x: x.msubs(D))
34323434

3433-
def diff(self, x):
3434-
cdef Basic x_ = sympify(x)
3435+
def _diff(self, Basic x):
34353436
cdef DenseMatrixBase R = self.__class__(self.rows, self.cols)
34363437
symengine.diff(<const symengine.DenseMatrix &>deref(self.thisptr),
3437-
x_.thisptr, <symengine.DenseMatrix &>deref(R.thisptr))
3438+
x.thisptr, <symengine.DenseMatrix &>deref(R.thisptr))
34383439
return R
34393440

3441+
def diff(self, *args):
3442+
return diff(self, *args)
3443+
34403444
#TODO: implement this in C++
34413445
def subs(self, *args):
34423446
cdef _DictBasic D = get_dict(*args)
@@ -3762,17 +3766,31 @@ cdef class Sieve_iterator:
37623766

37633767

37643768
def module_cleanup():
3765-
global I, E, pi, oo, minus_oo, zoo, nan, true, false, golden_ratio, catalan, eulergamma, sympy_module, sage_module
3769+
global I, E, pi, oo, minus_oo, zoo, nan, true, false, golden_ratio, \
3770+
catalan, eulergamma, sympy_module, sage_module, half, one, \
3771+
minus_one, zero
37663772
funcs.clear()
3767-
del I, E, pi, oo, minus_oo, zoo, nan, true, false, golden_ratio, catalan, eulergamma, sympy_module, sage_module
3773+
del I, E, pi, oo, minus_oo, zoo, nan, true, false, golden_ratio, \
3774+
catalan, eulergamma, sympy_module, sage_module, half, one, \
3775+
minus_one, zero
37683776

37693777
import atexit
37703778
atexit.register(module_cleanup)
37713779

3772-
def diff(ex, *x):
3780+
def diff(ex, *args):
37733781
ex = sympify(ex)
3774-
for i in x:
3775-
ex = ex.diff(i)
3782+
prev = 0
3783+
cdef Basic b
3784+
cdef size_t i
3785+
for x in args:
3786+
b = sympify(x)
3787+
if isinstance(b, Integer):
3788+
i = int(b) - 1
3789+
for j in range(i):
3790+
ex = ex._diff(prev)
3791+
else:
3792+
ex = ex._diff(b)
3793+
prev = b
37763794
return ex
37773795

37783796
def expand(x, deep=True):
@@ -4369,7 +4387,7 @@ cdef class _Lambdify(object):
43694387
cdef size_t args_size, tot_out_size
43704388
cdef list out_shapes
43714389
cdef readonly bint real
4372-
cdef readonly int n_exprs
4390+
cdef readonly size_t n_exprs
43734391
cdef public str order
43744392
cdef vector[int] accum_out_sizes
43754393
cdef object numpy_dtype

symengine/tests/test_functions.py

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
from symengine import (Symbol, sin, cos, sqrt, Add, Mul, function_symbol, Integer, log, E, symbols, I,
2-
Rational, EulerGamma)
2+
Rational, EulerGamma, Function)
33
from symengine.lib.symengine_wrapper import (Subs, Derivative, LambertW, zeta, dirichlet_eta,
44
zoo, pi, KroneckerDelta, LeviCivita, erf, erfc,
55
oo, lowergamma, uppergamma, exp, loggamma, beta,
@@ -78,6 +78,12 @@ def test_derivative():
7878
assert s.expr == function_symbol("f", x)
7979
assert s.variables == (x,)
8080

81+
fxy = Function("f")(x, y)
82+
g = Derivative(Function("f")(x, y), x, 2, y, 1)
83+
assert g == fxy.diff(x, x, y)
84+
assert g == fxy.diff(y, 1, x, 2)
85+
assert g == fxy.diff(y, x, 2)
86+
8187

8288
def test_abs():
8389
x = Symbol("x")

0 commit comments

Comments
 (0)