Skip to content

Commit eb13330

Browse files
authored
Merge pull request #110 from isuruf/llvm
Use LLVMDoubleVisitor
2 parents f05b74d + 4fac508 commit eb13330

File tree

7 files changed

+124
-51
lines changed

7 files changed

+124
-51
lines changed

.travis.yml

Lines changed: 15 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -66,9 +66,23 @@ matrix:
6666
- g++-4.8
6767
- cmake
6868
- cmake-data
69-
- env: BUILD_TYPE="Debug" WITH_BFD="yes" PYTHON_VERSION="3.5"
69+
- env: BUILD_TYPE="Debug" WITH_BFD="yes" PYTHON_VERSION="3.5" WITH_LLVM="yes"
7070
compiler: clang
7171
os: linux
72+
addons:
73+
apt:
74+
sources:
75+
- ubuntu-toolchain-r-test
76+
- llvm-toolchain-precise-3.8
77+
- george-edison55-precise-backports
78+
packages:
79+
- clang
80+
- libstdc++-4.8-dev
81+
- libgmp-dev
82+
- binutils-dev
83+
- llvm-3.8-dev
84+
- cmake
85+
- cmake-data
7286
- env: BUILD_TYPE="Release" PYTHON_VERSION="2.7"
7387
compiler: clang
7488
os: linux

CMakeLists.txt

Lines changed: 13 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -53,14 +53,21 @@ else()
5353
set(HAVE_SYMENGINE_FLINT False)
5454
endif()
5555

56+
if(SYMENGINE_LLVM_COMPONENTS)
57+
set(HAVE_SYMENGINE_LLVM True)
58+
else()
59+
set(HAVE_SYMENGINE_LLVM False)
60+
endif()
61+
5662

57-
message("CMAKE_BUILD_TYPE : ${CMAKE_BUILD_TYPE}")
63+
message("CMAKE_BUILD_TYPE : ${CMAKE_BUILD_TYPE}")
5864
message("CMAKE_CXX_FLAGS_RELEASE : ${CMAKE_CXX_FLAGS_RELEASE}")
59-
message("CMAKE_CXX_FLAGS_DEBUG : ${CMAKE_CXX_FLAGS_DEBUG}")
60-
message("HAVE_SYMENGINE_MPFR : ${HAVE_SYMENGINE_MPFR}")
61-
message("HAVE_SYMENGINE_MPC : ${HAVE_SYMENGINE_MPC}")
62-
message("HAVE_SYMENGINE_PIRANHA : ${HAVE_SYMENGINE_PIRANHA}")
63-
message("HAVE_SYMENGINE_FLINT : ${HAVE_SYMENGINE_FLINT}")
65+
message("CMAKE_CXX_FLAGS_DEBUG : ${CMAKE_CXX_FLAGS_DEBUG}")
66+
message("HAVE_SYMENGINE_MPFR : ${HAVE_SYMENGINE_MPFR}")
67+
message("HAVE_SYMENGINE_MPC : ${HAVE_SYMENGINE_MPC}")
68+
message("HAVE_SYMENGINE_PIRANHA : ${HAVE_SYMENGINE_PIRANHA}")
69+
message("HAVE_SYMENGINE_FLINT : ${HAVE_SYMENGINE_FLINT}")
70+
message("HAVE_SYMENGINE_LLVM : ${HAVE_SYMENGINE_LLVM}")
6471

6572
message("Copying source of python wrappers into: ${CMAKE_CURRENT_BINARY_DIR}")
6673
file(COPY symengine/ DESTINATION ${CMAKE_CURRENT_BINARY_DIR}/symengine)

setup.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -188,7 +188,7 @@ def run(self):
188188
integration with SymPy and Sage.'''
189189

190190
setup(name = "symengine",
191-
version="0.2.0",
191+
version="0.2.1.dev",
192192
description = "Python library providing wrappers to SymEngine",
193193
setup_requires = ['cython>=0.19.1'],
194194
long_description = long_description,

symengine/__init__.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
from .lib.symengine_wrapper import (Symbol, Integer, sympify, SympifyError,
22
Add, Mul, Pow, exp, log, gamma, sqrt, function_symbol, I, E, pi,
3-
have_mpfr, have_mpc, RealDouble, ComplexDouble, DenseMatrix, Matrix,
3+
have_mpfr, have_mpc, have_flint, have_piranha, have_llvm,
4+
RealDouble, ComplexDouble, DenseMatrix, Matrix,
45
sin, cos, tan, cot, csc, sec, asin, acos, atan, acot, acsc, asec,
56
sinh, cosh, tanh, coth, asinh, acosh, atanh, acoth, Lambdify,
67
LambdifyCSE, DictBasic, series, symarray, diff, zeros, eye, diag,
@@ -13,7 +14,7 @@
1314
if have_mpc:
1415
from .lib.symengine_wrapper import ComplexMPC
1516

16-
__version__ = "0.1.0.dev"
17+
__version__ = "0.2.1.dev"
1718

1819
def test():
1920
import pytest, os

symengine/lib/config.pxi.in

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2,3 +2,4 @@ DEF HAVE_SYMENGINE_MPFR = ${HAVE_SYMENGINE_MPFR}
22
DEF HAVE_SYMENGINE_MPC = ${HAVE_SYMENGINE_MPC}
33
DEF HAVE_SYMENGINE_PIRANHA = ${HAVE_SYMENGINE_PIRANHA}
44
DEF HAVE_SYMENGINE_FLINT = ${HAVE_SYMENGINE_FLINT}
5+
DEF HAVE_SYMENGINE_LLVM = ${HAVE_SYMENGINE_LLVM}

symengine/lib/symengine_wrapper.pyx

Lines changed: 79 additions & 40 deletions
Original file line numberDiff line numberDiff line change
@@ -2143,6 +2143,7 @@ have_mpfr = False
21432143
have_mpc = False
21442144
have_piranha = False
21452145
have_flint = False
2146+
have_llvm = False
21462147

21472148
IF HAVE_SYMENGINE_MPFR:
21482149
have_mpfr = True
@@ -2166,6 +2167,9 @@ IF HAVE_SYMENGINE_PIRANHA:
21662167
IF HAVE_SYMENGINE_FLINT:
21672168
have_flint = True
21682169

2170+
IF HAVE_SYMENGINE_LLVM:
2171+
have_llvm = True
2172+
21692173
def eval(x, long prec):
21702174
if prec <= 53:
21712175
return eval_complex_double(x)
@@ -2566,7 +2570,7 @@ ctypedef fused ValueType:
25662570
cython.double
25672571

25682572

2569-
cdef class Lambdify(object):
2573+
cdef class _Lambdify(object):
25702574
"""
25712575
Lambdify instances are callbacks that numerically evaluate their symbolic
25722576
expressions from user provided input (real or complex) into (possibly user
@@ -2598,22 +2602,21 @@ cdef class Lambdify(object):
25982602
cdef size_t args_size, out_size
25992603
cdef tuple out_shape
26002604
cdef readonly bool real
2601-
cdef vector[symengine.LambdaRealDoubleVisitor] lambda_double
2602-
cdef vector[symengine.LambdaComplexDoubleVisitor] lambda_double_complex
26032605

26042606
def __cinit__(self, args, exprs, bool real=True):
2607+
self.real = real
2608+
self.out_shape = get_shape(exprs)
2609+
self.args_size = _size(args)
2610+
self.out_size = reduce(mul, self.out_shape)
2611+
2612+
2613+
def __init__(self, args, exprs, bool real=True):
26052614
cdef:
2606-
symengine.vec_basic args_
2607-
symengine.vec_basic outs_
26082615
Basic e_
26092616
size_t ri, ci, nr, nc
26102617
symengine.MatrixBase *mtx
26112618
RCP[const symengine.Basic] b_
2612-
int idx = 0
2613-
self.real = real
2614-
self.out_shape = get_shape(exprs)
2615-
self.args_size = _size(args)
2616-
self.out_size = reduce(mul, self.out_shape)
2619+
symengine.vec_basic args_, outs_
26172620

26182621
if isinstance(args, DenseMatrix):
26192622
nr = args.nrows()
@@ -2640,35 +2643,32 @@ cdef class Lambdify(object):
26402643
e_ = sympify(e)
26412644
outs_.push_back(e_.thisptr)
26422645

2643-
if real:
2644-
self.lambda_double.resize(1)
2645-
self.lambda_double[0].init(args_, outs_)
2646-
else:
2647-
self.lambda_double_complex.resize(1)
2648-
self.lambda_double_complex[0].init(args_, outs_)
2646+
self._init(args_, outs_)
2647+
2648+
cdef _init(self, symengine.vec_basic& args_, symengine.vec_basic& outs_):
2649+
raise ValueError("Not supported")
26492650

2651+
cpdef unsafe_real(self, double[::1] inp, double[::1] out):
2652+
raise ValueError("Not supported")
26502653

2651-
cdef void _eval(self, ValueType[::1] inp, ValueType[::1] out):
2652-
cdef size_t idx, ninp = inp.size, nout = out.size
2654+
cpdef unsafe_complex(self, double complex[::1] inp, double complex[::1] out):
2655+
raise ValueError("Not supported")
26532656

2657+
# the two cpdef:ed methods below may use void return type
2658+
# once Cython 0.23 (from 2015) is acceptable as requirement.
2659+
cpdef eval_real(self, double[::1] inp, double[::1] out):
26542660
if inp.size != self.args_size:
26552661
raise ValueError("Size of inp incompatible with number of args.")
26562662
if out.size != self.out_size:
26572663
raise ValueError("Size of out incompatible with number of exprs.")
2664+
self.unsafe_real(inp, out)
26582665

2659-
# Convert expr_subs to doubles write to out
2660-
if ValueType == cython.double:
2661-
self.lambda_double[0].call(&out[0], &inp[0])
2662-
else:
2663-
self.lambda_double_complex[0].call(&out[0], &inp[0])
2664-
2665-
# the two cpdef:ed methods below may use void return type
2666-
# once Cython 0.23 (from 2015) is acceptable as requirement.
2667-
cpdef unsafe_real(self, double[::1] inp, double[::1] out):
2668-
self._eval(inp, out)
2669-
2670-
cpdef unsafe_complex(self, double complex[::1] inp, double complex[::1] out):
2671-
self._eval(inp, out)
2666+
cpdef eval_complex(self, double complex[::1] inp, double complex[::1] out):
2667+
if inp.size != self.args_size:
2668+
raise ValueError("Size of inp incompatible with number of args.")
2669+
if out.size != self.out_size:
2670+
raise ValueError("Size of out incompatible with number of exprs.")
2671+
self.unsafe_complex(inp, out)
26722672

26732673
def __call__(self, inp, out=None, use_numpy=None):
26742674
"""
@@ -2712,18 +2712,16 @@ cdef class Lambdify(object):
27122712
import numpy as np
27132713

27142714
if use_numpy:
2715+
numpy_dtype = np.float64 if self.real else np.complex128
27152716
if isinstance(inp, DenseMatrix):
2717+
arr = np.empty(inp_size, dtype=numpy_dtype)
27162718
if self.real:
2717-
arr = np.empty(inp_size, dtype=np.float64)
27182719
inp.dump_real(arr)
2719-
inp = arr
27202720
else:
2721-
arr = np.empty(inp_size, dtype=np.complex128)
27222721
inp.dump_complex(arr)
2723-
inp = arr
2722+
inp = arr
27242723
else:
2725-
inp = np.ascontiguousarray(inp, dtype=np.float64 if
2726-
self.real else np.complex128)
2724+
inp = np.ascontiguousarray(inp, dtype=numpy_dtype)
27272725
if inp.ndim > 1:
27282726
inp = inp.ravel()
27292727
else:
@@ -2732,8 +2730,7 @@ cdef class Lambdify(object):
27322730
if out is None:
27332731
# allocate output container
27342732
if use_numpy:
2735-
out = np.empty(new_out_size, dtype=np.float64 if
2736-
self.real else np.complex128)
2733+
out = np.empty(new_out_size, dtype=numpy_dtype)
27372734
else:
27382735
if self.real:
27392736
out = cython.view.array((new_out_size,),
@@ -2749,7 +2746,7 @@ cdef class Lambdify(object):
27492746
except AttributeError:
27502747
out = np.asarray(out)
27512748
out_dtype = out.dtype
2752-
if out_dtype != (np.float64 if self.real else np.complex128):
2749+
if out_dtype != numpy_dtype:
27532750
raise TypeError("Output array is of incorrect type")
27542751
if out.size < new_out_size:
27552752
raise ValueError("Incompatible size of output argument")
@@ -2801,6 +2798,48 @@ cdef class Lambdify(object):
28012798
out = tmp
28022799
return out
28032800

2801+
cdef class LambdaDouble(_Lambdify):
2802+
2803+
cdef vector[symengine.LambdaRealDoubleVisitor] lambda_double
2804+
cdef vector[symengine.LambdaComplexDoubleVisitor] lambda_double_complex
2805+
2806+
cdef _init(self, symengine.vec_basic& args_, symengine.vec_basic& outs_):
2807+
if self.real:
2808+
self.lambda_double.resize(1)
2809+
self.lambda_double[0].init(args_, outs_)
2810+
else:
2811+
self.lambda_double_complex.resize(1)
2812+
self.lambda_double_complex[0].init(args_, outs_)
2813+
2814+
cpdef unsafe_real(self, double[::1] inp, double[::1] out):
2815+
self.lambda_double[0].call(&out[0], &inp[0])
2816+
2817+
cpdef unsafe_complex(self, double complex[::1] inp, double complex[::1] out):
2818+
self.lambda_double_complex[0].call(&out[0], &inp[0])
2819+
2820+
2821+
IF HAVE_SYMENGINE_LLVM:
2822+
cdef class LLVMDouble(_Lambdify):
2823+
2824+
cdef vector[symengine.LLVMDoubleVisitor] lambda_double
2825+
2826+
cdef _init(self, symengine.vec_basic& args_, symengine.vec_basic& outs_):
2827+
self.lambda_double.resize(1)
2828+
self.lambda_double[0].init(args_, outs_)
2829+
2830+
cpdef unsafe_real(self, double[::1] inp, double[::1] out):
2831+
self.lambda_double[0].call(&out[0], &inp[0])
2832+
2833+
2834+
def Lambdify(args, exprs, bool real=True, backend="lambda"):
2835+
if backend == "llvm":
2836+
IF HAVE_SYMENGINE_LLVM:
2837+
return LLVMDouble(args, exprs, real)
2838+
ELSE:
2839+
raise ValueError("""llvm backend is chosen, but symengine is not compiled
2840+
with llvm support.""")
2841+
2842+
return LambdaDouble(args, exprs, real)
28042843

28052844
def LambdifyCSE(args, exprs, real=True, cse=None, concatenate=None):
28062845
"""

symengine/tests/test_lambdify.py

Lines changed: 12 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -46,7 +46,18 @@ def allclose(vec1, vec2, rtol=1e-13, atol=1e-13):
4646
def test_Lambdify():
4747
n = 7
4848
args = x, y, z = se.symbols('x y z')
49-
l = se.Lambdify(args, [x+y+z, x**2, (x-y)/z, x*y*z])
49+
l = se.Lambdify(args, [x+y+z, x**2, (x-y)/z, x*y*z], backend='lambda')
50+
assert allclose(l(range(n, n+len(args))),
51+
[3*n+3, n**2, -1/(n+2), n*(n+1)*(n+2)])
52+
53+
def test_Lambdify_LLVM():
54+
n = 7
55+
args = x, y, z = se.symbols('x y z')
56+
if not se.have_llvm:
57+
raises(ValueError, lambda: se.Lambdify(args, [x+y+z, x**2, (x-y)/z, x*y*z],
58+
backend='llvm'))
59+
return
60+
l = se.Lambdify(args, [x+y+z, x**2, (x-y)/z, x*y*z], backend='llvm')
5061
assert allclose(l(range(n, n+len(args))),
5162
[3*n+3, n**2, -1/(n+2), n*(n+1)*(n+2)])
5263

0 commit comments

Comments
 (0)