Skip to content

Commit 2c2abad

Browse files
committed
Fix tests for heterogenous Lambdify
1 parent 68b88c1 commit 2c2abad

File tree

2 files changed

+15
-8
lines changed

2 files changed

+15
-8
lines changed

symengine/lib/symengine_wrapper.pyx

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -2706,7 +2706,7 @@ cdef class _Lambdify(object):
27062706
self.out_shapes = [get_shape(expr) for expr in exprs]
27072707
self.n_exprs = len(exprs)
27082708
self.args_size = _size(args)
2709-
self.out_sizes = [reduce(mul, shape) for shape in self.out_shapes]
2709+
self.out_sizes = [reduce(mul, shape or (1,)) for shape in self.out_shapes]
27102710
self.accum_out_sizes = [sum(self.out_sizes[:i]) for i in range(self.n_exprs + 1)]
27112711
self.tot_out_size = sum(self.out_sizes)
27122712

@@ -2980,7 +2980,7 @@ def LambdifyCSE(args, *exprs, real=True, cse=None, concatenate=None):
29802980
new_exprs = []
29812981
n_taken = 0
29822982
for expr in exprs:
2983-
shape = get_shape(exprs)
2983+
shape = get_shape(expr)
29842984
size = long(reduce(mul, shape))
29852985
if len(shape) == 1:
29862986
new_exprs.append(flat_new_exprs[n_taken:n_taken+size])
@@ -2991,7 +2991,6 @@ def LambdifyCSE(args, *exprs, real=True, cse=None, concatenate=None):
29912991
n_taken += size
29922992
lmb = Lambdify(tuple(args) + cse_symbs, *new_exprs, real=real)
29932993
cse_lambda = Lambdify(args, cse_exprs, real=real)
2994-
29952994
def cb(inp, out=None, **kwargs):
29962995
cse_vals = cse_lambda(inp, **kwargs)
29972996
new_inp = concatenate((inp, cse_vals))

symengine/tests/test_lambdify.py

Lines changed: 13 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -43,6 +43,14 @@ def allclose(vec1, vec2, rtol=1e-13, atol=1e-13):
4343
return True
4444

4545

46+
def test_get_shape():
47+
get_shape = se.lib.symengine_wrapper.get_shape
48+
assert get_shape([1]) == (1,)
49+
assert get_shape([1, 1, 1]) == (3,)
50+
assert get_shape([[1], [1], [1]]) == (3, 1)
51+
assert get_shape([[1, 1, 1]]) == (1, 3)
52+
53+
4654
def test_Lambdify():
4755
n = 7
4856
args = x, y, z = se.symbols('x y z')
@@ -536,19 +544,19 @@ def _Lambdify_heterogeneous_output(Lambdify):
536544
args = se.DenseMatrix(2, 1, [x, y])
537545
v = se.DenseMatrix(2, 1, [x**3 * y, (x+1)*(y+1)])
538546
jac = v.jacobian(args)
539-
exprs = [jac, x+y, v, x*y]
540-
lmb = se.Lambdify(args, exprs)
547+
exprs = [jac, x+y, v, (x+1)*(y+1)]
548+
lmb = se.Lambdify(args, *exprs)
541549
inp0 = 7, 11
542550
inp1 = 8, 13
543551
inp2 = 5, 9
544552
inp = np.array([inp0, inp1, inp2])
545-
o_j, o_xpy, o_v, o_xty = lmb(inp, out)
553+
o_j, o_xpy, o_v, o_xty = lmb(inp)
546554
for idx, (X, Y) in enumerate([inp0, inp1, inp2]):
547555
assert np.allclose(o_j[idx, ...], [[3 * X**2 * Y, X**3],
548556
[Y + 1, X + 1]])
549557
assert np.allclose(o_xpy[idx, ...], [X+Y])
550-
assert np.allclose(o_v[idx, ...], [X**3 * Y, (X+1)*(Y+1)])
551-
assert np.allclose(o_xty[idx, ...], [X*Y])
558+
assert np.allclose(o_v[idx, ...], [[X**3 * Y], [(X+1)*(Y+1)]])
559+
assert np.allclose(o_xty[idx, ...], [(X+1)*(Y+1)])
552560

553561

554562
def test_Lambdify_heterogeneous_output():

0 commit comments

Comments
 (0)