1212#include < macis/util/mpi.hpp>
1313#include < variant>
1414
15+ #include < macis/asci/alpha_constraint.hpp>
16+
1517namespace macis {
1618
17- template <size_t N>
18- struct wfn_constraint {
19- wfn_t <N> C;
20- wfn_t <N> B;
21- unsigned C_min;
22- };
2319
24- template <size_t N>
25- bool satisfies_constraint (wfn_t <N> det, wfn_constraint<N> constraint) {
26- auto C = constraint.C ; auto C_min = constraint.C_min ;
27- return (det & C).count () == C.count () and ((det ^ C) >> C_min).count () == 0 ;
20+ template <size_t N, typename ConType>
21+ bool satisfies_constraint (wfn_t <N> det, ConType C) {
22+ return C.satisfies_constraint (det);
2823}
2924
30- template <size_t N>
31- auto generate_constraint_single_excitations (wfn_t <N> det, wfn_constraint<N> constraint) {
32- const auto C = constraint.C ; const auto B = constraint.B ;
33- if ((det & C).count () <
34- (C.count () - 1 )) // need to have at most one different from the constraint
25+ template <size_t N, typename ConType>
26+ auto generate_constraint_single_excitations (wfn_t <N> det, ConType constraint) {
27+ using spin_wfn_traits = typename ConType::spin_wfn_traits;
28+ const auto C = constraint.C (); const auto B = constraint.B ();
29+
30+ // need to have at most one different from the constraint
31+ if (constraint.overlap (det) < (constraint.count ()-1 ))
3532 return std::make_pair (wfn_t <N>(0 ), wfn_t <N>(0 ));
3633
3734 auto o = det ^ C;
3835 auto v = (~det) & B;
3936
40- if ((o & C).count () ==
41- 1 ) { // don't have to change this necessarily, but more clear without >=
37+ if ((o & C).count () == 1 ) {
4238 v = o & C;
4339 o ^= v;
4440 }
@@ -50,9 +46,9 @@ auto generate_constraint_single_excitations(wfn_t<N> det, wfn_constraint<N> cons
5046 return std::make_pair (o, v);
5147}
5248
53- template <size_t N>
54- auto generate_constraint_double_excitations (wfn_t <N> det, wfn_constraint<N> constraint) {
55- const auto C = constraint.C ; const auto B = constraint.B ;
49+ template <size_t N, typename ConType >
50+ auto generate_constraint_double_excitations (wfn_t <N> det, ConType constraint) {
51+ const auto C = constraint.C () ; const auto B = constraint.B () ;
5652 // Occ/Vir pairs to generate excitations
5753 std::vector<wfn_t <N>> O, V;
5854
@@ -102,8 +98,8 @@ auto generate_constraint_double_excitations(wfn_t<N> det, wfn_constraint<N> cons
10298 return std::make_tuple (O, V);
10399}
104100
105- template <size_t N>
106- void generate_constraint_singles (wfn_t <N> det, wfn_constraint<N> constraint,
101+ template <size_t N, typename ConType >
102+ void generate_constraint_singles (wfn_t <N> det, ConType constraint,
107103 std::vector<wfn_t <N>>& t_singles) {
108104 auto [o, v] = generate_constraint_single_excitations (det, constraint);
109105 const auto oc = o.count ();
@@ -128,8 +124,8 @@ unsigned count_constraint_singles(Args&&... args) {
128124 return o.count () * v.count ();
129125}
130126
131- template <size_t N>
132- void generate_constraint_doubles (wfn_t <N> det, wfn_constraint<N> constraint,
127+ template <size_t N, typename ConType >
128+ void generate_constraint_doubles (wfn_t <N> det, ConType constraint,
133129 std::vector<wfn_t <N>>& t_doubles) {
134130 auto [O, V] = generate_constraint_double_excitations (det, constraint);
135131
@@ -147,9 +143,9 @@ void generate_constraint_doubles(wfn_t<N> det, wfn_constraint<N> constraint,
147143 * @param[in] T Triplet constraint mask
148144 * @param[in] B B mask (?)
149145 */
150- template <size_t N>
151- unsigned count_constraint_doubles (wfn_t <N> det, wfn_constraint<N> constraint) {
152- const auto C = constraint.C ; const auto B = constraint.B ;
146+ template <size_t N, typename ConType >
147+ unsigned count_constraint_doubles (wfn_t <N> det, ConType constraint) {
148+ const auto C = constraint.C () ; const auto B = constraint.B () ;
153149 if ((det & C) == 0 ) return 0 ;
154150
155151 auto o = det ^ C;
@@ -193,9 +189,9 @@ unsigned count_constraint_doubles(wfn_t<N> det, wfn_constraint<N> constraint) {
193189 return no_pairs * nv_pairs;
194190}
195191
196- template <size_t N, typename ... Args >
192+ template <size_t N, typename ConType >
197193size_t constraint_histogram (wfn_t <N> det, size_t n_os_singles,
198- size_t n_os_doubles, wfn_constraint<N> constraint){
194+ size_t n_os_doubles, ConType constraint){
199195 auto ns = count_constraint_singles (det, constraint);
200196 auto nd = count_constraint_doubles (det, constraint);
201197
@@ -461,16 +457,18 @@ auto dist_triplets_histogram(size_t norb, size_t ns_othr, size_t nd_othr,
461457
462458template <size_t N>
463459auto make_triplet (unsigned i, unsigned j, unsigned k) {
464- wfn_constraint<N> con;
465-
466- con.C = 0 ;
467- con.C .flip (i).flip (j).flip (k);
468- con.B = 1 ;
469- con.B <<= k;
470- con.B = con.B .to_ullong () - 1 ;
471- con.C_min = k;
472-
473- return con;
460+ using wfn_type = wfn_t <N>;
461+ using wfn_traits = wavefunction_traits<wfn_type>;
462+ using constraint_type = alpha_constraint<wfn_traits>;
463+ using string_type = typename constraint_type::constraint_type;
464+
465+ string_type C = 0 ;
466+ C.flip (i).flip (j).flip (k);
467+ string_type B = 1 ;
468+ B <<= k;
469+ B = B.to_ullong () - 1 ;
470+
471+ return constraint_type (C,B,k);
474472}
475473
476474#ifdef MACIS_ENABLE_MPI
@@ -479,6 +477,10 @@ auto dist_constraint_general(size_t nlevels, size_t norb, size_t ns_othr,
479477 size_t nd_othr,
480478 const std::vector<wfn_t <N>>& unique_alpha,
481479 MPI_Comm comm) {
480+ using wfn_type = wfn_t <N>;
481+ using wfn_traits = wavefunction_traits<wfn_type>;
482+ using constraint_type = alpha_constraint<wfn_traits>;
483+ using string_type = typename constraint_type::constraint_type;
482484 auto world_rank = comm_rank (comm);
483485 auto world_size = comm_size (comm);
484486
@@ -488,18 +490,17 @@ auto dist_constraint_general(size_t nlevels, size_t norb, size_t ns_othr,
488490 std::vector<size_t > workloads (world_size, 0 );
489491
490492 // Generate triplets + heuristic
491- std::vector<std::pair<wfn_constraint<N> , size_t >> constraint_sizes;
493+ std::vector<std::pair<constraint_type , size_t >> constraint_sizes;
492494 constraint_sizes.reserve (norb * norb * norb);
493495 size_t total_work = 0 ;
494496 for (int t_i = 0 ; t_i < norb; ++t_i)
495497 for (int t_j = 0 ; t_j < t_i; ++t_j)
496498 for (int t_k = 0 ; t_k < t_j; ++t_k) {
497499 auto constraint = make_triplet<N>(t_i, t_j, t_k);
498- const auto & [T, B, _] = constraint;
499500
500501 size_t nw = 0 ;
501502 for (const auto & alpha : unique_alpha) {
502- nw += constraint_histogram (alpha, ns_othr, nd_othr, constraint);
503+ nw += constraint_histogram (wfn_traits::alpha_string ( alpha) , ns_othr, nd_othr, constraint);
503504 }
504505 if (nw) constraint_sizes.emplace_back (constraint, nw);
505506 total_work += nw;
@@ -509,7 +510,7 @@ auto dist_constraint_general(size_t nlevels, size_t norb, size_t ns_othr,
509510
510511 for (size_t ilevel = 0 ; ilevel < nlevels; ++ilevel) {
511512 // Select constraints larger than average to be broken apart
512- std::vector<std::pair<wfn_constraint<N> , size_t >> tps_to_next;
513+ std::vector<std::pair<constraint_type , size_t >> tps_to_next;
513514 {
514515 auto it = std::partition (
515516 constraint_sizes.begin (), constraint_sizes.end (),
@@ -525,20 +526,20 @@ auto dist_constraint_general(size_t nlevels, size_t norb, size_t ns_othr,
525526
526527 // Break apart constraints
527528 for (auto [c, nw_trip] : tps_to_next) {
528- const auto C_min = c.C_min ;
529+ const auto C_min = c.C_min () ;
529530
530531 // Loop over possible constraints with one more element
531532 for (auto q_l = 0 ; q_l < C_min; ++q_l) {
532533 // Generate masks / counts
533- wfn_constraint<N> c_next = c;
534- c_next. C .flip (q_l);
535- c_next. B >>= (C_min - q_l);
536- c_next. C_min = q_l;
534+ string_type cn_C = c. C () ;
535+ cn_C .flip (q_l);
536+ string_type cn_B = c. B () >> (C_min - q_l);
537+ constraint_type c_next (cn_C, cn_B, q_l) ;
537538
538539 size_t nw = 0 ;
539540
540541 for (const auto & alpha : unique_alpha) {
541- nw += constraint_histogram (alpha, ns_othr, nd_othr, c_next);
542+ nw += constraint_histogram (wfn_traits::alpha_string ( alpha) , ns_othr, nd_othr, c_next);
542543 }
543544 if (nw) constraint_sizes.emplace_back (c_next, nw);
544545 total_work += nw;
@@ -569,7 +570,7 @@ auto dist_constraint_general(size_t nlevels, size_t norb, size_t ns_othr,
569570 [](const auto & a, const auto & b) { return a.second > b.second ; });
570571
571572 // Assign work
572- std::vector<wfn_constraint<N> > constraints;
573+ std::vector<constraint_type > constraints;
573574 constraints.reserve (constraint_sizes.size () / world_size);
574575
575576 for (auto [c, nw] : constraint_sizes) {
0 commit comments