Skip to content

Commit be27d1e

Browse files
authored
954 modified Atoms.index (#955)
* Improved Atoms.index usability * update changes rst
1 parent c352954 commit be27d1e

File tree

3 files changed

+69
-5
lines changed

3 files changed

+69
-5
lines changed

changes/954.change.rst

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,5 @@
1+
Improved Atoms.index usability
2+
3+
Now one gets an empty list if no atoms are found.
4+
One can also get the combined indices of several
5+
atom types, e.g., `atoms.index(["C", "H"])`.

src/sisl/_core/atom.py

Lines changed: 54 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -91,19 +91,24 @@ class Atom(
9191
>>> Carbon = Atom("Carbon")
9292
9393
Add a tag to be able to distinguish it from other atoms
94+
9495
>>> tagged_Carbon = Atom("Carbon", tag="siteA")
9596
9697
Create deuterium
98+
9799
>>> D = Atom("H", mass=2.014)
98100
99101
Define an atom with 3 orbitals, each with a range of 2 Angstroem
102+
100103
>>> C3 = Atom("C", orbitals=[2, 2, 2])
101104
102105
Define an atom outside of the periodic table (negative will yield an
103106
AtomGhost object)
107+
104108
>>> ghost_C = Atom(-6)
105109
106110
Define an unknown atom (basically anything can do)
111+
107112
>>> unknown_atom = Atom(1000)
108113
109114
Notes
@@ -758,9 +763,55 @@ def Z(self):
758763
uZ = _a.arrayi([a.Z for a in self.atom])
759764
return uZ[self.species]
760765

761-
def index(self, atom):
762-
"""Return the indices of the atom object"""
763-
return (self._species == self.species_index(atom)).nonzero()[0]
766+
def index(self, atoms):
767+
"""Return indices of atoms matching the specified atom identifier(s).
768+
769+
Parameters
770+
----------
771+
atoms : str, Atom, int, or lists hereof
772+
Atom identifier (element string, Atom instance, or species_index).
773+
It can also be a list of identifiers for matching against any of the entries.
774+
775+
Returns
776+
-------
777+
numpy.ndarray
778+
Unique (sorted) indices of all matching atoms.
779+
Returns an empty array if no match is found.
780+
781+
Examples
782+
--------
783+
>>> atoms = Atoms(["C", "H", "Au"])
784+
>>> idx_C = atoms.index("C")
785+
>>> idx_C = atoms.index(0)
786+
>>> idx_CH = atoms.index(["C", "H"])
787+
>>> idx_Au = atoms.index(Atom(79))
788+
789+
This can be useful to get a subset of a geometry, e.g.,
790+
791+
>>> geom = Geometry(...)
792+
>>> idx_CH = geom.atoms.index(["C", "H"])
793+
>>> geom_CH = geom.sub(idx_CH)
794+
795+
Notes
796+
-----
797+
Unlike Python's ``list.index``, this method does **not** raise a ``ValueError``
798+
when no match is found. Instead, it returns an empty array.
799+
To check for existence, use e.g. ``if idx.size > 0: ...``.
800+
"""
801+
802+
if not isinstance(atoms, (list, tuple, np.ndarray)):
803+
atoms = [atoms]
804+
805+
idx = np.array([], dtype=int)
806+
for atom in atoms:
807+
try:
808+
arr = (self._species == self.species_index(atom)).nonzero()[0]
809+
idx = np.concatenate([idx, arr])
810+
except KeyError:
811+
# no species found
812+
pass
813+
814+
return np.unique(idx)
764815

765816
def species_index(self, atom):
766817
"""Return the species index of the atom object"""

src/sisl/_core/tests/test_atoms.py

Lines changed: 10 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -297,5 +297,13 @@ def test_charge_diff():
297297

298298
def test_index1():
299299
atom = Atoms(["C", "Au"])
300-
with pytest.raises(KeyError):
301-
atom.index(Atom("B"))
300+
assert len(atom.index("B")) == 0
301+
assert atom.index("C") == [0]
302+
assert atom.index(Atom(79)) == [1]
303+
assert atom.index(1) == [1]
304+
assert atom.index(["C", "C"]) == [0]
305+
assert atom.index(["B", "C", "C"]) == [0]
306+
idx = atom.index(["Au", "C"])
307+
assert idx.ndim == 1
308+
assert np.issubdtype(idx.dtype, np.integer)
309+
assert np.allclose(idx, [0, 1])

0 commit comments

Comments
 (0)