Skip to content

Commit 1a015f2

Browse files
committed
Handle SymPy matrix in Lambdify. Update benchmark.
1 parent 18fc670 commit 1a015f2

File tree

4 files changed

+32
-54
lines changed

4 files changed

+32
-54
lines changed

benchmarks/Lambdify_6_links_reference.pyx

Lines changed: 0 additions & 10 deletions
This file was deleted.

benchmarks/Lambdify_6_links_rhs.py renamed to benchmarks/heterogenous_output_Lambdify.py

Lines changed: 14 additions & 38 deletions
Original file line numberDiff line numberDiff line change
@@ -8,55 +8,45 @@
88
import symengine as se
99
import warnings
1010

11-
exprs_s = open(os.path.join(os.path.dirname(__file__), '6_links_rhs.txt'), 'tr').read()
12-
exprs = parse_expr(exprs_s, transformations=standard_transformations)
13-
args = sp.Matrix(1, 14, exprs).free_symbols
11+
src = os.path.join(os.path.dirname(__file__), '6_links_rhs.txt')
12+
serial = open(src, 'tr').read()
13+
parsed = parse_expr(serial, transformations=standard_transformations)
14+
vec = sp.Matrix(1, 14, parsed)
15+
args = tuple(sorted(vec.free_symbols, key=lambda arg: arg.name))
16+
exprs = vec, vec.jacobian(args[:-14])
1417
inp = np.ones(len(args))
1518
assert inp.size == 26
16-
print([expr.subs(dict(zip(args, [1]*len(args)))) for expr in exprs])
1719

18-
# Real-life example (ion speciation problem in water chemistry)
1920

20-
21-
lmb_sp = sp.lambdify(args, exprs, modules='math')
22-
lmb_se = se.Lambdify(args, exprs)
23-
# lmb_se_cse = se.LambdifyCSE(args, exprs)
24-
lmb_se_llvm = se.Lambdify(args, exprs, backend='llvm')
21+
lmb_sp = sp.lambdify(args, exprs, modules=['math', 'sympy'])
22+
lmb_se = se.Lambdify(args, *exprs)
23+
lmb_se_llvm = se.Lambdify(args, *exprs, backend='llvm')
2524

2625

2726
lmb_sp(*inp)
2827
tim_sympy = clock()
2928
for i in range(500):
30-
res_sympy = lmb_sp(*inp)
29+
v, m = lmb_sp(*inp)
3130
tim_sympy = clock() - tim_sympy
3231

3332
lmb_se(inp)
3433
tim_se = clock()
35-
res_se = np.empty(len(exprs))
3634
for i in range(500):
37-
res_se = lmb_se(inp)
35+
v, m = lmb_se(inp)
3836
tim_se = clock() - tim_se
3937

40-
# lmb_se_cse(inp)
41-
# tim_se_cse = clock()
42-
# res_se_cse = np.empty(len(exprs))
43-
# for i in range(500):
44-
# res_se_cse = lmb_se_cse(inp)
45-
# tim_se_cse = clock() - tim_se_cse
4638

4739
lmb_se_llvm(inp)
4840
tim_se_llvm = clock()
4941
res_se_llvm = np.empty(len(exprs))
5042
for i in range(500):
51-
res_se_llvm = lmb_se_llvm(inp)
43+
v, m = lmb_se_llvm(inp)
5244
tim_se_llvm = clock() - tim_se_llvm
5345

5446

5547
print('SymEngine (lambda double) speed-up factor (higher is better) vs sympy: %12.5g' %
5648
(tim_sympy/tim_se))
5749

58-
# print('symengine (lambda double + CSE) speed-up factor (higher is better) vs sympy: %12.5g' %
59-
# (tim_sympy/tim_se_cse))
6050

6151
print('symengine (LLVM) speed-up factor (higher is better) vs sympy: %12.5g' %
6252
(tim_sympy/tim_se_llvm))
@@ -80,28 +70,14 @@ def func(*args):
8070
return result
8171
return func
8272

83-
lmb_se_llvm_manual = ManualLLVM(args, np.array(exprs))
73+
lmb_se_llvm_manual = ManualLLVM(args, *exprs)
8474
lmb_se_llvm_manual(inp)
8575
tim_se_llvm_manual = clock()
86-
res_se_llvm_manual = np.empty(len(exprs))
8776
for i in range(500):
88-
res_se_llvm_manual = lmb_se_llvm_manual(inp)
77+
v, m = lmb_se_llvm_manual(inp)
8978
tim_se_llvm_manual = clock() - tim_se_llvm_manual
9079
print('symengine (ManualLLVM) speed-up factor (higher is better) vs sympy: %12.5g' %
9180
(tim_sympy/tim_se_llvm_manual))
9281

9382
if tim_se_llvm_manual < tim_se_llvm:
9483
warnings.warn("Cython code for Lambdify.__call__ is slow.")
95-
96-
import setuptools
97-
import pyximport
98-
pyximport.install()
99-
from Lambdify_6_links_reference import _benchmark_reference_for_Lambdify as lmb_ref
100-
101-
lmb_ref(inp)
102-
tim_ref = clock()
103-
for i in range(500):
104-
res_ref = lmb_ref(inp)
105-
tim_ref = clock() - tim_ref
106-
print('Hard-coded Cython code speed-up factor (higher is better) vs sympy: %12.5g' %
107-
(tim_sympy/tim_ref))

symengine/lib/symengine_wrapper.pyx

Lines changed: 6 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -3060,7 +3060,10 @@ IF HAVE_NUMPY:
30603060
try:
30613061
return ndarr.ravel()
30623062
except AttributeError:
3063-
return _ravel_nested(ndarr)
3063+
try:
3064+
return _ravel_nested(ndarr.tolist())
3065+
except AttributeError:
3066+
return _ravel_nested(ndarr)
30643067

30653068

30663069
cdef class _Lambdify(object):
@@ -3139,8 +3142,8 @@ IF HAVE_NUMPY:
31393142
for ci in range(nc):
31403143
args_.push_back(deref(mtx).get(ri, ci))
31413144
else:
3142-
for e in args:
3143-
e_ = _sympify(e)
3145+
for arg in args:
3146+
e_ = _sympify(arg)
31443147
args_.push_back(e_.thisptr)
31453148

31463149

symengine/tests/test_lambdify.py

Lines changed: 12 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -265,11 +265,11 @@ def test_broadcast_fortran():
265265
check(A[i, ...], inp[i, :])
266266

267267

268-
def _get_1_to_2by3_matrix():
268+
def _get_1_to_2by3_matrix(Mtx=se.DenseMatrix):
269269
x = se.symbols('x')
270270
args = x,
271-
exprs = se.DenseMatrix(2, 3, [x+1, x+2, x+3,
272-
1/x, 1/(x*x), 1/(x**3.0)])
271+
exprs = Mtx(2, 3, [x+1, x+2, x+3,
272+
1/x, 1/(x*x), 1/(x**3.0)])
273273
L = se.Lambdify(args, exprs)
274274

275275
def check(A, inp):
@@ -291,6 +291,15 @@ def test_2dim_Matrix():
291291
check(L(inp), inp)
292292

293293

294+
def test_2dim_Matrix__sympy():
295+
if not have_numpy: # nosetests work-around
296+
return
297+
import sympy as sp
298+
L, check = _get_1_to_2by3_matrix(sp.Matrix)
299+
inp = [7]
300+
check(L(inp), inp)
301+
302+
294303

295304
def _test_2dim_Matrix_broadcast():
296305
L, check = _get_1_to_2by3_matrix()

0 commit comments

Comments
 (0)