Skip to content

Commit cfb4f4b

Browse files
authored
Merge pull request #377 from isuruf/pickling
Support pickling of Basic objects
2 parents 1ef1c87 + 1ca9c70 commit cfb4f4b

File tree

6 files changed

+151
-4
lines changed

6 files changed

+151
-4
lines changed

symengine/lib/pywrapper.cpp

Lines changed: 85 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
#include "pywrapper.h"
2+
#include <symengine/serialize-cereal.h>
23

34
#if PY_MAJOR_VERSION >= 3
45
#define PyInt_FromLong PyLong_FromLong
@@ -269,4 +270,88 @@ int PyFunction::compare(const Basic &o) const {
269270
return unified_compare(get_vec(), s.get_vec());
270271
}
271272

273+
inline PyObject* get_pickle_module() {
274+
static PyObject *module = NULL;
275+
if (module == NULL) {
276+
module = PyImport_ImportModule("pickle");
277+
}
278+
if (module == NULL) {
279+
throw SymEngineException("error importing pickle module.");
280+
}
281+
return module;
282+
}
283+
284+
RCP<const Basic> load_basic(cereal::PortableBinaryInputArchive &ar, RCP<const Symbol> &)
285+
{
286+
bool is_pysymbol;
287+
std::string name;
288+
ar(is_pysymbol);
289+
ar(name);
290+
if (is_pysymbol) {
291+
std::string pickle_str;
292+
ar(pickle_str);
293+
PyObject *module = get_pickle_module();
294+
PyObject *pickle_bytes = PyBytes_FromStringAndSize(pickle_str.data(), pickle_str.size());
295+
PyObject *obj = PyObject_CallMethod(module, "loads", "O", pickle_bytes);
296+
if (obj == NULL) {
297+
throw SerializationError("error when loading pickled symbol subclass object");
298+
}
299+
RCP<const Basic> result = make_rcp<PySymbol>(name, obj);
300+
Py_XDECREF(pickle_bytes);
301+
return result;
302+
} else {
303+
return symbol(name);
304+
}
305+
}
306+
307+
void save_basic(cereal::PortableBinaryOutputArchive &ar, const Symbol &b)
308+
{
309+
bool is_pysymbol = is_a_sub<PySymbol>(b);
310+
ar(is_pysymbol);
311+
ar(b.__str__());
312+
if (is_pysymbol) {
313+
RCP<const PySymbol> p = rcp_static_cast<const PySymbol>(b.rcp_from_this());
314+
PyObject *module = get_pickle_module();
315+
PyObject *pickle_bytes = PyObject_CallMethod(module, "dumps", "O", p->get_py_object());
316+
if (pickle_bytes == NULL) {
317+
throw SerializationError("error when pickling symbol subclass object");
318+
}
319+
Py_ssize_t size;
320+
char* buffer;
321+
PyBytes_AsStringAndSize(pickle_bytes, &buffer, &size);
322+
std::string pickle_str(buffer, size);
323+
ar(pickle_str);
324+
Py_XDECREF(pickle_bytes);
325+
}
326+
}
327+
328+
std::string wrapper_dumps(const Basic &x)
329+
{
330+
std::ostringstream oss;
331+
unsigned short major = SYMENGINE_MAJOR_VERSION;
332+
unsigned short minor = SYMENGINE_MINOR_VERSION;
333+
cereal::PortableBinaryOutputArchive{oss}(major, minor,
334+
x.rcp_from_this());
335+
return oss.str();
336+
}
337+
338+
RCP<const Basic> wrapper_loads(const std::string &serialized)
339+
{
340+
unsigned short major, minor;
341+
RCP<const Basic> obj;
342+
std::istringstream iss(serialized);
343+
cereal::PortableBinaryInputArchive iarchive{iss};
344+
iarchive(major, minor);
345+
if (major != SYMENGINE_MAJOR_VERSION or minor != SYMENGINE_MINOR_VERSION) {
346+
throw SerializationError(StreamFmt()
347+
<< "SymEngine-" << SYMENGINE_MAJOR_VERSION
348+
<< "." << SYMENGINE_MINOR_VERSION
349+
<< " was asked to deserialize an object "
350+
<< "created using SymEngine-" << major << "."
351+
<< minor << ".");
352+
}
353+
iarchive(obj);
354+
return obj;
355+
}
356+
272357
} // SymEngine

symengine/lib/pywrapper.h

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -195,6 +195,9 @@ class PyFunction : public FunctionWrapper {
195195
virtual hash_t __hash__() const;
196196
};
197197

198+
std::string wrapper_dumps(const Basic &x);
199+
RCP<const Basic> wrapper_loads(const std::string &s);
200+
198201
}
199202

200203
#endif //SYMENGINE_PYWRAPPER_H

symengine/lib/symengine.pxd

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -183,6 +183,7 @@ cdef extern from "<symengine/basic.h>" namespace "SymEngine":
183183
unsigned int hash() nogil except +
184184
vec_basic get_args() nogil
185185
int __cmp__(const Basic &o) nogil
186+
186187
ctypedef RCP[const Number] rcp_const_number "SymEngine::RCP<const SymEngine::Number>"
187188
ctypedef unordered_map[int, rcp_const_basic] umap_int_basic "SymEngine::umap_int_basic"
188189
ctypedef unordered_map[int, rcp_const_basic].iterator umap_int_basic_iterator "SymEngine::umap_int_basic::iterator"
@@ -193,7 +194,6 @@ cdef extern from "<symengine/basic.h>" namespace "SymEngine":
193194
bool eq(const Basic &a, const Basic &b) nogil except +
194195
bool neq(const Basic &a, const Basic &b) nogil except +
195196

196-
197197
RCP[const Symbol] rcp_static_cast_Symbol "SymEngine::rcp_static_cast<const SymEngine::Symbol>"(rcp_const_basic &b) nogil
198198
RCP[const PySymbol] rcp_static_cast_PySymbol "SymEngine::rcp_static_cast<const SymEngine::PySymbol>"(rcp_const_basic &b) nogil
199199
RCP[const Integer] rcp_static_cast_Integer "SymEngine::rcp_static_cast<const SymEngine::Integer>"(rcp_const_basic &b) nogil
@@ -370,6 +370,9 @@ cdef extern from "pywrapper.h" namespace "SymEngine":
370370
PySymbol(string name, PyObject* pyobj)
371371
PyObject* get_py_object()
372372

373+
string wrapper_dumps(const Basic &x) nogil except +
374+
rcp_const_basic wrapper_loads(const string &s) nogil except +
375+
373376
cdef extern from "<symengine/integer.h>" namespace "SymEngine":
374377
cdef cppclass Integer(Number):
375378
Integer(int i) nogil

symengine/lib/symengine_wrapper.pyx

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -826,6 +826,10 @@ cdef list vec_pair_to_list(symengine.vec_pair& vec):
826826
return result
827827

828828

829+
def load_basic(bytes s):
830+
return c2py(symengine.wrapper_loads(s))
831+
832+
829833
repr_latex=[False]
830834

831835
cdef class Basic(object):
@@ -836,6 +840,10 @@ cdef class Basic(object):
836840
def __repr__(self):
837841
return self.__str__()
838842

843+
def __reduce__(self):
844+
cdef bytes s = symengine.wrapper_dumps(deref(self.thisptr))
845+
return (load_basic, (s,))
846+
839847
def _repr_latex_(self):
840848
if repr_latex[0]:
841849
return "${}$".format(latex(self))
@@ -1223,6 +1231,12 @@ cdef class Symbol(Expr):
12231231
import sympy
12241232
return sympy.Symbol(str(self))
12251233

1234+
def __reduce__(self):
1235+
if type(self) == Symbol:
1236+
return Basic.__reduce__(self)
1237+
else:
1238+
raise NotImplementedError("pickling for Symbol subclass not implemented")
1239+
12261240
def _sage_(self):
12271241
import sage.all as sage
12281242
return sage.SR.symbol(str(self))

symengine/tests/test_pickling.py

Lines changed: 44 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,50 @@
1-
from symengine import symbols, sin, sinh, have_numpy, have_llvm
1+
from symengine import symbols, sin, sinh, have_numpy, have_llvm, cos, Symbol
2+
from symengine.utilities import raises
23
import pickle
34
import unittest
45

6+
7+
def test_basic():
8+
x, y, z = symbols('x y z')
9+
expr = sin(cos(x + y)/z)**2
10+
s = pickle.dumps(expr)
11+
expr2 = pickle.loads(s)
12+
assert expr == expr2
13+
14+
15+
class MySymbolBase(Symbol):
16+
def __init__(self, name, attr):
17+
super().__init__(name=name)
18+
self.attr = attr
19+
20+
def __eq__(self, other):
21+
if not isinstance(other, MySymbolBase):
22+
return False
23+
return self.name == other.name and self.attr == other.attr
24+
25+
26+
class MySymbol(MySymbolBase):
27+
def __reduce__(self):
28+
return (self.__class__, (self.name, self.attr))
29+
30+
31+
def test_pysymbol():
32+
a = MySymbol("hello", attr=1)
33+
b = pickle.loads(pickle.dumps(a + 2)) - 2
34+
try:
35+
assert a == b
36+
finally:
37+
a._unsafe_reset()
38+
b._unsafe_reset()
39+
40+
a = MySymbolBase("hello", attr=1)
41+
try:
42+
raises(NotImplementedError, lambda: pickle.dumps(a))
43+
raises(NotImplementedError, lambda: pickle.dumps(a + 2))
44+
finally:
45+
a._unsafe_reset()
46+
47+
548
@unittest.skipUnless(have_llvm, "No LLVM support")
649
@unittest.skipUnless(have_numpy, "Numpy not installed")
750
def test_llvm_double():
@@ -14,4 +57,3 @@ def test_llvm_double():
1457
ll = pickle.loads(ss)
1558
inp = [1, 2, 3]
1659
assert np.allclose(l(inp), ll(inp))
17-

symengine_version.txt

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1 +1 @@
1-
23abf31763620463500d5fad114d855afd66d011
1+
36ac51d06e248657d828bfa4859cff32ab5f03ba

0 commit comments

Comments
 (0)