Skip to content

Commit 3b1fae7

Browse files
committed
Fix Lambdify for heterogeneous input
1 parent 723b0d2 commit 3b1fae7

File tree

2 files changed

+33
-15
lines changed

2 files changed

+33
-15
lines changed

symengine/lib/symengine_wrapper.pyx

Lines changed: 20 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -2117,6 +2117,9 @@ cdef class DenseMatrixBase(MatrixBase):
21172117
def __get__(self):
21182118
return self.nrows()*self.ncols()
21192119

2120+
def ravel(self):
2121+
return [self.get(i, j) for i in range(self.nrows()) for j in range(self.ncols())]
2122+
21202123
def reshape(self, rows, cols):
21212124
if len(self) != rows*cols:
21222125
raise ValueError("Invalid reshape parameters %d %d" % (rows, cols))
@@ -2132,9 +2135,9 @@ cdef class DenseMatrixBase(MatrixBase):
21322135
if j < 0:
21332136
j += nc
21342137
if i < 0 or i >= nr:
2135-
raise IndexError
2138+
raise IndexError("Row index out of bounds: %d" % i)
21362139
if j < 0 or j >= nc:
2137-
raise IndexError
2140+
raise IndexError("Column index out of bounds: %d" % j)
21382141
return i, j
21392142

21402143
def get(self, i, j):
@@ -3140,19 +3143,20 @@ cdef class _Lambdify(object):
31403143
e_ = _sympify(e)
31413144
args_.push_back(e_.thisptr)
31423145

3143-
if isinstance(exprs, DenseMatrixBase):
3144-
nr = exprs.nrows()
3145-
nc = exprs.ncols()
3146-
mtx = (<DenseMatrixBase>exprs).thisptr
3147-
for ri in range(nr):
3148-
for ci in range(nc):
3149-
b_ = deref(mtx).get(ri, ci)
3150-
outs_.push_back(b_)
3151-
else:
3152-
for e in ravel(exprs):
3153-
e_ = _sympify(e)
3154-
outs_.push_back(e_.thisptr)
31553146

3147+
for curr_expr in exprs:
3148+
if isinstance(curr_expr, DenseMatrixBase):
3149+
nr = curr_expr.nrows()
3150+
nc = curr_expr.ncols()
3151+
mtx = (<DenseMatrixBase>curr_expr).thisptr
3152+
for ri in range(nr):
3153+
for ci in range(nc):
3154+
b_ = deref(mtx).get(ri, ci)
3155+
outs_.push_back(b_)
3156+
else:
3157+
for e in ravel(curr_expr):
3158+
e_ = _sympify(e)
3159+
outs_.push_back(e_.thisptr)
31563160
self._init(args_, outs_)
31573161

31583162
cdef _init(self, symengine.vec_basic& args_, symengine.vec_basic& outs_):
@@ -3391,7 +3395,8 @@ def LambdifyCSE(args, *exprs, real=True, cse=None, concatenate=None):
33913395
if concatenate is None:
33923396
from numpy import concatenate
33933397
from sympy import sympify as s_sympify
3394-
subs, new_exprs = cse([s_sympify(expr) for expr in exprs])
3398+
flat_exprs = list(itertools.chain(*map(ravel, exprs)))
3399+
subs, flat_new_exprs = cse([s_sympify(expr) for expr in flat_exprs])
33953400
if subs:
33963401
cse_symbs, cse_exprs = zip(*subs)
33973402
new_exprs = []

symengine/tests/test_lambdify.py

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -50,6 +50,19 @@ def test_get_shape():
5050
assert get_shape([[1], [1], [1]]) == (3, 1)
5151
assert get_shape([[1, 1, 1]]) == (1, 3)
5252

53+
x = se.symbols('x')
54+
exprs = [x+1, x+2, x+3, 1/x, 1/(x*x), 1/(x**3.0)]
55+
A = se.DenseMatrix(2, 3, exprs)
56+
assert get_shape(A) == (2, 3)
57+
58+
59+
def test_ravel():
60+
x = se.symbols('x')
61+
ravel = se.lib.symengine_wrapper.ravel
62+
exprs = [x+1, x+2, x+3, 1/x, 1/(x*x), 1/(x**3.0)]
63+
A = se.DenseMatrix(2, 3, exprs)
64+
assert ravel(A) == exprs
65+
5366

5467
def test_Lambdify():
5568
n = 7

0 commit comments

Comments
 (0)