Skip to content

Commit 9dcb757

Browse files
isurufShikharJ
authored andcommitted
Wrap logic classes and functions
1 parent 9daa702 commit 9dcb757

File tree

2 files changed

+177
-3
lines changed

2 files changed

+177
-3
lines changed

symengine/lib/symengine.pxd

Lines changed: 24 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -156,6 +156,8 @@ cdef extern from "<symengine/symengine_rcp.h>" namespace "SymEngine":
156156
RCP[const BooleanAtom] rcp_static_cast_BooleanAtom "SymEngine::rcp_static_cast<const SymEngine::BooleanAtom>"(RCP[const Basic] &b) nogil
157157
RCP[const PyNumber] rcp_static_cast_PyNumber "SymEngine::rcp_static_cast<const SymEngine::PyNumber>"(RCP[const Basic] &b) nogil
158158
RCP[const PyFunction] rcp_static_cast_PyFunction "SymEngine::rcp_static_cast<const SymEngine::PyFunction>"(RCP[const Basic] &b) nogil
159+
RCP[const Boolean] rcp_static_cast_Boolean "SymEngine::rcp_static_cast<const SymEngine::Boolean>"(RCP[const Basic] &b) nogil
160+
RCP[const Set] rcp_static_cast_Set "SymEngine::rcp_static_cast<const SymEngine::Set>"(RCP[const Basic] &b) nogil
159161
Ptr[RCP[Basic]] outArg(RCP[const Basic] &arg) nogil
160162
Ptr[RCP[Integer]] outArg_Integer "SymEngine::outArg<SymEngine::RCP<const SymEngine::Integer>>"(RCP[const Integer] &arg) nogil
161163

@@ -292,7 +294,9 @@ cdef extern from "<symengine/basic.h>" namespace "SymEngine":
292294
bool is_a_Floor "SymEngine::is_a<SymEngine::Floor>"(const Basic &b) nogil
293295
bool is_a_Ceiling "SymEngine::is_a<SymEngine::Ceiling>"(const Basic &b) nogil
294296
bool is_a_Conjugate "SymEngine::is_a<SymEngine::Conjugate>"(const Basic &b) nogil
295-
297+
bool is_a_Interval "SymEngine::is_a<SymEngine::Interval>"(const Basic &b) nogil
298+
bool is_a_Piecewise "SymEngine::is_a<SymEngine::Piecewise>"(const Basic &b) nogil
299+
bool is_a_Contains "SymEngine::is_a<SymEngine::Contains>"(const Basic &b) nogil
296300
RCP[const Basic] expand(RCP[const Basic] &o) nogil except +
297301

298302
cdef extern from "<symengine/subs.h>" namespace "SymEngine":
@@ -385,7 +389,7 @@ cdef extern from "<symengine/constants.h>" namespace "SymEngine":
385389
RCP[const Basic] Inf
386390
RCP[const Basic] ComplexInf
387391
RCP[const Basic] Nan
388-
392+
389393
cdef extern from "<symengine/infinity.h>" namespace "SymEngine":
390394
cdef cppclass Infty(Number):
391395
pass
@@ -883,6 +887,10 @@ cdef extern from "<symengine/logic.h>" namespace "SymEngine":
883887
pass
884888
cdef cppclass StrictLessThan(Relational):
885889
pass
890+
cdef cppclass Piecewise(Basic):
891+
pass
892+
cdef cppclass Contains(Boolean):
893+
pass
886894

887895
RCP[const Basic] boolTrue
888896
RCP[const Basic] boolFalse
@@ -894,6 +902,11 @@ cdef extern from "<symengine/logic.h>" namespace "SymEngine":
894902
cdef RCP[const Boolean] Gt(RCP[const Basic] &lhs, RCP[const Basic] &rhs) nogil except+
895903
cdef RCP[const Boolean] Le(RCP[const Basic] &lhs, RCP[const Basic] &rhs) nogil except+
896904
cdef RCP[const Boolean] Lt(RCP[const Basic] &lhs, RCP[const Basic] &rhs) nogil except+
905+
ctypedef Boolean const_Boolean "const SymEngine::Boolean"
906+
ctypedef vector[pair[RCP[const_Basic], RCP[const_Boolean]]] PiecewiseVec;
907+
cdef RCP[const Basic] piecewise(PiecewiseVec vec) nogil except +
908+
cdef RCP[const Boolean] contains(RCP[const Basic] &expr,
909+
RCP[const Set] &set) nogil
897910

898911
cdef extern from "<utility>" namespace "std":
899912
cdef integer_class std_move_mpz "std::move" (integer_class) nogil
@@ -902,6 +915,7 @@ cdef extern from "<utility>" namespace "std":
902915
IF HAVE_SYMENGINE_MPC:
903916
cdef mpc_class std_move_mpc "std::move" (mpc_class) nogil
904917
cdef map_basic_basic std_move_map_basic_basic "std::move" (map_basic_basic) nogil
918+
cdef PiecewiseVec std_move_PiecewiseVec "std::move" (PiecewiseVec) nogil
905919

906920
cdef extern from "<symengine/eval_double.h>" namespace "SymEngine":
907921
double eval_double(const Basic &b) nogil except +
@@ -944,3 +958,11 @@ cdef extern from "<symengine/parser.h>" namespace "SymEngine":
944958

945959
cdef extern from "<symengine/codegen.h>" namespace "SymEngine":
946960
string ccode(const Basic &x) nogil except +
961+
962+
cdef extern from "<symengine/sets.h>" namespace "SymEngine":
963+
cdef cppclass Set(Basic):
964+
RCP[const Set] set_intersection(RCP[const Set] &o) nogil except +
965+
RCP[const Set] set_union(RCP[const Set] &o) nogil except +
966+
cdef cppclass Interval(Set):
967+
pass
968+
cdef RCP[const Basic] interval(RCP[const Number] &start, RCP[const Number] &end, bool l, bool r) nogil

symengine/lib/symengine_wrapper.pyx

Lines changed: 153 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -182,6 +182,16 @@ cdef c2py(RCP[const symengine.Basic] o):
182182
r = Function.__new__(conjugate)
183183
elif (symengine.is_a_PyNumber(deref(o))):
184184
r = PyNumber.__new__(PyNumber)
185+
elif (symengine.is_a_Infty(deref(o))):
186+
r = Infinity.__new__(Infinity)
187+
elif (symengine.is_a_Piecewise(deref(o))):
188+
r = Piecewise.__new__(Piecewise)
189+
elif (symengine.is_a_Contains(deref(o))):
190+
r = Contains.__new__(Contains)
191+
elif (symengine.is_a_BooleanAtom(deref(o))):
192+
r = BooleanAtom.__new__(BooleanAtom)
193+
elif (symengine.is_a_Interval(deref(o))):
194+
r = Interval.__new__(Interval)
185195
else:
186196
raise Exception("Unsupported SymEngine class.")
187197
r.thisptr = o
@@ -354,6 +364,22 @@ def sympy2symengine(a, raise_error=False):
354364
elif isinstance(a, sympy_AppliedUndef):
355365
name = str(a.func)
356366
return function_symbol(name, *(a.args))
367+
elif a == sympy.S.NegativeInfinity:
368+
return -oo
369+
elif a == sympy.S.Infinity:
370+
return oo
371+
elif a == sympy.S.ComplexInfinity:
372+
return zoo
373+
elif a == sympy.S.true:
374+
return BooleanTrue
375+
elif a == sympy.S.false:
376+
return BooleanFalse
377+
elif isinstance(a, (sympy.Piecewise)):
378+
return piecewise(*(a.args))
379+
elif isinstance(a, (sympy.Interval)):
380+
return interval(*(a.args))
381+
elif isinstance(a, (sympy.Contains)):
382+
return contains(*(a.args))
357383
elif isinstance(a, sympy.Function):
358384
return PyFunction(a, a.args, a.func, sympy_module)
359385
elif isinstance(a, sympy.MatrixBase):
@@ -2409,6 +2435,42 @@ class Subs(Basic):
24092435
return self.__class__
24102436

24112437

2438+
cdef class Piecewise(Basic):
2439+
def _sympy_(self):
2440+
import sympy
2441+
a = self.args
2442+
l = []
2443+
for i in range(0, len(a), 2):
2444+
l.append((a[i]._sympy_(), a[i + 1]._sympy_()))
2445+
return sympy.Piecewise(*l)
2446+
2447+
2448+
cdef class Set(Basic):
2449+
def intersection(self, a):
2450+
cdef Set other = sympify(a)
2451+
cdef RCP[const symengine.Set] other_ = symengine.rcp_static_cast_Set(other.thisptr)
2452+
return c2py(<RCP[const symengine.Basic]>(deref(symengine.rcp_static_cast_Set(self.thisptr))
2453+
.set_intersection(other_)))
2454+
2455+
def union(self, a):
2456+
cdef Set other = sympify(a)
2457+
cdef RCP[const symengine.Set] other_ = symengine.rcp_static_cast_Set(other.thisptr)
2458+
return c2py(<RCP[const symengine.Basic]>(deref(symengine.rcp_static_cast_Set(self.thisptr))
2459+
.set_intersection(other_)))
2460+
2461+
2462+
cdef class Interval(Set):
2463+
def _sympy_(self):
2464+
import sympy
2465+
return sympy.Interval(*[arg._sympy_() for arg in self.args])
2466+
2467+
2468+
cdef class Contains(Boolean):
2469+
def _sympy_(self):
2470+
import sympy
2471+
return sympy.Contains(*[arg._sympy_() for arg in self.args])
2472+
2473+
24122474
cdef class MatrixBase:
24132475

24142476
@property
@@ -3701,7 +3763,6 @@ def powermod(a, b, m):
37013763
cdef RCP[const symengine.Integer] m1 = symengine.rcp_static_cast_Integer(_m.thisptr)
37023764
cdef RCP[const symengine.Number] b1 = symengine.rcp_static_cast_Number(_b.thisptr)
37033765
cdef RCP[const symengine.Integer] root
3704-
37053766
cdef cppbool ret_val = symengine.powermod(symengine.outArg_Integer(root), a1, b1, m1)
37063767
if ret_val == 0:
37073768
return None
@@ -4092,5 +4153,96 @@ def ccode(expr):
40924153
cdef Basic expr_ = sympify(expr)
40934154
return symengine.ccode(deref(expr_.thisptr)).decode("utf-8")
40944155

4156+
4157+
def to_contains(relational):
4158+
from sympy import (And, Or, Not, Intersection, Union, Complement,
4159+
S, GreaterThan, LessThan, StrictLessThan,
4160+
StrictGreaterThan, Interval)
4161+
if isinstance(relational, And):
4162+
l = [to_contains(i) for i in relational.args]
4163+
ex = l[0].args[0]
4164+
if all(elem.args[0] == ex for elem in l):
4165+
sset = l[0].args[1]
4166+
for elem in l:
4167+
sset = sset.intersection(elem.args[1])
4168+
return contains(ex, sset)
4169+
else:
4170+
ValueError('Relational {} cannot be converted to a Contains'.format(relational))
4171+
elif isinstance(relational, Or):
4172+
l = [to_contains(i) for i in relational.args]
4173+
ex = l[0].args[0]
4174+
if all(elem.args[0] == ex for elem in l):
4175+
sset = l[0].args[1]
4176+
for elem in l:
4177+
sset = sset.union(elem.args[1])
4178+
return contains(ex, sset)
4179+
else:
4180+
ValueError('Relational {} cannot be converted to a Contains'.format(relational))
4181+
elif isinstance(relational, Not):
4182+
elem = to_contains(relational.args[0])
4183+
return contains(elem.args[0], Complement(S.Reals, elem.args[1]._sympy_()))
4184+
if relational == S.true or relational == S.false:
4185+
return sympify(relational)
4186+
4187+
if len(relational.args) != 2:
4188+
raise ValueError('Relational must only have two arguments')
4189+
4190+
lhs = relational.args[0]
4191+
rhs = relational.args[1]
4192+
if isinstance(relational, GreaterThan):
4193+
if rhs.is_Number:
4194+
return contains(lhs, interval(rhs, oo, left_open=False))
4195+
else:
4196+
return contains(rhs, interval(-oo, rhs, right_open=False))
4197+
elif isinstance(relational, StrictGreaterThan):
4198+
if rhs.is_Number:
4199+
return contains(lhs, interval(rhs, oo, left_open=True))
4200+
else:
4201+
return contains(rhs, interval(-oo, rhs, right_open=True))
4202+
elif isinstance(relational, LessThan):
4203+
if rhs.is_Number:
4204+
return contains(lhs, interval(rhs, oo, left_open=False))
4205+
else:
4206+
return contains(rhs, interval(-oo, rhs, right_open=False))
4207+
elif isinstance(relational, StrictLessThan):
4208+
if rhs.is_Number:
4209+
return contains(lhs, interval(rhs, oo, left_open=True))
4210+
else:
4211+
return contains(rhs, interval(-oo, rhs, right_open=True))
4212+
else:
4213+
raise ValueError('Unsupported Relational: {}'.format(relational.__class__.__name__))
4214+
4215+
4216+
def piecewise(*v):
4217+
cdef symengine.PiecewiseVec vec
4218+
cdef pair[RCP[symengine.const_Basic], RCP[symengine.const_Boolean]] p
4219+
cdef Basic e
4220+
cdef Boolean b
4221+
for expr, rel in v:
4222+
e = sympify(expr)
4223+
b = sympify(to_contains(rel))
4224+
p.first = <RCP[symengine.const_Basic]>e.thisptr
4225+
p.second = <RCP[symengine.const_Boolean]>symengine.rcp_static_cast_Boolean(b.thisptr)
4226+
vec.push_back(p)
4227+
return c2py(symengine.piecewise(symengine.std_move_PiecewiseVec(vec)))
4228+
4229+
4230+
def interval(start, end, left_open=False, right_open=False):
4231+
cdef Number start_ = sympify(start)
4232+
cdef Number end_ = sympify(end)
4233+
cdef cppbool left_open_ = left_open
4234+
cdef cppbool right_open_ = right_open
4235+
cdef RCP[const symengine.Number] n1 = symengine.rcp_static_cast_Number(start_.thisptr)
4236+
cdef RCP[const symengine.Number] n2 = symengine.rcp_static_cast_Number(end_.thisptr)
4237+
return c2py(symengine.interval(n1, n2, left_open_, right_open_))
4238+
4239+
4240+
def contains(expr, sset):
4241+
cdef Basic expr_ = sympify(expr)
4242+
cdef Set sset_ = sympify(sset)
4243+
cdef RCP[const symengine.Set] s = symengine.rcp_static_cast_Set(sset_.thisptr)
4244+
return c2py(<RCP[const symengine.Basic]>symengine.contains(expr_.thisptr, s))
4245+
4246+
40954247
# Turn on nice stacktraces:
40964248
symengine.print_stack_on_segfault()

0 commit comments

Comments
 (0)