Skip to content

Commit 3bbd61b

Browse files
Work on generic constraints, added templates over constraint type
1 parent e80b9d2 commit 3bbd61b

File tree

4 files changed

+77
-79
lines changed

4 files changed

+77
-79
lines changed

include/macis/asci/determinant_search.hpp

Lines changed: 6 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -281,7 +281,8 @@ asci_contrib_container<wfn_t<N>> asci_contributions_constraint(
281281

282282
const double h_el_tol = asci_settings.h_el_tol;
283283
const auto& [C, B, C_min] = con;
284-
wfn_t<N> O = full_mask<N>(norb);
284+
wfn_constraint<N/2> alpha_con{ wfn_traits::alpha_string(C), wfn_traits::alpha_string(B), C_min};
285+
//wfn_t<N> O = full_mask<N>(norb);
285286

286287
// Loop over unique alpha strings
287288
for(size_t i_alpha = 0; i_alpha < nuniq_alpha; ++i_alpha) {
@@ -296,7 +297,7 @@ asci_contrib_container<wfn_t<N>> asci_contributions_constraint(
296297
const auto& occ_beta = bcd.occ_beta;
297298
const auto& orb_ens_alpha = bcd.orb_ens_alpha;
298299
generate_constraint_singles_contributions_ss(
299-
coeff, det|beta, C, O, B, occ_alpha, occ_beta,
300+
coeff, det|beta, alpha_con, occ_alpha, occ_beta,
300301
orb_ens_alpha.data(), T_pq, norb, G_red, norb, V_red, norb,
301302
h_el_tol, h_diag, E_ASCI, ham_gen, asci_pairs);
302303
}
@@ -309,7 +310,7 @@ asci_contrib_container<wfn_t<N>> asci_contributions_constraint(
309310
const auto& occ_beta = bcd.occ_beta;
310311
const auto& orb_ens_alpha = bcd.orb_ens_alpha;
311312
generate_constraint_doubles_contributions_ss(
312-
coeff, det|beta, C, O, B, occ_alpha, occ_beta,
313+
coeff, det|beta, alpha_con, occ_alpha, occ_beta,
313314
orb_ens_alpha.data(), G_pqrs, norb, h_el_tol, h_diag, E_ASCI,
314315
ham_gen, asci_pairs);
315316
}
@@ -324,14 +325,14 @@ asci_contrib_container<wfn_t<N>> asci_contributions_constraint(
324325
const auto& orb_ens_alpha = bcd.orb_ens_alpha;
325326
const auto& orb_ens_beta = bcd.orb_ens_beta;
326327
generate_constraint_doubles_contributions_os(
327-
coeff, det|beta, C, O, B, occ_alpha, occ_beta, vir_beta,
328+
coeff, det|beta, alpha_con, occ_alpha, occ_beta, vir_beta,
328329
orb_ens_alpha.data(), orb_ens_beta.data(), V_pqrs, norb, h_el_tol,
329330
h_diag, E_ASCI, ham_gen, asci_pairs);
330331
}
331332

332333
// If the alpha determinant satisfies the constraint,
333334
// append BB and BBBB excitations
334-
if(satisfies_constraint(det, C, C_min)) {
335+
if(satisfies_constraint(det, con)) {
335336
for(const auto& bcd : uad[i_alpha].bcd) {
336337
const auto& beta = bcd.beta_string;
337338
const auto& coeff = bcd.coeff;

include/macis/asci/mask_constraints.hpp

Lines changed: 48 additions & 64 deletions
Original file line numberDiff line numberDiff line change
@@ -22,47 +22,20 @@ struct wfn_constraint {
2222
};
2323

2424
template <size_t N>
25-
auto make_triplet(unsigned i, unsigned j, unsigned k) {
26-
wfn_constraint<N> con;
27-
28-
con.C = 0;
29-
con.C.flip(i).flip(j).flip(k);
30-
con.B = 1;
31-
con.B <<= k;
32-
con.B = con.B.to_ullong() - 1;
33-
con.C_min = k;
34-
35-
return con;
36-
}
37-
38-
template <size_t N>
39-
auto make_quad(unsigned i, unsigned j, unsigned k, unsigned l) {
40-
wfn_constraint<N> con;
41-
42-
con.C = 0;
43-
con.C.flip(i).flip(j).flip(k).flip(l);
44-
con.B = 1;
45-
con.B <<= l;
46-
con.B = con.B.to_ullong() - 1;
47-
con.C_min = l;
48-
49-
return con;
50-
}
51-
52-
template <size_t N>
53-
bool satisfies_constraint(wfn_t<N> det, wfn_t<N> C, unsigned C_min) {
25+
bool satisfies_constraint(wfn_t<N> det, wfn_constraint<N> constraint) {
26+
auto C = constraint.C; auto C_min = constraint.C_min;
5427
return (det & C).count() == C.count() and ((det ^ C) >> C_min).count() == 0;
5528
}
5629

5730
template <size_t N>
58-
auto generate_constraint_single_excitations(wfn_t<N> det, wfn_t<N> C,
59-
wfn_t<N> O_mask, wfn_t<N> B) {
31+
auto generate_constraint_single_excitations(wfn_t<N> det, wfn_constraint<N> constraint) {
32+
const auto C = constraint.C; const auto B = constraint.B;
6033
if((det & C).count() <
6134
(C.count() - 1)) // need to have at most one different from the constraint
6235
return std::make_pair(wfn_t<N>(0), wfn_t<N>(0));
6336

6437
auto o = det ^ C;
65-
auto v = (~det) & O_mask & B;
38+
auto v = (~det) & B;
6639

6740
if((o & C).count() ==
6841
1) { // don't have to change this necessarily, but more clear without >=
@@ -78,15 +51,15 @@ auto generate_constraint_single_excitations(wfn_t<N> det, wfn_t<N> C,
7851
}
7952

8053
template <size_t N>
81-
auto generate_constraint_double_excitations(wfn_t<N> det, wfn_t<N> C,
82-
wfn_t<N> O_mask, wfn_t<N> B) {
54+
auto generate_constraint_double_excitations(wfn_t<N> det, wfn_constraint<N> constraint) {
55+
const auto C = constraint.C; const auto B = constraint.B;
8356
// Occ/Vir pairs to generate excitations
8457
std::vector<wfn_t<N>> O, V;
8558

8659
if((det & C) == 0) return std::make_tuple(O, V);
8760

8861
auto o = det ^ C;
89-
auto v = (~det) & O_mask & B;
62+
auto v = (~det) & B;
9063

9164
if((o & C).count() >= 3) return std::make_tuple(O, V);
9265

@@ -130,9 +103,9 @@ auto generate_constraint_double_excitations(wfn_t<N> det, wfn_t<N> C,
130103
}
131104

132105
template <size_t N>
133-
void generate_constraint_singles(wfn_t<N> det, wfn_t<N> T, wfn_t<N> O_mask,
134-
wfn_t<N> B, std::vector<wfn_t<N>>& t_singles) {
135-
auto [o, v] = generate_constraint_single_excitations(det, T, O_mask, B);
106+
void generate_constraint_singles(wfn_t<N> det, wfn_constraint<N> constraint,
107+
std::vector<wfn_t<N>>& t_singles) {
108+
auto [o, v] = generate_constraint_single_excitations(det, constraint);
136109
const auto oc = o.count();
137110
const auto vc = v.count();
138111
if(!oc or !vc) return;
@@ -156,9 +129,9 @@ unsigned count_constraint_singles(Args&&... args) {
156129
}
157130

158131
template <size_t N>
159-
void generate_constraint_doubles(wfn_t<N> det, wfn_t<N> T, wfn_t<N> O_mask,
160-
wfn_t<N> B, std::vector<wfn_t<N>>& t_doubles) {
161-
auto [O, V] = generate_constraint_double_excitations(det, T, O_mask, B);
132+
void generate_constraint_doubles(wfn_t<N> det, wfn_constraint<N> constraint,
133+
std::vector<wfn_t<N>>& t_doubles) {
134+
auto [O, V] = generate_constraint_double_excitations(det, constraint);
162135

163136
t_doubles.clear();
164137
for(auto ij : O) {
@@ -172,16 +145,15 @@ void generate_constraint_doubles(wfn_t<N> det, wfn_t<N> T, wfn_t<N> O_mask,
172145
/**
173146
* @param[in] det Input root determinant
174147
* @param[in] T Triplet constraint mask
175-
* @param[in] O Overfill mask (full mask 0 -> norb)
176148
* @param[in] B B mask (?)
177149
*/
178150
template <size_t N>
179-
unsigned count_constraint_doubles(wfn_t<N> det, wfn_t<N> C, wfn_t<N> O,
180-
wfn_t<N> B) {
151+
unsigned count_constraint_doubles(wfn_t<N> det, wfn_constraint<N> constraint) {
152+
const auto C = constraint.C; const auto B = constraint.B;
181153
if((det & C) == 0) return 0;
182154

183155
auto o = det ^ C;
184-
auto v = (~det) & O & B;
156+
auto v = (~det) & B;
185157

186158
if((o & C).count() >= 3) return 0;
187159

@@ -223,34 +195,32 @@ unsigned count_constraint_doubles(wfn_t<N> det, wfn_t<N> C, wfn_t<N> O,
223195

224196
template <size_t N, typename... Args>
225197
size_t constraint_histogram(wfn_t<N> det, size_t n_os_singles,
226-
size_t n_os_doubles, wfn_t<N> T, wfn_t<N> O_mask,
227-
wfn_t<N> B) {
228-
auto ns = count_constraint_singles(det, T, O_mask, B);
229-
auto nd = count_constraint_doubles(det, T, O_mask, B);
198+
size_t n_os_doubles, wfn_constraint<N> constraint){
199+
auto ns = count_constraint_singles(det, constraint);
200+
auto nd = count_constraint_doubles(det, constraint);
230201

231202
size_t ndet = 0;
232203
ndet += ns; // AA
233204
ndet += nd; // AAAA
234205
ndet += ns * n_os_singles; // AABB
235-
auto T_min = ffs(T) - 1;
236-
if(satisfies_constraint(det, T, T_min)) {
206+
if(satisfies_constraint(det, constraint)) {
237207
ndet += n_os_singles + n_os_doubles + 1; // BB + BBBB + No Excitations
238208
}
239209

240210
return ndet;
241211
}
242212

243-
template <typename WfnType>
213+
template <typename WfnType, typename ConType>
244214
void generate_constraint_singles_contributions_ss(
245-
double coeff, WfnType det, WfnType T, WfnType O, WfnType B,
215+
double coeff, WfnType det, ConType constraint,
246216
const std::vector<uint32_t>& occ_same,
247217
const std::vector<uint32_t>& occ_othr, const double* eps,
248218
const double* T_pq, const size_t LDT, const double* G_kpq, const size_t LDG,
249219
const double* V_kpq, const size_t LDV, double h_el_tol, double root_diag,
250220
double E0, HamiltonianGeneratorBase<double>& ham_gen,
251221
asci_contrib_container<WfnType>& asci_contributions) {
252222
using wfn_traits = wavefunction_traits<WfnType>;
253-
auto [o, v] = generate_constraint_single_excitations(wfn_traits::alpha_string(det), wfn_traits::alpha_string(T), wfn_traits::alpha_string(O), wfn_traits::alpha_string(B));
223+
auto [o, v] = generate_constraint_single_excitations(wfn_traits::alpha_string(det), constraint);
254224
const auto no = o.count();
255225
const auto nv = v.count();
256226
if(!no or !nv) return;
@@ -291,17 +261,17 @@ void generate_constraint_singles_contributions_ss(
291261
}
292262
}
293263

294-
template <typename WfnType>
264+
template <typename WfnType, typename ConType>
295265
void generate_constraint_doubles_contributions_ss(
296-
double coeff, WfnType det, WfnType T, WfnType O_mask, WfnType B,
266+
double coeff, WfnType det, ConType constraint,
297267
const std::vector<uint32_t>& occ_same,
298268
const std::vector<uint32_t>& occ_othr, const double* eps, const double* G,
299269
const size_t LDG, double h_el_tol, double root_diag, double E0,
300270
HamiltonianGeneratorBase<double>& ham_gen,
301271
asci_contrib_container<WfnType>& asci_contributions) {
302272
using wfn_traits = wavefunction_traits<WfnType>;
303273
using spin_wfn_traits = wavefunction_traits<spin_wfn_t<WfnType>>;
304-
auto [O, V] = generate_constraint_double_excitations(wfn_traits::alpha_string(det), wfn_traits::alpha_string(T), wfn_traits::alpha_string(O_mask), wfn_traits::alpha_string(B));
274+
auto [O, V] = generate_constraint_double_excitations(wfn_traits::alpha_string(det), constraint);
305275
const auto no_pairs = O.size();
306276
const auto nv_pairs = V.size();
307277
if(!no_pairs or !nv_pairs) return;
@@ -346,9 +316,9 @@ void generate_constraint_doubles_contributions_ss(
346316
}
347317
}
348318

349-
template <typename WfnType>
319+
template <typename WfnType, typename ConType>
350320
void generate_constraint_doubles_contributions_os(
351-
double coeff, WfnType det, WfnType T, WfnType O, WfnType B,
321+
double coeff, WfnType det, ConType constraint,
352322
const std::vector<uint32_t>& occ_same,
353323
const std::vector<uint32_t>& occ_othr,
354324
const std::vector<uint32_t>& vir_othr, const double* eps_same,
@@ -357,7 +327,7 @@ void generate_constraint_doubles_contributions_os(
357327
asci_contrib_container<WfnType>& asci_contributions) {
358328
using wfn_traits = wavefunction_traits<WfnType>;
359329
// Generate Single Excitations that Satisfy the Constraint
360-
auto [o, v] = generate_constraint_single_excitations(wfn_traits::alpha_string(det), wfn_traits::alpha_string(T), wfn_traits::alpha_string(O), wfn_traits::alpha_string(B));
330+
auto [o, v] = generate_constraint_single_excitations(wfn_traits::alpha_string(det), constraint);
361331
const auto no = o.count();
362332
const auto nv = v.count();
363333
if(!no or !nv) return;
@@ -488,6 +458,21 @@ auto dist_triplets_histogram(size_t norb, size_t ns_othr, size_t nd_othr,
488458
}
489459
#endif
490460

461+
462+
template <size_t N>
463+
auto make_triplet(unsigned i, unsigned j, unsigned k) {
464+
wfn_constraint<N> con;
465+
466+
con.C = 0;
467+
con.C.flip(i).flip(j).flip(k);
468+
con.B = 1;
469+
con.B <<= k;
470+
con.B = con.B.to_ullong() - 1;
471+
con.C_min = k;
472+
473+
return con;
474+
}
475+
491476
#ifdef MACIS_ENABLE_MPI
492477
template <size_t N>
493478
auto dist_constraint_general(size_t nlevels, size_t norb, size_t ns_othr,
@@ -497,7 +482,7 @@ auto dist_constraint_general(size_t nlevels, size_t norb, size_t ns_othr,
497482
auto world_rank = comm_rank(comm);
498483
auto world_size = comm_size(comm);
499484

500-
wfn_t<N> O = full_mask<N>(norb);
485+
//wfn_t<N> O = full_mask<N>(norb);
501486

502487
// Global workloads
503488
std::vector<size_t> workloads(world_size, 0);
@@ -514,7 +499,7 @@ auto dist_constraint_general(size_t nlevels, size_t norb, size_t ns_othr,
514499

515500
size_t nw = 0;
516501
for(const auto& alpha : unique_alpha) {
517-
nw += constraint_histogram(alpha, ns_othr, nd_othr, T, O, B);
502+
nw += constraint_histogram(alpha, ns_othr, nd_othr, constraint);
518503
}
519504
if(nw) constraint_sizes.emplace_back(constraint, nw);
520505
total_work += nw;
@@ -553,8 +538,7 @@ auto dist_constraint_general(size_t nlevels, size_t norb, size_t ns_othr,
553538
size_t nw = 0;
554539

555540
for(const auto& alpha : unique_alpha) {
556-
nw += constraint_histogram(alpha, ns_othr, nd_othr, c_next.C, O,
557-
c_next.B);
541+
nw += constraint_histogram(alpha, ns_othr, nd_othr, c_next);
558542
}
559543
if(nw) constraint_sizes.emplace_back(c_next, nw);
560544
total_work += nw;

include/macis/asci/pt2.hpp

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -131,7 +131,7 @@ double asci_pt2_constraint(
131131
//std::cout << std::distance(&constraints[0], &con) << "/" << constraints.size() << std::endl;
132132
const double h_el_tol = 1e-16;
133133
const auto& [C, B, C_min] = con;
134-
wfn_t<N> O = full_mask<N>(norb);
134+
wfn_constraint<N/2> alpha_con{ wfn_traits::alpha_string(C), wfn_traits::alpha_string(B), C_min};
135135

136136
// Loop over unique alpha strings
137137
for(size_t i_alpha = 0; i_alpha < nuniq_alpha; ++i_alpha) {
@@ -146,7 +146,7 @@ double asci_pt2_constraint(
146146
const auto& occ_beta = bcd.occ_beta;
147147
const auto& orb_ens_alpha = bcd.orb_ens_alpha;
148148
generate_constraint_singles_contributions_ss(
149-
coeff, det|beta, C, O, B, occ_alpha, occ_beta,
149+
coeff, det|beta, alpha_con, occ_alpha, occ_beta,
150150
orb_ens_alpha.data(), T_pq, norb, G_red, norb, V_red, norb,
151151
h_el_tol, h_diag, E_ASCI, ham_gen, asci_pairs);
152152
}
@@ -159,7 +159,7 @@ double asci_pt2_constraint(
159159
const auto& occ_beta = bcd.occ_beta;
160160
const auto& orb_ens_alpha = bcd.orb_ens_alpha;
161161
generate_constraint_doubles_contributions_ss(
162-
coeff, det|beta, C, O, B, occ_alpha, occ_beta,
162+
coeff, det|beta, alpha_con, occ_alpha, occ_beta,
163163
orb_ens_alpha.data(), G_pqrs, norb, h_el_tol, h_diag, E_ASCI,
164164
ham_gen, asci_pairs);
165165
}
@@ -174,14 +174,14 @@ double asci_pt2_constraint(
174174
const auto& orb_ens_alpha = bcd.orb_ens_alpha;
175175
const auto& orb_ens_beta = bcd.orb_ens_beta;
176176
generate_constraint_doubles_contributions_os(
177-
coeff, det|beta, C, O, B, occ_alpha, occ_beta, vir_beta,
177+
coeff, det|beta, alpha_con, occ_alpha, occ_beta, vir_beta,
178178
orb_ens_alpha.data(), orb_ens_beta.data(), V_pqrs, norb, h_el_tol,
179179
h_diag, E_ASCI, ham_gen, asci_pairs);
180180
}
181181

182182
// If the alpha determinant satisfies the constraint,
183183
// append BB and BBBB excitations
184-
if(satisfies_constraint(det, C, C_min)) {
184+
if(satisfies_constraint(det, con)) {
185185
for(const auto& bcd : uad[i_alpha].bcd) {
186186
const auto& beta = bcd.beta_string;
187187
const auto& coeff = bcd.coeff;

0 commit comments

Comments
 (0)