Skip to content

Commit 3630288

Browse files
committed
Use evalf from symengine
1 parent 2ca28c1 commit 3630288

File tree

3 files changed

+27
-41
lines changed

3 files changed

+27
-41
lines changed

symengine/lib/symengine.pxd

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -182,7 +182,6 @@ cdef extern from "<symengine/basic.h>" namespace "SymEngine":
182182
unsigned int hash() nogil except +
183183
vec_basic get_args() nogil
184184
int __cmp__(const Basic &o) nogil
185-
ctypedef rcp_const_basic rcp_const_basic "SymEngine::RCP<const SymEngine::Basic>"
186185
ctypedef RCP[const Number] rcp_const_number "SymEngine::RCP<const SymEngine::Number>"
187186
ctypedef unordered_map[int, rcp_const_basic] umap_int_basic "SymEngine::umap_int_basic"
188187
ctypedef unordered_map[int, rcp_const_basic].iterator umap_int_basic_iterator "SymEngine::umap_int_basic::iterator"
@@ -958,6 +957,9 @@ cdef extern from "<utility>" namespace "std":
958957
cdef map_basic_basic std_move_map_basic_basic "std::move" (map_basic_basic) nogil
959958
cdef PiecewiseVec std_move_PiecewiseVec "std::move" (PiecewiseVec) nogil
960959

960+
cdef extern from "<symengine/eval.h>" namespace "SymEngine":
961+
rcp_const_basic evalf(const Basic &b, unsigned long bits, bool real) nogil except +
962+
961963
cdef extern from "<symengine/eval_double.h>" namespace "SymEngine":
962964
double eval_double(const Basic &b) nogil except +
963965
double complex eval_complex_double(const Basic &b) nogil except +

symengine/lib/symengine_wrapper.pyx

Lines changed: 20 additions & 33 deletions
Original file line numberDiff line numberDiff line change
@@ -915,11 +915,8 @@ cdef class Basic(object):
915915
symengine.as_real_imag(self.thisptr, symengine.outArg(_real), symengine.outArg(_imag))
916916
return c2py(<rcp_const_basic>_real), c2py(<rcp_const_basic>_imag)
917917

918-
def n(self, prec = 53, real = False):
919-
if real:
920-
return eval_real(self, prec)
921-
else:
922-
return eval(self, prec)
918+
def n(self, unsigned long prec = 53, cppbool real=False):
919+
return evalf(self, prec, real)
923920

924921
evalf = n
925922

@@ -3994,13 +3991,17 @@ def Xnor(*args):
39943991
v.push_back(symengine.rcp_static_cast_Boolean(e_.thisptr))
39953992
return c2py(<rcp_const_basic>(symengine.logical_xnor(v)))
39963993

3997-
def eval_double(x):
3994+
def evalf(x, unsigned long bits=53, cppbool real=False):
39983995
cdef Basic X = sympify(x)
3999-
return c2py(<rcp_const_basic>(symengine.real_double(symengine.eval_double(deref(X.thisptr)))))
3996+
return c2py(<rcp_const_basic>(symengine.evalf(deref(X.thisptr), bits, real)))
3997+
3998+
def eval_double(x):
3999+
warnings.warn("eval_double is deprecated. Use evalf(..., real=True)", DeprecationWarning)
4000+
return evalf(x, 53, real=True)
40004001

40014002
def eval_complex_double(x):
4002-
cdef Basic X = sympify(x)
4003-
return c2py(<rcp_const_basic>(symengine.complex_double(symengine.eval_complex_double(deref(X.thisptr)))))
4003+
warnings.warn("eval_complex_double is deprecated. Use evalf(..., real=False)", DeprecationWarning)
4004+
return evalf(x, 53, real=False)
40044005

40054006
have_mpfr = False
40064007
have_mpc = False
@@ -4010,19 +4011,15 @@ have_llvm = False
40104011

40114012
IF HAVE_SYMENGINE_MPFR:
40124013
have_mpfr = True
4013-
def eval_mpfr(x, long prec):
4014-
cdef Basic X = sympify(x)
4015-
cdef symengine.mpfr_class a = symengine.mpfr_class(prec)
4016-
symengine.eval_mpfr(a.get_mpfr_t(), deref(X.thisptr), symengine.MPFR_RNDN)
4017-
return c2py(<rcp_const_basic>(symengine.real_mpfr(symengine.std_move_mpfr(a))))
4014+
def eval_mpfr(x, unsigned long prec):
4015+
warnings.warn("eval_mpfr is deprecated. Use evalf(..., real=True)", DeprecationWarning)
4016+
return evalf(x, prec, real=True)
40184017

40194018
IF HAVE_SYMENGINE_MPC:
40204019
have_mpc = True
4021-
def eval_mpc(x, long prec):
4022-
cdef Basic X = sympify(x)
4023-
cdef symengine.mpc_class a = symengine.mpc_class(prec)
4024-
symengine.eval_mpc(a.get_mpc_t(), deref(X.thisptr), symengine.MPFR_RNDN)
4025-
return c2py(<rcp_const_basic>(symengine.complex_mpc(symengine.std_move_mpc(a))))
4020+
def eval_mpc(x, unsigned long prec):
4021+
warnings.warn("eval_mpc is deprecated. Use evalf(..., real=False)", DeprecationWarning)
4022+
return evalf(x, prec, real=True)
40264023

40274024
IF HAVE_SYMENGINE_PIRANHA:
40284025
have_piranha = True
@@ -4038,22 +4035,12 @@ def require(obj, t):
40384035
raise TypeError("{} required. {} is of type {}".format(t, obj, type(obj)))
40394036

40404037
def eval(x, long prec):
4041-
if prec <= 53:
4042-
return eval_complex_double(x)
4043-
else:
4044-
IF HAVE_SYMENGINE_MPC:
4045-
return eval_mpc(x, prec)
4046-
ELSE:
4047-
raise ValueError("Precision %s is only supported with MPC" % prec)
4038+
warnings.warn("eval is deprecated. Use evalf(..., real=False)", DeprecationWarning)
4039+
return evalf(x, prec, real=False)
40484040

40494041
def eval_real(x, long prec):
4050-
if prec <= 53:
4051-
return eval_double(x)
4052-
else:
4053-
IF HAVE_SYMENGINE_MPFR:
4054-
return eval_mpfr(x, prec)
4055-
ELSE:
4056-
raise ValueError("Precision %s is only supported with MPFR" % prec)
4042+
warnings.warn("eval_real is deprecated. Use evalf(..., real=True)", DeprecationWarning)
4043+
return evalf(x, prec, real=True)
40574044

40584045
def probab_prime_p(n, reps = 25):
40594046
cdef Basic _n = sympify(n)

symengine/tests/test_eval.py

Lines changed: 4 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -1,27 +1,24 @@
11
from symengine.utilities import raises
22
from symengine import (Symbol, sin, cos, Integer, Add, I, RealDouble, ComplexDouble, sqrt)
33

4-
from symengine.lib.symengine_wrapper import eval_double
54
from unittest.case import SkipTest
65

76
def test_eval_double1():
87
x = Symbol("x")
98
y = Symbol("y")
109
e = sin(x)**2 + cos(x)**2
1110
e = e.subs(x, 7)
12-
assert abs(eval_double(e) - 1) < 1e-9
11+
assert abs(e.n(real=True) - 1) < 1e-9
1312

1413

1514
def test_eval_double2():
1615
x = Symbol("x")
17-
y = Symbol("y")
18-
e = sin(x)**2 + cos(x)**2
19-
raises(RuntimeError, lambda: (abs(eval_double(e) - 1) < 1e-9))
20-
16+
e = sin(x)**2 + sqrt(2)
17+
assert abs(e.n(real=True) - x**2 - 1.414) < 1e-3
2118

2219
def test_n():
2320
x = Symbol("x")
24-
raises(RuntimeError, lambda: (x.n()))
21+
assert x.n(real=True) == x + 0.0
2522

2623
x = 2 + I
2724
raises(RuntimeError, lambda: (x.n(real=True)))

0 commit comments

Comments
 (0)