Skip to content

Commit b85a0ba

Browse files
committed
Support pickling of Basic objects
1 parent 1ef1c87 commit b85a0ba

File tree

4 files changed

+22
-2
lines changed

4 files changed

+22
-2
lines changed

symengine/lib/symengine.pxd

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -183,6 +183,8 @@ cdef extern from "<symengine/basic.h>" namespace "SymEngine":
183183
unsigned int hash() nogil except +
184184
vec_basic get_args() nogil
185185
int __cmp__(const Basic &o) nogil
186+
string dumps() nogil except +
187+
186188
ctypedef RCP[const Number] rcp_const_number "SymEngine::RCP<const SymEngine::Number>"
187189
ctypedef unordered_map[int, rcp_const_basic] umap_int_basic "SymEngine::umap_int_basic"
188190
ctypedef unordered_map[int, rcp_const_basic].iterator umap_int_basic_iterator "SymEngine::umap_int_basic::iterator"
@@ -193,6 +195,7 @@ cdef extern from "<symengine/basic.h>" namespace "SymEngine":
193195
bool eq(const Basic &a, const Basic &b) nogil except +
194196
bool neq(const Basic &a, const Basic &b) nogil except +
195197

198+
RCP[const Basic] loads "SymEngine::Basic::loads"(const string &) nogil except +
196199

197200
RCP[const Symbol] rcp_static_cast_Symbol "SymEngine::rcp_static_cast<const SymEngine::Symbol>"(rcp_const_basic &b) nogil
198201
RCP[const PySymbol] rcp_static_cast_PySymbol "SymEngine::rcp_static_cast<const SymEngine::PySymbol>"(rcp_const_basic &b) nogil

symengine/lib/symengine_wrapper.pyx

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -826,6 +826,10 @@ cdef list vec_pair_to_list(symengine.vec_pair& vec):
826826
return result
827827

828828

829+
def load_basic(bytes s):
830+
return c2py(symengine.loads(s))
831+
832+
829833
repr_latex=[False]
830834

831835
cdef class Basic(object):
@@ -836,6 +840,10 @@ cdef class Basic(object):
836840
def __repr__(self):
837841
return self.__str__()
838842

843+
def __reduce__(self):
844+
cdef bytes s = deref(self.thisptr).dumps()
845+
return (load_basic, (s,))
846+
839847
def _repr_latex_(self):
840848
if repr_latex[0]:
841849
return "${}$".format(latex(self))

symengine/tests/test_pickling.py

Lines changed: 10 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,16 @@
1-
from symengine import symbols, sin, sinh, have_numpy, have_llvm
1+
from symengine import symbols, sin, sinh, have_numpy, have_llvm, cos
22
import pickle
33
import unittest
44

5+
6+
def test_basic():
7+
x, y, z = symbols('x y z')
8+
expr = sin(cos(x + y)/z)**2
9+
s = pickle.dumps(expr)
10+
expr2 = pickle.loads(s)
11+
assert expr == expr2
12+
13+
514
@unittest.skipUnless(have_llvm, "No LLVM support")
615
@unittest.skipUnless(have_numpy, "Numpy not installed")
716
def test_llvm_double():

symengine_version.txt

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1 +1 @@
1-
23abf31763620463500d5fad114d855afd66d011
1+
4b841d144bbc1ecd4367e4bd7dd4e7c6be8fac05

0 commit comments

Comments
 (0)