Skip to content

Commit e0468f3

Browse files
authored
Merge pull request #117 from mattwala/make-symengine-classes-closer-to-sympy
Make symengine.py classes closer to sympy
2 parents 15b7028 + b38981f commit e0468f3

File tree

4 files changed

+71
-5
lines changed

4 files changed

+71
-5
lines changed

symengine/lib/symengine_wrapper.pyx

Lines changed: 56 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -387,6 +387,13 @@ def get_dict(*args):
387387
return D
388388

389389

390+
cdef tuple vec_basic_to_tuple(symengine.vec_basic& vec):
391+
result = []
392+
for i in range(vec.size()):
393+
result.append(c2py(<RCP[const symengine.Basic]>(vec[i])))
394+
return tuple(result)
395+
396+
390397
cdef class Basic(object):
391398

392399
def __str__(self):
@@ -514,11 +521,8 @@ cdef class Basic(object):
514521

515522
@property
516523
def args(self):
517-
cdef symengine.vec_basic Y = deref(self.thisptr).get_args()
518-
s = []
519-
for i in range(Y.size()):
520-
s.append(c2py(<RCP[const symengine.Basic]>(Y[i])))
521-
return tuple(s)
524+
cdef symengine.vec_basic args = deref(self.thisptr).get_args()
525+
return vec_basic_to_tuple(args)
522526

523527
@property
524528
def free_symbols(self):
@@ -657,9 +661,14 @@ cdef class Symbol(Basic):
657661
import sage.all as sage
658662
return sage.SR.symbol(str(deref(X).get_name().decode("utf-8")))
659663

664+
@property
660665
def name(self):
661666
return self.__str__()
662667

668+
@property
669+
def is_Atom(self):
670+
return True
671+
663672
@property
664673
def is_Symbol(self):
665674
return True
@@ -837,6 +846,14 @@ cdef class Integer(Number):
837846
def __float__(self):
838847
return float(str(self))
839848

849+
@property
850+
def p(self):
851+
return int(self)
852+
853+
@property
854+
def q(self):
855+
return 1
856+
840857

841858
cdef class RealDouble(Number):
842859

@@ -956,6 +973,14 @@ cdef class Rational(Number):
956973
def is_Rational(self):
957974
return True
958975

976+
@property
977+
def p(self):
978+
return self.get_num_den()[0]
979+
980+
@property
981+
def q(self):
982+
return self.get_num_den()[1]
983+
959984
def get_num_den(self):
960985
cdef RCP[const symengine.Integer] _num, _den
961986
symengine.get_num_den(deref(symengine.rcp_static_cast_Rational(self.thisptr)),
@@ -1302,6 +1327,15 @@ cdef class Derivative(Basic):
13021327
def is_Derivative(self):
13031328
return True
13041329

1330+
@property
1331+
def expr(self):
1332+
cdef RCP[const symengine.Derivative] X = symengine.rcp_static_cast_Derivative(self.thisptr)
1333+
return c2py(deref(X).get_arg())
1334+
1335+
@property
1336+
def variables(self):
1337+
return self.args[1:]
1338+
13051339
def __cinit__(self, expr = None, symbols = None):
13061340
if expr is None or symbols is None:
13071341
return
@@ -1349,6 +1383,23 @@ cdef class Subs(Basic):
13491383
m[v_.thisptr] = p_.thisptr
13501384
self.thisptr = symengine.make_rcp_Subs(expr_.thisptr, m)
13511385

1386+
@property
1387+
def expr(self):
1388+
cdef RCP[const symengine.Subs] me = symengine.rcp_static_cast_Subs(self.thisptr)
1389+
return c2py(deref(me).get_arg())
1390+
1391+
@property
1392+
def variables(self):
1393+
cdef RCP[const symengine.Subs] me = symengine.rcp_static_cast_Subs(self.thisptr)
1394+
cdef symengine.vec_basic variables = deref(me).get_variables()
1395+
return vec_basic_to_tuple(variables)
1396+
1397+
@property
1398+
def point(self):
1399+
cdef RCP[const symengine.Subs] me = symengine.rcp_static_cast_Subs(self.thisptr)
1400+
cdef symengine.vec_basic point = deref(me).get_point()
1401+
return vec_basic_to_tuple(point)
1402+
13521403
def _sympy_(self):
13531404
cdef RCP[const symengine.Subs] X = symengine.rcp_static_cast_Subs(self.thisptr)
13541405
arg = c2py(deref(X).get_arg())._sympy_()

symengine/tests/test_functions.py

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -56,6 +56,10 @@ def test_derivative():
5656
assert f.diff(x).diff(y) == function_symbol("f", x, y).diff(x).diff(y)
5757
assert f.diff(Symbol("z")) == 0
5858

59+
s = Derivative(function_symbol("f", x), [x])
60+
assert s.expr == function_symbol("f", x)
61+
assert s.variables == (x,)
62+
5963
def test_abs():
6064
x = Symbol("x")
6165
e = abs(x)
@@ -86,6 +90,12 @@ def test_Subs():
8690
assert Subs(Derivative(function_symbol("f", x, y), [x]), [x, y], [_x, x]) \
8791
== Subs(Derivative(function_symbol("f", x, y), [x]), [y, x], [x, _x])
8892

93+
s = f.diff(x)/2
94+
_xi_1 = Symbol("_xi_1")
95+
assert s.expr == Derivative(function_symbol("f", _xi_1), [_xi_1])
96+
assert s.variables == (_xi_1,)
97+
assert s.point == (2*x,)
98+
8999
def test_FunctionWrapper():
90100
import sympy
91101
n, m, theta, phi = sympy.symbols("n, m, theta, phi")

symengine/tests/test_symbol.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@
44

55
def test_symbol():
66
x = Symbol("x")
7+
assert x.name == "x"
78
assert str(x) == "x"
89
assert str(x) != "y"
910
assert repr(x) == str(x)

symengine/tests/test_sympy_compat.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -6,11 +6,15 @@ def test_Integer():
66
assert isinstance(i, Integer)
77
assert isinstance(i, Rational)
88
assert isinstance(i, Basic)
9+
assert i.p == 5
10+
assert i.q == 1
911

1012
def test_Rational():
1113
i = S(1)/2
1214
assert isinstance(i, Rational)
1315
assert isinstance(i, Basic)
16+
assert i.p == 1
17+
assert i.q == 2
1418

1519
def test_Add():
1620
x, y = symbols("x y")

0 commit comments

Comments
 (0)