Skip to content

Commit a12f736

Browse files
authored
Merge pull request #201 from isuruf/scipy
SciPy LowLevelCallable from Lambdify
2 parents cef2fe9 + 30f640b commit a12f736

File tree

7 files changed

+89
-10
lines changed

7 files changed

+89
-10
lines changed

.travis.yml

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -52,7 +52,7 @@ matrix:
5252
packages:
5353
- binutils-dev
5454
- g++-4.8
55-
- env: BUILD_TYPE="Debug" WITH_BFD="yes" PYTHON_VERSION="3.5" WITH_LLVM="yes"
55+
- env: BUILD_TYPE="Debug" WITH_BFD="yes" PYTHON_VERSION="3.5" WITH_LLVM="yes" WITH_SCIPY="yes"
5656
compiler: clang
5757
os: linux
5858
addons:
@@ -78,7 +78,7 @@ matrix:
7878
- env: BUILD_TYPE="Debug" PYTHON_VERSION="2.7" WITH_NUMPY="no"
7979
compiler: gcc
8080
os: osx
81-
- env: BUILD_TYPE="Release" PYTHON_VERSION="3.5" INTEGER_CLASS="flint" WITH_FLINT="yes"
81+
- env: BUILD_TYPE="Release" PYTHON_VERSION="3.5"
8282
compiler: gcc
8383
os: osx
8484
- env: BUILD_TYPE="Release" WITH_SAGE="yes" WITH_MPC="yes" PYTHON_VERSION="2.7"

bin/install_travis.sh

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,10 @@ if [[ "${WITH_NUMPY}" != "no" ]]; then
1313
export conda_pkgs="${conda_pkgs} numpy";
1414
fi
1515

16+
if [[ "${WITH_SCIPY}" == "yes" ]]; then
17+
export conda_pkgs="${conda_pkgs} scipy";
18+
fi
19+
1620
if [[ "${WITH_SAGE}" == "yes" ]]; then
1721
# This is split to avoid the 10 minute limit
1822
conda install -q sagelib

symengine/__init__.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -24,11 +24,13 @@
2424
if have_numpy:
2525
from .lib.symengine_wrapper import Lambdify, LambdifyCSE
2626

27-
def lambdify(args, exprs, real=True, backend=None):
27+
def lambdify(args, exprs, real=True, backend=None, as_scipy=False):
2828
try:
2929
len(args)
3030
except TypeError:
3131
args = [args]
32+
if as_scipy:
33+
return Lambdify(args, *exprs, real=real, backend=backend, as_scipy=True)
3234
lmb = Lambdify(args, *exprs, real=real, backend=backend)
3335
def f(*inner_args):
3436
if len(inner_args) != len(args):

symengine/lib/symengine.pxd

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -961,17 +961,17 @@ cdef extern from "<symengine/lambda_double.h>" namespace "SymEngine":
961961
cdef cppclass LambdaRealDoubleVisitor:
962962
LambdaRealDoubleVisitor() nogil
963963
void init(const vec_basic &x, const vec_basic &b) nogil except +
964-
double call(double *r, const double *x) nogil except +
964+
void call(double *r, const double *x) nogil
965965
cdef cppclass LambdaComplexDoubleVisitor:
966966
LambdaComplexDoubleVisitor() nogil
967967
void init(const vec_basic &x, const vec_basic &b) nogil except +
968-
double complex call(double complex *r, const double complex *x) nogil except +
968+
void call(double complex *r, const double complex *x) nogil
969969

970970
cdef extern from "<symengine/llvm_double.h>" namespace "SymEngine":
971971
cdef cppclass LLVMDoubleVisitor:
972972
LLVMDoubleVisitor() nogil
973973
void init(const vec_basic &x, const vec_basic &b) nogil except +
974-
double call(double *r, const double *x) nogil except +
974+
void call(double *r, const double *x) nogil
975975

976976
cdef extern from "<symengine/series.h>" namespace "SymEngine":
977977
cdef cppclass SeriesCoeffInterface:

symengine/lib/symengine_wrapper.pyx

Lines changed: 59 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -4560,6 +4560,31 @@ cdef class _Lambdify(object):
45604560
return result
45614561

45624562

4563+
cdef double _scipy_callback_lambda_real(int n, double *x, void *user_data) nogil:
4564+
cdef symengine.LambdaRealDoubleVisitor* lamb = <symengine.LambdaRealDoubleVisitor *>user_data
4565+
cdef double result
4566+
deref(lamb).call(&result, x)
4567+
return result
4568+
4569+
4570+
IF HAVE_SYMENGINE_LLVM:
4571+
cdef double _scipy_callback_llvm_real(int n, double *x, void *user_data) nogil:
4572+
cdef symengine.LLVMDoubleVisitor* lamb = <symengine.LLVMDoubleVisitor *>user_data
4573+
cdef double result
4574+
deref(lamb).call(&result, x)
4575+
return result
4576+
4577+
4578+
def create_low_level_callable(lambdify, *args):
4579+
from scipy import LowLevelCallable
4580+
class LambdifyLowLevelCallable(LowLevelCallable):
4581+
def __init__(self, lambdify, *args):
4582+
self.lambdify = lambdify
4583+
def __new__(cls, value, *args, **kwargs):
4584+
return super(LambdifyLowLevelCallable, cls).__new__(cls, *args)
4585+
return LambdifyLowLevelCallable(lambdify, *args)
4586+
4587+
45634588
cdef class LambdaDouble(_Lambdify):
45644589

45654590
cdef vector[symengine.LambdaRealDoubleVisitor] lambda_double
@@ -4579,6 +4604,17 @@ cdef class LambdaDouble(_Lambdify):
45794604
cpdef unsafe_complex(self, double complex[::1] inp, double complex[::1] out, int inp_offset=0, int out_offset=0):
45804605
self.lambda_double_complex[0].call(&out[out_offset], &inp[inp_offset])
45814606

4607+
cpdef as_scipy_low_level_callable(self):
4608+
from ctypes import c_double, c_void_p, c_int, cast, POINTER, CFUNCTYPE
4609+
if not self.real:
4610+
raise RuntimeError("Lambda function has to be real")
4611+
if self.tot_out_size > 1:
4612+
raise RuntimeError("SciPy LowLevelCallable supports only functions with 1 output")
4613+
addr1 = cast(<size_t>&_scipy_callback_lambda_real,
4614+
CFUNCTYPE(c_double, c_int, POINTER(c_double), c_void_p))
4615+
addr2 = cast(<size_t>&self.lambda_double[0], c_void_p)
4616+
return create_low_level_callable(self, addr1, addr2)
4617+
45824618

45834619
IF HAVE_SYMENGINE_LLVM:
45844620
cdef class LLVMDouble(_Lambdify):
@@ -4592,8 +4628,19 @@ IF HAVE_SYMENGINE_LLVM:
45924628
cpdef unsafe_real(self, double[::1] inp, double[::1] out, int inp_offset=0, int out_offset=0):
45934629
self.lambda_double[0].call(&out[out_offset], &inp[inp_offset])
45944630

4631+
cpdef as_scipy_low_level_callable(self):
4632+
from ctypes import c_double, c_void_p, c_int, cast, POINTER, CFUNCTYPE
4633+
if not self.real:
4634+
raise RuntimeError("Lambda function has to be real")
4635+
if self.tot_out_size > 1:
4636+
raise RuntimeError("SciPy LowLevelCallable supports only functions with 1 output")
4637+
addr1 = cast(<size_t>&_scipy_callback_lambda_real,
4638+
CFUNCTYPE(c_double, c_int, POINTER(c_double), c_void_p))
4639+
addr2 = cast(<size_t>&self.lambda_double[0], c_void_p)
4640+
return create_low_level_callable(self, addr1, addr2)
4641+
45954642

4596-
def Lambdify(args, *exprs, cppbool real=True, backend=None, order='C'):
4643+
def Lambdify(args, *exprs, cppbool real=True, backend=None, order='C', as_scipy=False):
45974644
"""
45984645
Lambdify instances are callbacks that numerically evaluate their symbolic
45994646
expressions from user provided input (real or complex) into (possibly user
@@ -4616,6 +4663,9 @@ def Lambdify(args, *exprs, cppbool real=True, backend=None, order='C'):
46164663
(k, l, 3) (C-contiguous) input will give a (k, l, m, n) shaped output,
46174664
whereas a (3, k, l) (C-contiguous) input will give a (m, n, k, l) shaped
46184665
output. If ``None`` order is taken as ``self.order`` (from initialization).
4666+
as_scipy : bool
4667+
return a SciPy LowLevelCallable which can be used in SciPy's integrate
4668+
methods
46194669
46204670
Returns
46214671
-------
@@ -4637,15 +4687,21 @@ def Lambdify(args, *exprs, cppbool real=True, backend=None, order='C'):
46374687
backend = os.getenv('SYMENGINE_LAMBDIFY_BACKEND', "lambda")
46384688
if backend == "llvm":
46394689
IF HAVE_SYMENGINE_LLVM:
4640-
return LLVMDouble(args, *exprs, real=real, order=order)
4690+
ret = LLVMDouble(args, *exprs, real=real, order=order)
4691+
if as_scipy:
4692+
return ret.as_scipy_low_level_callable()
4693+
return ret
46414694
ELSE:
46424695
raise ValueError("""llvm backend is chosen, but symengine is not compiled
46434696
with llvm support.""")
46444697
elif backend == "lambda":
46454698
pass
46464699
else:
46474700
warnings.warn("Unknown SymEngine backend: %s\nUsing backend='lambda'" % backend)
4648-
return LambdaDouble(args, *exprs, real=real, order=order)
4701+
ret = LambdaDouble(args, *exprs, real=real, order=order)
4702+
if as_scipy:
4703+
return ret.as_scipy_low_level_callable()
4704+
return ret
46494705

46504706

46514707
def LambdifyCSE(args, *exprs, cse=None, order='C', **kwargs):

symengine/tests/test_lambdify.py

Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -25,6 +25,13 @@
2525
except ImportError:
2626
have_sympy = False
2727

28+
try:
29+
import scipy
30+
from scipy import LowLevelCallable
31+
have_scipy = True
32+
except ImportError:
33+
have_scipy = False
34+
2835
if have_numpy:
2936
import numpy as np
3037

@@ -785,3 +792,13 @@ def _mtx(_x, _y):
785792
assert np.all(out3b[..., i] == _mtx(*inp3b[2*i:2*(i+1)]))
786793
raises(ValueError, lambda: lmb3(inp3b.reshape((4, 2))))
787794
raises(ValueError, lambda: lmb3(inp3b.reshape((2, 4)).T))
795+
796+
797+
@unittest.skipUnless(have_scipy, "Scipy not installed")
798+
def test_scipy():
799+
from scipy import integrate
800+
import numpy as np
801+
args = t, x = se.symbols('t, x')
802+
lmb = se.Lambdify(args, [se.exp(-x*t)/t**5], as_scipy=True)
803+
res = integrate.nquad(lmb, [[1, np.inf], [0, np.inf]])
804+
assert abs(res[0] - 0.2) < 1e-7

symengine_version.txt

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1 +1 @@
1-
fdf132fcb4425589b69b40d60a90234944870b28
1+
f6eac46122d5da4519bc612d4c5203cbb3aa46b0

0 commit comments

Comments
 (0)