Skip to content

Commit 0395740

Browse files
committed
Optimize Lambdify.__call__ performance
1 parent b713a61 commit 0395740

File tree

2 files changed

+20
-27
lines changed

2 files changed

+20
-27
lines changed

symengine/lib/symengine_wrapper.pyx

Lines changed: 18 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -3085,9 +3085,11 @@ cdef class _Lambdify(object):
30853085
cdef readonly int n_exprs
30863086
cdef int *out_sizes
30873087
cdef int *accum_out_sizes
3088+
cdef object numpy_dtype
30883089

30893090
def __cinit__(self, args, *exprs, bool real=True):
30903091
self.real = real
3092+
self.numpy_dtype = np.float64 if self.real else np.complex128
30913093
self.out_shapes = [get_shape(expr) for expr in exprs]
30923094
self.n_exprs = len(exprs)
30933095
self.args_size = _size(args)
@@ -3189,14 +3191,13 @@ cdef class _Lambdify(object):
31893191
cdef cnp.ndarray[cnp.float64_t, ndim=1, mode='c'] real_out
31903192
cdef cnp.ndarray[cnp.complex128_t, ndim=1, mode='c'] cmplx_inp
31913193
cdef cnp.ndarray[cnp.complex128_t, ndim=1, mode='c'] cmplx_out
3192-
size_t nbroadcast = 1
3194+
size_t idx, nbroadcast = 1
31933195
long inp_size
31943196
tuple inp_shape
3195-
numpy_dtype = np.float64 if self.real else np.complex128
31963197
try:
3197-
inp = np.ascontiguousarray(inp, dtype=numpy_dtype)
3198+
inp = np.ascontiguousarray(inp, dtype=self.numpy_dtype)
31983199
except TypeError:
3199-
inp = np.fromiter(inp, dtype=numpy_dtype)
3200+
inp = np.fromiter(inp, dtype=self.numpy_dtype)
32003201
inp_shape = inp.shape
32013202
if self.real:
32023203
real_inp = inp.ravel()
@@ -3215,33 +3216,28 @@ cdef class _Lambdify(object):
32153216
new_tot_out_size = nbroadcast * self.tot_out_size
32163217
if out is None:
32173218
reshape_outs = len(new_out_shapes[0]) > 1
3218-
out = np.empty(new_tot_out_size, dtype=numpy_dtype)
3219+
out = np.empty(new_tot_out_size, dtype=self.numpy_dtype)
3220+
if self.real:
3221+
real_out = out
3222+
else:
3223+
cmplx_out = out
32193224
else:
32203225
reshape_outs = False
32213226
if out.size < new_tot_out_size:
32223227
raise ValueError("Incompatible size of output argument")
32233228
if not (out.flags['C_CONTIGUOUS'] or out.flags['F_CONTIGUOUS']):
32243229
raise ValueError("Output argument needs to be C-contiguous")
3225-
if self.n_exprs == 1:
3226-
for idx, avail in enumerate(out.shape[-len(self.out_shapes[0]):]):
3227-
req = self.out_shapes[0][idx-len(self.out_shapes[0])]
3228-
if idx + out.ndim - len(self.out_shapes[0]) == 0:
3229-
ok = avail >= req
3230-
else:
3231-
ok = avail == req
3232-
if not ok:
3233-
raise ValueError("Incompatible shape of output argument")
32343230
if not out.flags['WRITEABLE']:
32353231
raise ValueError("Output argument needs to be writeable")
32363232

3237-
if self.real:
3238-
real_out = out.ravel()
3239-
if <size_t>real_out.data != out.__array_interface__['data'][0]:
3240-
raise ValueError("out parameter not compatible")
3241-
else:
3242-
cmplx_out = out.ravel()
3243-
if <size_t>cmplx_out.data != out.__array_interface__['data'][0]:
3244-
raise ValueError("out parameter not compatible")
3233+
if self.real:
3234+
real_out = out.ravel()
3235+
if <size_t>real_out.data != out.__array_interface__['data'][0]:
3236+
raise ValueError("out parameter not compatible")
3237+
else:
3238+
cmplx_out = out.ravel()
3239+
if <size_t>cmplx_out.data != out.__array_interface__['data'][0]:
3240+
raise ValueError("out parameter not compatible")
32453241

32463242
for idx in range(nbroadcast):
32473243
if self.real:

symengine/tests/test_lambdify.py

Lines changed: 2 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -153,11 +153,8 @@ def test_numpy_array_out_exceptions():
153153
inp_bcast = [[1, 2, 3], [4, 5, 6], [7, 8, 9], [10, 11, 12]]
154154
lmb(np.array(inp_bcast), all_right_broadcast)
155155

156-
f_contig_broadcast = np.empty((4, len(exprs)), order='F')
157-
raises(ValueError, lambda: (lmb(inp_bcast, f_contig_broadcast)))
158-
159-
improper_bcast = np.empty((4, len(exprs)+1))
160-
raises(ValueError, lambda: (lmb(inp_bcast, improper_bcast)))
156+
noncontig_broadcast = np.empty((4, len(exprs), 3)).transpose((1, 2, 0))
157+
raises(ValueError, lambda: (lmb(inp_bcast, noncontig_broadcast)))
161158

162159

163160
# @pytest.mark.skipif(not HAVE_NUMPY, reason='requires numpy')

0 commit comments

Comments
 (0)