Skip to content

Commit e1e66a8

Browse files
committed
Wrapped Relationals and BooleanAtom
1 parent ec88cc9 commit e1e66a8

File tree

3 files changed

+256
-9
lines changed

3 files changed

+256
-9
lines changed

symengine/lib/symengine.pxd

Lines changed: 33 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -153,6 +153,7 @@ cdef extern from "<symengine/symengine_rcp.h>" namespace "SymEngine":
153153
RCP[const RealMPFR] rcp_static_cast_RealMPFR "SymEngine::rcp_static_cast<const SymEngine::RealMPFR>"(RCP[const Basic] &b) nogil
154154
RCP[const ComplexMPC] rcp_static_cast_ComplexMPC "SymEngine::rcp_static_cast<const SymEngine::ComplexMPC>"(RCP[const Basic] &b) nogil
155155
RCP[const Log] rcp_static_cast_Log "SymEngine::rcp_static_cast<const SymEngine::Log>"(RCP[const Basic] &b) nogil
156+
RCP[const Boolean] rcp_static_cast_Boolean "SymEngine::rcp_static_cast<const SymEngine::Boolean>"(RCP[const Basic] &b) nogil
156157
RCP[const PyNumber] rcp_static_cast_PyNumber "SymEngine::rcp_static_cast<const SymEngine::PyNumber>"(RCP[const Basic] &b) nogil
157158
RCP[const PyFunction] rcp_static_cast_PyFunction "SymEngine::rcp_static_cast<const SymEngine::PyFunction>"(RCP[const Basic] &b) nogil
158159
Ptr[RCP[Basic]] outArg(RCP[const Basic] &arg) nogil
@@ -266,6 +267,11 @@ cdef extern from "<symengine/basic.h>" namespace "SymEngine":
266267
bool is_a_RealMPFR "SymEngine::is_a<SymEngine::RealMPFR>"(const Basic &b) nogil
267268
bool is_a_ComplexMPC "SymEngine::is_a<SymEngine::ComplexMPC>"(const Basic &b) nogil
268269
bool is_a_Log "SymEngine::is_a<SymEngine::Log>"(const Basic &b) nogil
270+
bool is_a_BooleanAtom "SymEngine::is_a<SymEngine::BooleanAtom>"(const Basic &b) nogil
271+
bool is_a_Equality "SymEngine::is_a<SymEngine::Equality>"(const Basic &b) nogil
272+
bool is_a_Unequality "SymEngine::is_a<SymEngine::Unequality>"(const Basic &b) nogil
273+
bool is_a_LessThan "SymEngine::is_a<SymEngine::LessThan>"(const Basic &b) nogil
274+
bool is_a_StrictLessThan "SymEngine::is_a<SymEngine::StrictLessThan>"(const Basic &b) nogil
269275
bool is_a_PyNumber "SymEngine::is_a<SymEngine::PyNumber>"(const Basic &b) nogil
270276
bool is_a_ATan2 "SymEngine::is_a<SymEngine::ATan2>"(const Basic &b) nogil
271277
bool is_a_PySymbol "SymEngine::is_a_sub<SymEngine::PySymbol>"(const Basic &b) nogil
@@ -756,6 +762,33 @@ cdef extern from "<symengine/visitor.h>" namespace "SymEngine":
756762
RCP[const Basic] coeff(const Basic &b, const Basic &x, const Basic &n) nogil except +
757763
set_basic free_symbols(const Basic &b) nogil except +
758764

765+
cdef extern from "<symengine/logic.h>" namespace "SymEngine":
766+
cdef cppclass Boolean(Basic):
767+
pass
768+
cdef cppclass BooleanAtom(Boolean):
769+
pass
770+
cdef cppclass Relational(Boolean):
771+
pass
772+
cdef cppclass Equality(Relational):
773+
pass
774+
cdef cppclass Unequality(Relational):
775+
pass
776+
cdef cppclass LessThan(Relational):
777+
pass
778+
cdef cppclass StrictLessThan(Relational):
779+
pass
780+
781+
RCP[const Basic] boolTrue
782+
RCP[const Basic] boolFalse
783+
bool is_a_Relational(const Basic &b) nogil
784+
RCP[const Boolean] Eq(const RCP[const Basic] &lhs) nogil
785+
RCP[const Boolean] Eq(const RCP[const Basic] &lhs, const RCP[const Basic] &rhs) nogil
786+
RCP[const Boolean] Ne(const RCP[const Basic] &lhs, const RCP[const Basic] &rhs) nogil
787+
RCP[const Boolean] Ge(const RCP[const Basic] &lhs, const RCP[const Basic] &rhs) nogil
788+
RCP[const Boolean] Gt(const RCP[const Basic] &lhs, const RCP[const Basic] &rhs) nogil
789+
RCP[const Boolean] Le(const RCP[const Basic] &lhs, const RCP[const Basic] &rhs) nogil
790+
RCP[const Boolean] Lt(const RCP[const Basic] &lhs, const RCP[const Basic] &rhs) nogil
791+
759792
cdef extern from "<utility>" namespace "std":
760793
cdef integer_class std_move_mpz "std::move" (integer_class) nogil
761794
IF HAVE_SYMENGINE_MPFR:

symengine/lib/symengine_wrapper.pyx

Lines changed: 222 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
from cython.operator cimport dereference as deref, preincrement as inc
22
cimport symengine
33
from symengine cimport RCP, pair, map_basic_basic, umap_int_basic, umap_int_basic_iterator, umap_basic_num, umap_basic_num_iterator, rcp_const_basic, std_pair_short_rcp_const_basic, rcp_const_seriescoeffinterface
4-
from libcpp cimport bool
4+
from libcpp cimport bool as cppbool
55
from libcpp.string cimport string
66
from libcpp.vector cimport vector
77
from cpython cimport PyObject, Py_XINCREF, Py_XDECREF, \
@@ -67,6 +67,16 @@ cdef c2py(RCP[const symengine.Basic] o):
6767
r = Function.__new__(Max)
6868
elif (symengine.is_a_Min(deref(o))):
6969
r = Function.__new__(Min)
70+
elif (symengine.is_a_BooleanAtom(deref(o))):
71+
r = BooleanAtom.__new__(BooleanAtom)
72+
elif (symengine.is_a_Equality(deref(o))):
73+
r = Equality.__new__(Equality)
74+
elif (symengine.is_a_Unequality(deref(o))):
75+
r = Unequality.__new__(Unequality)
76+
elif (symengine.is_a_LessThan(deref(o))):
77+
r = LessThan.__new__(LessThan)
78+
elif (symengine.is_a_StrictLessThan(deref(o))):
79+
r = StrictLessThan.__new__(StrictLessThan)
7080
elif (symengine.is_a_Gamma(deref(o))):
7181
r = Function.__new__(Gamma)
7282
elif (symengine.is_a_Derivative(deref(o))):
@@ -185,6 +195,10 @@ def sympy2symengine(a, raise_error=False):
185195
return zoo
186196
elif a is sympy.nan:
187197
return nan
198+
elif a is sympy.S.true:
199+
return BooleanTrue
200+
elif a is sympy.S.false:
201+
return BooleanFalse
188202
elif isinstance(a, sympy.functions.elementary.trigonometric.TrigonometricFunction):
189203
if isinstance(a, sympy.sin):
190204
return sin(a.args[0])
@@ -242,6 +256,18 @@ def sympy2symengine(a, raise_error=False):
242256
return _max(*a.args)
243257
elif isinstance(a, sympy.Min):
244258
return _min(*a.args)
259+
elif isinstance(a, sympy.Equality):
260+
return eq(*a.args)
261+
elif isinstance(a, sympy.Unequality):
262+
return ne(*a.args)
263+
elif isinstance(a, sympy.GreaterThan):
264+
return ge(*a.args)
265+
elif isinstance(a, sympy.StrictGreaterThan):
266+
return gt(*a.args)
267+
elif isinstance(a, sympy.LessThan):
268+
return le(*a.args)
269+
elif isinstance(a, sympy.StrictLessThan):
270+
return lt(*a.args)
245271
elif isinstance(a, sympy.gamma):
246272
return gamma(a.args[0])
247273
elif isinstance(a, sympy.Derivative):
@@ -323,6 +349,8 @@ def _sympify(a, raise_error=True):
323349
"""
324350
if isinstance(a, (Basic, MatrixBase)):
325351
return a
352+
elif isinstance(a, bool):
353+
return (BooleanTrue if a else BooleanFalse)
326354
elif isinstance(a, (int, long)):
327355
return Integer(a)
328356
elif isinstance(a, float):
@@ -843,6 +871,158 @@ cdef class Constant(Basic):
843871
raise Exception("Unknown Constant")
844872

845873

874+
class Boolean(Basic):
875+
pass
876+
877+
878+
class BooleanAtom(Boolean):
879+
880+
def _sympy_(self):
881+
import sympy
882+
if self == BooleanTrue:
883+
return sympy.S.true
884+
else:
885+
return sympy.S.false
886+
887+
def _sage_(self):
888+
if self == BooleanTrue:
889+
return True
890+
else:
891+
return False
892+
893+
894+
class Relational(Boolean):
895+
pass
896+
897+
Rel = Relational
898+
899+
900+
class Equality(Relational):
901+
902+
def __new__(cls, *args):
903+
return eq(*args)
904+
905+
def _sympy_(self):
906+
import sympy
907+
s = self.args_as_sympy()
908+
return sympy.Equality(*s)
909+
910+
def _sage_(self):
911+
import sage.all as sage
912+
s = self.args_as_sage()
913+
return sage.eq(*s)
914+
915+
func = __class__
916+
917+
918+
Eq = Equality
919+
920+
921+
class Unequality(Relational):
922+
923+
def __new__(cls, *args):
924+
return ne(*args)
925+
926+
def _sympy_(self):
927+
import sympy
928+
s = self.args_as_sympy()
929+
return sympy.Unequality(*s)
930+
931+
def _sage_(self):
932+
import sage.all as sage
933+
s = self.args_as_sage()
934+
return sage.ne(*s)
935+
936+
func = __class__
937+
938+
939+
Ne = Unequality
940+
941+
942+
class GreaterThan(Relational):
943+
944+
def __new__(cls, *args):
945+
return ge(*args)
946+
947+
def _sympy_(self):
948+
import sympy
949+
s = self.args_as_sympy()
950+
return sympy.GreaterThan(*s)
951+
952+
def _sage_(self):
953+
import sage.all as sage
954+
s = self.args_as_sage()
955+
return sage.ge(*s)
956+
957+
func = __class__
958+
959+
960+
Ge = GreaterThan
961+
962+
963+
class StrictGreaterThan(Relational):
964+
965+
def __new__(cls, *args):
966+
return gt(*args)
967+
968+
def _sympy_(self):
969+
import sympy
970+
s = self.args_as_sympy()
971+
return sympy.StrictGreaterThan(*s)
972+
973+
def _sage_(self):
974+
import sage.all as sage
975+
s = self.args_as_sage()
976+
return sage.gt(*s)
977+
978+
func = __class__
979+
980+
981+
Gt = StrictGreaterThan
982+
983+
984+
class LessThan(Relational):
985+
986+
def __new__(cls, *args):
987+
return le(*args)
988+
989+
def _sympy_(self):
990+
import sympy
991+
s = self.args_as_sympy()
992+
return sympy.LessThan(*s)
993+
994+
def _sage_(self):
995+
import sage.all as sage
996+
s = self.args_as_sage()
997+
return sage.le(*s)
998+
999+
func = __class__
1000+
1001+
1002+
Le = LessThan
1003+
1004+
1005+
class StrictLessThan(Relational):
1006+
1007+
def __new__(cls, *args):
1008+
return lt(*args)
1009+
1010+
def _sympy_(self):
1011+
import sympy
1012+
s = self.args_as_sympy()
1013+
return sympy.StrictLessThan(*s)
1014+
1015+
def _sage_(self):
1016+
import sage.all as sage
1017+
s = self.args_as_sage()
1018+
return sage.lt(*s)
1019+
1020+
func = __class__
1021+
1022+
1023+
Lt = StrictLessThan
1024+
1025+
8461026
cdef class Number(Basic):
8471027
@property
8481028
def is_Atom(self):
@@ -2560,10 +2740,12 @@ pi = c2py(symengine.pi)
25602740
oo = c2py(symengine.Inf)
25612741
zoo = c2py(symengine.ComplexInf)
25622742
nan = c2py(symengine.Nan)
2743+
BooleanTrue = c2py(symengine.boolTrue)
2744+
BooleanFalse = c2py(symengine.boolFalse)
25632745

25642746
def module_cleanup():
2565-
global I, E, pi, oo, zoo, nan, sympy_module, sage_module
2566-
del I, E, pi, oo, zoo, nan, sympy_module, sage_module
2747+
global I, E, pi, oo, zoo, nan, BooleanTrue, BooleanFalse, sympy_module, sage_module
2748+
del I, E, pi, oo, zoo, nan, BooleanTrue, BooleanFalse, sympy_module, sage_module
25672749

25682750
import atexit
25692751
atexit.register(module_cleanup)
@@ -2614,6 +2796,38 @@ def gamma(x):
26142796
cdef Basic X = sympify(x)
26152797
return c2py(symengine.gamma(X.thisptr))
26162798

2799+
def eq(lhs, rhs = None):
2800+
cdef Basic X = sympify(lhs)
2801+
if rhs is None:
2802+
return c2py(<RCP[const symengine.Basic]>(symengine.Eq(X.thisptr)))
2803+
cdef Basic Y = sympify(rhs)
2804+
return c2py(<RCP[const symengine.Basic]>(symengine.Eq(X.thisptr, Y.thisptr)))
2805+
2806+
def ne(lhs, rhs):
2807+
cdef Basic X = sympify(lhs)
2808+
cdef Basic Y = sympify(rhs)
2809+
return c2py(<RCP[const symengine.Basic]>(symengine.Ne(X.thisptr, Y.thisptr)))
2810+
2811+
def ge(lhs, rhs):
2812+
cdef Basic X = sympify(lhs)
2813+
cdef Basic Y = sympify(rhs)
2814+
return c2py(<RCP[const symengine.Basic]>(symengine.Ge(X.thisptr, Y.thisptr)))
2815+
2816+
def gt(lhs, rhs):
2817+
cdef Basic X = sympify(lhs)
2818+
cdef Basic Y = sympify(rhs)
2819+
return c2py(<RCP[const symengine.Basic]>(symengine.Gt(X.thisptr, Y.thisptr)))
2820+
2821+
def le(lhs, rhs):
2822+
cdef Basic X = sympify(lhs)
2823+
cdef Basic Y = sympify(rhs)
2824+
return c2py(<RCP[const symengine.Basic]>(symengine.Le(X.thisptr, Y.thisptr)))
2825+
2826+
def lt(lhs, rhs):
2827+
cdef Basic X = sympify(lhs)
2828+
cdef Basic Y = sympify(rhs)
2829+
return c2py(<RCP[const symengine.Basic]>(symengine.Lt(X.thisptr, Y.thisptr)))
2830+
26172831
def eval_double(x):
26182832
cdef Basic X = sympify(x)
26192833
return c2py(<RCP[const symengine.Basic]>(symengine.real_double(symengine.eval_double(deref(X.thisptr)))))
@@ -2758,7 +2972,7 @@ def mod_inverse(a, b):
27582972
def crt(rem, mod):
27592973
cdef symengine.vec_integer _rem, _mod
27602974
cdef Basic _a
2761-
cdef bool ret_val
2975+
cdef cppbool ret_val
27622976
for i in range(len(rem)):
27632977
_a = sympify(rem[i])
27642978
require(_a, Integer)
@@ -2895,7 +3109,7 @@ def primitive_root(n):
28953109
cdef RCP[const symengine.Integer] g
28963110
cdef Basic _n = sympify(n)
28973111
require(_n, Integer)
2898-
cdef bool ret_val = symengine.primitive_root(symengine.outArg_Integer(g),
3112+
cdef cppbool ret_val = symengine.primitive_root(symengine.outArg_Integer(g),
28993113
deref(symengine.rcp_static_cast_Integer(_n.thisptr)))
29003114
if ret_val == 0:
29013115
return None
@@ -2932,7 +3146,7 @@ def multiplicative_order(a, n):
29323146
cdef RCP[const symengine.Integer] n1 = symengine.rcp_static_cast_Integer(_n.thisptr)
29333147
cdef RCP[const symengine.Integer] a1 = symengine.rcp_static_cast_Integer(_a.thisptr)
29343148
cdef RCP[const symengine.Integer] o
2935-
cdef bool c = symengine.multiplicative_order(symengine.outArg_Integer(o),
3149+
cdef cppbool c = symengine.multiplicative_order(symengine.outArg_Integer(o),
29363150
a1, n1)
29373151
if not c:
29383152
return None
@@ -2973,7 +3187,7 @@ def nthroot_mod(a, n, m):
29733187
cdef RCP[const symengine.Integer] n1 = symengine.rcp_static_cast_Integer(_n.thisptr)
29743188
cdef RCP[const symengine.Integer] a1 = symengine.rcp_static_cast_Integer(_a.thisptr)
29753189
cdef RCP[const symengine.Integer] m1 = symengine.rcp_static_cast_Integer(_m.thisptr)
2976-
cdef bool ret_val = symengine.nthroot_mod(symengine.outArg_Integer(root), a1, n1, m1)
3190+
cdef cppbool ret_val = symengine.nthroot_mod(symengine.outArg_Integer(root), a1, n1, m1)
29773191
if not ret_val:
29783192
return None
29793193
return c2py(<RCP[const symengine.Basic]>root)
@@ -3006,7 +3220,7 @@ def powermod(a, b, m):
30063220
cdef RCP[const symengine.Number] b1 = symengine.rcp_static_cast_Number(_b.thisptr)
30073221
cdef RCP[const symengine.Integer] root
30083222

3009-
cdef bool ret_val = symengine.powermod(symengine.outArg_Integer(root), a1, b1, m1)
3223+
cdef cppbool ret_val = symengine.powermod(symengine.outArg_Integer(root), a1, b1, m1)
30103224
if ret_val == 0:
30113225
return None
30123226
return c2py(<RCP[const symengine.Basic]>root)

symengine_version.txt

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1 +1 @@
1-
v0.3.0
1+
d575a3a073607589e7307a143cfdecaef439e71c

0 commit comments

Comments
 (0)