Skip to content

Commit d573252

Browse files
committed
Use LLVM double visitor
1 parent f05b74d commit d573252

File tree

4 files changed

+91
-46
lines changed

4 files changed

+91
-46
lines changed

CMakeLists.txt

Lines changed: 13 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -53,14 +53,21 @@ else()
5353
set(HAVE_SYMENGINE_FLINT False)
5454
endif()
5555

56+
if(SYMENGINE_LLVM_COMPONENTS)
57+
set(HAVE_SYMENGINE_LLVM True)
58+
else()
59+
set(HAVE_SYMENGINE_LLVM False)
60+
endif()
61+
5662

57-
message("CMAKE_BUILD_TYPE : ${CMAKE_BUILD_TYPE}")
63+
message("CMAKE_BUILD_TYPE : ${CMAKE_BUILD_TYPE}")
5864
message("CMAKE_CXX_FLAGS_RELEASE : ${CMAKE_CXX_FLAGS_RELEASE}")
59-
message("CMAKE_CXX_FLAGS_DEBUG : ${CMAKE_CXX_FLAGS_DEBUG}")
60-
message("HAVE_SYMENGINE_MPFR : ${HAVE_SYMENGINE_MPFR}")
61-
message("HAVE_SYMENGINE_MPC : ${HAVE_SYMENGINE_MPC}")
62-
message("HAVE_SYMENGINE_PIRANHA : ${HAVE_SYMENGINE_PIRANHA}")
63-
message("HAVE_SYMENGINE_FLINT : ${HAVE_SYMENGINE_FLINT}")
65+
message("CMAKE_CXX_FLAGS_DEBUG : ${CMAKE_CXX_FLAGS_DEBUG}")
66+
message("HAVE_SYMENGINE_MPFR : ${HAVE_SYMENGINE_MPFR}")
67+
message("HAVE_SYMENGINE_MPC : ${HAVE_SYMENGINE_MPC}")
68+
message("HAVE_SYMENGINE_PIRANHA : ${HAVE_SYMENGINE_PIRANHA}")
69+
message("HAVE_SYMENGINE_FLINT : ${HAVE_SYMENGINE_FLINT}")
70+
message("HAVE_SYMENGINE_LLVM : ${HAVE_SYMENGINE_LLVM}")
6471

6572
message("Copying source of python wrappers into: ${CMAKE_CURRENT_BINARY_DIR}")
6673
file(COPY symengine/ DESTINATION ${CMAKE_CURRENT_BINARY_DIR}/symengine)

symengine/__init__.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
from .lib.symengine_wrapper import (Symbol, Integer, sympify, SympifyError,
22
Add, Mul, Pow, exp, log, gamma, sqrt, function_symbol, I, E, pi,
3-
have_mpfr, have_mpc, RealDouble, ComplexDouble, DenseMatrix, Matrix,
3+
have_mpfr, have_mpc, have_flint, have_piranha, have_llvm,
4+
RealDouble, ComplexDouble, DenseMatrix, Matrix,
45
sin, cos, tan, cot, csc, sec, asin, acos, atan, acot, acsc, asec,
56
sinh, cosh, tanh, coth, asinh, acosh, atanh, acoth, Lambdify,
67
LambdifyCSE, DictBasic, series, symarray, diff, zeros, eye, diag,

symengine/lib/config.pxi.in

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2,3 +2,4 @@ DEF HAVE_SYMENGINE_MPFR = ${HAVE_SYMENGINE_MPFR}
22
DEF HAVE_SYMENGINE_MPC = ${HAVE_SYMENGINE_MPC}
33
DEF HAVE_SYMENGINE_PIRANHA = ${HAVE_SYMENGINE_PIRANHA}
44
DEF HAVE_SYMENGINE_FLINT = ${HAVE_SYMENGINE_FLINT}
5+
DEF HAVE_SYMENGINE_LLVM = ${HAVE_SYMENGINE_LLVM}

symengine/lib/symengine_wrapper.pyx

Lines changed: 75 additions & 39 deletions
Original file line numberDiff line numberDiff line change
@@ -2143,6 +2143,7 @@ have_mpfr = False
21432143
have_mpc = False
21442144
have_piranha = False
21452145
have_flint = False
2146+
have_llvm = False
21462147

21472148
IF HAVE_SYMENGINE_MPFR:
21482149
have_mpfr = True
@@ -2166,6 +2167,9 @@ IF HAVE_SYMENGINE_PIRANHA:
21662167
IF HAVE_SYMENGINE_FLINT:
21672168
have_flint = True
21682169

2170+
IF HAVE_SYMENGINE_LLVM:
2171+
have_llvm = True
2172+
21692173
def eval(x, long prec):
21702174
if prec <= 53:
21712175
return eval_complex_double(x)
@@ -2566,7 +2570,7 @@ ctypedef fused ValueType:
25662570
cython.double
25672571

25682572

2569-
cdef class Lambdify(object):
2573+
cdef class _Lambdify(object):
25702574
"""
25712575
Lambdify instances are callbacks that numerically evaluate their symbolic
25722576
expressions from user provided input (real or complex) into (possibly user
@@ -2598,22 +2602,21 @@ cdef class Lambdify(object):
25982602
cdef size_t args_size, out_size
25992603
cdef tuple out_shape
26002604
cdef readonly bool real
2601-
cdef vector[symengine.LambdaRealDoubleVisitor] lambda_double
2602-
cdef vector[symengine.LambdaComplexDoubleVisitor] lambda_double_complex
26032605

26042606
def __cinit__(self, args, exprs, bool real=True):
2607+
self.real = real
2608+
self.out_shape = get_shape(exprs)
2609+
self.args_size = _size(args)
2610+
self.out_size = reduce(mul, self.out_shape)
2611+
2612+
2613+
def __init__(self, args, exprs, bool real=True):
26052614
cdef:
2606-
symengine.vec_basic args_
2607-
symengine.vec_basic outs_
26082615
Basic e_
26092616
size_t ri, ci, nr, nc
26102617
symengine.MatrixBase *mtx
26112618
RCP[const symengine.Basic] b_
2612-
int idx = 0
2613-
self.real = real
2614-
self.out_shape = get_shape(exprs)
2615-
self.args_size = _size(args)
2616-
self.out_size = reduce(mul, self.out_shape)
2619+
symengine.vec_basic args_, outs_
26172620

26182621
if isinstance(args, DenseMatrix):
26192622
nr = args.nrows()
@@ -2640,35 +2643,32 @@ cdef class Lambdify(object):
26402643
e_ = sympify(e)
26412644
outs_.push_back(e_.thisptr)
26422645

2643-
if real:
2644-
self.lambda_double.resize(1)
2645-
self.lambda_double[0].init(args_, outs_)
2646-
else:
2647-
self.lambda_double_complex.resize(1)
2648-
self.lambda_double_complex[0].init(args_, outs_)
2649-
2646+
self._init(args_, outs_)
26502647

2651-
cdef void _eval(self, ValueType[::1] inp, ValueType[::1] out):
2652-
cdef size_t idx, ninp = inp.size, nout = out.size
2648+
cdef _init(self, symengine.vec_basic& args_, symengine.vec_basic& outs_):
2649+
raise ValueError("Not supported")
26532650

2654-
if inp.size != self.args_size:
2655-
raise ValueError("Size of inp incompatible with number of args.")
2656-
if out.size != self.out_size:
2657-
raise ValueError("Size of out incompatible with number of exprs.")
2651+
cdef void _eval_real(self, double[::1] inp, double[::1] out):
2652+
raise ValueError("Not supported")
26582653

2659-
# Convert expr_subs to doubles write to out
2660-
if ValueType == cython.double:
2661-
self.lambda_double[0].call(&out[0], &inp[0])
2662-
else:
2663-
self.lambda_double_complex[0].call(&out[0], &inp[0])
2654+
cdef void _eval_complex(self, double complex[::1] inp, double complex[::1] out):
2655+
raise ValueError("Not supported")
26642656

26652657
# the two cpdef:ed methods below may use void return type
26662658
# once Cython 0.23 (from 2015) is acceptable as requirement.
26672659
cpdef unsafe_real(self, double[::1] inp, double[::1] out):
2668-
self._eval(inp, out)
2660+
if inp.size != self.args_size:
2661+
raise ValueError("Size of inp incompatible with number of args.")
2662+
if out.size != self.out_size:
2663+
raise ValueError("Size of out incompatible with number of exprs.")
2664+
self._eval_real(inp, out)
26692665

26702666
cpdef unsafe_complex(self, double complex[::1] inp, double complex[::1] out):
2671-
self._eval(inp, out)
2667+
if inp.size != self.args_size:
2668+
raise ValueError("Size of inp incompatible with number of args.")
2669+
if out.size != self.out_size:
2670+
raise ValueError("Size of out incompatible with number of exprs.")
2671+
self._eval_complex(inp, out)
26722672

26732673
def __call__(self, inp, out=None, use_numpy=None):
26742674
"""
@@ -2712,18 +2712,16 @@ cdef class Lambdify(object):
27122712
import numpy as np
27132713

27142714
if use_numpy:
2715+
numpy_dtype = np.float64 if self.real else np.complex128
27152716
if isinstance(inp, DenseMatrix):
2717+
arr = np.empty(inp_size, dtype=numpy_dtype)
27162718
if self.real:
2717-
arr = np.empty(inp_size, dtype=np.float64)
27182719
inp.dump_real(arr)
2719-
inp = arr
27202720
else:
2721-
arr = np.empty(inp_size, dtype=np.complex128)
27222721
inp.dump_complex(arr)
2723-
inp = arr
2722+
inp = arr
27242723
else:
2725-
inp = np.ascontiguousarray(inp, dtype=np.float64 if
2726-
self.real else np.complex128)
2724+
inp = np.ascontiguousarray(inp, dtype=numpy_dtype)
27272725
if inp.ndim > 1:
27282726
inp = inp.ravel()
27292727
else:
@@ -2732,8 +2730,7 @@ cdef class Lambdify(object):
27322730
if out is None:
27332731
# allocate output container
27342732
if use_numpy:
2735-
out = np.empty(new_out_size, dtype=np.float64 if
2736-
self.real else np.complex128)
2733+
out = np.empty(new_out_size, dtype=numpy_dtype)
27372734
else:
27382735
if self.real:
27392736
out = cython.view.array((new_out_size,),
@@ -2749,7 +2746,7 @@ cdef class Lambdify(object):
27492746
except AttributeError:
27502747
out = np.asarray(out)
27512748
out_dtype = out.dtype
2752-
if out_dtype != (np.float64 if self.real else np.complex128):
2749+
if out_dtype != numpy_dtype:
27532750
raise TypeError("Output array is of incorrect type")
27542751
if out.size < new_out_size:
27552752
raise ValueError("Incompatible size of output argument")
@@ -2801,6 +2798,45 @@ cdef class Lambdify(object):
28012798
out = tmp
28022799
return out
28032800

2801+
cdef class LambdaDouble(_Lambdify):
2802+
2803+
cdef vector[symengine.LambdaRealDoubleVisitor] lambda_double
2804+
cdef vector[symengine.LambdaComplexDoubleVisitor] lambda_double_complex
2805+
2806+
cdef _init(self, symengine.vec_basic& args_, symengine.vec_basic& outs_):
2807+
if self.real:
2808+
self.lambda_double.resize(1)
2809+
self.lambda_double[0].init(args_, outs_)
2810+
else:
2811+
self.lambda_double_complex.resize(1)
2812+
self.lambda_double_complex[0].init(args_, outs_)
2813+
2814+
cdef void _eval_real(self, double[::1] inp, double[::1] out):
2815+
self.lambda_double[0].call(&out[0], &inp[0])
2816+
2817+
cdef void _eval_complex(self, double complex[::1] inp, double complex[::1] out):
2818+
self.lambda_double_complex[0].call(&out[0], &inp[0])
2819+
2820+
2821+
IF HAVE_SYMENGINE_LLVM:
2822+
cdef class LLVMDouble(_Lambdify):
2823+
2824+
cdef vector[symengine.LLVMDoubleVisitor] lambda_double
2825+
2826+
cdef _init(self, symengine.vec_basic& args_, symengine.vec_basic& outs_):
2827+
self.lambda_double.resize(1)
2828+
self.lambda_double[0].init(args_, outs_)
2829+
2830+
cdef void _eval_real(self, double[::1] inp, double[::1] out):
2831+
self.lambda_double[0].call(&out[0], &inp[0])
2832+
2833+
2834+
def Lambdify(args, exprs, bool real=True, bool llvm=False):
2835+
IF HAVE_SYMENGINE_LLVM:
2836+
if llvm:
2837+
return LLVMDouble(args, exprs, real)
2838+
2839+
return LambdaDouble(args, exprs, real)
28042840

28052841
def LambdifyCSE(args, exprs, real=True, cse=None, concatenate=None):
28062842
"""

0 commit comments

Comments
 (0)