Skip to content

Commit e224338

Browse files
changed wfn_constraint -> alpha_constraint and made appropriate changes
1 parent 3bbd61b commit e224338

File tree

5 files changed

+144
-74
lines changed

5 files changed

+144
-74
lines changed
Lines changed: 70 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,70 @@
1+
#pragma once
2+
3+
#include <macis/wfn/raw_bitset.hpp>
4+
namespace macis {
5+
6+
template <typename WfnTraits>
7+
class alpha_constraint {
8+
9+
public:
10+
using wfn_traits = WfnTraits;
11+
using wfn_type = typename WfnTraits::wfn_type;
12+
using spin_wfn_type = spin_wfn_t<wfn_type>;
13+
using spin_wfn_traits = wavefunction_traits<spin_wfn_type>;
14+
using constraint_type = spin_wfn_type;
15+
16+
private:
17+
constraint_type C_;
18+
constraint_type B_;
19+
uint32_t C_min_;
20+
uint32_t count_;
21+
22+
public:
23+
24+
alpha_constraint(constraint_type C, constraint_type B, uint32_t C_min) :
25+
C_(C), B_(B), C_min_(C_min), count_(spin_wfn_traits::count(C)) {}
26+
27+
alpha_constraint(const alpha_constraint&) = default;
28+
alpha_constraint& operator=(const alpha_constraint&) = default;
29+
30+
alpha_constraint(alpha_constraint&& other) noexcept = default;
31+
alpha_constraint& operator=(alpha_constraint&&) noexcept = default;
32+
33+
34+
inline auto C() const { return C_; }
35+
inline auto B() const { return B_; }
36+
inline auto C_min() const { return C_min_; }
37+
inline auto count() const { return count_; }
38+
39+
40+
inline spin_wfn_type c_mask_union(spin_wfn_type state) const {
41+
return state & C_;
42+
}
43+
inline spin_wfn_type b_mask_union(spin_wfn_type state) const {
44+
return state & B_;
45+
}
46+
47+
inline spin_wfn_type symmetric_difference(spin_wfn_type state) const {
48+
return state ^ C_;
49+
}
50+
inline spin_wfn_type symmetric_difference(wfn_type state) const {
51+
return symmetric_difference(wfn_traits::alpha_string(state));
52+
}
53+
54+
template <typename WfnType>
55+
inline auto overlap(WfnType state) const {
56+
return spin_wfn_traits::count(c_mask_union(state));
57+
}
58+
59+
template <typename WfnType>
60+
inline bool satisfies_constraint(WfnType state) const {
61+
return overlap(state) == count_ and
62+
spin_wfn_traits::count(symmetric_difference(state) >> C_min_) == 0;
63+
}
64+
65+
66+
67+
};
68+
69+
70+
}

include/macis/asci/determinant_search.hpp

Lines changed: 5 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -280,9 +280,6 @@ asci_contrib_container<wfn_t<N>> asci_contributions_constraint(
280280
auto size_before = asci_pairs.size();
281281

282282
const double h_el_tol = asci_settings.h_el_tol;
283-
const auto& [C, B, C_min] = con;
284-
wfn_constraint<N/2> alpha_con{ wfn_traits::alpha_string(C), wfn_traits::alpha_string(B), C_min};
285-
//wfn_t<N> O = full_mask<N>(norb);
286283

287284
// Loop over unique alpha strings
288285
for(size_t i_alpha = 0; i_alpha < nuniq_alpha; ++i_alpha) {
@@ -297,7 +294,7 @@ asci_contrib_container<wfn_t<N>> asci_contributions_constraint(
297294
const auto& occ_beta = bcd.occ_beta;
298295
const auto& orb_ens_alpha = bcd.orb_ens_alpha;
299296
generate_constraint_singles_contributions_ss(
300-
coeff, det|beta, alpha_con, occ_alpha, occ_beta,
297+
coeff, det|beta, con, occ_alpha, occ_beta,
301298
orb_ens_alpha.data(), T_pq, norb, G_red, norb, V_red, norb,
302299
h_el_tol, h_diag, E_ASCI, ham_gen, asci_pairs);
303300
}
@@ -310,7 +307,7 @@ asci_contrib_container<wfn_t<N>> asci_contributions_constraint(
310307
const auto& occ_beta = bcd.occ_beta;
311308
const auto& orb_ens_alpha = bcd.orb_ens_alpha;
312309
generate_constraint_doubles_contributions_ss(
313-
coeff, det|beta, alpha_con, occ_alpha, occ_beta,
310+
coeff, det|beta, con, occ_alpha, occ_beta,
314311
orb_ens_alpha.data(), G_pqrs, norb, h_el_tol, h_diag, E_ASCI,
315312
ham_gen, asci_pairs);
316313
}
@@ -325,14 +322,14 @@ asci_contrib_container<wfn_t<N>> asci_contributions_constraint(
325322
const auto& orb_ens_alpha = bcd.orb_ens_alpha;
326323
const auto& orb_ens_beta = bcd.orb_ens_beta;
327324
generate_constraint_doubles_contributions_os(
328-
coeff, det|beta, alpha_con, occ_alpha, occ_beta, vir_beta,
325+
coeff, det|beta, con, occ_alpha, occ_beta, vir_beta,
329326
orb_ens_alpha.data(), orb_ens_beta.data(), V_pqrs, norb, h_el_tol,
330327
h_diag, E_ASCI, ham_gen, asci_pairs);
331328
}
332329

333330
// If the alpha determinant satisfies the constraint,
334331
// append BB and BBBB excitations
335-
if(satisfies_constraint(det, con)) {
332+
if(satisfies_constraint(wfn_traits::alpha_string(det), con)) {
336333
for(const auto& bcd : uad[i_alpha].bcd) {
337334
const auto& beta = bcd.beta_string;
338335
const auto& coeff = bcd.coeff;
@@ -368,7 +365,7 @@ asci_contrib_container<wfn_t<N>> asci_contributions_constraint(
368365
});
369366
asci_pairs.erase(it, asci_pairs.end());
370367

371-
auto c_indices = bits_to_indices(C);
368+
auto c_indices = bits_to_indices(con.C());
372369
std::string c_string;
373370
for(int i = 0; i < c_indices.size(); ++i)
374371
c_string += std::to_string(c_indices[i]) + " ";

include/macis/asci/mask_constraints.hpp

Lines changed: 51 additions & 50 deletions
Original file line numberDiff line numberDiff line change
@@ -12,33 +12,29 @@
1212
#include <macis/util/mpi.hpp>
1313
#include <variant>
1414

15+
#include <macis/asci/alpha_constraint.hpp>
16+
1517
namespace macis {
1618

17-
template <size_t N>
18-
struct wfn_constraint {
19-
wfn_t<N> C;
20-
wfn_t<N> B;
21-
unsigned C_min;
22-
};
2319

24-
template <size_t N>
25-
bool satisfies_constraint(wfn_t<N> det, wfn_constraint<N> constraint) {
26-
auto C = constraint.C; auto C_min = constraint.C_min;
27-
return (det & C).count() == C.count() and ((det ^ C) >> C_min).count() == 0;
20+
template <size_t N, typename ConType>
21+
bool satisfies_constraint(wfn_t<N> det, ConType C) {
22+
return C.satisfies_constraint(det);
2823
}
2924

30-
template <size_t N>
31-
auto generate_constraint_single_excitations(wfn_t<N> det, wfn_constraint<N> constraint) {
32-
const auto C = constraint.C; const auto B = constraint.B;
33-
if((det & C).count() <
34-
(C.count() - 1)) // need to have at most one different from the constraint
25+
template <size_t N, typename ConType>
26+
auto generate_constraint_single_excitations(wfn_t<N> det, ConType constraint) {
27+
using spin_wfn_traits = typename ConType::spin_wfn_traits;
28+
const auto C = constraint.C(); const auto B = constraint.B();
29+
30+
// need to have at most one different from the constraint
31+
if(constraint.overlap(det) < (constraint.count()-1))
3532
return std::make_pair(wfn_t<N>(0), wfn_t<N>(0));
3633

3734
auto o = det ^ C;
3835
auto v = (~det) & B;
3936

40-
if((o & C).count() ==
41-
1) { // don't have to change this necessarily, but more clear without >=
37+
if((o & C).count() == 1) {
4238
v = o & C;
4339
o ^= v;
4440
}
@@ -50,9 +46,9 @@ auto generate_constraint_single_excitations(wfn_t<N> det, wfn_constraint<N> cons
5046
return std::make_pair(o, v);
5147
}
5248

53-
template <size_t N>
54-
auto generate_constraint_double_excitations(wfn_t<N> det, wfn_constraint<N> constraint) {
55-
const auto C = constraint.C; const auto B = constraint.B;
49+
template <size_t N, typename ConType>
50+
auto generate_constraint_double_excitations(wfn_t<N> det, ConType constraint) {
51+
const auto C = constraint.C(); const auto B = constraint.B();
5652
// Occ/Vir pairs to generate excitations
5753
std::vector<wfn_t<N>> O, V;
5854

@@ -102,8 +98,8 @@ auto generate_constraint_double_excitations(wfn_t<N> det, wfn_constraint<N> cons
10298
return std::make_tuple(O, V);
10399
}
104100

105-
template <size_t N>
106-
void generate_constraint_singles(wfn_t<N> det, wfn_constraint<N> constraint,
101+
template <size_t N, typename ConType>
102+
void generate_constraint_singles(wfn_t<N> det, ConType constraint,
107103
std::vector<wfn_t<N>>& t_singles) {
108104
auto [o, v] = generate_constraint_single_excitations(det, constraint);
109105
const auto oc = o.count();
@@ -128,8 +124,8 @@ unsigned count_constraint_singles(Args&&... args) {
128124
return o.count() * v.count();
129125
}
130126

131-
template <size_t N>
132-
void generate_constraint_doubles(wfn_t<N> det, wfn_constraint<N> constraint,
127+
template <size_t N, typename ConType >
128+
void generate_constraint_doubles(wfn_t<N> det, ConType constraint,
133129
std::vector<wfn_t<N>>& t_doubles) {
134130
auto [O, V] = generate_constraint_double_excitations(det, constraint);
135131

@@ -147,9 +143,9 @@ void generate_constraint_doubles(wfn_t<N> det, wfn_constraint<N> constraint,
147143
* @param[in] T Triplet constraint mask
148144
* @param[in] B B mask (?)
149145
*/
150-
template <size_t N>
151-
unsigned count_constraint_doubles(wfn_t<N> det, wfn_constraint<N> constraint) {
152-
const auto C = constraint.C; const auto B = constraint.B;
146+
template <size_t N, typename ConType>
147+
unsigned count_constraint_doubles(wfn_t<N> det, ConType constraint) {
148+
const auto C = constraint.C(); const auto B = constraint.B();
153149
if((det & C) == 0) return 0;
154150

155151
auto o = det ^ C;
@@ -193,9 +189,9 @@ unsigned count_constraint_doubles(wfn_t<N> det, wfn_constraint<N> constraint) {
193189
return no_pairs * nv_pairs;
194190
}
195191

196-
template <size_t N, typename... Args>
192+
template <size_t N, typename ConType>
197193
size_t constraint_histogram(wfn_t<N> det, size_t n_os_singles,
198-
size_t n_os_doubles, wfn_constraint<N> constraint){
194+
size_t n_os_doubles, ConType constraint){
199195
auto ns = count_constraint_singles(det, constraint);
200196
auto nd = count_constraint_doubles(det, constraint);
201197

@@ -461,16 +457,18 @@ auto dist_triplets_histogram(size_t norb, size_t ns_othr, size_t nd_othr,
461457

462458
template <size_t N>
463459
auto make_triplet(unsigned i, unsigned j, unsigned k) {
464-
wfn_constraint<N> con;
465-
466-
con.C = 0;
467-
con.C.flip(i).flip(j).flip(k);
468-
con.B = 1;
469-
con.B <<= k;
470-
con.B = con.B.to_ullong() - 1;
471-
con.C_min = k;
472-
473-
return con;
460+
using wfn_type = wfn_t<N>;
461+
using wfn_traits = wavefunction_traits<wfn_type>;
462+
using constraint_type = alpha_constraint<wfn_traits>;
463+
using string_type = typename constraint_type::constraint_type;
464+
465+
string_type C = 0;
466+
C.flip(i).flip(j).flip(k);
467+
string_type B = 1;
468+
B <<= k;
469+
B = B.to_ullong() - 1;
470+
471+
return constraint_type(C,B,k);
474472
}
475473

476474
#ifdef MACIS_ENABLE_MPI
@@ -479,6 +477,10 @@ auto dist_constraint_general(size_t nlevels, size_t norb, size_t ns_othr,
479477
size_t nd_othr,
480478
const std::vector<wfn_t<N>>& unique_alpha,
481479
MPI_Comm comm) {
480+
using wfn_type = wfn_t<N>;
481+
using wfn_traits = wavefunction_traits<wfn_type>;
482+
using constraint_type = alpha_constraint<wfn_traits>;
483+
using string_type = typename constraint_type::constraint_type;
482484
auto world_rank = comm_rank(comm);
483485
auto world_size = comm_size(comm);
484486

@@ -488,18 +490,17 @@ auto dist_constraint_general(size_t nlevels, size_t norb, size_t ns_othr,
488490
std::vector<size_t> workloads(world_size, 0);
489491

490492
// Generate triplets + heuristic
491-
std::vector<std::pair<wfn_constraint<N>, size_t>> constraint_sizes;
493+
std::vector<std::pair<constraint_type, size_t>> constraint_sizes;
492494
constraint_sizes.reserve(norb * norb * norb);
493495
size_t total_work = 0;
494496
for(int t_i = 0; t_i < norb; ++t_i)
495497
for(int t_j = 0; t_j < t_i; ++t_j)
496498
for(int t_k = 0; t_k < t_j; ++t_k) {
497499
auto constraint = make_triplet<N>(t_i, t_j, t_k);
498-
const auto& [T, B, _] = constraint;
499500

500501
size_t nw = 0;
501502
for(const auto& alpha : unique_alpha) {
502-
nw += constraint_histogram(alpha, ns_othr, nd_othr, constraint);
503+
nw += constraint_histogram(wfn_traits::alpha_string(alpha), ns_othr, nd_othr, constraint);
503504
}
504505
if(nw) constraint_sizes.emplace_back(constraint, nw);
505506
total_work += nw;
@@ -509,7 +510,7 @@ auto dist_constraint_general(size_t nlevels, size_t norb, size_t ns_othr,
509510

510511
for(size_t ilevel = 0; ilevel < nlevels; ++ilevel) {
511512
// Select constraints larger than average to be broken apart
512-
std::vector<std::pair<wfn_constraint<N>, size_t>> tps_to_next;
513+
std::vector<std::pair<constraint_type, size_t>> tps_to_next;
513514
{
514515
auto it = std::partition(
515516
constraint_sizes.begin(), constraint_sizes.end(),
@@ -525,20 +526,20 @@ auto dist_constraint_general(size_t nlevels, size_t norb, size_t ns_othr,
525526

526527
// Break apart constraints
527528
for(auto [c, nw_trip] : tps_to_next) {
528-
const auto C_min = c.C_min;
529+
const auto C_min = c.C_min();
529530

530531
// Loop over possible constraints with one more element
531532
for(auto q_l = 0; q_l < C_min; ++q_l) {
532533
// Generate masks / counts
533-
wfn_constraint<N> c_next = c;
534-
c_next.C.flip(q_l);
535-
c_next.B >>= (C_min - q_l);
536-
c_next.C_min = q_l;
534+
string_type cn_C = c.C();
535+
cn_C.flip(q_l);
536+
string_type cn_B = c.B() >> (C_min - q_l);
537+
constraint_type c_next(cn_C, cn_B, q_l);
537538

538539
size_t nw = 0;
539540

540541
for(const auto& alpha : unique_alpha) {
541-
nw += constraint_histogram(alpha, ns_othr, nd_othr, c_next);
542+
nw += constraint_histogram(wfn_traits::alpha_string(alpha), ns_othr, nd_othr, c_next);
542543
}
543544
if(nw) constraint_sizes.emplace_back(c_next, nw);
544545
total_work += nw;
@@ -569,7 +570,7 @@ auto dist_constraint_general(size_t nlevels, size_t norb, size_t ns_othr,
569570
[](const auto& a, const auto& b) { return a.second > b.second; });
570571

571572
// Assign work
572-
std::vector<wfn_constraint<N>> constraints;
573+
std::vector<constraint_type> constraints;
573574
constraints.reserve(constraint_sizes.size() / world_size);
574575

575576
for(auto [c, nw] : constraint_sizes) {

include/macis/asci/pt2.hpp

Lines changed: 4 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -130,8 +130,6 @@ double asci_pt2_constraint(
130130
const auto& con = constraints[ic];
131131
//std::cout << std::distance(&constraints[0], &con) << "/" << constraints.size() << std::endl;
132132
const double h_el_tol = 1e-16;
133-
const auto& [C, B, C_min] = con;
134-
wfn_constraint<N/2> alpha_con{ wfn_traits::alpha_string(C), wfn_traits::alpha_string(B), C_min};
135133

136134
// Loop over unique alpha strings
137135
for(size_t i_alpha = 0; i_alpha < nuniq_alpha; ++i_alpha) {
@@ -146,7 +144,7 @@ double asci_pt2_constraint(
146144
const auto& occ_beta = bcd.occ_beta;
147145
const auto& orb_ens_alpha = bcd.orb_ens_alpha;
148146
generate_constraint_singles_contributions_ss(
149-
coeff, det|beta, alpha_con, occ_alpha, occ_beta,
147+
coeff, det|beta, con, occ_alpha, occ_beta,
150148
orb_ens_alpha.data(), T_pq, norb, G_red, norb, V_red, norb,
151149
h_el_tol, h_diag, E_ASCI, ham_gen, asci_pairs);
152150
}
@@ -159,7 +157,7 @@ double asci_pt2_constraint(
159157
const auto& occ_beta = bcd.occ_beta;
160158
const auto& orb_ens_alpha = bcd.orb_ens_alpha;
161159
generate_constraint_doubles_contributions_ss(
162-
coeff, det|beta, alpha_con, occ_alpha, occ_beta,
160+
coeff, det|beta, con, occ_alpha, occ_beta,
163161
orb_ens_alpha.data(), G_pqrs, norb, h_el_tol, h_diag, E_ASCI,
164162
ham_gen, asci_pairs);
165163
}
@@ -174,14 +172,14 @@ double asci_pt2_constraint(
174172
const auto& orb_ens_alpha = bcd.orb_ens_alpha;
175173
const auto& orb_ens_beta = bcd.orb_ens_beta;
176174
generate_constraint_doubles_contributions_os(
177-
coeff, det|beta, alpha_con, occ_alpha, occ_beta, vir_beta,
175+
coeff, det|beta, con, occ_alpha, occ_beta, vir_beta,
178176
orb_ens_alpha.data(), orb_ens_beta.data(), V_pqrs, norb, h_el_tol,
179177
h_diag, E_ASCI, ham_gen, asci_pairs);
180178
}
181179

182180
// If the alpha determinant satisfies the constraint,
183181
// append BB and BBBB excitations
184-
if(satisfies_constraint(det, con)) {
182+
if(satisfies_constraint(wfn_traits::alpha_string(det), con)) {
185183
for(const auto& bcd : uad[i_alpha].bcd) {
186184
const auto& beta = bcd.beta_string;
187185
const auto& coeff = bcd.coeff;

0 commit comments

Comments
 (0)