77 */
88
99#pragma once
10+ #include < omp.h>
1011#include < spdlog/sinks/null_sink.h>
1112#include < spdlog/sinks/stdout_color_sinks.h>
1213#include < spdlog/spdlog.h>
1314
14- #include < omp.h>
1515#include < chrono>
1616#include < fstream>
1717#include < macis/asci/determinant_contributions.hpp>
@@ -220,7 +220,6 @@ asci_contrib_container<wfn_t<N>> asci_contributions_constraint(
220220 }
221221 }
222222
223-
224223 // const auto n_occ_alpha = wfn_traits::count(uniq_alpha_wfn[0]);
225224 const auto n_occ_alpha = spin_wfn_traits::count (uniq_alpha[0 ].first );
226225 const auto n_vir_alpha = norb - n_occ_alpha;
@@ -260,14 +259,14 @@ asci_contrib_container<wfn_t<N>> asci_contributions_constraint(
260259
261260 auto gen_c_st = clock_type::now ();
262261 auto constraints = gen_constraints_general<wfn_t <N>>(
263- asci_settings.constraint_level , norb, n_sing_beta,
264- n_doub_beta, uniq_alpha, world_size);
262+ asci_settings.constraint_level , norb, n_sing_beta, n_doub_beta,
263+ uniq_alpha, world_size);
265264 auto gen_c_en = clock_type::now ();
266265 duration_type gen_c_dur = gen_c_en - gen_c_st;
267266 logger->info (" * GEN_DUR = {:.2e} ms" , gen_c_dur.count ());
268267
269268 size_t max_size =
270- std::min (std::min (ntdets,asci_settings.pair_size_max ),
269+ std::min (std::min (ntdets, asci_settings.pair_size_max ),
271270 ncdets * (n_sing_alpha + n_sing_beta + // AA + BB
272271 n_doub_alpha + n_doub_beta + // AAAA + BBBB
273272 n_sing_alpha * n_sing_beta // AABB
@@ -281,122 +280,121 @@ asci_contrib_container<wfn_t<N>> asci_contributions_constraint(
281280 asci_contrib_container<wfn_t <N>> asci_pairs_total;
282281#pragma omp parallel
283282 {
284- // Process ASCI pair contributions for each constraint
285- asci_contrib_container<wfn_t <N>> asci_pairs;
286- asci_pairs.reserve (max_size);
287-
288- size_t ic = 0 ;
289- while (ic < ncon_total) {
290- auto size_before = asci_pairs.size ();
291- const double h_el_tol = asci_settings.h_el_tol ;
292-
293- // Atomically get the next task ID and increment for other
294- // MPI ranks and threads
295- size_t ntake = ic < 1000 ? 1 : 10 ;
296- ic = nxtval.fetch_and_add (ntake);
297-
298- // Loop over assigned tasks
299- const size_t c_end = std::min (ncon_total, ic + ntake);
300- for (; ic < c_end; ++ic) {
301- const auto & con = constraints[ic].first ;
302- printf (" [rank %4d tid:%4d] %10lu / %10lu\n " , world_rank,
303- omp_get_thread_num (), ic, ncon_total);
304-
305- for (size_t i_alpha = 0 , iw = 0 ; i_alpha < nuniq_alpha; ++i_alpha) {
306- const auto & alpha_det = uniq_alpha[i_alpha].first ;
307- const auto occ_alpha = bits_to_indices (alpha_det);
308- const bool alpha_satisfies_con = satisfies_constraint (alpha_det, con);
309-
310- const auto & bcd = uad[i_alpha];
311- const size_t nbeta = bcd.size ();
312- for (size_t j_beta = 0 ; j_beta < nbeta; ++j_beta, ++iw) {
313- const auto w = *(cdets_begin + iw);
314- const auto c = C[iw];
315- const auto & beta_det = bcd[j_beta].beta_string ;
316- const auto h_diag = bcd[j_beta].h_diag ;
317- const auto & occ_beta = bcd[j_beta].occ_beta ;
318- const auto & vir_beta = bcd[j_beta].vir_beta ;
319- const auto & orb_ens_alpha = bcd[j_beta].orb_ens_alpha ;
320- const auto & orb_ens_beta = bcd[j_beta].orb_ens_beta ;
321-
322- // AA excitations
323- generate_constraint_singles_contributions_ss (
324- c, w, con, occ_alpha, occ_beta, orb_ens_alpha.data (), T_pq, norb,
325- G_red, norb, V_red, norb, h_el_tol, h_diag, E_ASCI, ham_gen,
326- asci_pairs);
327-
328- // AAAA excitations
329- generate_constraint_doubles_contributions_ss (
330- c, w, con, occ_alpha, occ_beta, orb_ens_alpha.data (), G_pqrs, norb,
331- h_el_tol, h_diag, E_ASCI, ham_gen, asci_pairs);
332-
333- // AABB excitations
334- generate_constraint_doubles_contributions_os (
335- c, w, con, occ_alpha, occ_beta, vir_beta, orb_ens_alpha.data (),
336- orb_ens_beta.data (), V_pqrs, norb, h_el_tol, h_diag, E_ASCI,
337- ham_gen, asci_pairs);
338-
339- if (alpha_satisfies_con) {
340- // BB excitations
341- append_singles_asci_contributions<Spin::Beta>(
342- c, w, beta_det, occ_beta, vir_beta, occ_alpha,
343- orb_ens_beta.data (), T_pq, norb, G_red, norb, V_red, norb,
344- h_el_tol, h_diag, E_ASCI, ham_gen, asci_pairs);
345-
346- // BBBB excitations
347- append_ss_doubles_asci_contributions<Spin::Beta>(
348- c, w, beta_det, alpha_det, occ_beta, vir_beta, occ_alpha,
349- orb_ens_beta.data (), G_pqrs, norb, h_el_tol, h_diag, E_ASCI,
350- ham_gen, asci_pairs);
351-
352- // No excitation (push inf to remove from list)
353- asci_pairs.push_back (
354- {w, std::numeric_limits<double >::infinity (), 1.0 });
283+ // Process ASCI pair contributions for each constraint
284+ asci_contrib_container<wfn_t <N>> asci_pairs;
285+ asci_pairs.reserve (max_size);
286+
287+ size_t ic = 0 ;
288+ while (ic < ncon_total) {
289+ auto size_before = asci_pairs.size ();
290+ const double h_el_tol = asci_settings.h_el_tol ;
291+
292+ // Atomically get the next task ID and increment for other
293+ // MPI ranks and threads
294+ size_t ntake = ic < 1000 ? 1 : 10 ;
295+ ic = nxtval.fetch_and_add (ntake);
296+
297+ // Loop over assigned tasks
298+ const size_t c_end = std::min (ncon_total, ic + ntake);
299+ for (; ic < c_end; ++ic) {
300+ const auto & con = constraints[ic].first ;
301+ printf (" [rank %4d tid:%4d] %10lu / %10lu\n " , world_rank,
302+ omp_get_thread_num (), ic, ncon_total);
303+
304+ for (size_t i_alpha = 0 , iw = 0 ; i_alpha < nuniq_alpha; ++i_alpha) {
305+ const auto & alpha_det = uniq_alpha[i_alpha].first ;
306+ const auto occ_alpha = bits_to_indices (alpha_det);
307+ const bool alpha_satisfies_con = satisfies_constraint (alpha_det, con);
308+
309+ const auto & bcd = uad[i_alpha];
310+ const size_t nbeta = bcd.size ();
311+ for (size_t j_beta = 0 ; j_beta < nbeta; ++j_beta, ++iw) {
312+ const auto w = *(cdets_begin + iw);
313+ const auto c = C[iw];
314+ const auto & beta_det = bcd[j_beta].beta_string ;
315+ const auto h_diag = bcd[j_beta].h_diag ;
316+ const auto & occ_beta = bcd[j_beta].occ_beta ;
317+ const auto & vir_beta = bcd[j_beta].vir_beta ;
318+ const auto & orb_ens_alpha = bcd[j_beta].orb_ens_alpha ;
319+ const auto & orb_ens_beta = bcd[j_beta].orb_ens_beta ;
320+
321+ // AA excitations
322+ generate_constraint_singles_contributions_ss (
323+ c, w, con, occ_alpha, occ_beta, orb_ens_alpha.data (), T_pq,
324+ norb, G_red, norb, V_red, norb, h_el_tol, h_diag, E_ASCI,
325+ ham_gen, asci_pairs);
326+
327+ // AAAA excitations
328+ generate_constraint_doubles_contributions_ss (
329+ c, w, con, occ_alpha, occ_beta, orb_ens_alpha.data (), G_pqrs,
330+ norb, h_el_tol, h_diag, E_ASCI, ham_gen, asci_pairs);
331+
332+ // AABB excitations
333+ generate_constraint_doubles_contributions_os (
334+ c, w, con, occ_alpha, occ_beta, vir_beta, orb_ens_alpha.data (),
335+ orb_ens_beta.data (), V_pqrs, norb, h_el_tol, h_diag, E_ASCI,
336+ ham_gen, asci_pairs);
337+
338+ if (alpha_satisfies_con) {
339+ // BB excitations
340+ append_singles_asci_contributions<Spin::Beta>(
341+ c, w, beta_det, occ_beta, vir_beta, occ_alpha,
342+ orb_ens_beta.data (), T_pq, norb, G_red, norb, V_red, norb,
343+ h_el_tol, h_diag, E_ASCI, ham_gen, asci_pairs);
344+
345+ // BBBB excitations
346+ append_ss_doubles_asci_contributions<Spin::Beta>(
347+ c, w, beta_det, alpha_det, occ_beta, vir_beta, occ_alpha,
348+ orb_ens_beta.data (), G_pqrs, norb, h_el_tol, h_diag, E_ASCI,
349+ ham_gen, asci_pairs);
350+
351+ // No excitation (push inf to remove from list)
352+ asci_pairs.push_back (
353+ {w, std::numeric_limits<double >::infinity (), 1.0 });
354+ }
355+ }
356+
357+ // Prune Down Contributions
358+ if (asci_pairs.size () > asci_settings.pair_size_max ) {
359+ throw std::runtime_error (" DIE DIE DIE" );
360+ }
361+
362+ } // Unique Alpha Loop
363+
364+ // Local S&A for each quad
365+ {
366+ if (size_before > asci_pairs.size ())
367+ throw std::runtime_error (" DIE DIE DIE" );
368+ auto uit = sort_and_accumulate_asci_pairs (
369+ asci_pairs.begin () + size_before, asci_pairs.end ());
370+ asci_pairs.erase (uit, asci_pairs.end ());
371+
372+ // Remove small contributions
373+ uit = std::partition (asci_pairs.begin () + size_before,
374+ asci_pairs.end (), [=](const auto & x) {
375+ return std::abs (x.rv ()) >
376+ asci_settings.rv_prune_tol ;
377+ });
378+ asci_pairs.erase (uit, asci_pairs.end ());
355379 }
356- }
357-
358- // Prune Down Contributions
359- if (asci_pairs.size () > asci_settings.pair_size_max ) {
360- throw std::runtime_error (" DIE DIE DIE" );
361- }
380+ } // Loc constraint loop
381+ } // Constraint Loop
362382
363- } // Unique Alpha Loop
364-
365-
366- // Local S&A for each quad
383+ // Insert into list
384+ #pragma omp critical
367385 {
368- if (size_before > asci_pairs.size ())
369- throw std::runtime_error (" DIE DIE DIE" );
370- auto uit = sort_and_accumulate_asci_pairs (
371- asci_pairs.begin () + size_before, asci_pairs.end ());
372- asci_pairs.erase (uit, asci_pairs.end ());
373-
374- // Remove small contributions
375- uit =
376- std::partition (asci_pairs.begin () + size_before, asci_pairs.end (),
377- [=](const auto & x) {
378- return std::abs (x.rv ()) > asci_settings.rv_prune_tol ;
379- });
380- asci_pairs.erase (uit, asci_pairs.end ());
386+ if (asci_pairs_total.size ()) {
387+ // Preallocate space for insertion
388+ asci_pairs_total.reserve (asci_pairs.size () + asci_pairs_total.size ());
389+ asci_pairs_total.insert (asci_pairs_total.end (), asci_pairs.begin (),
390+ asci_pairs.end ());
391+ } else {
392+ asci_pairs_total = std::move (asci_pairs);
393+ }
394+ asci_contrib_container<wfn_t <N>>().swap (asci_pairs);
381395 }
382- } // Loc constraint loop
383- } // Constraint Loop
384-
385- // Insert into list
386- #pragma omp critical
387- {
388- if (asci_pairs_total.size ()) {
389- // Preallocate space for insertion
390- asci_pairs_total.reserve (asci_pairs.size () + asci_pairs_total.size ());
391- asci_pairs_total.insert (asci_pairs_total.end (), asci_pairs.begin (),
392- asci_pairs.end ());
393- } else {
394- asci_pairs_total = std::move (asci_pairs);
395- }
396- asci_contrib_container<wfn_t <N>>().swap (asci_pairs);
397- }
398396
399- } // OpenMP
397+ } // OpenMP
400398
401399 return asci_pairs_total;
402400}
@@ -462,8 +460,8 @@ std::vector<wfn_t<N>> asci_search(
462460 // #ifdef MACIS_ENABLE_MPI
463461 // else
464462 asci_pairs = asci_contributions_constraint (
465- asci_settings, ndets_max, cdets_begin, cdets_end, E_ASCI, C, norb, T_pq, G_red,
466- V_red, G_pqrs, V_pqrs, ham_gen MACIS_MPI_CODE (, comm));
463+ asci_settings, ndets_max, cdets_begin, cdets_end, E_ASCI, C, norb, T_pq,
464+ G_red, V_red, G_pqrs, V_pqrs, ham_gen MACIS_MPI_CODE (, comm));
467465 // #endif
468466 auto pairs_en = clock_type::now ();
469467
0 commit comments