Skip to content

Commit 322f2f4

Browse files
committed
handle errors gracefully
1 parent f096b4e commit 322f2f4

File tree

2 files changed

+18
-2
lines changed

2 files changed

+18
-2
lines changed

symengine/lib/pywrapper.cpp

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -275,6 +275,9 @@ inline PyObject* get_pickle_module() {
275275
if (module == NULL) {
276276
module = PyImport_ImportModule("pickle");
277277
}
278+
if (module == NULL) {
279+
throw SymEngineException("error importing pickle module.")
280+
}
278281
return module;
279282
}
280283

@@ -290,6 +293,9 @@ RCP<const Basic> load_basic(cereal::PortableBinaryInputArchive &ar, RCP<const Sy
290293
PyObject *module = get_pickle_module();
291294
PyObject *pickle_bytes = PyBytes_FromStringAndSize(pickle_str.data(), pickle_str.size());
292295
PyObject *obj = PyObject_CallMethod(module, "loads", "O", pickle_bytes);
296+
if (obj == NULL) {
297+
throw SymEngineException("error when loading pickled symbol subclass object");
298+
}
293299
RCP<const Basic> result = make_rcp<PySymbol>(name, obj);
294300
Py_XDECREF(pickle_bytes);
295301
return result;
@@ -307,6 +313,9 @@ void save_basic(cereal::PortableBinaryOutputArchive &ar, const Symbol &b)
307313
RCP<const PySymbol> p = rcp_static_cast<const PySymbol>(b.rcp_from_this());
308314
PyObject *module = get_pickle_module();
309315
PyObject *pickle_bytes = PyObject_CallMethod(module, "dumps", "O", p->get_py_object());
316+
if (pickle_bytes == NULL) {
317+
throw SymEngineException("error when pickling symbol subclass object");
318+
}
310319
Py_ssize_t size;
311320
char* buffer;
312321
PyBytes_AsStringAndSize(pickle_bytes, &buffer, &size);

symengine/tests/test_pickling.py

Lines changed: 9 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
from symengine import symbols, sin, sinh, have_numpy, have_llvm, cos, Symbol
2+
from symengine.utilities import raises
23
import pickle
34
import unittest
45

@@ -11,22 +12,28 @@ def test_basic():
1112
assert expr == expr2
1213

1314

14-
class MySymbol(Symbol):
15+
class MySymbolBase(Symbol):
1516
def __init__(self, name, attr):
1617
super().__init__(name=name)
1718
self.attr = attr
1819

20+
21+
class MySymbol(MySymbolBase):
1922
def __reduce__(self):
2023
return (self.__class__, (self.name, self.attr))
2124

2225

2326
def test_pysymbol():
2427
a = MySymbol("hello", attr=1)
25-
b = pickle.loads(pickle.dumps(a))
28+
b = pickle.loads(pickle.dumps(a + 2)) - 2
2629
assert b.attr == 1
2730
a._unsafe_reset()
2831
b._unsafe_reset()
2932

33+
a = MySymbolBase("hello", attr=1)
34+
raises(NotImplementedError, lambda: pickle.dumps(a + 2))
35+
a._unsafe_reset()
36+
3037

3138
@unittest.skipUnless(have_llvm, "No LLVM support")
3239
@unittest.skipUnless(have_numpy, "Numpy not installed")

0 commit comments

Comments
 (0)