Skip to content

Commit 1df763b

Browse files
authored
Merge pull request #100 from isuruf/pydy
Improve subs and msubs
2 parents 1f6857c + deeb6a9 commit 1df763b

File tree

2 files changed

+46
-43
lines changed

2 files changed

+46
-43
lines changed

benchmarks/pydy_pendulum.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,7 @@
66
import sympy
77
import symengine
88
import pydy
9-
from pydy.models import n_link_pendulum_on_cart
9+
from sympy.physics.mechanics.models import n_link_pendulum_on_cart
1010

1111
print(sympy.__file__)
1212
print(symengine.__file__)

symengine/lib/symengine_wrapper.pyx

Lines changed: 45 additions & 42 deletions
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,7 @@ import itertools
1212
from operator import mul
1313
from functools import reduce
1414
import collections
15-
15+
import warnings
1616

1717
include "config.pxi"
1818

@@ -365,6 +365,27 @@ class DictBasic(_DictBasic, collections.MutableMapping):
365365
def __repr__(self):
366366
return self.__str__()
367367

368+
def get_dict(*args):
369+
"""
370+
Returns a DictBasic instance from args. Inputs can be,
371+
1. a DictBasic
372+
2. a Python dictionary
373+
3. two args old, new
374+
"""
375+
if len(args) == 2:
376+
arg = {args[0]: args[1]}
377+
elif len(args) == 1:
378+
arg = args[0]
379+
else:
380+
raise TypeError("subs/msubs takes one or two arguments (%d given)" % \
381+
len(args))
382+
if isinstance(arg, DictBasic):
383+
return arg
384+
cdef _DictBasic D = DictBasic()
385+
for k, v in arg.items():
386+
D.add(k, v)
387+
return D
388+
368389

369390
cdef class Basic(object):
370391

@@ -467,49 +488,23 @@ cdef class Basic(object):
467488
cdef Basic s = sympify(x)
468489
return c2py(symengine.diff(self.thisptr, s.thisptr))
469490

470-
def subs_dict(Basic self not None, subs_dict):
471-
cdef _DictBasic D
472-
if isinstance(subs_dict, DictBasic):
473-
D = subs_dict
474-
return c2py(symengine.ssubs(self.thisptr, D.c))
475-
cdef symengine.map_basic_basic d
476-
cdef Basic K, V
477-
for k in subs_dict:
478-
K = sympify(k)
479-
V = sympify(subs_dict[k])
480-
d[K.thisptr] = V.thisptr
481-
return c2py(symengine.ssubs(self.thisptr, d))
491+
def subs_dict(Basic self not None, *args):
492+
warnings.warn("subs_dict() is deprecated. Use subs() instead", DeprecationWarning)
493+
return self.subs(*args)
482494

483495
def subs_oldnew(Basic self not None, old, new):
484-
return self.subs_dict({old: new})
496+
warnings.warn("subs_oldnew() is deprecated. Use subs() instead", DeprecationWarning)
497+
return self.subs({old: new})
485498

486499
def subs(Basic self not None, *args):
487-
if len(args) == 1:
488-
return self.subs_dict(args[0])
489-
elif len(args) == 2:
490-
return self.subs_oldnew(args[0], args[1])
491-
else:
492-
raise TypeError("subs() takes one or two arguments (%d given)" % \
493-
len(args))
500+
cdef _DictBasic D = get_dict(*args)
501+
return c2py(symengine.ssubs(self.thisptr, D.c))
494502

495503
xreplace = subs
496504

497505
def msubs(Basic self not None, *args):
498-
if len(args) == 2:
499-
arg = {args[0]: args[1]}
500-
else:
501-
arg = args[0]
502-
cdef _DictBasic D
503-
if isinstance(arg, DictBasic):
504-
D = arg
505-
return c2py(symengine.msubs(self.thisptr, D.c))
506-
cdef symengine.map_basic_basic d
507-
cdef Basic K, V
508-
for k in arg:
509-
K = sympify(k)
510-
V = sympify(arg[k])
511-
d[K.thisptr] = V.thisptr
512-
return c2py(symengine.msubs(self.thisptr, d))
506+
cdef _DictBasic D = get_dict(*args)
507+
return c2py(symengine.msubs(self.thisptr, D.c))
513508

514509
def n(self, prec = 53, real = False):
515510
if real:
@@ -1721,15 +1716,21 @@ cdef class DenseMatrix(MatrixBase):
17211716
return self.transpose()
17221717

17231718
def _applyfunc(self, f):
1724-
for i in range(self.nrows()):
1725-
for j in range(self.ncols()):
1726-
e_ = self._set(i, j, f(self._get(i, j)))
1719+
cdef int nr = self.nrows()
1720+
cdef int nc = self.ncols()
1721+
for i in range(nr):
1722+
for j in range(nc):
1723+
self._set(i, j, f(self._get(i, j)))
17271724

17281725
def applyfunc(self, f):
1729-
out = DenseMatrix(self)
1726+
cdef DenseMatrix out = DenseMatrix(self)
17301727
out._applyfunc(f)
17311728
return out
17321729

1730+
def msubs(self, *args):
1731+
cdef _DictBasic D = get_dict(*args)
1732+
return self.applyfunc(lambda x: x.msubs(D))
1733+
17331734
def diff(self, x):
17341735
cdef Basic x_ = sympify(x)
17351736
R = DenseMatrix(self.rows, self.cols)
@@ -1738,8 +1739,10 @@ cdef class DenseMatrix(MatrixBase):
17381739
return R
17391740

17401741
#TODO: implement this in C++
1741-
def subs(self, subs_dict):
1742-
return self.applyfunc(lambda y: y.subs(subs_dict))
1742+
def subs(self, *args):
1743+
cdef _DictBasic D = get_dict(*args)
1744+
return self.applyfunc(lambda x: x.subs(D))
1745+
17431746

17441747
@property
17451748
def free_symbols(self):

0 commit comments

Comments
 (0)