Skip to content

Commit e2a348e

Browse files
authored
Merge pull request #124 from isuruf/subclass
Allow subclasses of Symbol
2 parents 78800c4 + 755b6d7 commit e2a348e

File tree

5 files changed

+92
-5
lines changed

5 files changed

+92
-5
lines changed

symengine/lib/symengine.pxd

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -130,6 +130,7 @@ cdef extern from "<symengine/symengine_rcp.h>" namespace "SymEngine":
130130
T& operator*() nogil except +
131131

132132
RCP[const Symbol] rcp_static_cast_Symbol "SymEngine::rcp_static_cast<const SymEngine::Symbol>"(RCP[const Basic] &b) nogil
133+
RCP[const PySymbol] rcp_static_cast_PySymbol "SymEngine::rcp_static_cast<const SymEngine::PySymbol>"(RCP[const Basic] &b) nogil
133134
RCP[const Integer] rcp_static_cast_Integer "SymEngine::rcp_static_cast<const SymEngine::Integer>"(RCP[const Basic] &b) nogil
134135
RCP[const Rational] rcp_static_cast_Rational "SymEngine::rcp_static_cast<const SymEngine::Rational>"(RCP[const Basic] &b) nogil
135136
RCP[const Complex] rcp_static_cast_Complex "SymEngine::rcp_static_cast<const SymEngine::Complex>"(RCP[const Basic] &b) nogil
@@ -257,6 +258,7 @@ cdef extern from "<symengine/basic.h>" namespace "SymEngine":
257258
bool is_a_Log "SymEngine::is_a<SymEngine::Log>"(const Basic &b) nogil
258259
bool is_a_PyNumber "SymEngine::is_a<SymEngine::PyNumber>"(const Basic &b) nogil
259260
bool is_a_ATan2 "SymEngine::is_a<SymEngine::ATan2>"(const Basic &b) nogil
261+
bool is_a_PySymbol "SymEngine::is_a_sub<SymEngine::PySymbol>"(const Basic &b) nogil
260262

261263
RCP[const Basic] expand(RCP[const Basic] &o) nogil except +
262264

@@ -292,6 +294,11 @@ cdef extern from "<symengine/pywrapper.h>" namespace "SymEngine":
292294
cdef cppclass PyFunction:
293295
PyObject* get_py_object()
294296

297+
cdef extern from "<symengine/pywrapper.h>" namespace "SymEngine":
298+
cdef cppclass PySymbol(Symbol):
299+
PySymbol(string name, PyObject* pyobj)
300+
PyObject* get_py_object()
301+
295302
cdef extern from "<symengine/integer.h>" namespace "SymEngine":
296303
cdef cppclass Integer(Number):
297304
Integer(int i) nogil
@@ -376,6 +383,7 @@ cdef extern from "<symengine/pow.h>" namespace "SymEngine":
376383
cdef extern from "<symengine/basic.h>" namespace "SymEngine":
377384
# We need to specialize these for our classes:
378385
RCP[const Basic] make_rcp_Symbol "SymEngine::make_rcp<const SymEngine::Symbol>"(string name) nogil
386+
RCP[const Basic] make_rcp_PySymbol "SymEngine::make_rcp<const SymEngine::PySymbol>"(string name, PyObject * pyobj) nogil
379387
RCP[const Basic] make_rcp_Constant "SymEngine::make_rcp<const SymEngine::Constant>"(string name) nogil
380388
RCP[const Basic] make_rcp_Integer "SymEngine::make_rcp<const SymEngine::Integer>"(int i) nogil
381389
RCP[const Basic] make_rcp_Integer "SymEngine::make_rcp<const SymEngine::Integer>"(integer_class i) nogil

symengine/lib/symengine/pywrapper.h

Lines changed: 32 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,38 @@
88

99
namespace SymEngine {
1010

11+
/*
12+
* PySymbol is a subclass of Symbol that keeps a reference to a Python object.
13+
* When subclassing a Symbol from Python, the information stored in subclassed
14+
* object is lost because all the arithmetic and function evaluations happen on
15+
* the C++ side. The object returned by `(x + 1) - 1` is wrapped in the Python
16+
* class Symbol and therefore the fact that `x` is a subclass of Symbol is lost.
17+
*
18+
* By subclassing in the C++ side and keeping a python object reference, the
19+
* subclassed python object can be returned instead of wrapping in a Python
20+
* class Symbol.
21+
*
22+
* TODO: Python object and C++ object both keep a reference to each other as one
23+
* must be alive when the other is alive. This creates a cyclic reference and
24+
* should be fixed.
25+
*/
26+
27+
class PySymbol : public Symbol {
28+
private:
29+
PyObject* obj;
30+
public:
31+
PySymbol(const std::string& name, PyObject* obj) : Symbol(name), obj(obj) {
32+
Py_INCREF(obj);
33+
}
34+
PyObject* get_py_object() const {
35+
return obj;
36+
}
37+
virtual ~PySymbol() {
38+
// TODO: This is never called because of the cyclic reference.
39+
Py_DECREF(obj);
40+
}
41+
};
42+
1143
/*
1244
* This module provides classes to wrap Python objects defined in SymPy
1345
* or Sage into SymEngine.

symengine/lib/symengine_wrapper.pyx

Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -34,6 +34,8 @@ cdef c2py(RCP[const symengine.Basic] o):
3434
elif (symengine.is_a_Complex(deref(o))):
3535
r = Complex.__new__(Complex)
3636
elif (symengine.is_a_Symbol(deref(o))):
37+
if (symengine.is_a_PySymbol(deref(o))):
38+
return <object>(deref(symengine.rcp_static_cast_PySymbol(o)).get_py_object())
3739
r = Symbol.__new__(Symbol)
3840
elif (symengine.is_a_Constant(deref(o))):
3941
r = Constant.__new__(Constant)
@@ -653,11 +655,20 @@ def series(ex, x=None, x0=0, n=6, method='sympy', removeO=False):
653655

654656
cdef class Symbol(Basic):
655657

658+
"""
659+
Symbol is a class to store a symbolic variable with a given name.
660+
661+
Note: Subclassing `Symbol` will not work properly. Use `PySymbol`
662+
which is a subclass of `Symbol` for subclassing.
663+
"""
656664
def __cinit__(self, name = None):
657665
if name is None:
658666
return
659667
self.thisptr = symengine.make_rcp_Symbol(name.encode("utf-8"))
660668

669+
def __init__(self, name = None):
670+
return
671+
661672
def _sympy_(self):
662673
cdef RCP[const symengine.Symbol] X = symengine.rcp_static_cast_Symbol(self.thisptr)
663674
import sympy
@@ -680,6 +691,15 @@ cdef class Symbol(Basic):
680691
def is_Symbol(self):
681692
return True
682693

694+
695+
cdef class PySymbol(Symbol):
696+
def __init__(self, name, *args, **kwargs):
697+
super(PySymbol, self).__init__(name)
698+
if name is None:
699+
return
700+
self.thisptr = symengine.make_rcp_PySymbol(name.encode("utf-8"), <PyObject*>self)
701+
702+
683703
def symarray(prefix, shape, **kwargs):
684704
""" Creates an nd-array of symbols
685705

symengine/sympy_compat.py

Lines changed: 8 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
from .lib import symengine_wrapper as symengine
22
from .utilities import var, symbols
33
from .compatibility import with_metaclass
4-
from .lib.symengine_wrapper import (Symbol, sympify, sympify as S,
4+
from .lib.symengine_wrapper import (sympify, sympify as S,
55
SympifyError, sqrt, I, E, pi, Matrix, Derivative, exp,
66
Lambdify as lambdify, symarray, diff, zeros, eye, diag, ones, zeros,
77
expand, Subs, FunctionSymbol as AppliedUndef)
@@ -18,19 +18,23 @@ class Basic(with_metaclass(BasicMeta, object)):
1818

1919

2020
class Number(Basic):
21-
_classes = (symengine.Number,) + Basic._classes
21+
_classes = (symengine.Number,)
22+
pass
23+
24+
class Symbol(symengine.PySymbol, Basic):
25+
_classes = (symengine.Symbol,)
2226
pass
2327

2428

2529
class Rational(Number):
26-
_classes = (symengine.Rational,) + Number._classes
30+
_classes = (symengine.Rational, symengine.Integer)
2731

2832
def __new__(cls, num, den = 1):
2933
return symengine.Integer(num) / den
3034

3135

3236
class Integer(Rational):
33-
_classes = (symengine.Integer,) + Rational._classes
37+
_classes = (symengine.Integer,)
3438

3539
def __new__(cls, i):
3640
return symengine.Integer(i)

symengine/tests/test_sympy_compat.py

Lines changed: 24 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,20 +1,25 @@
11
from symengine.sympy_compat import (Integer, Rational, S, Basic, Add, Mul,
2-
Pow, symbols, Symbol, log, sin, zeros, atan2)
2+
Pow, symbols, Symbol, log, sin, zeros, atan2, Number)
33

44
def test_Integer():
55
i = Integer(5)
66
assert isinstance(i, Integer)
77
assert isinstance(i, Rational)
8+
assert isinstance(i, Number)
89
assert isinstance(i, Basic)
910
assert i.p == 5
1011
assert i.q == 1
1112

1213
def test_Rational():
1314
i = S(1)/2
1415
assert isinstance(i, Rational)
16+
assert isinstance(i, Number)
1517
assert isinstance(i, Basic)
1618
assert i.p == 1
1719
assert i.q == 2
20+
x = symbols("x")
21+
assert not isinstance(x, Rational)
22+
assert not isinstance(x, Number)
1823

1924
def test_Add():
2025
x, y = symbols("x y")
@@ -67,3 +72,21 @@ def test_zeros():
6772
def test_has_functions_module():
6873
import symengine.sympy_compat as sp
6974
assert sp.functions.sin(0) == 0
75+
76+
def test_subclass_symbol():
77+
# Subclass of Symbol with an extra attribute
78+
class Wrapper(Symbol):
79+
def __new__(cls, name, extra_attribute):
80+
return Symbol.__new__(cls, name)
81+
82+
def __init__(self, name, extra_attribute):
83+
super(Wrapper, self).__init__(name)
84+
self.extra_attribute = extra_attribute
85+
86+
# Instantiate the subclass
87+
x = Wrapper("x", extra_attribute=3)
88+
assert x.extra_attribute == 3
89+
two_x = 2 * x
90+
# Check that after arithmetic, same subclass is returned
91+
assert two_x.args[1] is x
92+

0 commit comments

Comments
 (0)