Skip to content

Commit b43ac95

Browse files
committed
expand(deep=False) wrapped
1 parent 18b94bc commit b43ac95

File tree

3 files changed

+10
-6
lines changed

3 files changed

+10
-6
lines changed

symengine/lib/symengine.pxd

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -310,7 +310,7 @@ cdef extern from "<symengine/basic.h>" namespace "SymEngine":
310310
bool is_a_Not "SymEngine::is_a<SymEngine::Not>"(const Basic &b) nogil
311311
bool is_a_Or "SymEngine::is_a<SymEngine::Or>"(const Basic &b) nogil
312312
bool is_a_Xor "SymEngine::is_a<SymEngine::Xor>"(const Basic &b) nogil
313-
RCP[const Basic] expand(RCP[const Basic] &o) nogil except +
313+
RCP[const Basic] expand(RCP[const Basic] &o, bool deep) nogil except +
314314
void as_numer_denom(RCP[const Basic] &x, const Ptr[RCP[Basic]] &numer, const Ptr[RCP[Basic]] &denom) nogil
315315

316316
cdef extern from "<symengine/subs.h>" namespace "SymEngine":

symengine/lib/symengine_wrapper.pyx

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -844,8 +844,8 @@ cdef class Basic(object):
844844
elif (op == 5):
845845
return c2py(<RCP[const symengine.Basic]>(symengine.Ge(A.thisptr, B.thisptr)))
846846

847-
def expand(Basic self not None):
848-
return c2py(symengine.expand(self.thisptr))
847+
def expand(Basic self not None, cppbool deep=True):
848+
return c2py(symengine.expand(self.thisptr, deep))
849849

850850
def diff(Basic self not None, x = None):
851851
if x is None:
@@ -3560,7 +3560,7 @@ cdef class DenseMatrixBase(MatrixBase):
35603560
return self._applyfunc(lambda x : x.simplify(*args, **kwargs))
35613561

35623562
def expand(self, *args, **kwargs):
3563-
return self.applyfunc(lambda x : x.expand())
3563+
return self.applyfunc(lambda x : x.expand(*args, **kwargs))
35643564

35653565

35663566
def div_matrices(a, b):
@@ -3726,8 +3726,8 @@ def diff(ex, *x):
37263726
ex = ex.diff(i)
37273727
return ex
37283728

3729-
def expand(x):
3730-
return sympify(x).expand()
3729+
def expand(x, deep=True):
3730+
return sympify(x).expand(deep)
37313731

37323732
expand_mul = expand
37333733

symengine/tests/test_arit.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -113,9 +113,13 @@ def test_arit9():
113113

114114

115115
def test_expand2():
116+
x = Symbol("x")
116117
y = Symbol("y")
117118
z = Symbol("z")
118119
assert ((1/(y*z) - y*z)*y*z).expand() == 1-(y*z)**2
120+
assert (2*(x + 2*(y + z))).expand(deep=False) == 2*x + 4*(y+z)
121+
ex = x + 2*(y + z)
122+
assert ex.expand(deep=False) == ex
119123

120124

121125
def test_expand3():

0 commit comments

Comments
 (0)