@@ -3213,11 +3213,12 @@ IF HAVE_NUMPY:
3213
3213
3214
3214
"""
3215
3215
cdef:
3216
- cdef cnp.ndarray[cnp.float64_t, ndim= 1 , mode= ' c' ] real_inp
3217
- cdef cnp.ndarray[cnp.float64_t, ndim= 1 , mode= ' c' ] real_out
3218
- cdef cnp.ndarray[cnp.complex128_t, ndim= 1 , mode= ' c' ] cmplx_inp
3219
- cdef cnp.ndarray[cnp.complex128_t, ndim= 1 , mode= ' c' ] cmplx_out
3220
- size_t idx, nbroadcast = 1
3216
+ cnp.ndarray[cnp.float64_t, ndim= 1 , mode= ' c' ] real_inp
3217
+ cnp.ndarray[cnp.float64_t, ndim= 1 , mode= ' c' ] real_out
3218
+ cnp.ndarray[cnp.complex128_t, ndim= 1 , mode= ' c' ] cmplx_inp
3219
+ cnp.ndarray[cnp.complex128_t, ndim= 1 , mode= ' c' ] cmplx_out
3220
+ bint reshape_outs
3221
+ size_t idx, new_tot_out_size, nbroadcast = 1
3221
3222
long inp_size
3222
3223
tuple inp_shape
3223
3224
try :
@@ -3237,10 +3238,9 @@ IF HAVE_NUMPY:
3237
3238
inp_shape = inp.shape + (1 ,)
3238
3239
else :
3239
3240
inp_shape = inp.shape
3240
- new_out_shapes = [inp_shape[:- 1 ] + out_shape for out_shape in self .out_shapes]
3241
- new_out_sizes = [nbroadcast* self .out_sizes[i] for i in range (self .n_exprs)]
3242
3241
new_tot_out_size = nbroadcast * self .tot_out_size
3243
3242
if out is None :
3243
+ new_out_shapes = [inp_shape[:- 1 ] + out_shape for out_shape in self .out_shapes]
3244
3244
reshape_outs = len (new_out_shapes[0 ]) > 1
3245
3245
out = np.empty(new_tot_out_size, dtype = self .numpy_dtype)
3246
3246
if self .real:
@@ -3265,11 +3265,12 @@ IF HAVE_NUMPY:
3265
3265
if < size_t> cmplx_out.data != out.__array_interface__[' data' ][0 ]:
3266
3266
raise ValueError (" out parameter not compatible" )
3267
3267
3268
- for idx in range (nbroadcast) :
3269
- if self .real :
3268
+ if self .real :
3269
+ for idx in range (nbroadcast) :
3270
3270
self .unsafe_real(real_inp, real_out,
3271
3271
idx* self .args_size, idx* self .tot_out_size)
3272
- else :
3272
+ else :
3273
+ for idx in range (nbroadcast):
3273
3274
self .unsafe_complex(cmplx_inp, cmplx_out,
3274
3275
idx* self .args_size, idx* self .tot_out_size)
3275
3276
@@ -3281,11 +3282,9 @@ IF HAVE_NUMPY:
3281
3282
result = [out]
3282
3283
3283
3284
if self .n_exprs == 1 :
3284
- result = result[0 ]
3285
+ return result[0 ]
3285
3286
else :
3286
- result = tuple (result)
3287
-
3288
- return result
3287
+ return result
3289
3288
3290
3289
3291
3290
cdef class LambdaDouble(_Lambdify):
0 commit comments