@@ -3085,9 +3085,11 @@ cdef class _Lambdify(object):
3085
3085
cdef readonly int n_exprs
3086
3086
cdef int * out_sizes
3087
3087
cdef int * accum_out_sizes
3088
+ cdef object numpy_dtype
3088
3089
3089
3090
def __cinit__ (self , args , *exprs , bool real = True ):
3090
3091
self .real = real
3092
+ self .numpy_dtype = np.float64 if self .real else np.complex128
3091
3093
self .out_shapes = [get_shape(expr) for expr in exprs]
3092
3094
self .n_exprs = len (exprs)
3093
3095
self .args_size = _size(args)
@@ -3189,14 +3191,13 @@ cdef class _Lambdify(object):
3189
3191
cdef cnp.ndarray[cnp.float64_t, ndim= 1 , mode= ' c' ] real_out
3190
3192
cdef cnp.ndarray[cnp.complex128_t, ndim= 1 , mode= ' c' ] cmplx_inp
3191
3193
cdef cnp.ndarray[cnp.complex128_t, ndim= 1 , mode= ' c' ] cmplx_out
3192
- size_t nbroadcast = 1
3194
+ size_t idx, nbroadcast = 1
3193
3195
long inp_size
3194
3196
tuple inp_shape
3195
- numpy_dtype = np.float64 if self .real else np.complex128
3196
3197
try :
3197
- inp = np.ascontiguousarray(inp, dtype = numpy_dtype)
3198
+ inp = np.ascontiguousarray(inp, dtype = self . numpy_dtype)
3198
3199
except TypeError :
3199
- inp = np.fromiter(inp, dtype = numpy_dtype)
3200
+ inp = np.fromiter(inp, dtype = self . numpy_dtype)
3200
3201
inp_shape = inp.shape
3201
3202
if self .real:
3202
3203
real_inp = inp.ravel()
@@ -3215,33 +3216,28 @@ cdef class _Lambdify(object):
3215
3216
new_tot_out_size = nbroadcast * self .tot_out_size
3216
3217
if out is None :
3217
3218
reshape_outs = len (new_out_shapes[0 ]) > 1
3218
- out = np.empty(new_tot_out_size, dtype = numpy_dtype)
3219
+ out = np.empty(new_tot_out_size, dtype = self .numpy_dtype)
3220
+ if self .real:
3221
+ real_out = out
3222
+ else :
3223
+ cmplx_out = out
3219
3224
else :
3220
3225
reshape_outs = False
3221
3226
if out.size < new_tot_out_size:
3222
3227
raise ValueError (" Incompatible size of output argument" )
3223
3228
if not (out.flags[' C_CONTIGUOUS' ] or out.flags[' F_CONTIGUOUS' ]):
3224
3229
raise ValueError (" Output argument needs to be C-contiguous" )
3225
- if self .n_exprs == 1 :
3226
- for idx, avail in enumerate (out.shape[- len (self .out_shapes[0 ]):]):
3227
- req = self .out_shapes[0 ][idx- len (self .out_shapes[0 ])]
3228
- if idx + out.ndim - len (self .out_shapes[0 ]) == 0 :
3229
- ok = avail >= req
3230
- else :
3231
- ok = avail == req
3232
- if not ok:
3233
- raise ValueError (" Incompatible shape of output argument" )
3234
3230
if not out.flags[' WRITEABLE' ]:
3235
3231
raise ValueError (" Output argument needs to be writeable" )
3236
3232
3237
- if self .real:
3238
- real_out = out.ravel()
3239
- if < size_t> real_out.data != out.__array_interface__[' data' ][0 ]:
3240
- raise ValueError (" out parameter not compatible" )
3241
- else :
3242
- cmplx_out = out.ravel()
3243
- if < size_t> cmplx_out.data != out.__array_interface__[' data' ][0 ]:
3244
- raise ValueError (" out parameter not compatible" )
3233
+ if self .real:
3234
+ real_out = out.ravel()
3235
+ if < size_t> real_out.data != out.__array_interface__[' data' ][0 ]:
3236
+ raise ValueError (" out parameter not compatible" )
3237
+ else :
3238
+ cmplx_out = out.ravel()
3239
+ if < size_t> cmplx_out.data != out.__array_interface__[' data' ][0 ]:
3240
+ raise ValueError (" out parameter not compatible" )
3245
3241
3246
3242
for idx in range (nbroadcast):
3247
3243
if self .real:
0 commit comments