diff --git a/changes/954.change.rst b/changes/954.change.rst new file mode 100644 index 000000000..519ec3f48 --- /dev/null +++ b/changes/954.change.rst @@ -0,0 +1,5 @@ +Improved Atoms.index usability + +Now one gets an empty list if no atoms are found. +One can also get the combined indices of several +atom types, e.g., `atoms.index(["C", "H"])`. diff --git a/src/sisl/_core/atom.py b/src/sisl/_core/atom.py index 9476a6761..1a3ee7e6c 100644 --- a/src/sisl/_core/atom.py +++ b/src/sisl/_core/atom.py @@ -91,19 +91,24 @@ class Atom( >>> Carbon = Atom("Carbon") Add a tag to be able to distinguish it from other atoms + >>> tagged_Carbon = Atom("Carbon", tag="siteA") Create deuterium + >>> D = Atom("H", mass=2.014) Define an atom with 3 orbitals, each with a range of 2 Angstroem + >>> C3 = Atom("C", orbitals=[2, 2, 2]) Define an atom outside of the periodic table (negative will yield an AtomGhost object) + >>> ghost_C = Atom(-6) Define an unknown atom (basically anything can do) + >>> unknown_atom = Atom(1000) Notes @@ -758,9 +763,55 @@ def Z(self): uZ = _a.arrayi([a.Z for a in self.atom]) return uZ[self.species] - def index(self, atom): - """Return the indices of the atom object""" - return (self._species == self.species_index(atom)).nonzero()[0] + def index(self, atoms): + """Return indices of atoms matching the specified atom identifier(s). + + Parameters + ---------- + atoms : str, Atom, int, or lists hereof + Atom identifier (element string, Atom instance, or species_index). + It can also be a list of identifiers for matching against any of the entries. + + Returns + ------- + numpy.ndarray + Unique (sorted) indices of all matching atoms. + Returns an empty array if no match is found. + + Examples + -------- + >>> atoms = Atoms(["C", "H", "Au"]) + >>> idx_C = atoms.index("C") + >>> idx_C = atoms.index(0) + >>> idx_CH = atoms.index(["C", "H"]) + >>> idx_Au = atoms.index(Atom(79)) + + This can be useful to get a subset of a geometry, e.g., + + >>> geom = Geometry(...) + >>> idx_CH = geom.atoms.index(["C", "H"]) + >>> geom_CH = geom.sub(idx_CH) + + Notes + ----- + Unlike Python's ``list.index``, this method does **not** raise a ``ValueError`` + when no match is found. Instead, it returns an empty array. + To check for existence, use e.g. ``if idx.size > 0: ...``. + """ + + if not isinstance(atoms, (list, tuple, np.ndarray)): + atoms = [atoms] + + idx = np.array([], dtype=int) + for atom in atoms: + try: + arr = (self._species == self.species_index(atom)).nonzero()[0] + idx = np.concatenate([idx, arr]) + except KeyError: + # no species found + pass + + return np.unique(idx) def species_index(self, atom): """Return the species index of the atom object""" diff --git a/src/sisl/_core/tests/test_atoms.py b/src/sisl/_core/tests/test_atoms.py index d112bc06f..ce718df3f 100644 --- a/src/sisl/_core/tests/test_atoms.py +++ b/src/sisl/_core/tests/test_atoms.py @@ -297,5 +297,13 @@ def test_charge_diff(): def test_index1(): atom = Atoms(["C", "Au"]) - with pytest.raises(KeyError): - atom.index(Atom("B")) + assert len(atom.index("B")) == 0 + assert atom.index("C") == [0] + assert atom.index(Atom(79)) == [1] + assert atom.index(1) == [1] + assert atom.index(["C", "C"]) == [0] + assert atom.index(["B", "C", "C"]) == [0] + idx = atom.index(["Au", "C"]) + assert idx.ndim == 1 + assert np.issubdtype(idx.dtype, np.integer) + assert np.allclose(idx, [0, 1])