@@ -43,16 +43,17 @@ double asci_pt2_constraint(ASCISettings asci_settings,
4343
4444 const size_t ncdets = std::distance (cdets_begin, cdets_end);
4545 logger->info (" [ASCI PT2 Settings]" );
46- logger->info (" * NDETS = {}" , ncdets);
47- logger->info (" * PT2_TOL = {}" , asci_settings.pt2_tol );
48- logger->info (" * PT2_RESERVE_COUNT = {}" , asci_settings.pt2_reserve_count );
49- logger->info (" * PT2_CONSTRAINT_LVL = {}" , asci_settings.pt2_constraint_level );
50- logger->info (" * PT2_PRUNE = {}" , asci_settings.pt2_prune );
51- logger->info (" * PT2_PRECOMP_EPS = {}" , asci_settings.pt2_precompute_eps );
52- logger->info (" * PT2_BIGCON_THRESH = {}" , asci_settings.pt2_bigcon_thresh );
53- logger->info (" * NXTVAL_BCOUNT_THRESH = {}" ,
46+ logger->info (" * NDETS = {}" , ncdets);
47+ logger->info (" * PT2_TOL = {}" , asci_settings.pt2_tol );
48+ logger->info (" * PT2_RESERVE_COUNT = {}" , asci_settings.pt2_reserve_count );
49+ logger->info (" * PT2_CONSTRAINT_LVL_MAX = {}" , asci_settings.pt2_max_constraint_level );
50+ logger->info (" * PT2_CONSTRAINT_LVL_MIN = {}" , asci_settings.pt2_min_constraint_level );
51+ logger->info (" * PT2_PRUNE = {}" , asci_settings.pt2_prune );
52+ logger->info (" * PT2_PRECOMP_EPS = {}" , asci_settings.pt2_precompute_eps );
53+ logger->info (" * PT2_BIGCON_THRESH = {}" , asci_settings.pt2_bigcon_thresh );
54+ logger->info (" * NXTVAL_BCOUNT_THRESH = {}" ,
5455 asci_settings.nxtval_bcount_thresh );
55- logger->info (" * NXTVAL_BCOUNT_INC = {}" ,
56+ logger->info (" * NXTVAL_BCOUNT_INC = {}" ,
5657 asci_settings.nxtval_bcount_inc );
5758 logger->info (" " );
5859
@@ -74,7 +75,7 @@ double asci_pt2_constraint(ASCISettings asci_settings,
7475
7576 beta_coeff_data (double c, size_t norb,
7677 const std::vector<uint32_t >& occ_alpha, wfn_t <N> w,
77- const HamiltonianGenerator<wfn_t <N>>& ham_gen, bool pce) {
78+ const HamiltonianGenerator<wfn_t <N>>& ham_gen, bool pce, bool pci ) {
7879 coeff = c;
7980
8081 beta_string = wfn_traits::beta_string (w);
@@ -84,11 +85,13 @@ double asci_pt2_constraint(ASCISettings asci_settings,
8485
8586 // Compute occ/vir for beta string
8687 std::vector<uint32_t > o_32, v_32;
87- spin_wfn_traits::state_to_occ_vir (norb, beta_string, o_32, v_32);
88- occ_beta.resize (o_32.size ());
89- std::copy (o_32.begin (), o_32.end (), occ_beta.begin ());
90- vir_beta.resize (v_32.size ());
91- std::copy (v_32.begin (), v_32.end (), vir_beta.begin ());
88+ if (pce or pci) {
89+ spin_wfn_traits::state_to_occ_vir (norb, beta_string, o_32, v_32);
90+ occ_beta.resize (o_32.size ());
91+ std::copy (o_32.begin (), o_32.end (), occ_beta.begin ());
92+ vir_beta.resize (v_32.size ());
93+ std::copy (v_32.begin (), v_32.end (), vir_beta.begin ());
94+ }
9295
9396 // Precompute orbital energies
9497 if (pce) {
@@ -117,23 +120,23 @@ double asci_pt2_constraint(ASCISettings asci_settings,
117120 uad[i].reserve (nbeta);
118121 for (auto j = 0 ; j < nbeta; ++j, ++iw) {
119122 const auto & w = *(cdets_begin + iw);
120- uad[i].emplace_back (C[iw], norb, occ_alpha, w, ham_gen,asci_settings.pt2_precompute_eps );
123+ uad[i].emplace_back (C[iw], norb, occ_alpha, w, ham_gen,asci_settings.pt2_precompute_eps , asci_settings. pt2_precompute_idx );
121124 }
122125 }
123126
124127 if (world_rank == 0 ) {
125128 constexpr double gib = 1024 * 1024 * 1024 ;
126- printf (" MEM REQ DETS = % .2e\n " , ncdets * sizeof (wfn_t <N>) / gib);
127- printf (" MEM REQ C = % .2e\n " , ncdets * sizeof (double ) / gib);
129+ logger-> info (" MEM REQ DETS = {: .2e} " , ncdets * sizeof (wfn_t <N>) / gib);
130+ logger-> info (" MEM REQ C = {: .2e} " , ncdets * sizeof (double ) / gib);
128131 size_t mem_alpha = 0 ;
129132 for ( auto i = 0ul ; i < nuniq_alpha; ++i) {
130133 mem_alpha += sizeof (spin_wfn_type);
131134 for (auto j = 0ul ; j < uad[i].size (); ++j) {
132135 mem_alpha += uad[i][j].mem ();
133136 }
134137 }
135- printf (" MEM REQ ALPH = % .2e\n " , mem_alpha / gib);
136- printf (" MEM REQ CONT = % .2e\n " , asci_settings.pt2_reserve_count * sizeof (asci_contrib<wfn_t <N>>)/ gib);
138+ logger-> info (" MEM REQ ALPH = {: .2e} " , mem_alpha / gib);
139+ logger-> info (" MEM REQ CONT = {: .2e} " , asci_settings.pt2_reserve_count * sizeof (asci_contrib<wfn_t <N>>)/ gib);
137140 }
138141 MPI_Barrier (comm);
139142
@@ -153,8 +156,9 @@ double asci_pt2_constraint(ASCISettings asci_settings,
153156 // auto constraints = dist_constraint_general<wfn_t<N>>(
154157 // 5, norb, n_sing_beta, n_doub_beta, uniq_alpha, comm);
155158 auto constraints = gen_constraints_general<wfn_t <N>>(
156- asci_settings.pt2_constraint_level , norb, n_sing_beta,
157- n_doub_beta, uniq_alpha, world_size * omp_get_max_threads (), 0 );
159+ asci_settings.pt2_max_constraint_level , norb, n_sing_beta,
160+ n_doub_beta, uniq_alpha, world_size * omp_get_max_threads (),
161+ asci_settings.pt2_min_constraint_level );
158162 auto gen_c_en = clock_type::now ();
159163 duration_type gen_c_dur = gen_c_en - gen_c_st;
160164 logger->info (" * GEN_DUR = {:.2e} ms" , gen_c_dur.count ());
@@ -198,7 +202,8 @@ double asci_pt2_constraint(ASCISettings asci_settings,
198202 // MPI ranks
199203 ic = nxtval_big.fetch_and_add (1 );
200204 if (ic >= ncon_big) continue ;
201- printf (" [pt2_big rank %4d] %10lu / %10lu\n " , world_rank, ic, ncon_total);
205+ if (asci_settings.pt2_print_progress )
206+ printf (" [pt2_big rank %4d] %10lu / %10lu\n " , world_rank, ic, ncon_total);
202207 const auto & con = constraints[ic].first ;
203208
204209 asci_contrib_container<wfn_t <N>> asci_pairs_con;
@@ -224,11 +229,16 @@ double asci_pt2_constraint(ASCISettings asci_settings,
224229 const auto h_diag = bcd[j_beta].h_diag ;
225230
226231 // TODO: These copies are slow
232+ #if 0
227233 const auto& occ_beta_8 = bcd[j_beta].occ_beta;
228234 const auto& vir_beta_8 = bcd[j_beta].vir_beta;
229235 std::vector<uint32_t> occ_beta(occ_beta_8.size()), vir_beta(vir_beta_8.size());
230236 std::copy(occ_beta_8.begin(), occ_beta_8.end(), occ_beta.begin());
231237 std::copy(vir_beta_8.begin(), vir_beta_8.end(), vir_beta.begin());
238+ #else
239+ std::vector<uint32_t > occ_beta, vir_beta;
240+ spin_wfn_traits::state_to_occ_vir (norb, beta_det, occ_beta, vir_beta);
241+ #endif
232242
233243 std::vector<double > orb_ens_alpha, orb_ens_beta;
234244 if (asci_settings.pt2_precompute_eps ) {
@@ -309,17 +319,21 @@ double asci_pt2_constraint(ASCISettings asci_settings,
309319
310320 double EPT2_local = 0.0 ;
311321 size_t NPT2_local = 0 ;
322+ size_t pair_size = 0 ;
312323 // Local S&A for each quad + update EPT2
313324 {
314325 auto uit = sort_and_accumulate_asci_pairs (asci_pairs_con.begin (),
315326 asci_pairs_con.end ());
327+ pair_size = std::distance (asci_pairs_con.begin (), uit);
316328 for (auto it = asci_pairs_con.begin (); it != uit; ++it) {
317329 if (!std::isinf (it->c_times_matel )) {
318330 EPT2_local += it->pt2 ();
319331 NPT2_local++;
320332 }
321333 }
322334 asci_pairs_con.clear ();
335+ if (asci_settings.pt2_print_progress )
336+ printf (" [pt2_big rank %4d] CAPACITY %lu SZ %lu\n " , world_rank, asci_pairs_con.capacity (), pair_size);
323337 }
324338
325339 EPT2 += EPT2_local;
@@ -345,8 +359,9 @@ double asci_pt2_constraint(ASCISettings asci_settings,
345359 const size_t c_end = std::min (ncon_total, ic + ntake);
346360 for (; ic < c_end; ++ic) {
347361 const auto & con = constraints[ic].first ;
348- printf (" [pt2_small rank %4d tid:%4d] %10lu / %10lu\n " , world_rank,
349- omp_get_thread_num (), ic, ncon_total);
362+ if (asci_settings.pt2_print_progress )
363+ printf (" [pt2_small rank %4d tid:%4d] %10lu / %10lu\n " , world_rank,
364+ omp_get_thread_num (), ic, ncon_total);
350365
351366 for (size_t i_alpha = 0 ; i_alpha < nuniq_alpha; ++i_alpha) {
352367 const size_t old_pair_size = asci_pairs.size ();
@@ -366,11 +381,16 @@ double asci_pt2_constraint(ASCISettings asci_settings,
366381 const auto h_diag = bcd[j_beta].h_diag ;
367382
368383 // TODO: These copies are slow
384+ #if 0
369385 const auto& occ_beta_8 = bcd[j_beta].occ_beta;
370386 const auto& vir_beta_8 = bcd[j_beta].vir_beta;
371387 std::vector<uint32_t> occ_beta(occ_beta_8.size()), vir_beta(vir_beta_8.size());
372388 std::copy(occ_beta_8.begin(), occ_beta_8.end(), occ_beta.begin());
373389 std::copy(vir_beta_8.begin(), vir_beta_8.end(), vir_beta.begin());
390+ #else
391+ std::vector<uint32_t > occ_beta, vir_beta;
392+ spin_wfn_traits::state_to_occ_vir (norb, beta_det, occ_beta, vir_beta);
393+ #endif
374394
375395 std::vector<double > orb_ens_alpha, orb_ens_beta;
376396 if (asci_settings.pt2_precompute_eps ) {
@@ -423,9 +443,14 @@ double asci_pt2_constraint(ASCISettings asci_settings,
423443 asci_pairs.erase (uit, asci_pairs.end ());
424444 // uit = std::stable_partition(asci_pairs.begin(), asci_pairs.end(), [&](const auto& p){ return std::abs(p.pt2()) > h_el_tol; });
425445 // asci_pairs.erase(uit, asci_pairs.end());
426- printf (" [rank %4d tid:%4d] IC = %lu / %lu IA = %lu / %lu SZ = %lu\n " , world_rank,
427- omp_get_thread_num (), ic, ncon_total, i_alpha,
428- nuniq_alpha, asci_pairs.size ());
446+ if (asci_settings.pt2_print_progress )
447+ printf (" [pt2_prune rank %4d tid:%4d] IC = %lu / %lu IA = %lu / %lu SZ = %lu\n " , world_rank,
448+ omp_get_thread_num (), ic, ncon_total, i_alpha,
449+ nuniq_alpha, asci_pairs.size ());
450+
451+ if (asci_pairs.size () > asci_settings.pt2_reserve_count ) {
452+ printf (" * WARNING: PRUNED SIZE LARGER THAN RESERVE COUNT\n " );
453+ }
429454 }
430455
431456 } // Unique Alpha Loop
0 commit comments