Skip to content

Commit ab87920

Browse files
committed
SciPy LowLevelCallable from Lambdify
1 parent cef2fe9 commit ab87920

File tree

3 files changed

+77
-6
lines changed

3 files changed

+77
-6
lines changed

symengine/lib/symengine.pxd

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -961,17 +961,17 @@ cdef extern from "<symengine/lambda_double.h>" namespace "SymEngine":
961961
cdef cppclass LambdaRealDoubleVisitor:
962962
LambdaRealDoubleVisitor() nogil
963963
void init(const vec_basic &x, const vec_basic &b) nogil except +
964-
double call(double *r, const double *x) nogil except +
964+
void call(double *r, const double *x) nogil
965965
cdef cppclass LambdaComplexDoubleVisitor:
966966
LambdaComplexDoubleVisitor() nogil
967967
void init(const vec_basic &x, const vec_basic &b) nogil except +
968-
double complex call(double complex *r, const double complex *x) nogil except +
968+
void call(double complex *r, const double complex *x) nogil
969969

970970
cdef extern from "<symengine/llvm_double.h>" namespace "SymEngine":
971971
cdef cppclass LLVMDoubleVisitor:
972972
LLVMDoubleVisitor() nogil
973973
void init(const vec_basic &x, const vec_basic &b) nogil except +
974-
double call(double *r, const double *x) nogil except +
974+
void call(double *r, const double *x) nogil
975975

976976
cdef extern from "<symengine/series.h>" namespace "SymEngine":
977977
cdef cppclass SeriesCoeffInterface:

symengine/lib/symengine_wrapper.pyx

Lines changed: 57 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -4560,6 +4560,31 @@ cdef class _Lambdify(object):
45604560
return result
45614561

45624562

4563+
cdef double _scipy_callback_lambda_real(int n, double *x, void *user_data):
4564+
cdef symengine.LambdaRealDoubleVisitor* lamb = <symengine.LambdaRealDoubleVisitor *>user_data
4565+
cdef double result
4566+
deref(lamb).call(&result, x)
4567+
return result
4568+
4569+
4570+
IF HAVE_SYMENGINE_LLVM:
4571+
cdef double _scipy_callback_llvm_real(int n, double *x, void *user_data):
4572+
cdef symengine.LLVMRealDoubleVisitor* lamb = <symengine.LLVMDoubleVisitor *>user_data
4573+
cdef double result
4574+
deref(lamb).call(&result, x)
4575+
return result
4576+
4577+
4578+
def create_low_level_callable(lambdify, *args):
4579+
from scipy import LowLevelCallable
4580+
class LambdifyLowLevelCallable(LowLevelCallable):
4581+
def __init__(self, lambdify, *args):
4582+
self.lambdify = lambdify
4583+
def __new__(cls, value, *args, **kwargs):
4584+
return super(LambdifyLowLevelCallable, cls).__new__(cls, *args)
4585+
return LambdifyLowLevelCallable(lambdify, *args)
4586+
4587+
45634588
cdef class LambdaDouble(_Lambdify):
45644589

45654590
cdef vector[symengine.LambdaRealDoubleVisitor] lambda_double
@@ -4579,6 +4604,16 @@ cdef class LambdaDouble(_Lambdify):
45794604
cpdef unsafe_complex(self, double complex[::1] inp, double complex[::1] out, int inp_offset=0, int out_offset=0):
45804605
self.lambda_double_complex[0].call(&out[out_offset], &inp[inp_offset])
45814606

4607+
cpdef as_scipy_low_level_callable(self):
4608+
from scipy import LowLevelCallable
4609+
from ctypes import c_double, c_void_p, c_int, cast, POINTER, CFUNCTYPE
4610+
if not self.real:
4611+
raise RuntimeError("Lambda function has to be real")
4612+
addr1 = cast(<size_t>&_scipy_callback_lambda_real,
4613+
CFUNCTYPE(c_double, c_int, POINTER(c_double), c_void_p))
4614+
addr2 = cast(<size_t>&self.lambda_double[0], c_void_p)
4615+
return create_low_level_callable(self, addr1, addr2)
4616+
45824617

45834618
IF HAVE_SYMENGINE_LLVM:
45844619
cdef class LLVMDouble(_Lambdify):
@@ -4592,8 +4627,18 @@ IF HAVE_SYMENGINE_LLVM:
45924627
cpdef unsafe_real(self, double[::1] inp, double[::1] out, int inp_offset=0, int out_offset=0):
45934628
self.lambda_double[0].call(&out[out_offset], &inp[inp_offset])
45944629

4630+
cpdef as_scipy_low_level_callable(self):
4631+
from scipy import LowLevelCallable
4632+
from ctypes import c_double, c_void_p, c_int, cast, POINTER, CFUNCTYPE
4633+
if not self.real:
4634+
raise RuntimeError("Lambda function has to be real")
4635+
addr1 = cast(<size_t>&_scipy_callback_lambda_real,
4636+
CFUNCTYPE(c_double, c_int, POINTER(c_double), c_void_p))
4637+
addr2 = cast(<size_t>&self.lambda_double[0], c_void_p)
4638+
return create_low_level_callable(self, addr1, addr2)
4639+
45954640

4596-
def Lambdify(args, *exprs, cppbool real=True, backend=None, order='C'):
4641+
def Lambdify(args, *exprs, cppbool real=True, backend=None, order='C', as_scipy=False):
45974642
"""
45984643
Lambdify instances are callbacks that numerically evaluate their symbolic
45994644
expressions from user provided input (real or complex) into (possibly user
@@ -4616,6 +4661,9 @@ def Lambdify(args, *exprs, cppbool real=True, backend=None, order='C'):
46164661
(k, l, 3) (C-contiguous) input will give a (k, l, m, n) shaped output,
46174662
whereas a (3, k, l) (C-contiguous) input will give a (m, n, k, l) shaped
46184663
output. If ``None`` order is taken as ``self.order`` (from initialization).
4664+
as_scipy : bool
4665+
return a SciPy LowLevelCallable which can be used in SciPy's integrate
4666+
methods
46194667
46204668
Returns
46214669
-------
@@ -4637,15 +4685,21 @@ def Lambdify(args, *exprs, cppbool real=True, backend=None, order='C'):
46374685
backend = os.getenv('SYMENGINE_LAMBDIFY_BACKEND', "lambda")
46384686
if backend == "llvm":
46394687
IF HAVE_SYMENGINE_LLVM:
4640-
return LLVMDouble(args, *exprs, real=real, order=order)
4688+
ret = LLVMDouble(args, *exprs, real=real, order=order)
4689+
if as_scipy:
4690+
return ret.as_scipy_low_level_callable()
4691+
return ret
46414692
ELSE:
46424693
raise ValueError("""llvm backend is chosen, but symengine is not compiled
46434694
with llvm support.""")
46444695
elif backend == "lambda":
46454696
pass
46464697
else:
46474698
warnings.warn("Unknown SymEngine backend: %s\nUsing backend='lambda'" % backend)
4648-
return LambdaDouble(args, *exprs, real=real, order=order)
4699+
ret = LambdaDouble(args, *exprs, real=real, order=order)
4700+
if as_scipy:
4701+
return ret.as_scipy_low_level_callable()
4702+
return ret
46494703

46504704

46514705
def LambdifyCSE(args, *exprs, cse=None, order='C', **kwargs):

symengine/tests/test_lambdify.py

Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -25,6 +25,13 @@
2525
except ImportError:
2626
have_sympy = False
2727

28+
try:
29+
import scipy
30+
from scipy import LowLevelCallable
31+
have_scipy = True
32+
except ImportError:
33+
have_scipy = False
34+
2835
if have_numpy:
2936
import numpy as np
3037

@@ -785,3 +792,13 @@ def _mtx(_x, _y):
785792
assert np.all(out3b[..., i] == _mtx(*inp3b[2*i:2*(i+1)]))
786793
raises(ValueError, lambda: lmb3(inp3b.reshape((4, 2))))
787794
raises(ValueError, lambda: lmb3(inp3b.reshape((2, 4)).T))
795+
796+
797+
@unittest.skipUnless(have_scipy, "Scipy not installed")
798+
def test_scipy():
799+
from scipy import integrate
800+
import numpy as np
801+
args = t, x = se.symbols('t, x')
802+
lmb = se.Lambdify(args, [se.exp(-x*t)/t**5], as_scipy=True)
803+
res = integrate.nquad(lmb, [[1, np.inf], [0, np.inf]])
804+
assert abs(res[0] - 0.2) < 1e-7

0 commit comments

Comments
 (0)