Skip to content

Commit f096b4e

Browse files
committed
support serializing and deserializing pysymbol
1 parent b85a0ba commit f096b4e

File tree

5 files changed

+109
-7
lines changed

5 files changed

+109
-7
lines changed

symengine/lib/pywrapper.cpp

Lines changed: 77 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
#include "pywrapper.h"
2+
#include <symengine/serialize-cereal.h>
23

34
#if PY_MAJOR_VERSION >= 3
45
#define PyInt_FromLong PyLong_FromLong
@@ -269,4 +270,80 @@ int PyFunction::compare(const Basic &o) const {
269270
return unified_compare(get_vec(), s.get_vec());
270271
}
271272

273+
inline PyObject* get_pickle_module() {
274+
static PyObject *module = NULL;
275+
if (module == NULL) {
276+
module = PyImport_ImportModule("pickle");
277+
}
278+
return module;
279+
}
280+
281+
RCP<const Basic> load_basic(cereal::PortableBinaryInputArchive &ar, RCP<const Symbol> &)
282+
{
283+
bool is_pysymbol;
284+
std::string name;
285+
ar(is_pysymbol);
286+
ar(name);
287+
if (is_pysymbol) {
288+
std::string pickle_str;
289+
ar(pickle_str);
290+
PyObject *module = get_pickle_module();
291+
PyObject *pickle_bytes = PyBytes_FromStringAndSize(pickle_str.data(), pickle_str.size());
292+
PyObject *obj = PyObject_CallMethod(module, "loads", "O", pickle_bytes);
293+
RCP<const Basic> result = make_rcp<PySymbol>(name, obj);
294+
Py_XDECREF(pickle_bytes);
295+
return result;
296+
} else {
297+
return symbol(name);
298+
}
299+
}
300+
301+
void save_basic(cereal::PortableBinaryOutputArchive &ar, const Symbol &b)
302+
{
303+
bool is_pysymbol = is_a_sub<PySymbol>(b);
304+
ar(is_pysymbol);
305+
ar(b.__str__());
306+
if (is_pysymbol) {
307+
RCP<const PySymbol> p = rcp_static_cast<const PySymbol>(b.rcp_from_this());
308+
PyObject *module = get_pickle_module();
309+
PyObject *pickle_bytes = PyObject_CallMethod(module, "dumps", "O", p->get_py_object());
310+
Py_ssize_t size;
311+
char* buffer;
312+
PyBytes_AsStringAndSize(pickle_bytes, &buffer, &size);
313+
std::string pickle_str(buffer, size);
314+
ar(pickle_str);
315+
Py_XDECREF(pickle_bytes);
316+
}
317+
}
318+
319+
std::string wrapper_dumps(const Basic &x)
320+
{
321+
std::cout << "qwe" << std::endl;
322+
std::ostringstream oss;
323+
unsigned short major = SYMENGINE_MAJOR_VERSION;
324+
unsigned short minor = SYMENGINE_MINOR_VERSION;
325+
cereal::PortableBinaryOutputArchive{oss}(major, minor,
326+
x.rcp_from_this());
327+
return oss.str();
328+
}
329+
330+
RCP<const Basic> wrapper_loads(const std::string &serialized)
331+
{
332+
unsigned short major, minor;
333+
RCP<const Basic> obj;
334+
std::istringstream iss(serialized);
335+
cereal::PortableBinaryInputArchive iarchive{iss};
336+
iarchive(major, minor);
337+
if (major != SYMENGINE_MAJOR_VERSION or minor != SYMENGINE_MINOR_VERSION) {
338+
throw SerializationError(StreamFmt()
339+
<< "SymEngine-" << SYMENGINE_MAJOR_VERSION
340+
<< "." << SYMENGINE_MINOR_VERSION
341+
<< " was asked to deserialize an object "
342+
<< "created using SymEngine-" << major << "."
343+
<< minor << ".");
344+
}
345+
iarchive(obj);
346+
return obj;
347+
}
348+
272349
} // SymEngine

symengine/lib/pywrapper.h

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -195,6 +195,9 @@ class PyFunction : public FunctionWrapper {
195195
virtual hash_t __hash__() const;
196196
};
197197

198+
std::string wrapper_dumps(const Basic &x);
199+
RCP<const Basic> wrapper_loads(const std::string &s);
200+
198201
}
199202

200203
#endif //SYMENGINE_PYWRAPPER_H

symengine/lib/symengine.pxd

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -183,7 +183,6 @@ cdef extern from "<symengine/basic.h>" namespace "SymEngine":
183183
unsigned int hash() nogil except +
184184
vec_basic get_args() nogil
185185
int __cmp__(const Basic &o) nogil
186-
string dumps() nogil except +
187186

188187
ctypedef RCP[const Number] rcp_const_number "SymEngine::RCP<const SymEngine::Number>"
189188
ctypedef unordered_map[int, rcp_const_basic] umap_int_basic "SymEngine::umap_int_basic"
@@ -195,8 +194,6 @@ cdef extern from "<symengine/basic.h>" namespace "SymEngine":
195194
bool eq(const Basic &a, const Basic &b) nogil except +
196195
bool neq(const Basic &a, const Basic &b) nogil except +
197196

198-
RCP[const Basic] loads "SymEngine::Basic::loads"(const string &) nogil except +
199-
200197
RCP[const Symbol] rcp_static_cast_Symbol "SymEngine::rcp_static_cast<const SymEngine::Symbol>"(rcp_const_basic &b) nogil
201198
RCP[const PySymbol] rcp_static_cast_PySymbol "SymEngine::rcp_static_cast<const SymEngine::PySymbol>"(rcp_const_basic &b) nogil
202199
RCP[const Integer] rcp_static_cast_Integer "SymEngine::rcp_static_cast<const SymEngine::Integer>"(rcp_const_basic &b) nogil
@@ -373,6 +370,9 @@ cdef extern from "pywrapper.h" namespace "SymEngine":
373370
PySymbol(string name, PyObject* pyobj)
374371
PyObject* get_py_object()
375372

373+
string wrapper_dumps(const Basic &x) nogil except +
374+
rcp_const_basic wrapper_loads(const string &s) nogil except +
375+
376376
cdef extern from "<symengine/integer.h>" namespace "SymEngine":
377377
cdef cppclass Integer(Number):
378378
Integer(int i) nogil

symengine/lib/symengine_wrapper.pyx

Lines changed: 8 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -827,7 +827,7 @@ cdef list vec_pair_to_list(symengine.vec_pair& vec):
827827

828828

829829
def load_basic(bytes s):
830-
return c2py(symengine.loads(s))
830+
return c2py(symengine.wrapper_loads(s))
831831

832832

833833
repr_latex=[False]
@@ -841,7 +841,7 @@ cdef class Basic(object):
841841
return self.__str__()
842842

843843
def __reduce__(self):
844-
cdef bytes s = deref(self.thisptr).dumps()
844+
cdef bytes s = symengine.wrapper_dumps(deref(self.thisptr))
845845
return (load_basic, (s,))
846846

847847
def _repr_latex_(self):
@@ -1231,6 +1231,12 @@ cdef class Symbol(Expr):
12311231
import sympy
12321232
return sympy.Symbol(str(self))
12331233

1234+
def __reduce__(self):
1235+
if type(self) == Symbol:
1236+
return Basic.__reduce__(self)
1237+
else:
1238+
raise NotImplementedError("pickling for Symbol subclass not implemented")
1239+
12341240
def _sage_(self):
12351241
import sage.all as sage
12361242
return sage.SR.symbol(str(self))

symengine/tests/test_pickling.py

Lines changed: 18 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
from symengine import symbols, sin, sinh, have_numpy, have_llvm, cos
1+
from symengine import symbols, sin, sinh, have_numpy, have_llvm, cos, Symbol
22
import pickle
33
import unittest
44

@@ -11,6 +11,23 @@ def test_basic():
1111
assert expr == expr2
1212

1313

14+
class MySymbol(Symbol):
15+
def __init__(self, name, attr):
16+
super().__init__(name=name)
17+
self.attr = attr
18+
19+
def __reduce__(self):
20+
return (self.__class__, (self.name, self.attr))
21+
22+
23+
def test_pysymbol():
24+
a = MySymbol("hello", attr=1)
25+
b = pickle.loads(pickle.dumps(a))
26+
assert b.attr == 1
27+
a._unsafe_reset()
28+
b._unsafe_reset()
29+
30+
1431
@unittest.skipUnless(have_llvm, "No LLVM support")
1532
@unittest.skipUnless(have_numpy, "Numpy not installed")
1633
def test_llvm_double():
@@ -23,4 +40,3 @@ def test_llvm_double():
2340
ll = pickle.loads(ss)
2441
inp = [1, 2, 3]
2542
assert np.allclose(l(inp), ll(inp))
26-

0 commit comments

Comments
 (0)