|
1 | 1 | import numpy as np |
2 | 2 | from numpy import einsum |
| 3 | +from numpy import ndarray, bool_ |
3 | 4 |
|
4 | 5 | import sisl._array as _a |
5 | 6 | from sisl.messages import warn |
@@ -31,6 +32,14 @@ def __init__(self, parent, **info): |
31 | 32 | self.parent = parent |
32 | 33 | self.info = info |
33 | 34 |
|
| 35 | + def _sanitize_index(self, idx): |
| 36 | + r""" Ensure indices are transferred to acceptable integers """ |
| 37 | + if isinstance(idx, ndarray) and idx.dtype == bool_: |
| 38 | + return np.flatnonzero(idx) |
| 39 | + elif isinstance(idx, (list, tuple)) and isinstance(idx[0], bool): |
| 40 | + return np.flatnonzero(idx) |
| 41 | + return _a.asarrayi(idx).ravel() |
| 42 | + |
34 | 43 |
|
35 | 44 | class Coefficient(ParentContainer): |
36 | 45 | """ An object holding coefficients for a parent with info |
@@ -136,7 +145,7 @@ def sub(self, idx): |
136 | 145 | Coefficient |
137 | 146 | a new coefficient only containing the requested elements |
138 | 147 | """ |
139 | | - idx = _a.asarrayi(idx).ravel() # this ensures that the first dimension is preserved |
| 148 | + idx = self._sanitize_index(idx) |
140 | 149 | sub = self.__class__(self.c[idx].copy(), self.parent) |
141 | 150 | sub.info = self.info |
142 | 151 | return sub |
@@ -262,7 +271,7 @@ def sub(self, idx): |
262 | 271 | State |
263 | 272 | a new state only containing the requested elements |
264 | 273 | """ |
265 | | - idx = _a.asarrayi(idx).ravel() # this ensures that the first dimension is preserved |
| 274 | + idx = self._sanitize_index(idx) |
266 | 275 | sub = self.__class__(self.state[idx].copy(), self.parent) |
267 | 276 | sub.info = self.info |
268 | 277 | return sub |
@@ -705,7 +714,7 @@ def outer(self, idx=None): |
705 | 714 | """ |
706 | 715 | if idx is None: |
707 | 716 | return einsum('k,ki,kj->ij', self.c, self.state, _conj(self.state)) |
708 | | - idx = np.asarray(idx).ravel() |
| 717 | + idx = self._sanitize_index(idx) |
709 | 718 | return einsum('k,ki,kj->ij', self.c[idx], self.state[idx], _conj(self.state[idx])) |
710 | 719 |
|
711 | 720 | def sort(self, ascending=True): |
@@ -763,7 +772,7 @@ def sub(self, idx): |
763 | 772 | ------- |
764 | 773 | StateC : a new object with a subset of the states |
765 | 774 | """ |
766 | | - idx = _a.asarrayi(idx).ravel() |
| 775 | + idx = self._sanitize_index(idx) |
767 | 776 | sub = self.__class__(self.state[idx, ...], self.c[idx], self.parent) |
768 | 777 | sub.info = self.info |
769 | 778 | return sub |
|
0 commit comments