@@ -4369,7 +4369,7 @@ cdef class _Lambdify(object):
4369
4369
cdef vector[int ] accum_out_sizes
4370
4370
cdef object numpy_dtype
4371
4371
4372
- def __init__ (self , args , *exprs , cppbool real = True , order = ' C' ):
4372
+ def __init__ (self , args , *exprs , cppbool real = True , order = ' C' , cppbool cse = False ):
4373
4373
cdef:
4374
4374
Basic e_
4375
4375
size_t ri, ci, nr, nc
@@ -4409,9 +4409,9 @@ cdef class _Lambdify(object):
4409
4409
for e in np.ravel(curr_expr, order = self .order):
4410
4410
e_ = _sympify(e)
4411
4411
outs_.push_back(e_.thisptr)
4412
- self ._init(args_, outs_)
4412
+ self ._init(args_, outs_, cse )
4413
4413
4414
- cdef _init(self , symengine.vec_basic& args_, symengine.vec_basic& outs_):
4414
+ cdef _init(self , symengine.vec_basic& args_, symengine.vec_basic& outs_, cppbool cse ):
4415
4415
raise ValueError (" Not supported" )
4416
4416
4417
4417
cpdef unsafe_real(self ,
@@ -4590,13 +4590,13 @@ cdef class LambdaDouble(_Lambdify):
4590
4590
cdef vector[symengine.LambdaRealDoubleVisitor] lambda_double
4591
4591
cdef vector[symengine.LambdaComplexDoubleVisitor] lambda_double_complex
4592
4592
4593
- cdef _init(self , symengine.vec_basic& args_, symengine.vec_basic& outs_):
4593
+ cdef _init(self , symengine.vec_basic& args_, symengine.vec_basic& outs_, cppbool cse ):
4594
4594
if self .real:
4595
4595
self .lambda_double.resize(1 )
4596
- self .lambda_double[0 ].init(args_, outs_)
4596
+ self .lambda_double[0 ].init(args_, outs_, cse )
4597
4597
else :
4598
4598
self .lambda_double_complex.resize(1 )
4599
- self .lambda_double_complex[0 ].init(args_, outs_)
4599
+ self .lambda_double_complex[0 ].init(args_, outs_, cse )
4600
4600
4601
4601
cpdef unsafe_real(self , double [::1 ] inp, double [::1 ] out, int inp_offset = 0 , int out_offset = 0 ):
4602
4602
self .lambda_double[0 ].call(& out[out_offset], & inp[inp_offset])
@@ -4621,9 +4621,9 @@ IF HAVE_SYMENGINE_LLVM:
4621
4621
4622
4622
cdef vector[symengine.LLVMDoubleVisitor] lambda_double
4623
4623
4624
- cdef _init(self , symengine.vec_basic& args_, symengine.vec_basic& outs_):
4624
+ cdef _init(self , symengine.vec_basic& args_, symengine.vec_basic& outs_, cppbool cse ):
4625
4625
self .lambda_double.resize(1 )
4626
- self .lambda_double[0 ].init(args_, outs_)
4626
+ self .lambda_double[0 ].init(args_, outs_, cse )
4627
4627
4628
4628
cpdef unsafe_real(self , double [::1 ] inp, double [::1 ] out, int inp_offset = 0 , int out_offset = 0 ):
4629
4629
self .lambda_double[0 ].call(& out[out_offset], & inp[inp_offset])
@@ -4640,7 +4640,7 @@ IF HAVE_SYMENGINE_LLVM:
4640
4640
return create_low_level_callable(self , addr1, addr2)
4641
4641
4642
4642
4643
- def Lambdify (args , *exprs , cppbool real = True , backend = None , order = ' C' , as_scipy = False ):
4643
+ def Lambdify (args , *exprs , cppbool real = True , backend = None , order = ' C' , as_scipy = False , cse = False ):
4644
4644
"""
4645
4645
Lambdify instances are callbacks that numerically evaluate their symbolic
4646
4646
expressions from user provided input (real or complex) into (possibly user
@@ -4666,6 +4666,9 @@ def Lambdify(args, *exprs, cppbool real=True, backend=None, order='C', as_scipy=
4666
4666
as_scipy : bool
4667
4667
return a SciPy LowLevelCallable which can be used in SciPy's integrate
4668
4668
methods
4669
+ cse : bool
4670
+ Run Common Subexpression Elimination on the output before generating
4671
+ the callback.
4669
4672
4670
4673
Returns
4671
4674
-------
@@ -4687,7 +4690,7 @@ def Lambdify(args, *exprs, cppbool real=True, backend=None, order='C', as_scipy=
4687
4690
backend = os.getenv(' SYMENGINE_LAMBDIFY_BACKEND' , " lambda" )
4688
4691
if backend == " llvm" :
4689
4692
IF HAVE_SYMENGINE_LLVM:
4690
- ret = LLVMDouble(args, * exprs, real = real, order = order)
4693
+ ret = LLVMDouble(args, * exprs, real = real, order = order, cse = cse )
4691
4694
if as_scipy:
4692
4695
return ret.as_scipy_low_level_callable()
4693
4696
return ret
@@ -4698,63 +4701,17 @@ def Lambdify(args, *exprs, cppbool real=True, backend=None, order='C', as_scipy=
4698
4701
pass
4699
4702
else :
4700
4703
warnings.warn(" Unknown SymEngine backend: %s \n Using backend='lambda'" % backend)
4701
- ret = LambdaDouble(args, * exprs, real = real, order = order)
4704
+ ret = LambdaDouble(args, * exprs, real = real, order = order, cse = cse )
4702
4705
if as_scipy:
4703
4706
return ret.as_scipy_low_level_callable()
4704
4707
return ret
4705
4708
4706
4709
4707
- def LambdifyCSE (args , *exprs , cse = None , order = ' C' , **kwargs ):
4710
+ def LambdifyCSE (args , *exprs , order = ' C' , **kwargs ):
4708
4711
""" Analogous with Lambdify but performs common subexpression elimination.
4709
-
4710
- See docstring of Lambdify.
4711
-
4712
- Parameters
4713
- ----------
4714
- args: iterable of symbols
4715
- exprs: iterable of expressions (with symbols from args)
4716
- cse: callback (default: None)
4717
- defaults to sympy.cse (see SymPy documentation)
4718
- order : str
4719
- order (passed to numpy.ravel and numpy.reshape).
4720
- \\ *\\ *kwargs: Keyword arguments passed onto Lambdify
4721
-
4722
4712
"""
4723
- if cse is None :
4724
- from sympy import cse
4725
- _exprs = [np.asanyarray(e) for e in exprs]
4726
- _args = np.ravel(args, order = order)
4727
- from sympy import sympify as s_sympify
4728
- flat_exprs = list (itertools.chain(* [np.ravel(e, order = order) for e in _exprs]))
4729
- subs, flat_new_exprs = cse([s_sympify(expr) for expr in flat_exprs])
4730
-
4731
- if subs:
4732
- explicit_subs = {}
4733
- for k, v in subs:
4734
- explicit_subs[k] = v.xreplace(explicit_subs)
4735
-
4736
- cse_symbs, cse_exprs = zip (* subs)
4737
- new_exprs = []
4738
- n_taken = 0
4739
- for expr in _exprs:
4740
- new_exprs.append(np.reshape(flat_new_exprs[n_taken:n_taken+ expr.size],
4741
- expr.shape, order = order))
4742
- n_taken += expr.size
4743
- new_lmb = Lambdify(tuple (_args) + cse_symbs, * new_exprs, order = order, ** kwargs)
4744
- cse_lambda = Lambdify(_args, [ce.xreplace(explicit_subs) for ce in cse_exprs], ** kwargs)
4745
- def cb (inp , *, out = None , **kw ):
4746
- _inp = np.asanyarray(inp)
4747
- cse_vals = cse_lambda(_inp, ** kw)
4748
- if order == ' C' :
4749
- new_inp = np.concatenate((_inp[(Ellipsis ,) + (np.newaxis,)* (cse_vals.ndim - _inp.ndim)],
4750
- cse_vals), axis = - 1 )
4751
- else :
4752
- new_inp = np.concatenate((_inp[(np.newaxis,)* (cse_vals.ndim - _inp.ndim) + (Ellipsis ,)],
4753
- cse_vals), axis = 0 )
4754
- return new_lmb(new_inp, out = out, ** kw)
4755
- return cb
4756
- else :
4757
- return Lambdify(args, * exprs, ** kwargs)
4713
+ warnings.warn(" LambdifyCSE is deprecated. Use Lambdify(..., cse=True)" , DeprecationWarning )
4714
+ return Lambdify(args, * exprs, cse = True , order = order, ** kwargs)
4758
4715
4759
4716
4760
4717
def ccode (expr ):
0 commit comments