Skip to content

Commit e949b3f

Browse files
committed
Simplify Lambdify's __init__
1 parent 473679e commit e949b3f

File tree

1 file changed

+18
-25
lines changed

1 file changed

+18
-25
lines changed

symengine/lib/symengine_wrapper.pyx

Lines changed: 18 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -3264,8 +3264,15 @@ cdef class DenseMatrixBase(MatrixBase):
32643264
def __get__(self):
32653265
return self.nrows()*self.ncols()
32663266

3267-
def ravel(self):
3268-
return [self._get(i, j) for i in range(self.nrows()) for j in range(self.ncols())]
3267+
def ravel(self, order='C'):
3268+
if order == 'C':
3269+
return [self._get(i, j) for i in range(self.nrows())
3270+
for j in range(self.ncols())]
3271+
elif order == 'F':
3272+
return [self._get(i, j) for j in range(self.ncols())
3273+
for i in range(self.nrows())]
3274+
else:
3275+
raise NotImplementedError("Unknown order '%s'" % order)
32693276

32703277
def reshape(self, rows, cols):
32713278
if len(self) != rows*cols:
@@ -4389,28 +4396,14 @@ cdef class _Lambdify(object):
43894396
RCP[const symengine.Basic] b_
43904397
symengine.vec_basic args_, outs_
43914398

4392-
if isinstance(args, DenseMatrixBase):
4393-
nr = args.nrows()
4394-
nc = args.ncols()
4395-
mtx = (<DenseMatrixBase>args).thisptr
4396-
for ri in range(nr):
4397-
for ci in range(nc):
4398-
args_.push_back(deref(mtx).get(ri, ci))
4399-
else:
4400-
for arg in np.ravel(args, order=order):
4401-
e_ = _sympify(arg)
4402-
args_.push_back(e_.thisptr)
4403-
4404-
4405-
for curr_expr in exprs:
4406-
if isinstance(curr_expr, DenseMatrixBase):
4407-
nr = curr_expr.nrows()
4408-
nc = curr_expr.ncols()
4409-
mtx = (<DenseMatrixBase>curr_expr).thisptr
4410-
for ri in range(nr):
4411-
for ci in range(nc):
4412-
b_ = deref(mtx).get(ri, ci)
4413-
outs_.push_back(b_)
4399+
for arg in np.ravel(args, order=order):
4400+
e_ = _sympify(arg)
4401+
args_.push_back(e_.thisptr)
4402+
4403+
for curr_expr in map(np.array, exprs):
4404+
if curr_expr.ndim == 0:
4405+
e_ = _sympify(curr_expr.item())
4406+
outs_.push_back(e_.thisptr)
44144407
else:
44154408
for e in np.ravel(curr_expr, order=order):
44164409
e_ = _sympify(e)
@@ -4509,7 +4502,7 @@ cdef class _Lambdify(object):
45094502
extra_dim = inp.shape[1:]
45104503
else:
45114504
extra_dim = inp.shape
4512-
else:
4505+
else:
45134506
if nbroadcast > 1 and inp.ndim == 1:
45144507
extra_dim = (nbroadcast,) # special case
45154508
else:

0 commit comments

Comments
 (0)