Skip to content

Commit 78800c4

Browse files
authored
Merge pull request #125 from ShikharJ/Issue121
Minor Improvement in atoms() and Tests
2 parents 83912b7 + c589f9d commit 78800c4

File tree

2 files changed

+27
-12
lines changed

2 files changed

+27
-12
lines changed

symengine/lib/symengine_wrapper.pyx

Lines changed: 18 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -585,12 +585,15 @@ cdef class Basic(object):
585585
return ring(self._sage_())
586586

587587
def atoms(self, *types):
588-
s = set()
589-
if (isinstance(self, types)):
590-
s.add(self)
591-
for arg in self.args:
592-
s.update(arg.atoms(*types))
593-
return s
588+
if types:
589+
s = set()
590+
if (isinstance(self, types)):
591+
s.add(self)
592+
for arg in self.args:
593+
s.update(arg.atoms(*types))
594+
return s
595+
else:
596+
return self.free_symbols
594597

595598
def simplify(self, *args, **kwargs):
596599
return sympify(self._sympy_().simplify(*args, **kwargs))
@@ -1969,12 +1972,15 @@ cdef class DenseMatrix(MatrixBase):
19691972
return self[:]
19701973

19711974
def atoms(self, *types):
1972-
s = set()
1973-
if (isinstance(self, types)):
1974-
s.add(self)
1975-
for arg in self.tolist():
1976-
s.update(arg.atoms(*types))
1977-
return s
1975+
if types:
1976+
s = set()
1977+
if (isinstance(self, types)):
1978+
s.add(self)
1979+
for arg in self.tolist():
1980+
s.update(arg.atoms(*types))
1981+
return s
1982+
else:
1983+
return self.free_symbols
19781984

19791985
def simplify(self, *args, **kwargs):
19801986
return self._applyfunc(lambda x : x.simplify(*args, **kwargs))

symengine/tests/test_arit.py

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -119,6 +119,15 @@ def test_args():
119119
assert (2*x**2).args == (2, x**2)
120120
assert set((2*x**2*y).args) == set((Integer(2), x**2, y))
121121

122+
def test_atoms():
123+
x = Symbol("x")
124+
y = Symbol("y")
125+
z = Symbol("z")
126+
assert (x**2).atoms() == set([x])
127+
assert (x**2).atoms(Symbol) == set([x])
128+
assert (x ** y + z).atoms() == set([x, y, z])
129+
assert (x**y + z).atoms(Symbol) == set([x, y, z])
130+
122131
def test_free_symbols():
123132
x = Symbol("x")
124133
y = Symbol("y")

0 commit comments

Comments
 (0)