Skip to content

Commit c0ffda2

Browse files
committed
Use symengine's cse
1 parent a12f736 commit c0ffda2

File tree

3 files changed

+25
-80
lines changed

3 files changed

+25
-80
lines changed

symengine/lib/symengine.pxd

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -960,11 +960,11 @@ cdef extern from "<symengine/eval_double.h>" namespace "SymEngine":
960960
cdef extern from "<symengine/lambda_double.h>" namespace "SymEngine":
961961
cdef cppclass LambdaRealDoubleVisitor:
962962
LambdaRealDoubleVisitor() nogil
963-
void init(const vec_basic &x, const vec_basic &b) nogil except +
963+
void init(const vec_basic &x, const vec_basic &b, bool cse) nogil except +
964964
void call(double *r, const double *x) nogil
965965
cdef cppclass LambdaComplexDoubleVisitor:
966966
LambdaComplexDoubleVisitor() nogil
967-
void init(const vec_basic &x, const vec_basic &b) nogil except +
967+
void init(const vec_basic &x, const vec_basic &b, bool cse) nogil except +
968968
void call(double complex *r, const double complex *x) nogil
969969

970970
cdef extern from "<symengine/llvm_double.h>" namespace "SymEngine":

symengine/lib/symengine_wrapper.pyx

Lines changed: 17 additions & 60 deletions
Original file line numberDiff line numberDiff line change
@@ -4369,7 +4369,7 @@ cdef class _Lambdify(object):
43694369
cdef vector[int] accum_out_sizes
43704370
cdef object numpy_dtype
43714371

4372-
def __init__(self, args, *exprs, cppbool real=True, order='C'):
4372+
def __init__(self, args, *exprs, cppbool real=True, order='C', cppbool cse=False):
43734373
cdef:
43744374
Basic e_
43754375
size_t ri, ci, nr, nc
@@ -4409,9 +4409,9 @@ cdef class _Lambdify(object):
44094409
for e in np.ravel(curr_expr, order=self.order):
44104410
e_ = _sympify(e)
44114411
outs_.push_back(e_.thisptr)
4412-
self._init(args_, outs_)
4412+
self._init(args_, outs_, cse)
44134413

4414-
cdef _init(self, symengine.vec_basic& args_, symengine.vec_basic& outs_):
4414+
cdef _init(self, symengine.vec_basic& args_, symengine.vec_basic& outs_, cppbool cse):
44154415
raise ValueError("Not supported")
44164416

44174417
cpdef unsafe_real(self,
@@ -4590,13 +4590,13 @@ cdef class LambdaDouble(_Lambdify):
45904590
cdef vector[symengine.LambdaRealDoubleVisitor] lambda_double
45914591
cdef vector[symengine.LambdaComplexDoubleVisitor] lambda_double_complex
45924592

4593-
cdef _init(self, symengine.vec_basic& args_, symengine.vec_basic& outs_):
4593+
cdef _init(self, symengine.vec_basic& args_, symengine.vec_basic& outs_, cppbool cse):
45944594
if self.real:
45954595
self.lambda_double.resize(1)
4596-
self.lambda_double[0].init(args_, outs_)
4596+
self.lambda_double[0].init(args_, outs_, cse)
45974597
else:
45984598
self.lambda_double_complex.resize(1)
4599-
self.lambda_double_complex[0].init(args_, outs_)
4599+
self.lambda_double_complex[0].init(args_, outs_, cse)
46004600

46014601
cpdef unsafe_real(self, double[::1] inp, double[::1] out, int inp_offset=0, int out_offset=0):
46024602
self.lambda_double[0].call(&out[out_offset], &inp[inp_offset])
@@ -4621,9 +4621,9 @@ IF HAVE_SYMENGINE_LLVM:
46214621

46224622
cdef vector[symengine.LLVMDoubleVisitor] lambda_double
46234623

4624-
cdef _init(self, symengine.vec_basic& args_, symengine.vec_basic& outs_):
4624+
cdef _init(self, symengine.vec_basic& args_, symengine.vec_basic& outs_, cppbool cse):
46254625
self.lambda_double.resize(1)
4626-
self.lambda_double[0].init(args_, outs_)
4626+
self.lambda_double[0].init(args_, outs_, cse)
46274627

46284628
cpdef unsafe_real(self, double[::1] inp, double[::1] out, int inp_offset=0, int out_offset=0):
46294629
self.lambda_double[0].call(&out[out_offset], &inp[inp_offset])
@@ -4640,7 +4640,7 @@ IF HAVE_SYMENGINE_LLVM:
46404640
return create_low_level_callable(self, addr1, addr2)
46414641

46424642

4643-
def Lambdify(args, *exprs, cppbool real=True, backend=None, order='C', as_scipy=False):
4643+
def Lambdify(args, *exprs, cppbool real=True, backend=None, order='C', as_scipy=False, cse=False):
46444644
"""
46454645
Lambdify instances are callbacks that numerically evaluate their symbolic
46464646
expressions from user provided input (real or complex) into (possibly user
@@ -4666,6 +4666,9 @@ def Lambdify(args, *exprs, cppbool real=True, backend=None, order='C', as_scipy=
46664666
as_scipy : bool
46674667
return a SciPy LowLevelCallable which can be used in SciPy's integrate
46684668
methods
4669+
cse : bool
4670+
Run Common Subexpression Elimination on the output before generating
4671+
the callback.
46694672
46704673
Returns
46714674
-------
@@ -4687,7 +4690,7 @@ def Lambdify(args, *exprs, cppbool real=True, backend=None, order='C', as_scipy=
46874690
backend = os.getenv('SYMENGINE_LAMBDIFY_BACKEND', "lambda")
46884691
if backend == "llvm":
46894692
IF HAVE_SYMENGINE_LLVM:
4690-
ret = LLVMDouble(args, *exprs, real=real, order=order)
4693+
ret = LLVMDouble(args, *exprs, real=real, order=order, cse=cse)
46914694
if as_scipy:
46924695
return ret.as_scipy_low_level_callable()
46934696
return ret
@@ -4698,63 +4701,17 @@ def Lambdify(args, *exprs, cppbool real=True, backend=None, order='C', as_scipy=
46984701
pass
46994702
else:
47004703
warnings.warn("Unknown SymEngine backend: %s\nUsing backend='lambda'" % backend)
4701-
ret = LambdaDouble(args, *exprs, real=real, order=order)
4704+
ret = LambdaDouble(args, *exprs, real=real, order=order, cse=cse)
47024705
if as_scipy:
47034706
return ret.as_scipy_low_level_callable()
47044707
return ret
47054708

47064709

4707-
def LambdifyCSE(args, *exprs, cse=None, order='C', **kwargs):
4710+
def LambdifyCSE(args, *exprs, order='C', **kwargs):
47084711
""" Analogous with Lambdify but performs common subexpression elimination.
4709-
4710-
See docstring of Lambdify.
4711-
4712-
Parameters
4713-
----------
4714-
args: iterable of symbols
4715-
exprs: iterable of expressions (with symbols from args)
4716-
cse: callback (default: None)
4717-
defaults to sympy.cse (see SymPy documentation)
4718-
order : str
4719-
order (passed to numpy.ravel and numpy.reshape).
4720-
\\*\\*kwargs: Keyword arguments passed onto Lambdify
4721-
47224712
"""
4723-
if cse is None:
4724-
from sympy import cse
4725-
_exprs = [np.asanyarray(e) for e in exprs]
4726-
_args = np.ravel(args, order=order)
4727-
from sympy import sympify as s_sympify
4728-
flat_exprs = list(itertools.chain(*[np.ravel(e, order=order) for e in _exprs]))
4729-
subs, flat_new_exprs = cse([s_sympify(expr) for expr in flat_exprs])
4730-
4731-
if subs:
4732-
explicit_subs = {}
4733-
for k, v in subs:
4734-
explicit_subs[k] = v.xreplace(explicit_subs)
4735-
4736-
cse_symbs, cse_exprs = zip(*subs)
4737-
new_exprs = []
4738-
n_taken = 0
4739-
for expr in _exprs:
4740-
new_exprs.append(np.reshape(flat_new_exprs[n_taken:n_taken+expr.size],
4741-
expr.shape, order=order))
4742-
n_taken += expr.size
4743-
new_lmb = Lambdify(tuple(_args) + cse_symbs, *new_exprs, order=order, **kwargs)
4744-
cse_lambda = Lambdify(_args, [ce.xreplace(explicit_subs) for ce in cse_exprs], **kwargs)
4745-
def cb(inp, *, out=None, **kw):
4746-
_inp = np.asanyarray(inp)
4747-
cse_vals = cse_lambda(_inp, **kw)
4748-
if order == 'C':
4749-
new_inp = np.concatenate((_inp[(Ellipsis,) + (np.newaxis,)*(cse_vals.ndim - _inp.ndim)],
4750-
cse_vals), axis=-1)
4751-
else:
4752-
new_inp = np.concatenate((_inp[(np.newaxis,)*(cse_vals.ndim - _inp.ndim) + (Ellipsis,)],
4753-
cse_vals), axis=0)
4754-
return new_lmb(new_inp, out=out, **kw)
4755-
return cb
4756-
else:
4757-
return Lambdify(args, *exprs, **kwargs)
4713+
warnings.warn("LambdifyCSE is deprecated. Use Lambdify(..., cse=True)", DeprecationWarning)
4714+
return Lambdify(args, *exprs, cse=True, order=order, **kwargs)
47584715

47594716

47604717
def ccode(expr):

symengine/tests/test_lambdify.py

Lines changed: 6 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -193,32 +193,28 @@ def test_broadcast_multiple_extra_dimensions():
193193
assert abs(out[-1, -1, 1] - 11**3) < 1e-14
194194

195195

196-
@unittest.skipUnless(have_sympy, "SymPy not installed")
197196
def _get_cse_exprs():
198-
import sympy as sp
199-
args = x, y = sp.symbols('x y')
197+
args = x, y = se.symbols('x y')
200198
exprs = [x*x + y, y/(x*x), y*x*x+x]
201199
inp = [11, 13]
202200
ref = [121+13, 13/121, 13*121 + 11]
203201
return args, exprs, inp, ref
204202

205203

206204
@unittest.skipUnless(have_numpy, "Numpy not installed")
207-
@unittest.skipUnless(have_sympy, "SymPy not installed")
208205
def test_cse():
209206
args, exprs, inp, ref = _get_cse_exprs()
210-
lmb = se.LambdifyCSE(args, exprs)
207+
lmb = se.Lambdify(args, exprs, cse=True)
211208
out = lmb(inp)
212209
assert allclose(out, ref)
213210

214211

215212
@unittest.skipUnless(have_numpy, "Numpy not installed")
216-
@unittest.skipUnless(have_sympy, "SymPy not installed")
217213
def test_cse_gh174():
218214
x = se.symbols('x')
219215
funcs = [se.cos(x)**i for i in range(5)]
220216
f_lmb = se.Lambdify([x], funcs)
221-
f_cse = se.LambdifyCSE([x], funcs)
217+
f_cse = se.Lambdify([x], funcs, cse=True)
222218
a = np.array([1, 2, 3])
223219
assert np.allclose(f_lmb(a), f_cse(a))
224220

@@ -250,10 +246,9 @@ def _get_cse_exprs_big():
250246

251247

252248
@unittest.skipUnless(have_numpy, "Numpy not installed")
253-
@unittest.skipUnless(have_sympy, "SymPy not installed")
254249
def test_cse_big():
255250
args, exprs, inp = _get_cse_exprs_big()
256-
lmb = se.LambdifyCSE(args, exprs)
251+
lmb = se.Lambdify(args, exprs, cse=True)
257252
out = lmb(inp)
258253
ref = [expr.xreplace(dict(zip(args, inp))) for expr in exprs]
259254
assert allclose(out, ref)
@@ -526,12 +521,6 @@ def test_Lambdify_heterogeneous_output():
526521
_Lambdify_heterogeneous_output(se.Lambdify)
527522

528523

529-
@unittest.skipUnless(have_numpy, "Numpy not installed")
530-
@unittest.skipUnless(have_sympy, "SymPy not installed")
531-
def test_LambdifyCSE_heterogeneous_output():
532-
_Lambdify_heterogeneous_output(se.LambdifyCSE)
533-
534-
535524
def _sympy_lambdify_heterogeneous_output(cb, Mtx):
536525
x, y = se.symbols('x, y')
537526
args = Mtx(2, 1, [x, y])
@@ -600,11 +589,10 @@ def test_Lambdify_scalar_vector_matrix():
600589
_test_Lambdify_scalar_vector_matrix(lambda *args: se.Lambdify(*args, backend='llvm'))
601590

602591

603-
@unittest.skipUnless(have_sympy, "SymPy not installed")
604592
def test_Lambdify_scalar_vector_matrix_cse():
605-
_test_Lambdify_scalar_vector_matrix(lambda *args: se.LambdifyCSE(*args, backend='lambda'))
593+
_test_Lambdify_scalar_vector_matrix(lambda *args: se.Lambdify(*args, backend='lambda', cse=True))
606594
if se.have_llvm:
607-
_test_Lambdify_scalar_vector_matrix(lambda *args: se.LambdifyCSE(*args, backend='llvm'))
595+
_test_Lambdify_scalar_vector_matrix(lambda *args: se.Lambdify(*args, backend='llvm', cse=True))
608596

609597

610598
@unittest.skipUnless(have_numpy, "Numpy not installed")

0 commit comments

Comments
 (0)