Skip to content
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 5 additions & 0 deletions changes/954.change.rst
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
Improved Atoms.index usability

Now one gets an empty list if no atoms are found.
One can also get the combined indices of several
atom types, e.g., `atoms.index("C", "H")`.
37 changes: 34 additions & 3 deletions src/sisl/_core/atom.py
Original file line number Diff line number Diff line change
Expand Up @@ -758,9 +758,40 @@ def Z(self):
uZ = _a.arrayi([a.Z for a in self.atom])
return uZ[self.species]

def index(self, atom):
"""Return the indices of the atom object"""
return (self._species == self.species_index(atom)).nonzero()[0]
def index(self, *atoms):
"""Return indices of atoms of the given atom type.

Parameters
----------
atoms : str, Atom, or lists hereof
One or more atom descriptors

Returns
-------
numpy.ndarray
Indices of all matching atoms. Returns an empty array if no match is found.

Notes
-----
Unlike Python's ``list.index``, this method does **not** raise a ``ValueError``
when no match is found. Instead, it returns an empty array.
To check for existence, use e.g. ``if idx.size > 0: ...``.
"""

if len(atoms) == 1 and isinstance(atoms[0], (list, tuple, np.ndarray)):
atoms = tuple(atoms[0])

idx = []
for a in atoms:
if not isinstance(a, Atom):
a = Atom(a)
mask = self.Z == a.Z # compare Z values
if mask.any():
idx.append(np.nonzero(mask)[0])
if len(idx) > 0:
idx = np.unique(idx)

return idx

def species_index(self, atom):
"""Return the species index of the atom object"""
Expand Down
9 changes: 7 additions & 2 deletions src/sisl/_core/tests/test_atoms.py
Original file line number Diff line number Diff line change
Expand Up @@ -297,5 +297,10 @@ def test_charge_diff():

def test_index1():
atom = Atoms(["C", "Au"])
with pytest.raises(KeyError):
atom.index(Atom("B"))
assert atom.index("B") == []
assert atom.index("C") == [0]
assert atom.index("C", Atom(6)) == [0]
assert atom.index(["C", "C"]) == [0]
assert atom.index(["B", "C", "C"]) == [0]
assert atom.index(Atom(79)) == [1]
assert np.allclose(atom.index("Au", "C"), [0, 1])
Loading