Skip to content

Commit 01b6906

Browse files
committed
Added tests
1 parent e1e66a8 commit 01b6906

File tree

8 files changed

+91
-82
lines changed

8 files changed

+91
-82
lines changed

symengine/__init__.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,9 @@
77
MatrixBase, Basic, DictBasic, symarray, series, diff, zeros,
88
eye, diag, ones, Derivative, Subs, add, expand, has_symbol,
99
UndefFunction, Function, FunctionSymbol as AppliedUndef,
10-
have_numpy)
10+
have_numpy, true, false, Equality, Unequality, GreaterThan,
11+
LessThan, StrictGreaterThan, StrictLessThan, Eq, Ne, Ge, Le,
12+
Gt, Lt)
1113
from .utilities import var, symbols
1214
from .functions import *
1315

symengine/lib/symengine.pxd

Lines changed: 9 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -153,7 +153,7 @@ cdef extern from "<symengine/symengine_rcp.h>" namespace "SymEngine":
153153
RCP[const RealMPFR] rcp_static_cast_RealMPFR "SymEngine::rcp_static_cast<const SymEngine::RealMPFR>"(RCP[const Basic] &b) nogil
154154
RCP[const ComplexMPC] rcp_static_cast_ComplexMPC "SymEngine::rcp_static_cast<const SymEngine::ComplexMPC>"(RCP[const Basic] &b) nogil
155155
RCP[const Log] rcp_static_cast_Log "SymEngine::rcp_static_cast<const SymEngine::Log>"(RCP[const Basic] &b) nogil
156-
RCP[const Boolean] rcp_static_cast_Boolean "SymEngine::rcp_static_cast<const SymEngine::Boolean>"(RCP[const Basic] &b) nogil
156+
RCP[const BooleanAtom] rcp_static_cast_BooleanAtom "SymEngine::rcp_static_cast<const SymEngine::BooleanAtom>"(RCP[const Basic] &b) nogil
157157
RCP[const PyNumber] rcp_static_cast_PyNumber "SymEngine::rcp_static_cast<const SymEngine::PyNumber>"(RCP[const Basic] &b) nogil
158158
RCP[const PyFunction] rcp_static_cast_PyFunction "SymEngine::rcp_static_cast<const SymEngine::PyFunction>"(RCP[const Basic] &b) nogil
159159
Ptr[RCP[Basic]] outArg(RCP[const Basic] &arg) nogil
@@ -766,7 +766,7 @@ cdef extern from "<symengine/logic.h>" namespace "SymEngine":
766766
cdef cppclass Boolean(Basic):
767767
pass
768768
cdef cppclass BooleanAtom(Boolean):
769-
pass
769+
bool get_val() nogil
770770
cdef cppclass Relational(Boolean):
771771
pass
772772
cdef cppclass Equality(Relational):
@@ -781,13 +781,13 @@ cdef extern from "<symengine/logic.h>" namespace "SymEngine":
781781
RCP[const Basic] boolTrue
782782
RCP[const Basic] boolFalse
783783
bool is_a_Relational(const Basic &b) nogil
784-
RCP[const Boolean] Eq(const RCP[const Basic] &lhs) nogil
785-
RCP[const Boolean] Eq(const RCP[const Basic] &lhs, const RCP[const Basic] &rhs) nogil
786-
RCP[const Boolean] Ne(const RCP[const Basic] &lhs, const RCP[const Basic] &rhs) nogil
787-
RCP[const Boolean] Ge(const RCP[const Basic] &lhs, const RCP[const Basic] &rhs) nogil
788-
RCP[const Boolean] Gt(const RCP[const Basic] &lhs, const RCP[const Basic] &rhs) nogil
789-
RCP[const Boolean] Le(const RCP[const Basic] &lhs, const RCP[const Basic] &rhs) nogil
790-
RCP[const Boolean] Lt(const RCP[const Basic] &lhs, const RCP[const Basic] &rhs) nogil
784+
cdef RCP[const Boolean] Eq(RCP[const Basic] &lhs) nogil except+
785+
cdef RCP[const Boolean] Eq(RCP[const Basic] &lhs, RCP[const Basic] &rhs) nogil except+
786+
cdef RCP[const Boolean] Ne(RCP[const Basic] &lhs, RCP[const Basic] &rhs) nogil except+
787+
cdef RCP[const Boolean] Ge(RCP[const Basic] &lhs, RCP[const Basic] &rhs) nogil except+
788+
cdef RCP[const Boolean] Gt(RCP[const Basic] &lhs, RCP[const Basic] &rhs) nogil except+
789+
cdef RCP[const Boolean] Le(RCP[const Basic] &lhs, RCP[const Basic] &rhs) nogil except+
790+
cdef RCP[const Boolean] Lt(RCP[const Basic] &lhs, RCP[const Basic] &rhs) nogil except+
791791

792792
cdef extern from "<utility>" namespace "std":
793793
cdef integer_class std_move_mpz "std::move" (integer_class) nogil

symengine/lib/symengine_wrapper.pyx

Lines changed: 34 additions & 65 deletions
Original file line numberDiff line numberDiff line change
@@ -68,7 +68,10 @@ cdef c2py(RCP[const symengine.Basic] o):
6868
elif (symengine.is_a_Min(deref(o))):
6969
r = Function.__new__(Min)
7070
elif (symengine.is_a_BooleanAtom(deref(o))):
71-
r = BooleanAtom.__new__(BooleanAtom)
71+
if (deref(symengine.rcp_static_cast_BooleanAtom(o)).get_val()):
72+
r = BooleanAtom.__new__(BooleanTrue)
73+
else:
74+
r = BooleanAtom.__new__(BooleanFalse)
7275
elif (symengine.is_a_Equality(deref(o))):
7376
r = Equality.__new__(Equality)
7477
elif (symengine.is_a_Unequality(deref(o))):
@@ -196,9 +199,9 @@ def sympy2symengine(a, raise_error=False):
196199
elif a is sympy.nan:
197200
return nan
198201
elif a is sympy.S.true:
199-
return BooleanTrue
202+
return true
200203
elif a is sympy.S.false:
201-
return BooleanFalse
204+
return false
202205
elif isinstance(a, sympy.functions.elementary.trigonometric.TrigonometricFunction):
203206
if isinstance(a, sympy.sin):
204207
return sin(a.args[0])
@@ -350,7 +353,7 @@ def _sympify(a, raise_error=True):
350353
if isinstance(a, (Basic, MatrixBase)):
351354
return a
352355
elif isinstance(a, bool):
353-
return (BooleanTrue if a else BooleanFalse)
356+
return (true if a else false)
354357
elif isinstance(a, (int, long)):
355358
return Integer(a)
356359
elif isinstance(a, float):
@@ -876,19 +879,27 @@ class Boolean(Basic):
876879

877880

878881
class BooleanAtom(Boolean):
882+
pass
883+
884+
885+
class BooleanTrue(BooleanAtom):
879886

880887
def _sympy_(self):
881888
import sympy
882-
if self == BooleanTrue:
883-
return sympy.S.true
884-
else:
885-
return sympy.S.false
889+
return sympy.S.true
886890

887891
def _sage_(self):
888-
if self == BooleanTrue:
889-
return True
890-
else:
891-
return False
892+
return True
893+
894+
895+
class BooleanFalse(BooleanAtom):
896+
897+
def _sympy_(self):
898+
import sympy
899+
return sympy.S.false
900+
901+
def _sage_(self):
902+
return False
892903

893904

894905
class Relational(Boolean):
@@ -939,48 +950,6 @@ class Unequality(Relational):
939950
Ne = Unequality
940951

941952

942-
class GreaterThan(Relational):
943-
944-
def __new__(cls, *args):
945-
return ge(*args)
946-
947-
def _sympy_(self):
948-
import sympy
949-
s = self.args_as_sympy()
950-
return sympy.GreaterThan(*s)
951-
952-
def _sage_(self):
953-
import sage.all as sage
954-
s = self.args_as_sage()
955-
return sage.ge(*s)
956-
957-
func = __class__
958-
959-
960-
Ge = GreaterThan
961-
962-
963-
class StrictGreaterThan(Relational):
964-
965-
def __new__(cls, *args):
966-
return gt(*args)
967-
968-
def _sympy_(self):
969-
import sympy
970-
s = self.args_as_sympy()
971-
return sympy.StrictGreaterThan(*s)
972-
973-
def _sage_(self):
974-
import sage.all as sage
975-
s = self.args_as_sage()
976-
return sage.gt(*s)
977-
978-
func = __class__
979-
980-
981-
Gt = StrictGreaterThan
982-
983-
984953
class LessThan(Relational):
985954

986955
def __new__(cls, *args):
@@ -996,8 +965,6 @@ class LessThan(Relational):
996965
s = self.args_as_sage()
997966
return sage.le(*s)
998967

999-
func = __class__
1000-
1001968

1002969
Le = LessThan
1003970

@@ -1017,8 +984,6 @@ class StrictLessThan(Relational):
1017984
s = self.args_as_sage()
1018985
return sage.lt(*s)
1019986

1020-
func = __class__
1021-
1022987

1023988
Lt = StrictLessThan
1024989

@@ -2740,12 +2705,12 @@ pi = c2py(symengine.pi)
27402705
oo = c2py(symengine.Inf)
27412706
zoo = c2py(symengine.ComplexInf)
27422707
nan = c2py(symengine.Nan)
2743-
BooleanTrue = c2py(symengine.boolTrue)
2744-
BooleanFalse = c2py(symengine.boolFalse)
2708+
true = c2py(symengine.boolTrue)
2709+
false = c2py(symengine.boolFalse)
27452710

27462711
def module_cleanup():
2747-
global I, E, pi, oo, zoo, nan, BooleanTrue, BooleanFalse, sympy_module, sage_module
2748-
del I, E, pi, oo, zoo, nan, BooleanTrue, BooleanFalse, sympy_module, sage_module
2712+
global I, E, pi, oo, zoo, nan, true, false, sympy_module, sage_module
2713+
del I, E, pi, oo, zoo, nan, true, false, sympy_module, sage_module
27492714

27502715
import atexit
27512716
atexit.register(module_cleanup)
@@ -2813,11 +2778,15 @@ def ge(lhs, rhs):
28132778
cdef Basic Y = sympify(rhs)
28142779
return c2py(<RCP[const symengine.Basic]>(symengine.Ge(X.thisptr, Y.thisptr)))
28152780

2781+
Ge = GreaterThan = ge
2782+
28162783
def gt(lhs, rhs):
28172784
cdef Basic X = sympify(lhs)
28182785
cdef Basic Y = sympify(rhs)
28192786
return c2py(<RCP[const symengine.Basic]>(symengine.Gt(X.thisptr, Y.thisptr)))
28202787

2788+
Gt = StrictGreaterThan = gt
2789+
28212790
def le(lhs, rhs):
28222791
cdef Basic X = sympify(lhs)
28232792
cdef Basic Y = sympify(rhs)
@@ -3339,7 +3308,7 @@ cdef class _Lambdify(object):
33393308
cdef vector[int] accum_out_sizes
33403309
cdef object numpy_dtype
33413310

3342-
def __cinit__(self, args, *exprs, bool real=True):
3311+
def __cinit__(self, args, *exprs, cppbool real=True):
33433312
cdef vector[int] out_sizes
33443313
self.real = real
33453314
self.numpy_dtype = np.float64 if self.real else np.complex128
@@ -3355,7 +3324,7 @@ cdef class _Lambdify(object):
33553324
for j in range(i):
33563325
self.accum_out_sizes[i] += out_sizes[j]
33573326

3358-
def __init__(self, args, *exprs, bool real=True):
3327+
def __init__(self, args, *exprs, cppbool real=True):
33593328
cdef:
33603329
Basic e_
33613330
size_t ri, ci, nr, nc
@@ -3538,7 +3507,7 @@ IF HAVE_SYMENGINE_LLVM:
35383507
self.lambda_double[0].call(&out[out_offset], &inp[inp_offset])
35393508

35403509

3541-
def Lambdify(args, *exprs, bool real=True, backend=None):
3510+
def Lambdify(args, *exprs, cppbool real=True, backend=None):
35423511
if backend is None:
35433512
backend = os.getenv('SYMENGINE_LAMBDIFY_BACKEND', "lambda")
35443513
if backend == "llvm":

symengine/tests/CMakeLists.txt

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -18,5 +18,6 @@ install(FILES __init__.py
1818
test_var.py
1919
test_lambdify.py
2020
test_sympy_compat.py
21+
test_logic.py
2122
DESTINATION ${PY_PATH}
2223
)

symengine/tests/test_logic.py

Lines changed: 24 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,24 @@
1+
from symengine.utilities import raises
2+
from symengine.lib.symengine_wrapper import (true, false, Eq, Ne,
3+
Ge, Gt, Le, Lt, Symbol, I)
4+
5+
x = Symbol("x")
6+
7+
def test_relationals():
8+
assert Eq(x, x) == true
9+
assert Eq(0, 0) == true
10+
assert Eq(1, 0) == false
11+
assert Ne(0, 0) == false
12+
assert Ne(1, 0) == true
13+
assert Lt(0, 1) == true
14+
assert Lt(1, 0) == false
15+
assert Le(0, 1) == true
16+
assert Le(1, 0) == false
17+
assert Le(0, 0) == true
18+
assert Gt(1, 0) == true
19+
assert Gt(0, 1) == false
20+
assert Ge(1, 0) == true
21+
assert Ge(0, 1) == false
22+
assert Ge(1, 1) == true
23+
assert Eq(I, 2) == false
24+
assert Ne(I, 2) == true

symengine/tests/test_sage.py

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
from symengine import (Integer, symbols, sin, cos, pi, E, I, oo, zoo,
2-
nan, Add, function_symbol, DenseMatrix, sympify, log)
2+
nan, true, false, Add, function_symbol, DenseMatrix,
3+
sympify, log)
34
from symengine.lib.symengine_wrapper import (PyNumber, PyFunction,
45
sage_module, wrap_sage_function)
56

@@ -68,13 +69,15 @@ def test_sage_conversions():
6869
# For the following test, sage needs to be modified
6970
# assert sage.sin(x) == sage.sin(x1)
7071

71-
# Constants
72+
# Constants and Booleans
7273
assert pi._sage_() == sage.pi
7374
assert E._sage_() == sage.e
7475
assert I._sage_() == sage.I
7576
assert oo._sage_() == sage.oo
7677
assert zoo._sage_() == sage.unsigned_infinity
7778
assert nan._sage_() == sage.NaN
79+
assert true._sage_() == True
80+
assert false._sage_() == False
7881

7982
assert pi == sympify(sage.pi)
8083
assert E == sympify(sage.e)

symengine/tests/test_sympify.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
from symengine.utilities import raises
22

3-
from symengine import Symbol, Integer, sympify, SympifyError
3+
from symengine import Symbol, Integer, sympify, SympifyError, true, false
44
from symengine.lib.symengine_wrapper import _sympify
55

66

@@ -10,6 +10,8 @@ def test_sympify1():
1010
assert sympify(-5) == Integer(-5)
1111
assert sympify(Integer(3)) == Integer(3)
1212
assert sympify("3+5") == Integer(8)
13+
assert true == sympify(True)
14+
assert false == sympify(False)
1315

1416

1517
def test_sympify_error1a():

symengine/tests/test_sympy_conv.py

Lines changed: 12 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,8 @@
11
from symengine import (Symbol, Integer, sympify, SympifyError, log,
2-
function_symbol, I, E, pi, oo, zoo, nan, exp, gamma, have_mpfr,
3-
have_mpc, DenseMatrix, sin, cos, tan, cot, csc, sec, asin, acos,
4-
atan, acot, acsc, asec, sinh, cosh, tanh, coth, asinh, acosh,
5-
atanh, acoth, Add, Mul, Pow, diff)
2+
function_symbol, I, E, pi, oo, zoo, nan, true, false,
3+
exp, gamma, have_mpfr, have_mpc, DenseMatrix, sin, cos, tan, cot,
4+
csc, sec, asin, acos, atan, acot, acsc, asec, sinh, cosh, tanh, coth,
5+
asinh, acosh, atanh, acoth, Add, Mul, Pow, diff)
66
from symengine.lib.symengine_wrapper import (Subs, Derivative, RealMPFR,
77
ComplexMPC, PyNumber, Function)
88
import sympy
@@ -380,6 +380,14 @@ def test_constants():
380380
assert sympy.nan == nan._sympy_()
381381

382382

383+
def test_booleans():
384+
assert sympify(sympy.S.true) == true
385+
assert sympy.S.true == true._sympy_()
386+
387+
assert sympify(sympy.S.false) == false
388+
assert sympy.S.false == false._sympy_()
389+
390+
383391
def test_abs():
384392
x = Symbol("x")
385393
e1 = abs(sympy.Symbol("x"))

0 commit comments

Comments
 (0)