Skip to content

Commit 871017c

Browse files
committed
Fix lambdify & LambdifyCSE
1 parent 3b2182f commit 871017c

File tree

3 files changed

+62
-38
lines changed

3 files changed

+62
-38
lines changed

symengine/__init__.py

Lines changed: 1 addition & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -25,11 +25,7 @@ def lambdify(args, exprs):
2525
len(args)
2626
except TypeError:
2727
args = [args]
28-
try:
29-
len(exprs)
30-
except TypeError:
31-
exprs = [exprs]
32-
lmb = Lambdify(args, exprs)
28+
lmb = Lambdify(args, *exprs)
3329
def f(*inner_args):
3430
if len(inner_args) != len(args):
3531
raise TypeError("Incorrect number of arguments")

symengine/lib/symengine_wrapper.pyx

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -3375,7 +3375,7 @@ IF HAVE_NUMPY:
33753375
new_exprs = []
33763376
n_taken = 0
33773377
for expr in exprs:
3378-
shape = get_shape(expr)
3378+
shape = get_shape(expr) or (1,)
33793379
size = long(reduce(mul, shape))
33803380
if len(shape) == 1:
33813381
new_exprs.append(flat_new_exprs[n_taken:n_taken+size])
@@ -3389,7 +3389,8 @@ IF HAVE_NUMPY:
33893389
cse_lambda = Lambdify(args, [expr.xreplace(explicit_subs) for expr in cse_exprs], **kwargs)
33903390
def cb(inp, out=None, **kw):
33913391
cse_vals = cse_lambda(inp, **kw)
3392-
new_inp = concatenate((inp, cse_vals))
3392+
print(inp, cse_vals) # DO-NOT-MERGE!
3393+
new_inp = concatenate((inp, cse_vals), axis=-1)
33933394
return lmb(new_inp, out, **kw)
33943395

33953396
return cb

symengine/tests/test_lambdify.py

Lines changed: 58 additions & 31 deletions
Original file line numberDiff line numberDiff line change
@@ -39,7 +39,7 @@ def allclose(vec1, vec2, rtol=1e-13, atol=1e-13):
3939

4040

4141
def test_get_shape():
42-
if not have_numpy: # nosetests work-around
42+
if not have_numpy:
4343
return
4444
get_shape = se.lib.symengine_wrapper.get_shape
4545
assert get_shape([1]) == (1,)
@@ -54,7 +54,7 @@ def test_get_shape():
5454

5555

5656
def test_ravel():
57-
if not have_numpy: # nosetests work-around
57+
if not have_numpy:
5858
return
5959
x = se.symbols('x')
6060
ravel = se.lib.symengine_wrapper.ravel
@@ -64,7 +64,7 @@ def test_ravel():
6464

6565

6666
def test_Lambdify():
67-
if not have_numpy: # nosetests work-around
67+
if not have_numpy:
6868
return
6969
n = 7
7070
args = x, y, z = se.symbols('x y z')
@@ -74,7 +74,7 @@ def test_Lambdify():
7474

7575

7676
def test_Lambdify_LLVM():
77-
if not have_numpy: # nosetests work-around
77+
if not have_numpy:
7878
return
7979
n = 7
8080
args = x, y, z = se.symbols('x y z')
@@ -104,7 +104,7 @@ def check(A, inp):
104104

105105

106106
def test_Lambdify_2dim():
107-
if not have_numpy: # nosetests work-around
107+
if not have_numpy:
108108
return
109109
lmb, check = _get_2_to_2by2()
110110
for inp in [(5, 7), np.array([5, 7]), [5.0, 7.0]]:
@@ -125,7 +125,7 @@ def check(arr):
125125

126126

127127
def test_array():
128-
if not have_numpy: # nosetests work-around
128+
if not have_numpy:
129129
return
130130
args, exprs, inp, check = _get_array()
131131
lmb = se.Lambdify(args, exprs)
@@ -134,7 +134,7 @@ def test_array():
134134

135135

136136
def test_numpy_array_out_exceptions():
137-
if not have_numpy: # nosetests work-around
137+
if not have_numpy:
138138
return
139139
args, exprs, inp, check = _get_array()
140140
lmb = se.Lambdify(args, exprs)
@@ -161,7 +161,7 @@ def test_numpy_array_out_exceptions():
161161

162162

163163
def test_broadcast():
164-
if not have_numpy: # nosetests work-around
164+
if not have_numpy:
165165
return
166166
a = np.linspace(-np.pi, np.pi)
167167
inp = np.vstack((np.cos(a), np.sin(a))).T # 50 rows 2 cols
@@ -174,7 +174,7 @@ def test_broadcast():
174174

175175

176176
def test_broadcast_multiple_extra_dimensions():
177-
if not have_numpy: # nosetests work-around
177+
if not have_numpy:
178178
return
179179
inp = np.arange(12.).reshape((4, 3, 1))
180180
x = se.symbols('x')
@@ -198,7 +198,7 @@ def _get_cse_exprs():
198198

199199

200200
def test_cse():
201-
if not have_numpy: # nosetests work-around
201+
if not have_numpy:
202202
return
203203
args, exprs, inp, ref = _get_cse_exprs()
204204
lmb = se.LambdifyCSE(args, exprs)
@@ -232,7 +232,7 @@ def _get_cse_exprs_big():
232232

233233

234234
def test_cse_big():
235-
if not have_numpy: # nosetests work-around
235+
if not have_numpy:
236236
return
237237
args, exprs, inp = _get_cse_exprs_big()
238238
lmb = se.LambdifyCSE(args, exprs)
@@ -242,7 +242,7 @@ def test_cse_big():
242242

243243

244244
def test_broadcast_c():
245-
if not have_numpy: # nosetests work-around
245+
if not have_numpy:
246246
return
247247
n = 3
248248
inp = np.arange(2*n).reshape((n, 2))
@@ -254,7 +254,7 @@ def test_broadcast_c():
254254

255255

256256
def test_broadcast_fortran():
257-
if not have_numpy: # nosetests work-around
257+
if not have_numpy:
258258
return
259259
n = 3
260260
inp = np.arange(2*n).reshape((n, 2), order='F')
@@ -284,15 +284,15 @@ def check(A, inp):
284284

285285

286286
def test_2dim_Matrix():
287-
if not have_numpy: # nosetests work-around
287+
if not have_numpy:
288288
return
289289
L, check = _get_1_to_2by3_matrix()
290290
inp = [7]
291291
check(L(inp), inp)
292292

293293

294294
def test_2dim_Matrix__sympy():
295-
if not have_numpy: # nosetests work-around
295+
if not have_numpy:
296296
return
297297
import sympy as sp
298298
L, check = _get_1_to_2by3_matrix(sp.Matrix)
@@ -311,13 +311,13 @@ def _test_2dim_Matrix_broadcast():
311311

312312

313313
def test_2dim_Matrix_broadcast():
314-
if not have_numpy: # nosetests work-around
314+
if not have_numpy:
315315
return
316316
_test_2dim_Matrix_broadcast()
317317

318318

319319
def test_2dim_Matrix_broadcast_multiple_extra_dim():
320-
if not have_numpy: # nosetests work-around
320+
if not have_numpy:
321321
return
322322
L, check = _get_1_to_2by3_matrix()
323323
inp = np.arange(1, 4*5*6+1).reshape((4, 5, 6))
@@ -328,7 +328,7 @@ def test_2dim_Matrix_broadcast_multiple_extra_dim():
328328

329329

330330
def test_jacobian():
331-
if not have_numpy: # nosetests work-around
331+
if not have_numpy:
332332
return
333333
x, y = se.symbols('x, y')
334334
args = se.DenseMatrix(2, 1, [x, y])
@@ -343,7 +343,7 @@ def test_jacobian():
343343

344344

345345
def test_jacobian__broadcast():
346-
if not have_numpy: # nosetests work-around
346+
if not have_numpy:
347347
return
348348
x, y = se.symbols('x, y')
349349
args = se.DenseMatrix(2, 1, [x, y])
@@ -362,7 +362,7 @@ def test_jacobian__broadcast():
362362

363363

364364
def test_excessive_args():
365-
if not have_numpy: # nosetests work-around
365+
if not have_numpy:
366366
return
367367
x = se.symbols('x')
368368
lmb = se.Lambdify([x], [-x])
@@ -374,7 +374,7 @@ def test_excessive_args():
374374

375375

376376
def test_excessive_out():
377-
if not have_numpy: # nosetests work-around
377+
if not have_numpy:
378378
return
379379
x = se.symbols('x')
380380
lmb = se.Lambdify([x], [-x])
@@ -419,7 +419,7 @@ def check(A, inp):
419419

420420

421421
def test_2_to_2by2():
422-
if not have_numpy: # nosetests work-around
422+
if not have_numpy:
423423
return
424424
L, check = _get_2_to_2by2_list()
425425
inp = [13, 17]
@@ -428,7 +428,7 @@ def test_2_to_2by2():
428428

429429

430430
def test_unsafe_real():
431-
if not have_numpy: # nosetests work-around
431+
if not have_numpy:
432432
return
433433
L, check = _get_2_to_2by2_list()
434434
inp = np.array([13., 17.])
@@ -438,7 +438,7 @@ def test_unsafe_real():
438438

439439

440440
def test_unsafe_complex():
441-
if not have_numpy: # nosetests work-around
441+
if not have_numpy:
442442
return
443443
L, check = _get_2_to_2by2_list(real=False)
444444
assert not L.real
@@ -449,7 +449,7 @@ def test_unsafe_complex():
449449

450450

451451
def test_itertools_chain():
452-
if not have_numpy: # nosetests work-around
452+
if not have_numpy:
453453
return
454454
args, exprs, inp, check = _get_array()
455455
L = se.Lambdify(args, exprs)
@@ -460,7 +460,7 @@ def test_itertools_chain():
460460

461461
# @pytest.mark.xfail(not have_numpy, reason='array.array lacks "Zd"')
462462
def test_complex_1():
463-
if not have_numpy: # nosetests work-around
463+
if not have_numpy:
464464
return
465465
x = se.Symbol('x')
466466
lmb = se.Lambdify([x], [1j + x], real=False)
@@ -470,7 +470,7 @@ def test_complex_1():
470470

471471
# @pytest.mark.xfail(not have_numpy, reason='array.array lacks "Zd"')
472472
def test_complex_2():
473-
if not have_numpy: # nosetests work-around
473+
if not have_numpy:
474474
return
475475
x = se.Symbol('x')
476476
lmb = se.Lambdify([x], [3 + x - 1j], real=False)
@@ -482,7 +482,7 @@ def test_more_than_255_args():
482482
# SymPy's lambdify can handle at most 255 arguments
483483
# this is a proof of concept that this limitation does
484484
# not affect SymEngine's Lambdify class
485-
if not have_numpy: # nosetests work-around
485+
if not have_numpy:
486486
return
487487
n = 257
488488
x = se.symarray('x', n)
@@ -507,7 +507,7 @@ def _Lambdify_heterogeneous_output(Lambdify):
507507
v = se.DenseMatrix(2, 1, [x**3 * y, (x+1)*(y+1)])
508508
jac = v.jacobian(args)
509509
exprs = [jac, x+y, v, (x+1)*(y+1)]
510-
lmb = se.Lambdify(args, *exprs)
510+
lmb = Lambdify(args, *exprs)
511511
inp0 = 7, 11
512512
inp1 = 8, 13
513513
inp2 = 5, 9
@@ -522,12 +522,39 @@ def _Lambdify_heterogeneous_output(Lambdify):
522522

523523

524524
def test_Lambdify_heterogeneous_output():
525-
if not have_numpy: # nosetests work-around
525+
if not have_numpy:
526526
return
527527
_Lambdify_heterogeneous_output(se.Lambdify)
528528

529529

530530
def test_LambdifyCSE_heterogeneous_output():
531-
if not have_numpy: # nosetests work-around
531+
if not have_numpy:
532532
return
533533
_Lambdify_heterogeneous_output(se.LambdifyCSE)
534+
535+
536+
def _sympy_lambdify_heterogeneous_output(cb, Mtx):
537+
x, y = se.symbols('x, y')
538+
args = Mtx(2, 1, [x, y])
539+
v = Mtx(2, 1, [x**3 * y, (x+1)*(y+1)])
540+
jac = v.jacobian(args)
541+
exprs = [jac, x+y, v, (x+1)*(y+1)]
542+
lmb = cb(args, exprs)
543+
inp0 = 7, 11
544+
inp1 = 8, 13
545+
inp2 = 5, 9
546+
for idx, (X, Y) in enumerate([inp0, inp1, inp2]):
547+
o_j, o_xpy, o_v, o_xty = lmb(X, Y)
548+
assert np.allclose(o_j, [[3 * X**2 * Y, X**3],
549+
[Y + 1, X + 1]])
550+
assert np.allclose(o_xpy, [X+Y])
551+
assert np.allclose(o_v, [[X**3 * Y], [(X+1)*(Y+1)]])
552+
assert np.allclose(o_xty, [(X+1)*(Y+1)])
553+
554+
555+
def test_lambdify__sympy():
556+
if not have_numpy:
557+
return
558+
import sympy as sp
559+
_sympy_lambdify_heterogeneous_output(se.lambdify, se.DenseMatrix)
560+
_sympy_lambdify_heterogeneous_output(sp.lambdify, sp.Matrix)

0 commit comments

Comments
 (0)