Skip to content

Commit 3528389

Browse files
authored
Merge pull request #401 from isuruf/bool
Raise TypeError for bool(Booleans) except true, false
2 parents 2ae4422 + 59056f8 commit 3528389

File tree

3 files changed

+11
-2
lines changed

3 files changed

+11
-2
lines changed

symengine/lib/symengine_wrapper.pyx

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1419,6 +1419,9 @@ cdef class Boolean(Expr):
14191419
def logical_not(self):
14201420
return c2py(<rcp_const_basic>(deref(symengine.rcp_static_cast_Boolean(self.thisptr)).logical_not()))
14211421

1422+
def __bool__(self):
1423+
raise TypeError("cannot determine truth value of Boolean")
1424+
14221425

14231426
cdef class BooleanAtom(Boolean):
14241427

@@ -1524,6 +1527,10 @@ class Relational(Boolean):
15241527
def is_Relational(self):
15251528
return True
15261529

1530+
def __bool__(self):
1531+
raise TypeError("cannot determine truth value of Relational")
1532+
1533+
15271534
Rel = Relational
15281535

15291536

symengine/tests/test_eval.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,7 @@ def test_eval_double2():
1616
x = Symbol("x")
1717
e = sin(x)**2 + sqrt(2)
1818
raises(RuntimeError, lambda: e.n(real=True))
19-
assert abs(e.n() - x**2 - 1.414) < 1e-3
19+
assert abs(e.n() - sin(x)**2.0 - 1.414) < 1e-3
2020

2121
def test_n():
2222
x = Symbol("x")

symengine/tests/test_logic.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -44,6 +44,7 @@ def test_And():
4444
assert And(True, False) == false
4545
assert And(False, False) == false
4646
assert And(True, True, True) == true
47+
raises(TypeError, lambda: x < y and y < 1)
4748

4849

4950
def test_Or():
@@ -54,6 +55,7 @@ def test_Or():
5455
assert Or(True, False) == true
5556
assert Or(False, False) == false
5657
assert Or(True, False, False) == true
58+
raises(TypeError, lambda: x < y or y < 1)
5759

5860

5961
def test_Nor():
@@ -116,4 +118,4 @@ def test_Contains():
116118
assert Contains(x, Interval(1, 1)) != false
117119
assert Contains(oo, Interval(-oo, oo)) == false
118120
assert Contains(-oo, Interval(-oo, oo)) == false
119-
121+

0 commit comments

Comments
 (0)