Skip to content

Commit 3b2182f

Browse files
committed
Optimizing _Lambdify.__call__
1 parent 1a015f2 commit 3b2182f

File tree

2 files changed

+13
-14
lines changed

2 files changed

+13
-14
lines changed

symengine/lib/symengine_wrapper.pyx

Lines changed: 13 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -3213,11 +3213,12 @@ IF HAVE_NUMPY:
32133213
32143214
"""
32153215
cdef:
3216-
cdef cnp.ndarray[cnp.float64_t, ndim=1, mode='c'] real_inp
3217-
cdef cnp.ndarray[cnp.float64_t, ndim=1, mode='c'] real_out
3218-
cdef cnp.ndarray[cnp.complex128_t, ndim=1, mode='c'] cmplx_inp
3219-
cdef cnp.ndarray[cnp.complex128_t, ndim=1, mode='c'] cmplx_out
3220-
size_t idx, nbroadcast = 1
3216+
cnp.ndarray[cnp.float64_t, ndim=1, mode='c'] real_inp
3217+
cnp.ndarray[cnp.float64_t, ndim=1, mode='c'] real_out
3218+
cnp.ndarray[cnp.complex128_t, ndim=1, mode='c'] cmplx_inp
3219+
cnp.ndarray[cnp.complex128_t, ndim=1, mode='c'] cmplx_out
3220+
bint reshape_outs
3221+
size_t idx, new_tot_out_size, nbroadcast = 1
32213222
long inp_size
32223223
tuple inp_shape
32233224
try:
@@ -3237,10 +3238,9 @@ IF HAVE_NUMPY:
32373238
inp_shape = inp.shape + (1,)
32383239
else:
32393240
inp_shape = inp.shape
3240-
new_out_shapes = [inp_shape[:-1] + out_shape for out_shape in self.out_shapes]
3241-
new_out_sizes = [nbroadcast*self.out_sizes[i] for i in range(self.n_exprs)]
32423241
new_tot_out_size = nbroadcast * self.tot_out_size
32433242
if out is None:
3243+
new_out_shapes = [inp_shape[:-1] + out_shape for out_shape in self.out_shapes]
32443244
reshape_outs = len(new_out_shapes[0]) > 1
32453245
out = np.empty(new_tot_out_size, dtype=self.numpy_dtype)
32463246
if self.real:
@@ -3265,11 +3265,12 @@ IF HAVE_NUMPY:
32653265
if <size_t>cmplx_out.data != out.__array_interface__['data'][0]:
32663266
raise ValueError("out parameter not compatible")
32673267

3268-
for idx in range(nbroadcast):
3269-
if self.real:
3268+
if self.real:
3269+
for idx in range(nbroadcast):
32703270
self.unsafe_real(real_inp, real_out,
32713271
idx*self.args_size, idx*self.tot_out_size)
3272-
else:
3272+
else:
3273+
for idx in range(nbroadcast):
32733274
self.unsafe_complex(cmplx_inp, cmplx_out,
32743275
idx*self.args_size, idx*self.tot_out_size)
32753276

@@ -3281,11 +3282,9 @@ IF HAVE_NUMPY:
32813282
result = [out]
32823283

32833284
if self.n_exprs == 1:
3284-
result = result[0]
3285+
return result[0]
32853286
else:
3286-
result = tuple(result)
3287-
3288-
return result
3287+
return result
32893288

32903289

32913290
cdef class LambdaDouble(_Lambdify):

0 commit comments

Comments
 (0)