Skip to content

Commit 49cb51e

Browse files
committed
Set use_numpy at init of Lambdify (6x performance gain)
1 parent e5c5e83 commit 49cb51e

File tree

4 files changed

+87
-54
lines changed

4 files changed

+87
-54
lines changed

benchmarks/Lambdify.py

Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -18,17 +18,20 @@
1818

1919
inp = np.ones(28)
2020

21+
lmb_sp(*inp)
2122
tim_sympy = clock()
2223
for i in range(500):
2324
res_sympy = lmb_sp(*inp)
2425
tim_sympy = clock() - tim_sympy
2526

27+
lbm_se(inp)
2628
tim_se = clock()
2729
res_se = np.empty(len(exprs))
2830
for i in range(500):
2931
res_se = lbm_se(inp)
3032
tim_se = clock() - tim_se
3133

34+
lbm_se_llvm(inp)
3235
tim_se_llvm = clock()
3336
res_se_llvm = np.empty(len(exprs))
3437
for i in range(500):
@@ -62,6 +65,7 @@ def func(*args):
6265
return func
6366

6467
lbm_se_llvm_manual = ManualLLVM(args, np.array(exprs))
68+
lbm_se_llvm_manual(inp)
6569
tim_se_llvm_manual = clock()
6670
res_se_llvm_manual = np.empty(len(exprs))
6771
for i in range(500):
@@ -72,3 +76,16 @@ def func(*args):
7276

7377
if tim_se_llvm_manual < tim_se_llvm:
7478
warnings.warn("Cython code for Lambdify.__call__ is slow.")
79+
80+
import setuptools
81+
import pyximport
82+
pyximport.install()
83+
from Lambdify_reference import _benchmark_reference_for_Lambdify as lmb_ref
84+
85+
lmb_ref(inp)
86+
tim_ref = clock()
87+
for i in range(500):
88+
res_ref = lmb_ref(inp)
89+
tim_ref = clock() - tim_ref
90+
print('Hard-coded Cython code speed-up factor (higher is better) vs sympy: %12.5g' %
91+
(tim_sympy/tim_ref))

benchmarks/Lambdify_reference.pyx

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,10 @@
1+
# Benchmark reference: hand-written Cython evaluation that benchmarks/Lambdify.py times as a baseline.
2+
cimport numpy as cnp
3+
import numpy as np
4+
from libc.math cimport exp as exp_
5+
6+
def _benchmark_reference_for_Lambdify(cnp.ndarray[cnp.float64_t] x):
7+
cdef cnp.ndarray[cnp.float64_t] out = np.empty(14)
8+
cdef double * data = <double *>out.data
9+
data[:] = [x[0] + x[1] - x[4] + 36.252574322669, x[0] - x[2] + x[3] + 21.3219379611249, x[3] + x[5] - x[6] + 9.9011158998744, 2*x[3] + x[5] - x[7] + 18.190422234653, 3*x[3] + x[5] - x[8] + 24.8679190043357, 4*x[3] + x[5] - x[9] + 29.9336062089226, -x[10] + 5*x[3] + x[5] + 28.5520551531262, 2*x[0] + x[11] - 2*x[4] - 2*x[5] + 32.4401680272417, 3*x[1] - x[12] + x[5] + 34.9992934135095, 4*x[1] - x[13] + x[5] + 37.0716199972041, x[14+0] - x[14+1] + 2*x[14+10] + 2*x[14+11] - x[14+12] - 2*x[14+13] + x[14+2] + 2*x[14+5] + 2*x[14+6] + 2*x[14+7] + 2*x[14+8] + 2*x[14+9] - exp_(x[0]) + exp_(x[1]) - 2*exp_(x[10]) - 2*exp_(x[11]) + exp_(x[12]) + 2*exp_(x[13]) - exp_(x[2]) - 2*exp_(x[5]) - 2*exp_(x[6]) - 2*exp_(x[7]) - 2*exp_(x[8]) - 2*exp_(x[9]), -x[14+0] - x[14+1] - 15*x[14+10] - 2*x[14+11] - 3*x[14+12] - 4*x[14+13] - 4*x[14+2] - 3*x[14+3] - 2*x[14+4] - 3*x[14+6] - 6*x[14+7] - 9*x[14+8] - 12*x[14+9] + exp_(x[0]) + exp_(x[1]) + 15*exp_(x[10]) + 2*exp_(x[11]) + 3*exp_(x[12]) + 4*exp_(x[13]) + 4*exp_(x[2]) + 3*exp_(x[3]) + 2*exp_(x[4]) + 3*exp_(x[6]) + 6*exp_(x[7]) + 9*exp_(x[8]) + 12*exp_(x[9]), -5*x[14+10] - x[14+2] - x[14+3] - x[14+6] - 2*x[14+7] - 3*x[14+8] - 4*x[14+9] + 5*exp_(x[10]) + exp_(x[2]) + exp_(x[3]) + exp_(x[6]) + 2*exp_(x[7]) + 3*exp_(x[8]) + 4*exp_(x[9]), -x[14+1] - 2*x[14+11] - 3*x[14+12] - 4*x[14+13] - x[14+4] + exp_(x[1]) + 2*exp_(x[11]) + 3*exp_(x[12]) + 4*exp_(x[13]) + exp_(x[4]), -x[14+10] - 2*x[14+11] - x[14+12] - x[14+13] - x[14+5] - x[14+6] - x[14+7] - x[14+8] - x[14+9] + exp_(x[10]) + 2*exp_(x[11]) + exp_(x[12]) + exp_(x[13]) + exp_(x[5]) + exp_(x[6]) + exp_(x[7]) + exp_(x[8]) + exp_(x[9])]
10+
return out

symengine/lib/symengine_wrapper.pyx

Lines changed: 37 additions & 31 deletions
Original file line numberDiff line numberDiff line change
@@ -3092,6 +3092,8 @@ cdef class _Lambdify(object):
30923092
the shape of exprs is preserved
30933093
real : bool
30943094
Whether datatype is ``double`` (``double complex`` otherwise).
3095+
use_numpy: bool (default: None)
3096+
None -> use numpy if it can be imported, otherwise fall back to cython.view.array
30953097
30963098
Returns
30973099
-------
@@ -3115,9 +3117,22 @@ cdef class _Lambdify(object):
31153117
cdef int *accum_out_sizes
31163118
cdef readonly bool real
31173119
cdef readonly int n_exprs
3120+
cdef readonly bint use_numpy
3121+
cdef object _np
31183122

3119-
def __cinit__(self, args, *exprs, bool real=True):
3123+
def __cinit__(self, args, *exprs, bool real=True, use_numpy=None):
31203124
self.real = real
3125+
if use_numpy is None:
3126+
try:
3127+
import numpy as np
3128+
except ImportError:
3129+
use_numpy = False # we will use cython.view.array instead
3130+
else:
3131+
use_numpy = True
3132+
if use_numpy is True:
3133+
import numpy as np
3134+
self._np = np
3135+
self.use_numpy = use_numpy
31213136
self.out_shapes = [get_shape(expr) for expr in exprs]
31223137
self.n_exprs = len(exprs)
31233138
self.args_size = _size(args)
@@ -3136,7 +3151,7 @@ cdef class _Lambdify(object):
31363151
free(self.out_sizes)
31373152
free(self.accum_out_sizes)
31383153

3139-
def __init__(self, args, *exprs, bool real=True):
3154+
def __init__(self, args, *exprs, bool real=True, use_numpy=None):
31403155
cdef:
31413156
Basic e_
31423157
size_t ri, ci, nr, nc
@@ -3197,7 +3212,7 @@ cdef class _Lambdify(object):
31973212
raise ValueError("Size of out incompatible with number of exprs.")
31983213
self.unsafe_complex(inp, out)
31993214

3200-
def __call__(self, inp, out=None, use_numpy=None):
3215+
def __call__(self, inp, out=None):
32013216
"""
32023217
Parameters
32033218
----------
@@ -3208,26 +3223,26 @@ cdef class _Lambdify(object):
32083223
If ``None``: an output container will be allocated (NumPy ndarray or
32093224
cython.view.array). If ``len(exprs) > 0`` output is found in the corresponding
32103225
order. Note that ``out`` is not reshaped.
3211-
use_numpy: bool (default: None)
3212-
None -> use numpy if available
32133226
32143227
Returns
32153228
-------
32163229
If ``len(exprs) == 1``: ``numpy.ndarray`` or ``cython.view.array``, otherwise
32173230
a tuple of such.
32183231
32193232
"""
3220-
cdef cython.view.array tmp
3221-
cdef double[::1] real_out_view, real_inp_view
3222-
cdef double complex[::1] cmplx_out_view, cmplx_inp_view
3223-
cdef size_t nbroadcast = 1
3233+
cdef:
3234+
cython.view.array tmp
3235+
double[::1] real_out_view, real_inp_view
3236+
double complex[::1] cmplx_out_view, cmplx_inp_view
3237+
size_t nbroadcast = 1
3238+
long inp_size
32243239

32253240
try:
32263241
inp_shape = getattr(inp, 'shape', (len(inp),))
32273242
except TypeError:
32283243
inp = tuple(inp)
32293244
inp_shape = (len(inp),)
3230-
inp_size = long(reduce(mul, inp_shape))
3245+
inp_size = reduce(mul, inp_shape)
32313246
if inp_size % self.args_size != 0:
32323247
raise ValueError("Broadcasting failed")
32333248
nbroadcast = inp_size // self.args_size
@@ -3236,36 +3251,27 @@ cdef class _Lambdify(object):
32363251
new_out_shapes = [inp_shape[:-1] + out_shape for out_shape in self.out_shapes]
32373252
new_out_sizes = [nbroadcast*self.out_sizes[i] for i in range(self.n_exprs)]
32383253
new_tot_out_size = nbroadcast * self.tot_out_size
3239-
if use_numpy is None:
3240-
try:
3241-
import numpy as np
3242-
except ImportError:
3243-
use_numpy = False # we will use cython.view.array instead
3244-
else:
3245-
use_numpy = True
3246-
elif use_numpy is True:
3247-
import numpy as np
32483254

3249-
if use_numpy:
3250-
numpy_dtype = np.float64 if self.real else np.complex128
3255+
if self.use_numpy:
3256+
numpy_dtype = self._np.float64 if self.real else self._np.complex128
32513257
if isinstance(inp, DenseMatrixBase):
3252-
arr = np.empty(inp_size, dtype=numpy_dtype)
3258+
arr = self._np.empty(inp_size, dtype=numpy_dtype)
32533259
if self.real:
32543260
inp.dump_real(arr)
32553261
else:
32563262
inp.dump_complex(arr)
32573263
inp = arr
32583264
else:
3259-
inp = np.ascontiguousarray(inp, dtype=numpy_dtype)
3265+
inp = self._np.ascontiguousarray(inp, dtype=numpy_dtype)
32603266
if inp.ndim > 1:
32613267
inp = inp.ravel()
32623268
else:
32633269
inp = with_buffer(inp, self.real)
32643270

32653271
if out is None:
32663272
# allocate output container
3267-
if use_numpy:
3268-
out = np.empty(new_tot_out_size, dtype=numpy_dtype)
3273+
if self.use_numpy:
3274+
out = self._np.empty(new_tot_out_size, dtype=numpy_dtype)
32693275
else:
32703276
if self.real:
32713277
out = cython.view.array((new_tot_out_size,),
@@ -3276,11 +3282,11 @@ cdef class _Lambdify(object):
32763282
reshape_outs = len(new_out_shapes[0]) > 1
32773283
else:
32783284
reshape_outs = False
3279-
if use_numpy:
3285+
if self.use_numpy:
32803286
try:
32813287
out_dtype = out.dtype
32823288
except AttributeError:
3283-
out = np.asarray(out)
3289+
out = self._np.asarray(out)
32843290
out_dtype = out.dtype
32853291
if out_dtype != numpy_dtype:
32863292
raise TypeError("Output array is of incorrect type")
@@ -3310,7 +3316,7 @@ cdef class _Lambdify(object):
33103316
self.unsafe_complex(complex_inp_view[idx*self.args_size:(idx+1)*self.args_size],
33113317
complex_out_view[idx*self.tot_out_size:(idx+1)*self.tot_out_size])
33123318

3313-
if use_numpy and reshape_outs:
3319+
if self.use_numpy and reshape_outs:
33143320
out = out.reshape((nbroadcast, self.tot_out_size))
33153321
result = [out[:, self.accum_out_sizes[idx]:self.accum_out_sizes[idx+1]].reshape(new_out_shapes[idx])
33163322
for idx in range(self.n_exprs)]
@@ -3375,15 +3381,15 @@ IF HAVE_SYMENGINE_LLVM:
33753381
self.lambda_double[0].call(&out[0], &inp[0])
33763382

33773383

3378-
def Lambdify(args, *exprs, bool real=True, backend="lambda"):
3384+
def Lambdify(args, *exprs, bool real=True, backend="lambda", use_numpy=None):
33793385
if backend == "llvm":
33803386
IF HAVE_SYMENGINE_LLVM:
3381-
return LLVMDouble(args, *exprs, real=real)
3387+
return LLVMDouble(args, *exprs, real=real, use_numpy=use_numpy)
33823388
ELSE:
33833389
raise ValueError("""llvm backend is chosen, but symengine is not compiled
33843390
with llvm support.""")
33853391

3386-
return LambdaDouble(args, *exprs, real=real)
3392+
return LambdaDouble(args, *exprs, real=real, use_numpy=use_numpy)
33873393

33883394

33893395
def LambdifyCSE(args, *exprs, cse=None, concatenate=None, **kwargs):

symengine/tests/test_lambdify.py

Lines changed: 23 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -183,9 +183,9 @@ def test_array_out_no_numpy():
183183
if sys.version_info[0] < 3:
184184
return # requires Py3
185185
args, exprs, inp, check = _get_array()
186-
lmb = se.Lambdify(args, exprs)
186+
lmb = se.Lambdify(args, exprs, use_numpy=False)
187187
out1 = array.array('d', [0]*len(exprs))
188-
out2 = lmb(inp, out1, use_numpy=False)
188+
out2 = lmb(inp, out1)
189189
# Ensure buffer points to still data point:
190190
assert out1.buffer_info() == out2.buffer_info()
191191
assert out1 is out2
@@ -198,10 +198,10 @@ def test_array_out_no_numpy():
198198

199199
def test_memview_out():
200200
args, exprs, inp, check = _get_array()
201-
lmb = se.Lambdify(args, exprs)
202-
cy_arr1 = lmb(inp, use_numpy=False)
201+
lmb = se.Lambdify(args, exprs, use_numpy=False)
202+
cy_arr1 = lmb(inp)
203203
check(cy_arr1)
204-
cy_arr2 = lmb(inp, cy_arr1, use_numpy=False)
204+
cy_arr2 = lmb(inp, cy_arr1)
205205
check(cy_arr2)
206206
assert cy_arr2[0] != -1
207207
cy_arr1[0] = -1
@@ -249,17 +249,17 @@ def _get_cse_exprs():
249249
def test_cse_list_input():
250250
args, exprs, inp, ref = _get_cse_exprs()
251251
lmb = se.LambdifyCSE(args, exprs, concatenate=lambda tup:
252-
tup[0]+list(tup[1]))
253-
out = lmb(inp, use_numpy=False)
252+
tup[0]+list(tup[1]), use_numpy=False)
253+
out = lmb(inp)
254254
assert allclose(out, ref)
255255

256256

257257
def test_cse_array_input():
258258
args, exprs, inp, ref = _get_cse_exprs()
259259
inp = array.array('d', inp)
260260
lmb = se.LambdifyCSE(args, exprs, concatenate=lambda tup:
261-
tup[0]+array.array('d', tup[1]))
262-
out = lmb(inp, use_numpy=False)
261+
tup[0]+array.array('d', tup[1]), use_numpy=False)
262+
out = lmb(inp)
263263
assert allclose(out, ref)
264264

265265

@@ -299,12 +299,12 @@ def test_broadcast_fortran():
299299
check(A[i, ...], inp[i, :])
300300

301301

302-
def _get_1_to_2by3_matrix():
302+
def _get_1_to_2by3_matrix(use_numpy=None):
303303
x = se.symbols('x')
304304
args = x,
305305
exprs = se.DenseMatrix(2, 3, [x+1, x+2, x+3,
306306
1/x, 1/(x*x), 1/(x**3.0)])
307-
L = se.Lambdify(args, exprs)
307+
L = se.Lambdify(args, exprs, use_numpy=use_numpy)
308308

309309
def check(A, inp):
310310
X, = inp
@@ -318,9 +318,9 @@ def check(A, inp):
318318

319319

320320
def _test_2dim_Matrix(use_numpy):
321-
L, check = _get_1_to_2by3_matrix()
321+
L, check = _get_1_to_2by3_matrix(use_numpy=use_numpy)
322322
inp = [7]
323-
check(L(inp, use_numpy=use_numpy), inp)
323+
check(L(inp), inp)
324324

325325

326326
def test_2dim_Matrix():
@@ -335,9 +335,9 @@ def test_2dim_Matrix_numpy():
335335

336336

337337
def _test_2dim_Matrix_broadcast(use_numpy):
338-
L, check = _get_1_to_2by3_matrix()
338+
L, check = _get_1_to_2by3_matrix(use_numpy=use_numpy)
339339
inp = range(1, 5)
340-
out = L(inp, use_numpy=use_numpy)
340+
out = L(inp)
341341
for i in range(len(inp)):
342342
check(out[i, ...], (inp[i],))
343343

@@ -443,10 +443,10 @@ def ravelled(A):
443443
return L
444444

445445

446-
def _get_2_to_2by2_list(real=True):
446+
def _get_2_to_2by2_list(real=True, use_numpy=None):
447447
args = x, y = se.symbols('x y')
448448
exprs = [[x + y*y, y*y], [x*y*y, se.sqrt(x)+y*y]]
449-
L = se.Lambdify(args, exprs, real=real)
449+
L = se.Lambdify(args, exprs, real=real, use_numpy=use_numpy)
450450

451451
def check(A, inp):
452452
X, Y = inp
@@ -461,19 +461,19 @@ def check(A, inp):
461461

462462

463463
def test_2_to_2by2_list():
464-
L, check = _get_2_to_2by2_list()
464+
L, check = _get_2_to_2by2_list(use_numpy=False)
465465
inp = [13, 17]
466-
A = L(inp, use_numpy=False)
466+
A = L(inp)
467467
check(A, inp)
468468

469469

470470
# @pytest.mark.skipif(not HAVE_NUMPY, reason='requires numpy')
471471
def test_2_to_2by2_numpy():
472472
if not HAVE_NUMPY: # nosetests work-around
473473
return
474-
L, check = _get_2_to_2by2_list()
474+
L, check = _get_2_to_2by2_list(use_numpy=True)
475475
inp = [13, 17]
476-
A = L(inp, use_numpy=True)
476+
A = L(inp)
477477
check(A, inp)
478478

479479

@@ -502,9 +502,9 @@ def test_unsafe_complex():
502502

503503
def test_itertools_chain():
504504
args, exprs, inp, check = _get_array()
505-
L = se.Lambdify(args, exprs)
505+
L = se.Lambdify(args, exprs, use_numpy=False)
506506
inp = itertools.chain([inp[0]], (inp[1],), [inp[2]])
507-
A = L(inp, use_numpy=False)
507+
A = L(inp)
508508
check(A)
509509

510510

0 commit comments

Comments
 (0)