Skip to content

Commit 23ce74d

Browse files
committed
improve subs and msubs
1 parent 1f6857c commit 23ce74d

File tree

1 file changed

+37
-39
lines changed

1 file changed

+37
-39
lines changed

symengine/lib/symengine_wrapper.pyx

Lines changed: 37 additions & 39 deletions
Original file line numberDiff line numberDiff line change
@@ -365,6 +365,21 @@ class DictBasic(_DictBasic, collections.MutableMapping):
365365
def __repr__(self):
366366
return self.__str__()
367367

368+
def get_dict(*args):
369+
if len(args) == 2:
370+
arg = {args[0]: args[1]}
371+
elif len(args) == 1:
372+
arg = args[0]
373+
else:
374+
raise TypeError("subs/msubs takes one or two arguments (%d given)" % \
375+
len(args))
376+
if isinstance(arg, DictBasic):
377+
return arg
378+
cdef _DictBasic D = DictBasic()
379+
for k, v in arg.items():
380+
D.add(k, v)
381+
return D
382+
368383

369384
cdef class Basic(object):
370385

@@ -467,49 +482,24 @@ cdef class Basic(object):
467482
cdef Basic s = sympify(x)
468483
return c2py(symengine.diff(self.thisptr, s.thisptr))
469484

485+
#TODO: deprecate this
470486
def subs_dict(Basic self not None, subs_dict):
471-
cdef _DictBasic D
472-
if isinstance(subs_dict, DictBasic):
473-
D = subs_dict
474-
return c2py(symengine.ssubs(self.thisptr, D.c))
475-
cdef symengine.map_basic_basic d
476-
cdef Basic K, V
477-
for k in subs_dict:
478-
K = sympify(k)
479-
V = sympify(subs_dict[k])
480-
d[K.thisptr] = V.thisptr
481-
return c2py(symengine.ssubs(self.thisptr, d))
487+
cdef _DictBasic D = get_dict(*args)
488+
return c2py(symengine.msubs(self.thisptr, D.c))
482489

490+
#TODO: deprecate this
483491
def subs_oldnew(Basic self not None, old, new):
484492
return self.subs_dict({old: new})
485493

486494
def subs(Basic self not None, *args):
487-
if len(args) == 1:
488-
return self.subs_dict(args[0])
489-
elif len(args) == 2:
490-
return self.subs_oldnew(args[0], args[1])
491-
else:
492-
raise TypeError("subs() takes one or two arguments (%d given)" % \
493-
len(args))
495+
cdef _DictBasic D = get_dict(*args)
496+
return c2py(symengine.msubs(self.thisptr, D.c))
494497

495498
xreplace = subs
496499

497500
def msubs(Basic self not None, *args):
498-
if len(args) == 2:
499-
arg = {args[0]: args[1]}
500-
else:
501-
arg = args[0]
502-
cdef _DictBasic D
503-
if isinstance(arg, DictBasic):
504-
D = arg
505-
return c2py(symengine.msubs(self.thisptr, D.c))
506-
cdef symengine.map_basic_basic d
507-
cdef Basic K, V
508-
for k in arg:
509-
K = sympify(k)
510-
V = sympify(arg[k])
511-
d[K.thisptr] = V.thisptr
512-
return c2py(symengine.msubs(self.thisptr, d))
501+
cdef _DictBasic D = get_dict(*args)
502+
return c2py(symengine.msubs(self.thisptr, D.c))
513503

514504
def n(self, prec = 53, real = False):
515505
if real:
@@ -1721,15 +1711,21 @@ cdef class DenseMatrix(MatrixBase):
17211711
return self.transpose()
17221712

17231713
def _applyfunc(self, f):
1724-
for i in range(self.nrows()):
1725-
for j in range(self.ncols()):
1726-
e_ = self._set(i, j, f(self._get(i, j)))
1714+
cdef int nr = self.nrows()
1715+
cdef int nc = self.ncols()
1716+
for i in range(nr):
1717+
for j in range(nc):
1718+
self._set(i, j, f(self._get(i, j)))
17271719

17281720
def applyfunc(self, f):
1729-
out = DenseMatrix(self)
1721+
cdef DenseMatrix out = DenseMatrix(self)
17301722
out._applyfunc(f)
17311723
return out
17321724

1725+
def msubs(self, *args):
1726+
cdef _DictBasic D = get_dict(*args)
1727+
return self.applyfunc(lambda x: x.msubs(D))
1728+
17331729
def diff(self, x):
17341730
cdef Basic x_ = sympify(x)
17351731
R = DenseMatrix(self.rows, self.cols)
@@ -1738,8 +1734,10 @@ cdef class DenseMatrix(MatrixBase):
17381734
return R
17391735

17401736
#TODO: implement this in C++
1741-
def subs(self, subs_dict):
1742-
return self.applyfunc(lambda y: y.subs(subs_dict))
1737+
def subs(self, *args):
1738+
cdef _DictBasic D = get_dict(*args)
1739+
return self.applyfunc(lambda x: x.subs(D))
1740+
17431741

17441742
@property
17451743
def free_symbols(self):

0 commit comments

Comments
 (0)