Skip to content

Commit 7dd26e0

Browse files
authored
Merge pull request #300 from ichumuh/master
exposed llvm lambdify opt_level in python api
2 parents ef36dc2 + 602563b commit 7dd26e0

File tree

4 files changed

+31
-7
lines changed

4 files changed

+31
-7
lines changed

symengine/lib/symengine.pxd

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -986,7 +986,7 @@ cdef extern from "<symengine/lambda_double.h>" namespace "SymEngine":
986986
cdef extern from "<symengine/llvm_double.h>" namespace "SymEngine":
987987
cdef cppclass LLVMDoubleVisitor:
988988
LLVMDoubleVisitor() nogil
989-
void init(const vec_basic &x, const vec_basic &b, bool cse) nogil except +
989+
void init(const vec_basic &x, const vec_basic &b, bool cse, int opt_level) nogil except +
990990
void call(double *r, const double *x) nogil
991991
const string& dumps() nogil
992992
void loads(const string&) nogil

symengine/lib/symengine_wrapper.pxd

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -58,6 +58,7 @@ cdef class LambdaDouble(_Lambdify):
5858

5959
IF HAVE_SYMENGINE_LLVM:
6060
cdef class LLVMDouble(_Lambdify):
61+
cdef int opt_level
6162
cdef vector[symengine.LLVMDoubleVisitor] lambda_double
6263
cdef _init(self, symengine.vec_basic& args_, symengine.vec_basic& outs_, cppbool cse)
6364
cdef _load(self, const string &s)

symengine/lib/symengine_wrapper.pyx

Lines changed: 12 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -4469,7 +4469,7 @@ def has_symbol(obj, symbol=None):
44694469

44704470

44714471
cdef class _Lambdify(object):
4472-
def __init__(self, args, *exprs, cppbool real=True, order='C', cppbool cse=False, cppbool _load=False):
4472+
def __init__(self, args, *exprs, cppbool real=True, order='C', cppbool cse=False, cppbool _load=False, **kwargs):
44734473
cdef:
44744474
Basic e_
44754475
size_t ri, ci, nr, nc
@@ -4706,6 +4706,10 @@ def create_low_level_callable(lambdify, *args):
47064706

47074707

47084708
cdef class LambdaDouble(_Lambdify):
4709+
def __cinit__(self, args, *exprs, cppbool real=True, order='C', cppbool cse=False, cppbool _load=False):
4710+
# reject additional arguments
4711+
pass
4712+
47094713
cdef _init(self, symengine.vec_basic& args_, symengine.vec_basic& outs_, cppbool cse):
47104714
if self.real:
47114715
self.lambda_double.resize(1)
@@ -4751,9 +4755,12 @@ cdef class LambdaDouble(_Lambdify):
47514755

47524756
IF HAVE_SYMENGINE_LLVM:
47534757
cdef class LLVMDouble(_Lambdify):
4758+
def __cinit__(self, args, *exprs, cppbool real=True, order='C', cppbool cse=False, cppbool _load=False, opt_level=3):
4759+
self.opt_level = opt_level
4760+
47544761
cdef _init(self, symengine.vec_basic& args_, symengine.vec_basic& outs_, cppbool cse):
47554762
self.lambda_double.resize(1)
4756-
self.lambda_double[0].init(args_, outs_, cse)
4763+
self.lambda_double[0].init(args_, outs_, cse, self.opt_level)
47574764

47584765
cdef _load(self, const string &s):
47594766
self.lambda_double.resize(1)
@@ -4801,7 +4808,7 @@ IF HAVE_SYMENGINE_LLVM:
48014808
def llvm_loading_func(*args):
48024809
return LLVMDouble(args, _load=True)
48034810

4804-
def Lambdify(args, *exprs, cppbool real=True, backend=None, order='C', as_scipy=False, cse=False):
4811+
def Lambdify(args, *exprs, cppbool real=True, backend=None, order='C', as_scipy=False, cse=False, **kwargs):
48054812
"""
48064813
Lambdify instances are callbacks that numerically evaluate their symbolic
48074814
expressions from user provided input (real or complex) into (possibly user
@@ -4851,7 +4858,7 @@ def Lambdify(args, *exprs, cppbool real=True, backend=None, order='C', as_scipy=
48514858
backend = os.getenv('SYMENGINE_LAMBDIFY_BACKEND', "lambda")
48524859
if backend == "llvm":
48534860
IF HAVE_SYMENGINE_LLVM:
4854-
ret = LLVMDouble(args, *exprs, real=real, order=order, cse=cse)
4861+
ret = LLVMDouble(args, *exprs, real=real, order=order, cse=cse, **kwargs)
48554862
if as_scipy:
48564863
return ret.as_scipy_low_level_callable()
48574864
return ret
@@ -4862,7 +4869,7 @@ def Lambdify(args, *exprs, cppbool real=True, backend=None, order='C', as_scipy=
48624869
pass
48634870
else:
48644871
warnings.warn("Unknown SymEngine backend: %s\nUsing backend='lambda'" % backend)
4865-
ret = LambdaDouble(args, *exprs, real=real, order=order, cse=cse)
4872+
ret = LambdaDouble(args, *exprs, real=real, order=order, cse=cse, **kwargs)
48664873
if as_scipy:
48674874
return ret.as_scipy_low_level_callable()
48684875
return ret

symengine/tests/test_lambdify.py

Lines changed: 17 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -75,6 +75,10 @@ def test_Lambdify():
7575
assert allclose(L(range(n, n+len(args))),
7676
[3*n+3, n**2, -1/(n+2), n*(n+1)*(n+2)])
7777

78+
@unittest.skipUnless(have_numpy, "Numpy not installed")
79+
def test_Lambdify_with_opt_level():
80+
args = x, y, z = se.symbols('x y z')
81+
raises(TypeError, lambda: se.Lambdify(args, [x+y+z, x**2, (x-y)/z, x*y*z], backend='lambda', opt_level=0))
7882

7983
def _test_Lambdify_Piecewise(Lambdify):
8084
x = se.symbols('x')
@@ -91,7 +95,6 @@ def test_Lambdify_Piecewise():
9195
if se.have_llvm:
9296
_test_Lambdify_Piecewise(lambda *args: se.Lambdify(*args, backend='llvm'))
9397

94-
9598
@unittest.skipUnless(have_numpy, "Numpy not installed")
9699
def test_Lambdify_LLVM():
97100
n = 7
@@ -105,6 +108,19 @@ def test_Lambdify_LLVM():
105108
assert allclose(L(range(n, n+len(args))),
106109
[3*n+3, n**2, -1/(n+2), n*(n+1)*(n+2)])
107110

111+
@unittest.skipUnless(have_numpy, "Numpy not installed")
112+
def test_Lambdify_LLVM_with_opt_level():
113+
for opt_level in range(4):
114+
n = 7
115+
args = x, y, z = se.symbols('x y z')
116+
if not se.have_llvm:
117+
raises(ValueError, lambda: se.Lambdify(args, [x+y+z, x**2,
118+
(x-y)/z, x*y*z],
119+
backend='llvm', opt_level=opt_level))
120+
raise SkipTest("No LLVM support")
121+
L = se.Lambdify(args, [x+y+z, x**2, (x-y)/z, x*y*z], backend='llvm', opt_level=opt_level)
122+
assert allclose(L(range(n, n+len(args))),
123+
[3*n+3, n**2, -1/(n+2), n*(n+1)*(n+2)])
108124

109125
def _get_2_to_2by2():
110126
args = x, y = se.symbols('x y')

0 commit comments

Comments
 (0)