Skip to content

Commit ef54593

Browse files
authored
Merge pull request #215 from isuruf/lambdify
Simplify lambdify
2 parents 359e933 + fd2cf72 commit ef54593

File tree

2 files changed

+10
-18
lines changed

2 files changed

+10
-18
lines changed

symengine/__init__.py

Lines changed: 4 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -22,21 +22,10 @@
2222
from .lib.symengine_wrapper import ComplexMPC
2323

2424
if have_numpy:
25-
from .lib.symengine_wrapper import Lambdify, LambdifyCSE
26-
27-
def lambdify(args, exprs, real=True, backend=None, as_scipy=False):
28-
try:
29-
len(args)
30-
except TypeError:
31-
args = [args]
32-
if as_scipy:
33-
return Lambdify(args, *exprs, real=real, backend=backend, as_scipy=True)
34-
lmb = Lambdify(args, *exprs, real=real, backend=backend)
35-
def f(*inner_args):
36-
if len(inner_args) != len(args):
37-
raise TypeError("Incorrect number of arguments")
38-
return lmb(inner_args)
39-
return f
25+
from .lib.symengine_wrapper import (Lambdify, LambdifyCSE)
26+
27+
def lambdify(args, exprs, **kwargs):
28+
return Lambdify(args, *exprs, **kwargs)
4029

4130

4231
__version__ = "0.3.0"

symengine/lib/symengine_wrapper.pyx

Lines changed: 6 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -4447,7 +4447,7 @@ cdef class _Lambdify(object):
44474447
raise ValueError("Size of out incompatible with number of exprs.")
44484448
self.unsafe_complex(inp, out)
44494449

4450-
def __call__(self, inp, *, out=None):
4450+
def __call__(self, *args, out=None):
44514451
"""
44524452
Parameters
44534453
----------
@@ -4473,10 +4473,13 @@ cdef class _Lambdify(object):
44734473
if self.order not in ('C', 'F'):
44744474
raise NotImplementedError("Only C & F order supported for now.")
44754475

4476+
if len(args) == 1:
4477+
args = args[0]
4478+
44764479
try:
4477-
inp = np.asanyarray(inp, dtype=self.numpy_dtype)
4480+
inp = np.asanyarray(args, dtype=self.numpy_dtype)
44784481
except TypeError:
4479-
inp = np.fromiter(inp, dtype=self.numpy_dtype)
4482+
inp = np.fromiter(args, dtype=self.numpy_dtype)
44804483

44814484
if self.real:
44824485
real_inp = np.ascontiguousarray(inp.ravel(order=self.order))

0 commit comments

Comments
 (0)