Skip to content

Commit 40c23d1

Browse files
committed
enh: allowed bool arrays for state objects
Signed-off-by: Nick Papior <[email protected]>
1 parent 2e6ae92 commit 40c23d1

File tree

2 files changed

+15
-4
lines changed

2 files changed

+15
-4
lines changed

sisl/physics/state.py

Lines changed: 13 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
import numpy as np
22
from numpy import einsum
3+
from numpy import ndarray, bool_
34

45
import sisl._array as _a
56
from sisl.messages import warn
@@ -31,6 +32,14 @@ def __init__(self, parent, **info):
3132
self.parent = parent
3233
self.info = info
3334

35+
def _sanitize_index(self, idx):
36+
r""" Ensure indices are transferred to acceptable integers """
37+
if isinstance(idx, ndarray) and idx.dtype == bool_:
38+
return np.flatnonzero(idx)
39+
elif isinstance(idx, (list, tuple)) and isinstance(idx[0], bool):
40+
return np.flatnonzero(idx)
41+
return _a.asarrayi(idx).ravel()
42+
3443

3544
class Coefficient(ParentContainer):
3645
""" An object holding coefficients for a parent with info
@@ -136,7 +145,7 @@ def sub(self, idx):
136145
Coefficient
137146
a new coefficient only containing the requested elements
138147
"""
139-
idx = _a.asarrayi(idx).ravel() # this ensures that the first dimension is preserved
148+
idx = self._sanitize_index(idx)
140149
sub = self.__class__(self.c[idx].copy(), self.parent)
141150
sub.info = self.info
142151
return sub
@@ -262,7 +271,7 @@ def sub(self, idx):
262271
State
263272
a new state only containing the requested elements
264273
"""
265-
idx = _a.asarrayi(idx).ravel() # this ensures that the first dimension is preserved
274+
idx = self._sanitize_index(idx)
266275
sub = self.__class__(self.state[idx].copy(), self.parent)
267276
sub.info = self.info
268277
return sub
@@ -705,7 +714,7 @@ def outer(self, idx=None):
705714
"""
706715
if idx is None:
707716
return einsum('k,ki,kj->ij', self.c, self.state, _conj(self.state))
708-
idx = np.asarray(idx).ravel()
717+
idx = self._sanitize_index(idx)
709718
return einsum('k,ki,kj->ij', self.c[idx], self.state[idx], _conj(self.state[idx]))
710719

711720
def sort(self, ascending=True):
@@ -763,7 +772,7 @@ def sub(self, idx):
763772
-------
764773
StateC : a new object with a subset of the states
765774
"""
766-
idx = _a.asarrayi(idx).ravel()
775+
idx = self._sanitize_index(idx)
767776
sub = self.__class__(self.state[idx, ...], self.c[idx], self.parent)
768777
sub.info = self.info
769778
return sub

sisl/physics/tests/test_state.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -116,6 +116,8 @@ def test_state_sub1():
116116
for i, sub in enumerate(state.iter(True)):
117117
assert (sub ** 2).sum() == norm2[i]
118118

119+
assert np.allclose(state.sub([False, True, False, True]).state, state.sub([1, 3]).state)
120+
119121

120122
def test_state_outer1():
121123
state = ar(10, 10)

0 commit comments

Comments
 (0)