Skip to content

Commit 894e7a2

Browse files
authored
Merge branch 'master' into release
2 parents 5d89eb8 + 0680b42 commit 894e7a2

File tree

7 files changed

+141
-43
lines changed

7 files changed

+141
-43
lines changed

symengine/__init__.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -55,6 +55,11 @@
5555

5656
def lambdify(args, exprs, **kwargs):
5757
return Lambdify(args, *exprs, **kwargs)
58+
else:
59+
def __getattr__(name):
60+
if name == 'lambdify':
61+
raise AttributeError("Cannot import numpy, which is required for `lambdify` to work")
62+
raise AttributeError(f"module 'symengine' has no attribute '{name}'")
5863

5964

6065
__version__ = "0.10.0"

symengine/lib/pywrapper.cpp

Lines changed: 32 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -281,47 +281,61 @@ inline PyObject* get_pickle_module() {
281281
return module;
282282
}
283283

284+
PyObject* pickle_loads(const std::string &pickle_str) {
285+
PyObject *module = get_pickle_module();
286+
PyObject *pickle_bytes = PyBytes_FromStringAndSize(pickle_str.data(), pickle_str.size());
287+
PyObject *obj = PyObject_CallMethod(module, "loads", "O", pickle_bytes);
288+
Py_XDECREF(pickle_bytes);
289+
if (obj == NULL) {
290+
throw SerializationError("error when loading pickled symbol subclass object");
291+
}
292+
return obj;
293+
}
294+
284295
RCP<const Basic> load_basic(cereal::PortableBinaryInputArchive &ar, RCP<const Symbol> &)
285296
{
286297
bool is_pysymbol;
298+
bool store_pickle;
287299
std::string name;
288300
ar(is_pysymbol);
289301
ar(name);
290302
if (is_pysymbol) {
291303
std::string pickle_str;
292304
ar(pickle_str);
293-
PyObject *module = get_pickle_module();
294-
PyObject *pickle_bytes = PyBytes_FromStringAndSize(pickle_str.data(), pickle_str.size());
295-
PyObject *obj = PyObject_CallMethod(module, "loads", "O", pickle_bytes);
296-
if (obj == NULL) {
297-
throw SerializationError("error when loading pickled symbol subclass object");
298-
}
299-
RCP<const Basic> result = make_rcp<PySymbol>(name, obj);
300-
Py_XDECREF(pickle_bytes);
305+
ar(store_pickle);
306+
PyObject *obj = pickle_loads(pickle_str);
307+
RCP<const Basic> result = make_rcp<PySymbol>(name, obj, store_pickle);
308+
Py_XDECREF(obj);
301309
return result;
302310
} else {
303311
return symbol(name);
304312
}
305313
}
306314

315+
std::string pickle_dumps(const PyObject * obj) {
316+
PyObject *module = get_pickle_module();
317+
PyObject *pickle_bytes = PyObject_CallMethod(module, "dumps", "O", obj);
318+
if (pickle_bytes == NULL) {
319+
throw SerializationError("error when pickling symbol subclass object");
320+
}
321+
Py_ssize_t size;
322+
char* buffer;
323+
PyBytes_AsStringAndSize(pickle_bytes, &buffer, &size);
324+
return std::string(buffer, size);
325+
}
326+
307327
void save_basic(cereal::PortableBinaryOutputArchive &ar, const Symbol &b)
308328
{
309329
bool is_pysymbol = is_a_sub<PySymbol>(b);
310330
ar(is_pysymbol);
311331
ar(b.__str__());
312332
if (is_pysymbol) {
313333
RCP<const PySymbol> p = rcp_static_cast<const PySymbol>(b.rcp_from_this());
314-
PyObject *module = get_pickle_module();
315-
PyObject *pickle_bytes = PyObject_CallMethod(module, "dumps", "O", p->get_py_object());
316-
if (pickle_bytes == NULL) {
317-
throw SerializationError("error when pickling symbol subclass object");
318-
}
319-
Py_ssize_t size;
320-
char* buffer;
321-
PyBytes_AsStringAndSize(pickle_bytes, &buffer, &size);
322-
std::string pickle_str(buffer, size);
334+
PyObject *obj = p->get_py_object();
335+
std::string pickle_str = pickle_dumps(obj);
323336
ar(pickle_str);
324-
Py_XDECREF(pickle_bytes);
337+
ar(p->store_pickle);
338+
Py_XDECREF(obj);
325339
}
326340
}
327341

symengine/lib/pywrapper.h

Lines changed: 22 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,9 @@
88

99
namespace SymEngine {
1010

11+
std::string pickle_dumps(const PyObject *);
12+
PyObject* pickle_loads(const std::string &);
13+
1114
/*
1215
* PySymbol is a subclass of Symbol that keeps a reference to a Python object.
1316
* When subclassing a Symbol from Python, the information stored in subclassed
@@ -27,16 +30,30 @@ namespace SymEngine {
2730
class PySymbol : public Symbol {
2831
private:
2932
PyObject* obj;
33+
std::string bytes;
3034
public:
31-
PySymbol(const std::string& name, PyObject* obj) : Symbol(name), obj(obj) {
32-
Py_INCREF(obj);
35+
const bool store_pickle;
36+
PySymbol(const std::string& name, PyObject* obj, bool store_pickle) :
37+
Symbol(name), obj(obj), store_pickle(store_pickle) {
38+
if (store_pickle) {
39+
bytes = pickle_dumps(obj);
40+
} else {
41+
Py_INCREF(obj);
42+
}
3343
}
3444
PyObject* get_py_object() const {
35-
return obj;
45+
if (store_pickle) {
46+
return pickle_loads(bytes);
47+
} else {
48+
Py_INCREF(obj);
49+
return obj;
50+
}
3651
}
3752
virtual ~PySymbol() {
38-
// TODO: This is never called because of the cyclic reference.
39-
Py_DECREF(obj);
53+
if (not store_pickle) {
54+
// TODO: This is never called because of the cyclic reference.
55+
Py_DECREF(obj);
56+
}
4057
}
4158
};
4259

symengine/lib/symengine.pxd

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -197,7 +197,7 @@ cdef extern from "<symengine/basic.h>" namespace "SymEngine":
197197
bool neq(const Basic &a, const Basic &b) nogil except +
198198

199199
RCP[const Symbol] rcp_static_cast_Symbol "SymEngine::rcp_static_cast<const SymEngine::Symbol>"(rcp_const_basic &b) nogil
200-
RCP[const PySymbol] rcp_static_cast_PySymbol "SymEngine::rcp_static_cast<const SymEngine::PySymbol>"(rcp_const_basic &b) nogil
200+
RCP[const PySymbol] rcp_static_cast_PySymbol "SymEngine::rcp_static_cast<const SymEngine::PySymbol>"(rcp_const_basic &b) nogil except +
201201
RCP[const Integer] rcp_static_cast_Integer "SymEngine::rcp_static_cast<const SymEngine::Integer>"(rcp_const_basic &b) nogil
202202
RCP[const Rational] rcp_static_cast_Rational "SymEngine::rcp_static_cast<const SymEngine::Rational>"(rcp_const_basic &b) nogil
203203
RCP[const Complex] rcp_static_cast_Complex "SymEngine::rcp_static_cast<const SymEngine::Complex>"(rcp_const_basic &b) nogil
@@ -369,8 +369,8 @@ cdef extern from "pywrapper.h" namespace "SymEngine":
369369

370370
cdef extern from "pywrapper.h" namespace "SymEngine":
371371
cdef cppclass PySymbol(Symbol):
372-
PySymbol(string name, PyObject* pyobj)
373-
PyObject* get_py_object()
372+
PySymbol(string name, PyObject* pyobj, bool use_pickle) except +
373+
PyObject* get_py_object() except +
374374

375375
string wrapper_dumps(const Basic &x) nogil except +
376376
rcp_const_basic wrapper_loads(const string &s) nogil except +
@@ -479,7 +479,7 @@ cdef extern from "<symengine/basic.h>" namespace "SymEngine":
479479
rcp_const_basic make_rcp_Symbol "SymEngine::make_rcp<const SymEngine::Symbol>"(string name) nogil
480480
rcp_const_basic make_rcp_Dummy "SymEngine::make_rcp<const SymEngine::Dummy>"() nogil
481481
rcp_const_basic make_rcp_Dummy "SymEngine::make_rcp<const SymEngine::Dummy>"(string name) nogil
482-
rcp_const_basic make_rcp_PySymbol "SymEngine::make_rcp<const SymEngine::PySymbol>"(string name, PyObject * pyobj) nogil
482+
rcp_const_basic make_rcp_PySymbol "SymEngine::make_rcp<const SymEngine::PySymbol>"(string name, PyObject * pyobj, bool use_pickle) except +
483483
rcp_const_basic make_rcp_Constant "SymEngine::make_rcp<const SymEngine::Constant>"(string name) nogil
484484
rcp_const_basic make_rcp_Infty "SymEngine::make_rcp<const SymEngine::Infty>"(RCP[const Number] i) nogil
485485
rcp_const_basic make_rcp_NaN "SymEngine::make_rcp<const SymEngine::NaN>"() nogil

symengine/lib/symengine_wrapper.pyx

Lines changed: 39 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -46,6 +46,7 @@ cpdef void assign_to_capsule(object capsule, object value):
4646

4747
cdef object c2py(rcp_const_basic o):
4848
cdef Basic r
49+
cdef PyObject *obj
4950
if (symengine.is_a_Add(deref(o))):
5051
r = Expr.__new__(Add)
5152
elif (symengine.is_a_Mul(deref(o))):
@@ -74,7 +75,10 @@ cdef object c2py(rcp_const_basic o):
7475
r = Dummy.__new__(Dummy)
7576
elif (symengine.is_a_Symbol(deref(o))):
7677
if (symengine.is_a_PySymbol(deref(o))):
77-
return <object>(deref(symengine.rcp_static_cast_PySymbol(o)).get_py_object())
78+
obj = deref(symengine.rcp_static_cast_PySymbol(o)).get_py_object()
79+
result = <object>(obj)
80+
Py_XDECREF(obj);
81+
return result
7882
r = Symbol.__new__(Symbol)
7983
elif (symengine.is_a_Constant(deref(o))):
8084
r = S.Pi
@@ -1216,16 +1220,26 @@ cdef class Expr(Basic):
12161220

12171221

12181222
cdef class Symbol(Expr):
1219-
12201223
"""
12211224
Symbol is a class to store a symbolic variable with a given name.
1225+
Subclassing Symbol leads to a memory leak due to a cycle in reference counting.
1226+
To avoid this with a performance penalty, set the kwarg store_pickle=True
1227+
in the constructor and support the pickle protocol in the subclass by
1228+
implmenting __reduce__.
12221229
"""
12231230

12241231
def __init__(Basic self, name, *args, **kwargs):
1232+
cdef cppbool store_pickle;
12251233
if type(self) == Symbol:
12261234
self.thisptr = symengine.make_rcp_Symbol(name.encode("utf-8"))
12271235
else:
1228-
self.thisptr = symengine.make_rcp_PySymbol(name.encode("utf-8"), <PyObject*>self)
1236+
store_pickle = kwargs.pop("store_pickle", False)
1237+
if store_pickle:
1238+
# First set the pointer to a regular symbol so that when pickle.dumps
1239+
# is called when the PySymbol is created, methods like name works.
1240+
self.thisptr = symengine.make_rcp_Symbol(name.encode("utf-8"))
1241+
self.thisptr = symengine.make_rcp_PySymbol(name.encode("utf-8"), <PyObject*>self,
1242+
store_pickle)
12291243

12301244
def _sympy_(self):
12311245
import sympy
@@ -2635,6 +2649,14 @@ class atan2(Function):
26352649
cdef Basic Y = sympify(y)
26362650
return c2py(symengine.atan2(X.thisptr, Y.thisptr))
26372651

2652+
def _sympy_(self):
2653+
import sympy
2654+
return sympy.atan2(*self.args_as_sympy())
2655+
2656+
def _sage_(self):
2657+
import sage.all as sage
2658+
return sage.atan2(*self.args_as_sage())
2659+
26382660
# For backwards compatibility
26392661

26402662
Sin = sin
@@ -3895,7 +3917,7 @@ cdef class DenseMatrixBase(MatrixBase):
38953917
l.append(c2py(A.get(i, j))._sympy_())
38963918
s.append(l)
38973919
import sympy
3898-
return sympy.Matrix(s)
3920+
return sympy.ImmutableMatrix(s)
38993921

39003922
def _sage_(self):
39013923
s = []
@@ -3906,7 +3928,7 @@ cdef class DenseMatrixBase(MatrixBase):
39063928
l.append(c2py(A.get(i, j))._sage_())
39073929
s.append(l)
39083930
import sage.all as sage
3909-
return sage.Matrix(s)
3931+
return sage.Matrix(s, immutable=True)
39103932

39113933
def dump_real(self, double[::1] out):
39123934
cdef size_t ri, ci, nr, nc
@@ -4046,6 +4068,12 @@ cdef class ImmutableDenseMatrix(DenseMatrixBase):
40464068
def __setitem__(self, key, value):
40474069
raise TypeError("Cannot set values of {}".format(self.__class__))
40484070

4071+
def _applyfunc(self, f):
4072+
res = DenseMatrix(self)
4073+
res._applyfunc(f)
4074+
return ImmutableDenseMatrix(res)
4075+
4076+
40494077
ImmutableMatrix = ImmutableDenseMatrix
40504078

40514079

@@ -5203,7 +5231,7 @@ def Lambdify(args, *exprs, cppbool real=True, backend=None, order='C',
52035231
Whether datatype is ``double`` (``double complex`` otherwise).
52045232
backend : str
52055233
'llvm' or 'lambda'. When ``None`` the environment variable
5206-
'SYMENGINE_LAMBDIFY_BACKEND' is used (taken as 'lambda' if unset).
5234+
'SYMENGINE_LAMBDIFY_BACKEND' is used (taken as 'llvm' if available, otherwise 'lambda').
52075235
order : 'C' or 'F'
52085236
C- or Fortran-contiguous memory layout. Note that this affects
52095237
broadcasting: e.g. a (m, n) matrix taking 3 arguments and given a
@@ -5235,7 +5263,11 @@ def Lambdify(args, *exprs, cppbool real=True, backend=None, order='C',
52355263
52365264
"""
52375265
if backend is None:
5238-
backend = os.getenv('SYMENGINE_LAMBDIFY_BACKEND', "lambda")
5266+
IF HAVE_SYMENGINE_LLVM:
5267+
backend_default = 'llvm' if real else 'lambda'
5268+
ELSE:
5269+
backend_default = 'lambda'
5270+
backend = os.getenv('SYMENGINE_LAMBDIFY_BACKEND', backend_default)
52395271
if backend == "llvm":
52405272
IF HAVE_SYMENGINE_LLVM:
52415273
if dtype == None:

symengine/tests/test_matrices.py

Lines changed: 20 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -3,13 +3,23 @@
33
Rational, function_symbol, I, NonSquareMatrixError, ShapeError, zeros,
44
ones, eye, ImmutableMatrix)
55
from symengine.test_utilities import raises
6+
import unittest
67

78

89
try:
910
import numpy as np
10-
HAVE_NUMPY = True
11+
have_numpy = True
1112
except ImportError:
12-
HAVE_NUMPY = False
13+
have_numpy = False
14+
15+
try:
16+
import sympy
17+
from sympy.core.cache import clear_cache
18+
import atexit
19+
atexit.register(clear_cache)
20+
have_sympy = True
21+
except ImportError:
22+
have_sympy = False
1323

1424

1525
def test_init():
@@ -520,21 +530,18 @@ def test_reshape():
520530
assert C != A
521531

522532

523-
# @pytest.mark.skipif(not HAVE_NUMPY, reason='requires numpy')
533+
@unittest.skipIf(not have_numpy, 'requires numpy')
524534
def test_dump_real():
525-
if not HAVE_NUMPY: # nosetests work-around
526-
return
527535
ref = [1, 2, 3, 4]
528536
A = DenseMatrix(2, 2, ref)
529537
out = np.empty(4)
530538
A.dump_real(out)
531539
assert np.allclose(out, ref)
532540

533541

534-
# @pytest.mark.skipif(not HAVE_NUMPY, reason='requires numpy')
542+
543+
@unittest.skipIf(not have_numpy, 'requires numpy')
535544
def test_dump_complex():
536-
if not HAVE_NUMPY: # nosetests work-around
537-
return
538545
ref = [1j, 2j, 3j, 4j]
539546
A = DenseMatrix(2, 2, ref)
540547
out = np.empty(4, dtype=np.complex128)
@@ -741,3 +748,8 @@ def test_repr_latex():
741748
latex_string = testmat._repr_latex_()
742749
assert isinstance(latex_string, str)
743750
init_printing(False)
751+
752+
@unittest.skipIf(not have_sympy, "SymPy not installed")
753+
def test_simplify():
754+
A = ImmutableMatrix([1])
755+
assert type(A.simplify()) == type(A)

symengine/tests/test_sympy_conv.py

Lines changed: 19 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@
22
function_symbol, I, E, pi, oo, zoo, nan, true, false,
33
exp, gamma, have_mpfr, have_mpc, DenseMatrix, sin, cos, tan, cot,
44
csc, sec, asin, acos, atan, acot, acsc, asec, sinh, cosh, tanh, coth,
5-
asinh, acosh, atanh, acoth, Add, Mul, Pow, diff, GoldenRatio,
5+
asinh, acosh, atanh, acoth, atan2, Add, Mul, Pow, diff, GoldenRatio,
66
Catalan, EulerGamma, UnevaluatedExpr, RealDouble)
77
from symengine.lib.symengine_wrapper import (Subs, Derivative, RealMPFR,
88
ComplexMPC, PyNumber, Function, LambertW, zeta, dirichlet_eta,
@@ -171,6 +171,7 @@ def test_conv7():
171171
assert acot(x/3) == acot(sympy.Symbol("x") / 3)
172172
assert acsc(x/3) == acsc(sympy.Symbol("x") / 3)
173173
assert asec(x/3) == asec(sympy.Symbol("x") / 3)
174+
assert atan2(x/3, y) == atan2(sympy.Symbol("x") / 3, sympy.Symbol("y"))
174175

175176
assert sin(x/3)._sympy_() == sympy.sin(sympy.Symbol("x") / 3)
176177
assert sin(x/3)._sympy_() != sympy.cos(sympy.Symbol("x") / 3)
@@ -185,6 +186,22 @@ def test_conv7():
185186
assert acot(x/3)._sympy_() == sympy.acot(sympy.Symbol("x") / 3)
186187
assert acsc(x/3)._sympy_() == sympy.acsc(sympy.Symbol("x") / 3)
187188
assert asec(x/3)._sympy_() == sympy.asec(sympy.Symbol("x") / 3)
189+
assert atan2(x/3, y)._sympy_() == sympy.atan2(sympy.Symbol("x") / 3, sympy.Symbol("y"))
190+
191+
assert sympy.sympify(sin(x/3)) == sympy.sin(sympy.Symbol("x") / 3)
192+
assert sympy.sympify(sin(x/3)) != sympy.cos(sympy.Symbol("x") / 3)
193+
assert sympy.sympify(cos(x/3)) == sympy.cos(sympy.Symbol("x") / 3)
194+
assert sympy.sympify(tan(x/3)) == sympy.tan(sympy.Symbol("x") / 3)
195+
assert sympy.sympify(cot(x/3)) == sympy.cot(sympy.Symbol("x") / 3)
196+
assert sympy.sympify(csc(x/3)) == sympy.csc(sympy.Symbol("x") / 3)
197+
assert sympy.sympify(sec(x/3)) == sympy.sec(sympy.Symbol("x") / 3)
198+
assert sympy.sympify(asin(x/3)) == sympy.asin(sympy.Symbol("x") / 3)
199+
assert sympy.sympify(acos(x/3)) == sympy.acos(sympy.Symbol("x") / 3)
200+
assert sympy.sympify(atan(x/3)) == sympy.atan(sympy.Symbol("x") / 3)
201+
assert sympy.sympify(acot(x/3)) == sympy.acot(sympy.Symbol("x") / 3)
202+
assert sympy.sympify(acsc(x/3)) == sympy.acsc(sympy.Symbol("x") / 3)
203+
assert sympy.sympify(asec(x/3)) == sympy.asec(sympy.Symbol("x") / 3)
204+
assert sympy.sympify(atan2(x/3, y)) == sympy.atan2(sympy.Symbol("x") / 3, sympy.Symbol("y"))
188205

189206

190207
@unittest.skipIf(not have_sympy, "SymPy not installed")
@@ -204,6 +221,7 @@ def test_conv7b():
204221
assert sympify(sympy.acot(x/3)) == acot(Symbol("x") / 3)
205222
assert sympify(sympy.acsc(x/3)) == acsc(Symbol("x") / 3)
206223
assert sympify(sympy.asec(x/3)) == asec(Symbol("x") / 3)
224+
assert sympify(sympy.atan2(x/3, y)) == atan2(Symbol("x") / 3, Symbol("y"))
207225

208226

209227
@unittest.skipIf(not have_sympy, "SymPy not installed")

0 commit comments

Comments
 (0)