Skip to content

Commit 699d49e

Browse files
committed
Wrapped Sign, Floor, Ceiling and Conjugate classes
1 parent 92bdf4e commit 699d49e

File tree

7 files changed

+283
-9
lines changed

7 files changed

+283
-9
lines changed

symengine/functions.py

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2,4 +2,9 @@
22
asin, acos, atan, acot, acsc, asec,
33
sinh, cosh, tanh, coth, sech, csch,
44
asinh, acosh, atanh, acoth, asech, acsch,
5-
gamma, log, atan2, sqrt, exp, Abs)
5+
gamma, log, atan2, sqrt, exp, Abs,
6+
LambertW, zeta, dirichlet_eta,
7+
KroneckerDelta, LeviCivita, erf, erfc,
8+
lowergamma, uppergamma, loggamma, beta,
9+
polygamma, sign, floor, ceiling,
10+
conjugate, digamma, trigamma)

symengine/lib/symengine.pxd

Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -287,6 +287,10 @@ cdef extern from "<symengine/basic.h>" namespace "SymEngine":
287287
bool is_a_Beta "SymEngine::is_a<SymEngine::Beta>"(const Basic &b) nogil
288288
bool is_a_PolyGamma "SymEngine::is_a<SymEngine::PolyGamma>"(const Basic &b) nogil
289289
bool is_a_PySymbol "SymEngine::is_a_sub<SymEngine::PySymbol>"(const Basic &b) nogil
290+
bool is_a_Sign "SymEngine::is_a<SymEngine::Sign>"(const Basic &b) nogil
291+
bool is_a_Floor "SymEngine::is_a<SymEngine::Floor>"(const Basic &b) nogil
292+
bool is_a_Ceiling "SymEngine::is_a<SymEngine::Ceiling>"(const Basic &b) nogil
293+
bool is_a_Conjugate "SymEngine::is_a<SymEngine::Conjugate>"(const Basic &b) nogil
290294

291295
RCP[const Basic] expand(RCP[const Basic] &o) nogil except +
292296

@@ -491,6 +495,10 @@ cdef extern from "<symengine/functions.h>" namespace "SymEngine":
491495
cdef RCP[const Basic] polygamma(RCP[const Basic] &n, RCP[const Basic] &x) nogil except+
492496
cdef RCP[const Basic] digamma(RCP[const Basic] &x) nogil except+
493497
cdef RCP[const Basic] trigamma(RCP[const Basic] &x) nogil except+
498+
cdef RCP[const Basic] sign(RCP[const Basic] &x) nogil except+
499+
cdef RCP[const Basic] floor(RCP[const Basic] &x) nogil except+
500+
cdef RCP[const Basic] ceiling(RCP[const Basic] &x) nogil except+
501+
cdef RCP[const Basic] conjugate(RCP[const Basic] &x) nogil except+
494502

495503

496504
cdef cppclass Function(Basic):
@@ -647,6 +655,17 @@ cdef extern from "<symengine/functions.h>" namespace "SymEngine":
647655
cdef cppclass PolyGamma(Function):
648656
pass
649657

658+
cdef cppclass Sign(OneArgFunction):
659+
pass
660+
661+
cdef cppclass Floor(OneArgFunction):
662+
pass
663+
664+
cdef cppclass Ceiling(OneArgFunction):
665+
pass
666+
667+
cdef cppclass Conjugate(OneArgFunction):
668+
pass
650669

651670
IF HAVE_SYMENGINE_MPFR:
652671
cdef extern from "mpfr.h":

symengine/lib/symengine_wrapper.pyx

Lines changed: 40 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -170,6 +170,14 @@ cdef c2py(RCP[const symengine.Basic] o):
170170
r = Function.__new__(beta)
171171
elif (symengine.is_a_PolyGamma(deref(o))):
172172
r = Function.__new__(polygamma)
173+
elif (symengine.is_a_Sign(deref(o))):
174+
r = Function.__new__(sign)
175+
elif (symengine.is_a_Floor(deref(o))):
176+
r = Function.__new__(floor)
177+
elif (symengine.is_a_Ceiling(deref(o))):
178+
r = Function.__new__(ceiling)
179+
elif (symengine.is_a_Conjugate(deref(o))):
180+
r = Function.__new__(conjugate)
173181
elif (symengine.is_a_PyNumber(deref(o))):
174182
r = PyNumber.__new__(PyNumber)
175183
else:
@@ -325,6 +333,14 @@ def sympy2symengine(a, raise_error=False):
325333
return beta(*a.args)
326334
elif isinstance(a, sympy.polygamma):
327335
return polygamma(*a.args)
336+
elif isinstance(a, sympy.sign):
337+
return sign(a.args[0])
338+
elif isinstance(a, sympy.floor):
339+
return floor(a.args[0])
340+
elif isinstance(a, sympy.ceiling):
341+
return ceiling(a.args[0])
342+
elif isinstance(a, sympy.conjugate):
343+
return conjugate(a.args[0])
328344
elif isinstance(a, sympy.gamma):
329345
return gamma(a.args[0])
330346
elif isinstance(a, sympy.Derivative):
@@ -1729,6 +1745,30 @@ class polygamma(Function):
17291745
import sympy
17301746
return sympy.polygamma(*self.args_as_sympy())
17311747

1748+
class sign(OneArgFunction):
1749+
def __new__(cls, x):
1750+
cdef Basic X = sympify(x)
1751+
return c2py(symengine.sign(X.thisptr))
1752+
1753+
class floor(OneArgFunction):
1754+
def __new__(cls, x):
1755+
cdef Basic X = sympify(x)
1756+
return c2py(symengine.floor(X.thisptr))
1757+
1758+
class ceiling(OneArgFunction):
1759+
def __new__(cls, x):
1760+
cdef Basic X = sympify(x)
1761+
return c2py(symengine.ceiling(X.thisptr))
1762+
1763+
def _sage_(self):
1764+
import sage.all as sage
1765+
return sage.ceil(self.get_arg()._sage_())
1766+
1767+
class conjugate(OneArgFunction):
1768+
def __new__(cls, x):
1769+
cdef Basic X = sympify(x)
1770+
return c2py(symengine.conjugate(X.thisptr))
1771+
17321772
class log(OneArgFunction):
17331773
def __new__(cls, x, y=None):
17341774
cdef Basic X = sympify(x)

symengine/tests/test_functions.py

Lines changed: 172 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,10 @@
1-
from symengine import Symbol, sin, cos, sqrt, Add, Mul, function_symbol, Integer, log, E, symbols
1+
from symengine import (Symbol, sin, cos, sqrt, Add, Mul, function_symbol, Integer, log, E, symbols, I,
2+
Rational)
23
from symengine.lib.symengine_wrapper import (Subs, Derivative, LambertW, zeta, dirichlet_eta,
34
zoo, pi, KroneckerDelta, LeviCivita, erf, erfc,
45
oo, lowergamma, uppergamma, exp, loggamma, beta,
5-
polygamma, digamma, trigamma, EulerGamma)
6+
polygamma, digamma, trigamma, EulerGamma, sign,
7+
floor, ceiling, conjugate, nan, Float)
68

79

810
def test_sin():
@@ -140,64 +142,230 @@ def test_log():
140142
assert log(x, x) == 1
141143
assert log(x, y) == log(x) / log(y)
142144

145+
143146
def test_lambertw():
144147
assert LambertW(0) == 0
145148
assert LambertW(E) == 1
149+
assert LambertW(-1/E) == -1
150+
assert LambertW(-log(2)/2) == -log(2)
151+
146152

147153
def test_zeta():
148154
x = Symbol("x")
149155
assert zeta(1) == zoo
150156
assert zeta(1, x) == zoo
151157

158+
assert zeta(0) == Rational(-1, 2)
159+
assert zeta(0, x) == Rational(1, 2) - x
160+
161+
assert zeta(1, 2) == zoo
162+
assert zeta(1, -7) == zoo
163+
164+
assert zeta(2, 1) == pi**2/6
165+
166+
assert zeta(2) == pi**2/6
167+
assert zeta(4) == pi**4/90
168+
assert zeta(6) == pi**6/945
169+
170+
assert zeta(2, 2) == pi**2/6 - 1
171+
assert zeta(4, 3) == pi**4/90 - Rational(17, 16)
172+
assert zeta(6, 4) == pi**6/945 - Rational(47449, 46656)
173+
174+
assert zeta(-1) == -Rational(1, 12)
175+
assert zeta(-2) == 0
176+
assert zeta(-3) == Rational(1, 120)
177+
assert zeta(-4) == 0
178+
assert zeta(-5) == -Rational(1, 252)
179+
180+
assert zeta(-1, 3) == -Rational(37, 12)
181+
assert zeta(-1, 7) == -Rational(253, 12)
182+
assert zeta(-1, -4) == Rational(119, 12)
183+
assert zeta(-1, -9) == Rational(539, 12)
184+
185+
assert zeta(-4, 3) == -17
186+
assert zeta(-4, -8) == 8772
187+
188+
assert zeta(0, 1) == -Rational(1, 2)
189+
assert zeta(0, -1) == Rational(3, 2)
190+
191+
assert zeta(0, 2) == -Rational(3, 2)
192+
assert zeta(0, -2) == Rational(5, 2)
193+
194+
152195
def test_dirichlet_eta():
196+
assert dirichlet_eta(0) == Rational(1, 2)
197+
assert dirichlet_eta(-1) == Rational(1, 4)
153198
assert dirichlet_eta(1) == log(2)
154199
assert dirichlet_eta(2) == pi**2/12
200+
assert dirichlet_eta(4) == pi**4*Rational(7, 720)
201+
155202

156203
def test_kronecker_delta():
157204
x = Symbol("x")
205+
y = Symbol("y")
158206
assert KroneckerDelta(1, 1) == 1
159207
assert KroneckerDelta(1, 2) == 0
160208
assert KroneckerDelta(x, x) == 1
209+
assert KroneckerDelta(x**2 - y**2, x**2 - y**2) == 1
210+
assert KroneckerDelta(0, 0) == 1
211+
assert KroneckerDelta(0, 1) == 0
212+
161213

162214
def test_levi_civita():
215+
i = Symbol("i")
216+
j = Symbol("j")
163217
assert LeviCivita(1, 2, 3) == 1
164218
assert LeviCivita(1, 3, 2) == -1
165219
assert LeviCivita(1, 2, 2) == 0
220+
assert LeviCivita(i, j, i) == 0
221+
assert LeviCivita(1, i, i) == 0
222+
assert LeviCivita(1, 2, 3, 1) == 0
223+
assert LeviCivita(4, 5, 1, 2, 3) == 1
224+
assert LeviCivita(4, 5, 2, 1, 3) == -1
225+
166226

167227
def test_erf():
168-
assert erf(0) == 0
228+
x = Symbol("x")
229+
y = Symbol("y")
230+
assert erf(nan) == nan
169231
assert erf(oo) == 1
232+
assert erf(-oo) == -1
233+
assert erf(0) == 0
234+
assert erf(-2) == -erf(2)
235+
assert erf(-x*y) == -erf(x*y)
236+
assert erf(-x - y) == -erf(x + y)
237+
170238

171239
def test_erfc():
172-
assert erfc(0) == 1
240+
assert erfc(nan) == nan
173241
assert erfc(oo) == 0
242+
assert erfc(-oo) == 2
243+
assert erfc(0) == 1
244+
174245

175246
def test_lowergamma():
176247
assert lowergamma(1, 2) == 1 - exp(-2)
177248

249+
178250
def test_uppergamma():
179251
assert uppergamma(1, 2) == exp(-2)
180252
assert uppergamma(4, 0) == 6
181253

254+
182255
def test_loggamma():
183256
assert loggamma(-1) == oo
257+
assert loggamma(-2) == oo
184258
assert loggamma(0) == oo
185259
assert loggamma(1) == 0
260+
assert loggamma(2) == 0
186261
assert loggamma(3) == log(2)
187262

263+
188264
def test_beta():
189265
assert beta(3, 2) == beta(2, 3)
190266

267+
191268
def test_polygamma():
269+
assert polygamma(0, -9) == zoo
270+
assert polygamma(0, -9) == zoo
271+
assert polygamma(0, -1) == zoo
192272
assert polygamma(0, 0) == zoo
273+
assert polygamma(0, 1) == -EulerGamma
274+
assert polygamma(0, 7) == Rational(49, 20) - EulerGamma
275+
assert polygamma(1, 1) == pi**2/6
276+
assert polygamma(1, 2) == pi**2/6 - 1
277+
assert polygamma(1, 3) == pi**2/6 - Rational(5, 4)
278+
assert polygamma(3, 1) == pi**4 / 15
279+
assert polygamma(3, 5) == 6*(Rational(-22369, 20736) + pi**4/90)
280+
assert polygamma(5, 1) == 8 * pi**6 / 63
281+
193282

194283
def test_digamma():
195284
x = Symbol("x")
196285
assert digamma(x) == polygamma(0, x)
197286
assert digamma(0) == zoo
198287
assert digamma(1) == -EulerGamma
199288

289+
200290
def test_trigamma():
201291
x = Symbol("x")
202292
assert trigamma(-2) == zoo
203293
assert trigamma(x) == polygamma(1, x)
294+
295+
296+
def test_sign():
297+
assert sign(1.2) == 1
298+
assert sign(-1.2) == -1
299+
assert sign(3*I) == I
300+
assert sign(-3*I) == -I
301+
assert sign(0) == 0
302+
assert sign(nan) == nan
303+
304+
305+
def test_floor():
306+
x = Symbol("x")
307+
y = Symbol("y")
308+
assert floor(nan) == nan
309+
assert floor(oo) == oo
310+
assert floor(-oo) == -oo
311+
assert floor(0) == 0
312+
assert floor(1) == 1
313+
assert floor(-1) == -1
314+
assert floor(E) == 2
315+
assert floor(pi) == 3
316+
assert floor(Rational(1, 2)) == 0
317+
assert floor(-Rational(1, 2)) == -1
318+
assert floor(Rational(7, 3)) == 2
319+
assert floor(-Rational(7, 3)) == -3
320+
assert floor(Float(17.0)) == 17
321+
assert floor(-Float(17.0)) == -17
322+
assert floor(Float(7.69)) == 7
323+
assert floor(-Float(7.69)) == -8
324+
assert floor(I) == I
325+
assert floor(-I) == -I
326+
assert floor(2*I) == 2*I
327+
assert floor(-2*I) == -2*I
328+
assert floor(E + pi) == floor(E + pi)
329+
assert floor(I + pi) == floor(I + pi)
330+
assert floor(floor(pi)) == 3
331+
assert floor(floor(y)) == floor(y)
332+
assert floor(floor(x)) == floor(floor(x))
333+
assert floor(x) == floor(x)
334+
assert floor(2*x) == floor(2*x)
335+
336+
337+
def test_ceiling():
338+
x = Symbol("x")
339+
y = Symbol("y")
340+
assert ceiling(nan) == nan
341+
assert ceiling(oo) == oo
342+
assert ceiling(-oo) == -oo
343+
assert ceiling(0) == 0
344+
assert ceiling(1) == 1
345+
assert ceiling(-1) == -1
346+
assert ceiling(E) == 3
347+
assert ceiling(pi) == 4
348+
assert ceiling(Rational(1, 2)) == 1
349+
assert ceiling(-Rational(1, 2)) == 0
350+
assert ceiling(Rational(7, 3)) == 3
351+
assert ceiling(-Rational(7, 3)) == -2
352+
assert ceiling(Float(17.0)) == 17
353+
assert ceiling(-Float(17.0)) == -17
354+
assert ceiling(Float(7.69)) == 8
355+
assert ceiling(-Float(7.69)) == -7
356+
assert ceiling(I) == I
357+
assert ceiling(-I) == -I
358+
assert ceiling(2*I) == 2*I
359+
assert ceiling(-2*I) == -2*I
360+
assert ceiling(E + pi) == ceiling(E + pi)
361+
assert ceiling(I + pi) == ceiling(I + pi)
362+
assert ceiling(ceiling(pi)) == 4
363+
assert ceiling(ceiling(y)) == ceiling(y)
364+
assert ceiling(ceiling(x)) == ceiling(ceiling(x))
365+
assert ceiling(x) == ceiling(x)
366+
assert ceiling(2*x) == ceiling(2*x)
367+
368+
369+
def test_conjugate():
370+
assert conjugate(pi) == pi
371+
assert conjugate(I) == -I

symengine/tests/test_sage.py

Lines changed: 11 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -3,8 +3,8 @@
33
sympify, log)
44
from symengine.lib.symengine_wrapper import (PyNumber, PyFunction,
55
sage_module, wrap_sage_function, Catalan, GoldenRatio, EulerGamma,
6-
LambertW, KroneckerDelta, erf, lowergamma,
7-
uppergamma, loggamma, beta)
6+
LambertW, KroneckerDelta, erf, lowergamma, uppergamma, loggamma,
7+
beta, floor, ceiling, conjugate)
88

99

1010
def test_sage_conversions():
@@ -89,6 +89,15 @@ def test_sage_conversions():
8989
assert beta(x1, y1) == beta(x, y)
9090
assert beta(x1, y1)._sage_() == sage.beta(x, y)
9191

92+
assert floor(x1) == floor(x)
93+
assert floor(x1)._sage_() == sage.floor(x)
94+
95+
assert ceiling(x1) == ceiling(x)
96+
assert ceiling(x1)._sage_() == sage.ceil(x)
97+
98+
assert conjugate(x1) == conjugate(x)
99+
assert conjugate(x1)._sage_() == sage.conjugate(x)
100+
92101
# For the following test, sage needs to be modified
93102
# assert sage.sin(x) == sage.sin(x1)
94103

0 commit comments

Comments
 (0)