Skip to content

Commit 90739ec

Browse files
committed
Lambdify: env chosen backend, std::vector, kwargs in lambdify
1 parent 871017c commit 90739ec

File tree

3 files changed

+18
-17
lines changed

3 files changed

+18
-17
lines changed

CMakeLists.txt

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -59,10 +59,12 @@ else()
5959
endif()
6060

6161
if(WITH_NUMPY)
62+
set(WITH_NUMPY True CACHE BOOL INTERNAL)
6263
find_package(NumPy REQUIRED)
6364
include_directories(${NUMPY_INCLUDE_PATH})
6465
set(HAVE_NUMPY True)
6566
else()
67+
set(WITH_NUMPY False CACHE BOOL INTERNAL)
6668
set(HAVE_NUMPY False)
6769
endif()
6870

symengine/__init__.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -20,12 +20,12 @@
2020
if have_numpy:
2121
from .lib.symengine_wrapper import Lambdify, LambdifyCSE
2222

23-
def lambdify(args, exprs):
23+
def lambdify(args, exprs, real=True, backend=None):
2424
try:
2525
len(args)
2626
except TypeError:
2727
args = [args]
28-
lmb = Lambdify(args, *exprs)
28+
lmb = Lambdify(args, *exprs, real=real, backend=backend)
2929
def f(*inner_args):
3030
if len(inner_args) != len(args):
3131
raise TypeError("Incorrect number of arguments")

symengine/lib/symengine_wrapper.pyx

Lines changed: 14 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,6 @@ from libcpp.string cimport string
66
from libcpp.vector cimport vector
77
from cpython cimport PyObject, Py_XINCREF, Py_XDECREF, \
88
PyObject_CallMethodObjArgs
9-
from libc.stdlib cimport malloc, free
109
from libc.string cimport memcpy
1110
import cython
1211
import itertools
@@ -3014,6 +3013,7 @@ def has_symbol(obj, symbol=None):
30143013

30153014
IF HAVE_NUMPY:
30163015
# Lambdify requires NumPy (since b713a61, see gh-112)
3016+
import os
30173017
cimport numpy as cnp
30183018
import numpy as np
30193019
have_numpy = True
@@ -3101,30 +3101,24 @@ IF HAVE_NUMPY:
31013101
cdef list out_shapes
31023102
cdef readonly bint real
31033103
cdef readonly int n_exprs
3104-
cdef int *out_sizes
3105-
cdef int *accum_out_sizes
3104+
cdef vector[int] accum_out_sizes
31063105
cdef object numpy_dtype
31073106

31083107
def __cinit__(self, args, *exprs, bool real=True):
3108+
cdef vector[int] out_sizes
31093109
self.real = real
31103110
self.numpy_dtype = np.float64 if self.real else np.complex128
31113111
self.out_shapes = [get_shape(expr) for expr in exprs]
31123112
self.n_exprs = len(exprs)
31133113
self.args_size = _size(args)
3114-
self.out_sizes = <int *>malloc(sizeof(int)*self.n_exprs)
3115-
self.accum_out_sizes = <int *>malloc(sizeof(int)*(self.n_exprs+1))
31163114
self.tot_out_size = 0
31173115
for idx, shape in enumerate(self.out_shapes):
3118-
self.out_sizes[idx] = reduce(mul, shape or (1,))
3119-
self.tot_out_size += self.out_sizes[idx]
3116+
out_sizes.push_back(reduce(mul, shape or (1,)))
3117+
self.tot_out_size += out_sizes[idx]
31203118
for i in range(self.n_exprs + 1):
3121-
self.accum_out_sizes[i] = 0
3119+
self.accum_out_sizes.push_back(0)
31223120
for j in range(i):
3123-
self.accum_out_sizes[i] += self.out_sizes[j]
3124-
3125-
def __dealloc__(self):
3126-
free(self.out_sizes)
3127-
free(self.accum_out_sizes)
3121+
self.accum_out_sizes[i] += out_sizes[j]
31283122

31293123
def __init__(self, args, *exprs, bool real=True):
31303124
cdef:
@@ -3329,14 +3323,19 @@ IF HAVE_NUMPY:
33293323
self.lambda_double[0].call(&out[out_offset], &inp[inp_offset])
33303324

33313325

3332-
def Lambdify(args, *exprs, bool real=True, backend="lambda"):
3326+
def Lambdify(args, *exprs, bool real=True, backend=None):
3327+
if backend is None:
3328+
backend = os.getenv('SYMENGINE_LAMBDIFY_BACKEND', "lambda")
33333329
if backend == "llvm":
33343330
IF HAVE_SYMENGINE_LLVM:
33353331
return LLVMDouble(args, *exprs, real=real)
33363332
ELSE:
33373333
raise ValueError("""llvm backend is chosen, but symengine is not compiled
33383334
with llvm support.""")
3339-
3335+
elif backend == "lambda":
3336+
pass
3337+
else:
3338+
warnings.warn("Unknown SymEngine backend: %s\nUsing backend='lambda'" % backend)
33403339
return LambdaDouble(args, *exprs, real=real)
33413340

33423341

0 commit comments

Comments
 (0)