Skip to content

Commit 05cc34e

Browse files
authored
Merge pull request #171 from bjodah/fix-Lambdify-heterogeneous
Fix lambdify heterogeneous
2 parents 6deb3b1 + 8655954 commit 05cc34e

File tree

2 files changed

+61
-25
lines changed

2 files changed

+61
-25
lines changed

symengine/lib/symengine_wrapper.pyx

Lines changed: 9 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -3389,25 +3389,24 @@ cdef class _Lambdify(object):
33893389
raise ValueError("Size of out incompatible with number of exprs.")
33903390
self.unsafe_complex(inp, out)
33913391

3392-
def __call__(self, inp, out=None):
3392+
def __call__(self, inp, *, out=None):
33933393
"""
33943394
Parameters
33953395
----------
33963396
inp: array_like
33973397
last dimension must be equal to number of arguments.
33983398
out: array_like or None (default)
3399-
Allows for for low-overhead use (output argument, must be contiguous).
3399+
Allows for low-overhead use (output argument, must be contiguous).
34003400
If ``None``: an output container will be allocated (NumPy ndarray).
34013401
If ``len(exprs) > 0`` output is found in the corresponding
3402-
order. Note that ``out`` is not reshaped.
3402+
order.
34033403
34043404
Returns
34053405
-------
34063406
If ``len(exprs) == 1``: ``numpy.ndarray``, otherwise a tuple of such.
34073407
34083408
"""
34093409
cdef:
3410-
bint reshape_outs
34113410
size_t idx, new_tot_out_size, nbroadcast = 1
34123411
long inp_size
34133412
tuple inp_shape
@@ -3432,12 +3431,10 @@ cdef class _Lambdify(object):
34323431
else:
34333432
inp_shape = inp.shape
34343433
new_tot_out_size = nbroadcast * self.tot_out_size
3434+
new_out_shapes = [inp_shape[:-1] + out_shape for out_shape in self.out_shapes]
34353435
if out is None:
3436-
new_out_shapes = [inp_shape[:-1] + out_shape for out_shape in self.out_shapes]
3437-
reshape_outs = len(new_out_shapes[0]) > 1
34383436
out = np.empty(new_tot_out_size, dtype=self.numpy_dtype)
34393437
else:
3440-
reshape_outs = False
34413438
if out.size < new_tot_out_size:
34423439
raise ValueError("Incompatible size of output argument")
34433440
if not (out.flags['C_CONTIGUOUS'] or out.flags['F_CONTIGUOUS']):
@@ -3460,12 +3457,9 @@ cdef class _Lambdify(object):
34603457
self.unsafe_complex(cmplx_inp, cmplx_out,
34613458
idx*self.args_size, idx*self.tot_out_size)
34623459

3463-
if reshape_outs:
3464-
out = out.reshape((nbroadcast, self.tot_out_size))
3465-
result = [out[:, self.accum_out_sizes[idx]:self.accum_out_sizes[idx+1]].reshape(
3466-
new_out_shapes[idx]) for idx in range(self.n_exprs)]
3467-
else:
3468-
result = [out]
3460+
out = out.reshape((nbroadcast, self.tot_out_size))
3461+
result = [out[:, self.accum_out_sizes[idx]:self.accum_out_sizes[idx+1]].reshape(
3462+
new_out_shapes[idx]) for idx in range(self.n_exprs)]
34693463

34703464
if self.n_exprs == 1:
34713465
return result[0]
@@ -3569,11 +3563,11 @@ def LambdifyCSE(args, *exprs, cse=None, concatenate=None, **kwargs):
35693563
n_taken += size
35703564
lmb = Lambdify(tuple(args) + cse_symbs, *new_exprs, **kwargs)
35713565
cse_lambda = Lambdify(args, [expr.xreplace(explicit_subs) for expr in cse_exprs], **kwargs)
3572-
def cb(inp, out=None, **kw):
3566+
def cb(inp, *, out=None, **kw):
35733567
cse_vals = cse_lambda(inp, **kw)
35743568
print(inp, cse_vals) # DO-NOT-MERGE!
35753569
new_inp = concatenate((inp, cse_vals), axis=-1)
3576-
return lmb(new_inp, out, **kw)
3570+
return lmb(new_inp, out=out, **kw)
35773571

35783572
return cb
35793573
else:

symengine/tests/test_lambdify.py

Lines changed: 52 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -140,24 +140,24 @@ def test_numpy_array_out_exceptions():
140140
lmb = se.Lambdify(args, exprs)
141141

142142
all_right = np.empty(len(exprs))
143-
lmb(inp, all_right)
143+
lmb(inp, out=all_right)
144144

145145
too_short = np.empty(len(exprs) - 1)
146-
raises(ValueError, lambda: (lmb(inp, too_short)))
146+
raises(ValueError, lambda: (lmb(inp, out=too_short)))
147147

148148
wrong_dtype = np.empty(len(exprs), dtype=int)
149-
raises(ValueError, lambda: (lmb(inp, wrong_dtype)))
149+
raises(ValueError, lambda: (lmb(inp, out=wrong_dtype)))
150150

151151
read_only = np.empty(len(exprs))
152152
read_only.flags['WRITEABLE'] = False
153-
raises(ValueError, lambda: (lmb(inp, read_only)))
153+
raises(ValueError, lambda: (lmb(inp, out=read_only)))
154154

155155
all_right_broadcast = np.empty((4, len(exprs)))
156156
inp_bcast = [[1, 2, 3], [4, 5, 6], [7, 8, 9], [10, 11, 12]]
157-
lmb(np.array(inp_bcast), all_right_broadcast)
157+
lmb(np.array(inp_bcast), out=all_right_broadcast)
158158

159159
noncontig_broadcast = np.empty((4, len(exprs), 3)).transpose((1, 2, 0))
160-
raises(ValueError, lambda: (lmb(inp_bcast, noncontig_broadcast)))
160+
raises(ValueError, lambda: (lmb(inp_bcast, out=noncontig_broadcast)))
161161

162162

163163
def test_broadcast():
@@ -337,7 +337,7 @@ def test_jacobian():
337337
lmb = se.Lambdify(args, jac)
338338
out = np.empty((2, 2))
339339
inp = X, Y = 7, 11
340-
lmb(inp, out)
340+
lmb(inp, out=out)
341341
assert np.allclose(out, [[3 * X**2 * Y, X**3],
342342
[Y + 1, X + 1]])
343343

@@ -355,7 +355,7 @@ def test_jacobian__broadcast():
355355
inp1 = 8, 13
356356
inp2 = 5, 9
357357
inp = np.array([inp0, inp1, inp2])
358-
lmb(inp, out)
358+
lmb(inp, out=out)
359359
for idx, (X, Y) in enumerate([inp0, inp1, inp2]):
360360
assert np.allclose(out[idx, ...], [[3 * X**2 * Y, X**3],
361361
[Y + 1, X + 1]])
@@ -380,9 +380,8 @@ def test_excessive_out():
380380
lmb = se.Lambdify([x], [-x])
381381
inp = np.ones(1)
382382
out = np.ones(2)
383-
out = lmb(inp, out)
383+
_ = lmb(inp, out=out[:inp.size])
384384
assert np.allclose(inp, [1, 1])
385-
assert out.shape == (2,)
386385
assert out[0] == -1
387386
assert out[1] == 1
388387

@@ -558,3 +557,46 @@ def test_lambdify__sympy():
558557
import sympy as sp
559558
_sympy_lambdify_heterogeneous_output(se.lambdify, se.DenseMatrix)
560559
_sympy_lambdify_heterogeneous_output(sp.lambdify, sp.Matrix)
560+
561+
562+
def _test_Lambdify_scalar_vector_matrix(Lambdify):
563+
if not have_numpy:
564+
return
565+
args = x, y = se.symbols('x y')
566+
vec = se.DenseMatrix([x+y, x*y])
567+
jac = vec.jacobian(se.DenseMatrix(args))
568+
f = Lambdify(args, x**y, vec, jac)
569+
assert f.n_exprs == 3
570+
s, v, m = f([2, 3])
571+
print(s, v, m)
572+
assert s == 2**3
573+
assert np.allclose(v, [[2+3], [2*3]])
574+
assert np.allclose(m, [
575+
[1, 1],
576+
[3, 2]
577+
])
578+
579+
s2, v2, m2 = f([[2, 3], [5, 7]])
580+
assert np.allclose(s2, [2**3, 5**7])
581+
assert np.allclose(v2, [
582+
[[2+3], [2*3]],
583+
[[5+7], [5*7]]
584+
])
585+
assert np.allclose(m2, [
586+
[
587+
[1, 1],
588+
[3, 2]
589+
],
590+
[
591+
[1, 1],
592+
[7, 5]
593+
]
594+
])
595+
596+
597+
def test_Lambdify_scalar_vector_matrix():
598+
_test_Lambdify_scalar_vector_matrix(lambda *args: se.Lambdify(*args, backend='lambda'))
599+
_test_Lambdify_scalar_vector_matrix(lambda *args: se.LambdifyCSE(*args, backend='lambda'))
600+
if se.have_llvm:
601+
_test_Lambdify_scalar_vector_matrix(lambda *args: se.Lambdify(*args, backend='llvm'))
602+
_test_Lambdify_scalar_vector_matrix(lambda *args: se.LambdifyCSE(*args, backend='llvm'))

0 commit comments

Comments
 (0)