77
88@cython .boundscheck (False )
99@cython .wraparound (False )
10- def add_cnc_diag_spin (state : complex_or_float [:, :], row_orbs : cython .int [:], col_orbs_uc : cython .int [:],
10+ def add_cnc_diag_spin (state : complex_or_float [:, :], DM_ptr : cython .int [:], DM_col_uc : cython .int [:],
1111 occs : cython .floating [:], DM_kpoint : complex_or_float [:], occtol : float = 1e-9 ):
1212 """Adds the cnc contributions of all orbital pairs to the DM given a array of states.
1313
@@ -18,10 +18,10 @@ def add_cnc_diag_spin(state: complex_or_float[:, :], row_orbs: cython.int[:], co
1818 state:
1919 The coefficients of all eigenstates for this contribution.
2020 Array of shape (n_eigenstates, n_basisorbitals)
21- row_orbs :
22- The orbital row indices of the sparsity pattern .
23- Shape (nnz , ), where nnz is the number of nonzero elements in the sparsity pattern .
24- col_orbs_uc :
21+ DM_ptr :
22+ The pointer to row array of the sparse DM .
23+ Shape (no + 1 , ), where no is the number of orbitals in the unit cell .
24+ DM_col_uc :
2525 The orbital col indices of the sparsity pattern, but converted to the unit cell.
2626 Shape (nnz, ), where nnz is the number of nonzero elements in the sparsity pattern.
2727 occs:
@@ -37,11 +37,13 @@ def add_cnc_diag_spin(state: complex_or_float[:, :], row_orbs: cython.int[:], co
3737 i : cython .int
3838 u : cython .int
3939 v : cython .int
40- ipair : cython .int
40+
41+ # Number of orbitals in the unit cell
42+ no : cython .int = DM_ptr .shape [0 ] - 1
43+ ival : cython .int
4144
4245 # Loop lengths
4346 n_wfs : cython .int = state .shape [0 ]
44- n_opairs : cython .int = row_orbs .shape [0 ]
4547
4648 # Variable to store the occupation of each state
4749 occ : float
@@ -54,17 +56,17 @@ def add_cnc_diag_spin(state: complex_or_float[:, :], row_orbs: cython.int[:], co
5456 if occ < occtol :
5557 continue
5658
57- # The occupation is above the tolerance threshold, loop through all overlaping orbital pairs
58- for ipair in range (n_opairs ):
59- # Get the orbital indices of this pair
60- u = row_orbs [ ipair ]
61- v = col_orbs_uc [ ipair ]
62- # Add the contribution of this eigenstate to the DM_{u,v} element
63- DM_kpoint [ ipair ] = DM_kpoint [ ipair ] + state [ i , u ] * occ * state [ i , v ]. conjugate ()
59+ # Loop over all non zero elements in the sparsity pattern
60+ for u in range (no ):
61+ for ival in range ( DM_ptr [ u ], DM_ptr [ u + 1 ]):
62+ v = DM_col_uc [ ival ]
63+ # Add the contribution of this eigenstate to the DM_{u,v} element
64+ DM_kpoint [ ival ] = DM_kpoint [ ival ] + state [ i , u ] * occ * state [ i , v ]. conjugate ()
65+
6466
6567@cython .boundscheck (False )
6668@cython .wraparound (False )
67- def add_cnc_nc (state : cython .complex [:, :, :], row_orbs : cython .int [:], col_orbs_uc : cython .int [:],
69+ def add_cnc_nc (state : cython .complex [:, :, :], DM_ptr : cython .int [:], DM_col_uc : cython .int [:],
6870 occs : cython .floating [:], DM_kpoint : cython .complex [:, :, :], occtol : float = 1e-9 ):
6971 """Adds the cnc contributions of all orbital pairs to the DM given a array of states.
7072
@@ -76,10 +78,10 @@ def add_cnc_nc(state: cython.complex[:, :, :], row_orbs: cython.int[:], col_orbs
7678 The coefficients of all eigenstates for this contribution.
7779 Array of shape (n_eigenstates, n_basisorbitals, 2), where the last dimension is the spin
7880 "up"/"down" dimension.
79- row_orbs :
80- The orbital row indices of the sparsity pattern .
81- Shape (nnz , ), where nnz is the number of nonzero elements in the sparsity pattern .
82- col_orbs_uc :
81+ DM_ptr :
82+ The pointer to row array of the sparse DM .
83+ Shape (no + 1 , ), where no is the number of orbitals in the unit cell .
84+ DM_col_uc :
8385 The orbital col indices of the sparsity pattern, but converted to the unit cell.
8486 Shape (nnz, ), where nnz is the number of nonzero elements in the sparsity pattern.
8587 occs:
@@ -96,14 +98,17 @@ def add_cnc_nc(state: cython.complex[:, :, :], row_orbs: cython.int[:], col_orbs
9698 i : cython .int
9799 u : cython .int
98100 v : cython .int
99- ipair : cython .int
101+ ival : cython .int
102+
103+ # Number of orbitals in the unit cell
104+ no : cython .int = DM_ptr .shape [0 ] - 1
105+
100106 # The spin box indices.
101107 Di : cython .int
102108 Dj : cython .int
103109
104110 # Loop lengths
105111 n_wfs : cython .int = state .shape [0 ]
106- n_opairs : cython .int = row_orbs .shape [0 ]
107112
108113 # Variable to store the occupation of each state
109114 occ : float
@@ -115,14 +120,13 @@ def add_cnc_nc(state: cython.complex[:, :, :], row_orbs: cython.int[:], col_orbs
115120 # If the occupation is lower than the tolerance, skip the state
116121 if occ < occtol :
117122 continue
118-
119- # The occupation is above the tolerance threshold, loop through all overlaping orbital pairs
120- for ipair in range (n_opairs ):
121- # Get the orbital indices of this pair
122- u = row_orbs [ipair ]
123- v = col_orbs_uc [ipair ]
123+
124+ # Loop over all non zero elements in the sparsity pattern
125+ for u in range (no ):
126+ for ival in range (DM_ptr [u ], DM_ptr [u + 1 ]):
127+ v = DM_col_uc [ival ]
124128
125- # Add to spin box
126- for Di in range (2 ):
127- for Dj in range (2 ):
128- DM_kpoint [ipair , Di , Dj ] = DM_kpoint [ipair , Di , Dj ] + state [i , u , Di ] * occ * state [i , v , Dj ].conjugate ()
129+ # Add to spin box
130+ for Di in range (2 ):
131+ for Dj in range (2 ):
132+ DM_kpoint [ival , Di , Dj ] = DM_kpoint [ival , Di , Dj ] + state [i , u , Di ] * occ * state [i , v , Dj ].conjugate ()
0 commit comments