Skip to content

Commit 960ecfa

Browse files
committed
Discovering orbital pairs on the fly using ptr
1 parent 63a685e commit 960ecfa

File tree

2 files changed

+40
-38
lines changed

2 files changed

+40
-38
lines changed

sisl/physics/_compute_dm.py

Lines changed: 35 additions & 31 deletions
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,7 @@
77

88
@cython.boundscheck(False)
99
@cython.wraparound(False)
10-
def add_cnc_diag_spin(state: complex_or_float[:, :], row_orbs: cython.int[:], col_orbs_uc: cython.int[:],
10+
def add_cnc_diag_spin(state: complex_or_float[:, :], DM_ptr: cython.int[:], DM_col_uc: cython.int[:],
1111
occs: cython.floating[:], DM_kpoint: complex_or_float[:], occtol: float = 1e-9):
1212
"""Adds the cnc contributions of all orbital pairs to the DM given a array of states.
1313
@@ -18,10 +18,10 @@ def add_cnc_diag_spin(state: complex_or_float[:, :], row_orbs: cython.int[:], co
1818
state:
1919
The coefficients of all eigenstates for this contribution.
2020
Array of shape (n_eigenstates, n_basisorbitals)
21-
row_orbs:
22-
The orbital row indices of the sparsity pattern.
23-
Shape (nnz, ), where nnz is the number of nonzero elements in the sparsity pattern.
24-
col_orbs_uc:
21+
DM_ptr:
22+
The pointer to row array of the sparse DM.
23+
Shape (no + 1, ), where no is the number of orbitals in the unit cell.
24+
DM_col_uc:
2525
The orbital col indices of the sparsity pattern, but converted to the unit cell.
2626
Shape (nnz, ), where nnz is the number of nonzero elements in the sparsity pattern.
2727
occs:
@@ -37,11 +37,13 @@ def add_cnc_diag_spin(state: complex_or_float[:, :], row_orbs: cython.int[:], co
3737
i: cython.int
3838
u: cython.int
3939
v: cython.int
40-
ipair: cython.int
40+
41+
# Number of orbitals in the unit cell
42+
no: cython.int = DM_ptr.shape[0] - 1
43+
ival: cython.int
4144

4245
# Loop lengths
4346
n_wfs: cython.int = state.shape[0]
44-
n_opairs: cython.int = row_orbs.shape[0]
4547

4648
# Variable to store the occupation of each state
4749
occ: float
@@ -54,17 +56,17 @@ def add_cnc_diag_spin(state: complex_or_float[:, :], row_orbs: cython.int[:], co
5456
if occ < occtol:
5557
continue
5658

57-
# The occupation is above the tolerance threshold, loop through all overlaping orbital pairs
58-
for ipair in range(n_opairs):
59-
# Get the orbital indices of this pair
60-
u = row_orbs[ipair]
61-
v = col_orbs_uc[ipair]
62-
# Add the contribution of this eigenstate to the DM_{u,v} element
63-
DM_kpoint[ipair] = DM_kpoint[ipair] + state[i, u] * occ * state[i, v].conjugate()
59+
# Loop over all non zero elements in the sparsity pattern
60+
for u in range(no):
61+
for ival in range(DM_ptr[u], DM_ptr[u+1]):
62+
v = DM_col_uc[ival]
63+
# Add the contribution of this eigenstate to the DM_{u,v} element
64+
DM_kpoint[ival] = DM_kpoint[ival] + state[i, u] * occ * state[i, v].conjugate()
65+
6466

6567
@cython.boundscheck(False)
6668
@cython.wraparound(False)
67-
def add_cnc_nc(state: cython.complex[:, :, :], row_orbs: cython.int[:], col_orbs_uc: cython.int[:],
69+
def add_cnc_nc(state: cython.complex[:, :, :], DM_ptr: cython.int[:], DM_col_uc: cython.int[:],
6870
occs: cython.floating[:], DM_kpoint: cython.complex[:, :, :], occtol: float = 1e-9):
6971
"""Adds the cnc contributions of all orbital pairs to the DM given a array of states.
7072
@@ -76,10 +78,10 @@ def add_cnc_nc(state: cython.complex[:, :, :], row_orbs: cython.int[:], col_orbs
7678
The coefficients of all eigenstates for this contribution.
7779
Array of shape (n_eigenstates, n_basisorbitals, 2), where the last dimension is the spin
7880
"up"/"down" dimension.
79-
row_orbs:
80-
The orbital row indices of the sparsity pattern.
81-
Shape (nnz, ), where nnz is the number of nonzero elements in the sparsity pattern.
82-
col_orbs_uc:
81+
DM_ptr:
82+
The pointer to row array of the sparse DM.
83+
Shape (no + 1, ), where no is the number of orbitals in the unit cell.
84+
DM_col_uc:
8385
The orbital col indices of the sparsity pattern, but converted to the unit cell.
8486
Shape (nnz, ), where nnz is the number of nonzero elements in the sparsity pattern.
8587
occs:
@@ -96,14 +98,17 @@ def add_cnc_nc(state: cython.complex[:, :, :], row_orbs: cython.int[:], col_orbs
9698
i: cython.int
9799
u: cython.int
98100
v: cython.int
99-
ipair: cython.int
101+
ival: cython.int
102+
103+
# Number of orbitals in the unit cell
104+
no: cython.int = DM_ptr.shape[0] - 1
105+
100106
# The spin box indices.
101107
Di: cython.int
102108
Dj: cython.int
103109

104110
# Loop lengths
105111
n_wfs: cython.int = state.shape[0]
106-
n_opairs: cython.int = row_orbs.shape[0]
107112

108113
# Variable to store the occupation of each state
109114
occ: float
@@ -115,14 +120,13 @@ def add_cnc_nc(state: cython.complex[:, :, :], row_orbs: cython.int[:], col_orbs
115120
# If the occupation is lower than the tolerance, skip the state
116121
if occ < occtol:
117122
continue
118-
119-
# The occupation is above the tolerance threshold, loop through all overlaping orbital pairs
120-
for ipair in range(n_opairs):
121-
# Get the orbital indices of this pair
122-
u = row_orbs[ipair]
123-
v = col_orbs_uc[ipair]
123+
124+
# Loop over all non zero elements in the sparsity pattern
125+
for u in range(no):
126+
for ival in range(DM_ptr[u], DM_ptr[u+1]):
127+
v = DM_col_uc[ival]
124128

125-
# Add to spin box
126-
for Di in range(2):
127-
for Dj in range(2):
128-
DM_kpoint[ipair, Di, Dj] = DM_kpoint[ipair, Di, Dj] + state[i, u, Di] * occ * state[i, v, Dj].conjugate()
129+
# Add to spin box
130+
for Di in range(2):
131+
for Dj in range(2):
132+
DM_kpoint[ival, Di, Dj] = DM_kpoint[ival, Di, Dj] + state[i, u, Di] * occ * state[i, v, Dj].conjugate()

sisl/physics/compute_dm.py

Lines changed: 5 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -46,9 +46,7 @@ def compute_dm(bz: BrillouinZone, eigenstates: Optional[Sequence[EigenstateElect
4646
geom = H.geometry
4747

4848
# Sparsity pattern information
49-
row_orbs, col_orbs = H.nonzero()
50-
col_orbs_uc = H.osc2uc(col_orbs)
51-
col_isc = col_orbs // H.no
49+
col_isc, col_uc = np.divmod(H._csr.col, H.no)
5250
sc_offsets = H.sc_off.dot(H.cell)
5351

5452
# Initialize the density matrix using the sparsity pattern of the Hamiltonian.
@@ -121,8 +119,8 @@ def compute_dm(bz: BrillouinZone, eigenstates: Optional[Sequence[EigenstateElect
121119

122120
if DM.spin.is_diagonal:
123121
# Calculate the matrix elements contributions for this k point.
124-
DM_kpoint = np.zeros(row_orbs.shape[0], dtype=k_eigs.state.dtype)
125-
add_cnc_diag_spin(state, row_orbs, col_orbs_uc, occs, DM_kpoint, occtol=occtol)
122+
DM_kpoint = np.zeros(DM.nnz, dtype=k_eigs.state.dtype)
123+
add_cnc_diag_spin(state, H._csr.ptr, col_uc, occs, DM_kpoint, occtol=occtol)
126124

127125
# Apply phases
128126
DM_kpoint = DM_kpoint * phases
@@ -139,8 +137,8 @@ def compute_dm(bz: BrillouinZone, eigenstates: Optional[Sequence[EigenstateElect
139137

140138
# Calculate the matrix elements contributions for this k point. For each matrix element
141139
# we allocate a 2x2 spin box.
142-
DM_kpoint = np.zeros((row_orbs.shape[0], 2, 2), dtype=np.complex128)
143-
add_cnc_nc(state, row_orbs, col_orbs_uc, occs, DM_kpoint, occtol=occtol)
140+
DM_kpoint = np.zeros((DM.nnz, 2, 2), dtype=np.complex128)
141+
add_cnc_nc(state, H._csr.ptr, col_uc, occs, DM_kpoint, occtol=occtol)
144142

145143
# Apply phases
146144
DM_kpoint *= phases.reshape(-1, 1, 1)

0 commit comments

Comments
 (0)