Skip to content

Commit 9d1f0ba

Browse files
committed
Derivative/Subs: Add expr, variables attributes.
Subs: Add point attribute.
1 parent 5b24aae commit 9d1f0ba

File tree

2 files changed

+40
-5
lines changed

2 files changed

+40
-5
lines changed

symengine/lib/symengine_wrapper.pyx

Lines changed: 31 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -387,6 +387,13 @@ def get_dict(*args):
387387
return D
388388

389389

390+
cdef tuple vec_basic_to_tuple(symengine.vec_basic vec):
391+
result = []
392+
for i in range(vec.size()):
393+
result.append(c2py(<RCP[const symengine.Basic]>(vec[i])))
394+
return tuple(result)
395+
396+
390397
cdef class Basic(object):
391398

392399
def __str__(self):
@@ -514,11 +521,8 @@ cdef class Basic(object):
514521

515522
@property
516523
def args(self):
517-
cdef symengine.vec_basic Y = deref(self.thisptr).get_args()
518-
s = []
519-
for i in range(Y.size()):
520-
s.append(c2py(<RCP[const symengine.Basic]>(Y[i])))
521-
return tuple(s)
524+
cdef symengine.vec_basic args = deref(self.thisptr).get_args()
525+
return vec_basic_to_tuple(args)
522526

523527
@property
524528
def free_symbols(self):
@@ -1310,6 +1314,14 @@ cdef class Derivative(Basic):
13101314
def is_Derivative(self):
13111315
return True
13121316

1317+
@property
1318+
def expr(self):
1319+
return self.args[0]
1320+
1321+
@property
1322+
def variables(self):
1323+
return self.args[1:]
1324+
13131325
def __cinit__(self, expr = None, symbols = None):
13141326
if expr is None or symbols is None:
13151327
return
@@ -1357,6 +1369,20 @@ cdef class Subs(Basic):
13571369
m[v_.thisptr] = p_.thisptr
13581370
self.thisptr = symengine.make_rcp_Subs(expr_.thisptr, m)
13591371

1372+
@property
1373+
def expr(self):
1374+
return self.args[0]
1375+
1376+
@property
1377+
def variables(self):
1378+
cdef RCP[const symengine.Subs] me = symengine.rcp_static_cast_Subs(self.thisptr)
1379+
return vec_basic_to_tuple(deref(me).get_variables())
1380+
1381+
@property
1382+
def point(self):
1383+
cdef RCP[const symengine.Subs] me = symengine.rcp_static_cast_Subs(self.thisptr)
1384+
return vec_basic_to_tuple(deref(me).get_point())
1385+
13601386
def _sympy_(self):
13611387
cdef RCP[const symengine.Subs] X = symengine.rcp_static_cast_Subs(self.thisptr)
13621388
arg = c2py(deref(X).get_arg())._sympy_()

symengine/tests/test_functions.py

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -56,6 +56,10 @@ def test_derivative():
5656
assert f.diff(x).diff(y) == function_symbol("f", x, y).diff(x).diff(y)
5757
assert f.diff(Symbol("z")) == 0
5858

59+
s = Derivative(function_symbol("f", x), [x])
60+
assert s.expr == function_symbol("f", x)
61+
assert s.variables == (x,)
62+
5963
def test_abs():
6064
x = Symbol("x")
6165
e = abs(x)
@@ -86,6 +90,11 @@ def test_Subs():
8690
assert Subs(Derivative(function_symbol("f", x, y), [x]), [x, y], [_x, x]) \
8791
== Subs(Derivative(function_symbol("f", x, y), [x]), [y, x], [x, _x])
8892

93+
s = Subs(function_symbol("f", _x), [_x], [x])
94+
assert s.expr == function_symbol("f", _x)
95+
assert s.variables == (_x,)
96+
assert s.point == (x,)
97+
8998
def test_FunctionWrapper():
9099
import sympy
91100
n, m, theta, phi = sympy.symbols("n, m, theta, phi")

0 commit comments

Comments
 (0)