Skip to content
Merged
Show file tree
Hide file tree
Changes from 3 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 3 additions & 1 deletion symengine/lib/symengine.pxd
Original file line number Diff line number Diff line change
Expand Up @@ -133,6 +133,7 @@ cdef extern from "<symengine/basic.h>" namespace "SymEngine":
RCP[const Rational] rcp_static_cast_Rational "SymEngine::rcp_static_cast<const SymEngine::Rational>"(rcp_const_basic &b) nogil
RCP[const Complex] rcp_static_cast_Complex "SymEngine::rcp_static_cast<const SymEngine::Complex>"(rcp_const_basic &b) nogil
RCP[const Number] rcp_static_cast_Number "SymEngine::rcp_static_cast<const SymEngine::Number>"(rcp_const_basic &b) nogil
RCP[const Dummy] rcp_static_cast_Dummy "SymEngine::rcp_static_cast<const SymEngine::Dummy>"(rcp_const_basic &b) nogil
RCP[const Add] rcp_static_cast_Add "SymEngine::rcp_static_cast<const SymEngine::Add>"(rcp_const_basic &b) nogil
RCP[const Mul] rcp_static_cast_Mul "SymEngine::rcp_static_cast<const SymEngine::Mul>"(rcp_const_basic &b) nogil
RCP[const Pow] rcp_static_cast_Pow "SymEngine::rcp_static_cast<const SymEngine::Pow>"(rcp_const_basic &b) nogil
Expand Down Expand Up @@ -180,7 +181,7 @@ cdef extern from "<symengine/symbol.h>" namespace "SymEngine":
Symbol(string name) nogil
string get_name() nogil
cdef cppclass Dummy(Symbol):
pass
size_t get_index()

cdef extern from "<symengine/number.h>" namespace "SymEngine":
cdef cppclass Number(Basic):
Expand Down Expand Up @@ -322,6 +323,7 @@ cdef extern from "<symengine/basic.h>" namespace "SymEngine":
rcp_const_basic make_rcp_Symbol "SymEngine::make_rcp<const SymEngine::Symbol>"(string name) nogil
rcp_const_basic make_rcp_Dummy "SymEngine::make_rcp<const SymEngine::Dummy>"() nogil
rcp_const_basic make_rcp_Dummy "SymEngine::make_rcp<const SymEngine::Dummy>"(string name) nogil
rcp_const_basic make_rcp_Dummy "SymEngine::make_rcp<const SymEngine::Dummy>"(string &name, size_t index) nogil
rcp_const_basic make_rcp_PySymbol "SymEngine::make_rcp<const SymEngine::PySymbol>"(string name, PyObject * pyobj, bool use_pickle) except +
rcp_const_basic make_rcp_Constant "SymEngine::make_rcp<const SymEngine::Constant>"(string name) nogil
rcp_const_basic make_rcp_Infty "SymEngine::make_rcp<const SymEngine::Infty>"(RCP[const Number] i) nogil
Expand Down
31 changes: 21 additions & 10 deletions symengine/lib/symengine_wrapper.in.pyx
Original file line number Diff line number Diff line change
Expand Up @@ -278,10 +278,10 @@ def sympy2symengine(a, raise_error=False):
"""
import sympy
from sympy.core.function import AppliedUndef as sympy_AppliedUndef
if isinstance(a, sympy.Symbol):
if isinstance(a, sympy.Dummy):
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Since Dummy is a subclass of Symbol, we need to check for it first.

return Dummy(a.name, a.dummy_index)
elif isinstance(a, sympy.Symbol):
return Symbol(a.name)
elif isinstance(a, sympy.Dummy):
return Dummy(a.name)
elif isinstance(a, sympy.Mul):
return mul(*[sympy2symengine(x, raise_error) for x in a.args])
elif isinstance(a, sympy.Add):
Expand Down Expand Up @@ -1301,10 +1301,10 @@ cdef class Symbol(Expr):
return sympy.Symbol(str(self))

def __reduce__(self):
if type(self) == Symbol:
if type(self) in (Symbol, Dummy):
return Basic.__reduce__(self)
else:
raise NotImplementedError("pickling for Symbol subclass not implemented")
raise NotImplementedError("pickling for subclass of Symbol or Dummy not implemented")

def _sage_(self):
import sage.all as sage
Expand Down Expand Up @@ -1337,15 +1337,20 @@ cdef class Symbol(Expr):

cdef class Dummy(Symbol):

def __init__(Basic self, name=None, *args, **kwargs):
if name is None:
self.thisptr = symengine.make_rcp_Dummy()
def __init__(Basic self, name=None, dummy_index=None, *args, **kwargs):
cdef size_t index
if dummy_index is None:
if name is None:
self.thisptr = symengine.make_rcp_Dummy()
else:
self.thisptr = symengine.make_rcp_Dummy(name.encode("utf-8"))
else:
self.thisptr = symengine.make_rcp_Dummy(name.encode("utf-8"))
index = dummy_index
self.thisptr = symengine.make_rcp_Dummy(name.encode("utf-8"), index)

def _sympy_(self):
import sympy
return sympy.Dummy(str(self)[1:])
return sympy.Dummy(name=self.name, dummy_index=self.dummy_index)

@property
def is_Dummy(self):
Expand All @@ -1355,6 +1360,12 @@ cdef class Dummy(Symbol):
def func(self):
return self.__class__

@property
def dummy_index(self):
cdef RCP[const symengine.Dummy] this = \
symengine.rcp_static_cast_Dummy(self.thisptr)
cdef size_t index = deref(this).get_index()
return index

def symarray(prefix, shape, **kwargs):
""" Creates an nd-array of symbols
Expand Down
18 changes: 17 additions & 1 deletion symengine/tests/test_pickling.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
from symengine import symbols, sin, sinh, have_numpy, have_llvm, cos, Symbol
from symengine import symbols, sin, sinh, have_numpy, have_llvm, cos, Symbol, Dummy
from symengine.test_utilities import raises
import pickle
import unittest
Expand Down Expand Up @@ -57,3 +57,19 @@ def test_llvm_double():
ll = pickle.loads(ss)
inp = [1, 2, 3]
assert np.allclose(l(inp), ll(inp))


def _check_pickling_roundtrip(arg):
s2 = pickle.dumps(arg)
arg2 = pickle.loads(s2)
assert arg == arg2
s3 = pickle.dumps(arg2)
arg3 = pickle.loads(s3)
assert arg == arg3


def test_pickling_roundtrip():
x, y, z = symbols('x y z')
_check_pickling_roundtrip(x+y)
_check_pickling_roundtrip(Dummy('d'))
_check_pickling_roundtrip(Dummy('d') - z)
3 changes: 3 additions & 0 deletions symengine/tests/test_symbol.py
Original file line number Diff line number Diff line change
Expand Up @@ -156,6 +156,9 @@ def test_dummy():
x2 = Symbol('x')
xdummy1 = Dummy('x')
xdummy2 = Dummy('x')
assert xdummy1.dummy_index != xdummy2.dummy_index # maybe test using "less than"?
assert xdummy1.name == 'x'
assert xdummy2.name == 'x'

assert x1 == x2
assert x1 != xdummy1
Expand Down
23 changes: 22 additions & 1 deletion symengine/tests/test_sympy_conv.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
from symengine import (Symbol, Integer, sympify, SympifyError, log,
function_symbol, I, E, pi, oo, zoo, nan, true, false,
exp, gamma, have_mpfr, have_mpc, DenseMatrix, sin, cos, tan, cot,
exp, gamma, have_mpfr, have_mpc, DenseMatrix, Dummy, sin, cos, tan, cot,
csc, sec, asin, acos, atan, acot, acsc, asec, sinh, cosh, tanh, coth,
asinh, acosh, atanh, acoth, atan2, Add, Mul, Pow, diff, GoldenRatio,
Catalan, EulerGamma, UnevaluatedExpr, RealDouble)
Expand Down Expand Up @@ -833,3 +833,24 @@ def test_conv_large_integers():
if have_sympy:
c = a._sympy_()
d = sympify(c)


def _check_sympy_roundtrip(arg):
arg_sy1 = sympy.sympify(arg)
arg_se2 = sympify(arg_sy1)
assert arg == arg_se2
arg_sy2 = sympy.sympify(arg_se2)
assert arg_sy2 == arg_sy1
arg_se3 = sympify(arg_sy2)
assert arg_se3 == arg


@unittest.skipIf(not have_sympy, "SymPy not installed")
def test_sympy_roundtrip():
x = Symbol("x")
y = Symbol("y")
d = Dummy("d")
_check_sympy_roundtrip(x)
_check_sympy_roundtrip(x+y)
_check_sympy_roundtrip(x**y)
_check_sympy_roundtrip(d)
2 changes: 1 addition & 1 deletion symengine_version.txt
Original file line number Diff line number Diff line change
@@ -1 +1 @@
c9510fb4b5c30b84adb993573a51f2a9a38a4cfe
c574fa8d7018a850481afa7a59809d30e774d78d
Loading