@@ -2117,6 +2117,9 @@ cdef class DenseMatrixBase(MatrixBase):
2117
2117
def __get__ (self ):
2118
2118
return self .nrows()* self .ncols()
2119
2119
2120
+ def ravel (self ):
2121
+ return [self .get(i, j) for i in range (self .nrows()) for j in range (self .ncols())]
2122
+
2120
2123
def reshape (self , rows , cols ):
2121
2124
if len (self ) != rows* cols:
2122
2125
raise ValueError (" Invalid reshape parameters %d %d " % (rows, cols))
@@ -2132,9 +2135,9 @@ cdef class DenseMatrixBase(MatrixBase):
2132
2135
if j < 0 :
2133
2136
j += nc
2134
2137
if i < 0 or i >= nr:
2135
- raise IndexError
2138
+ raise IndexError ( " Row index out of bounds: %d " % i)
2136
2139
if j < 0 or j >= nc:
2137
- raise IndexError
2140
+ raise IndexError ( " Column index out of bounds: %d " % j)
2138
2141
return i, j
2139
2142
2140
2143
def get (self , i , j ):
@@ -3140,19 +3143,20 @@ cdef class _Lambdify(object):
3140
3143
e_ = _sympify(e)
3141
3144
args_.push_back(e_.thisptr)
3142
3145
3143
- if isinstance (exprs, DenseMatrixBase):
3144
- nr = exprs.nrows()
3145
- nc = exprs.ncols()
3146
- mtx = (< DenseMatrixBase> exprs).thisptr
3147
- for ri in range (nr):
3148
- for ci in range (nc):
3149
- b_ = deref(mtx).get(ri, ci)
3150
- outs_.push_back(b_)
3151
- else :
3152
- for e in ravel(exprs):
3153
- e_ = _sympify(e)
3154
- outs_.push_back(e_.thisptr)
3155
3146
3147
+ for curr_expr in exprs:
3148
+ if isinstance (curr_expr, DenseMatrixBase):
3149
+ nr = curr_expr.nrows()
3150
+ nc = curr_expr.ncols()
3151
+ mtx = (< DenseMatrixBase> curr_expr).thisptr
3152
+ for ri in range (nr):
3153
+ for ci in range (nc):
3154
+ b_ = deref(mtx).get(ri, ci)
3155
+ outs_.push_back(b_)
3156
+ else :
3157
+ for e in ravel(curr_expr):
3158
+ e_ = _sympify(e)
3159
+ outs_.push_back(e_.thisptr)
3156
3160
self ._init(args_, outs_)
3157
3161
3158
3162
cdef _init(self , symengine.vec_basic& args_, symengine.vec_basic& outs_):
@@ -3391,7 +3395,8 @@ def LambdifyCSE(args, *exprs, real=True, cse=None, concatenate=None):
3391
3395
if concatenate is None :
3392
3396
from numpy import concatenate
3393
3397
from sympy import sympify as s_sympify
3394
- subs, new_exprs = cse([s_sympify(expr) for expr in exprs])
3398
+ flat_exprs = list (itertools.chain(* map (ravel, exprs)))
3399
+ subs, flat_new_exprs = cse([s_sympify(expr) for expr in flat_exprs])
3395
3400
if subs:
3396
3401
cse_symbs, cse_exprs = zip (* subs)
3397
3402
new_exprs = []
0 commit comments