Skip to content

Commit f6a2270

Browse files
Merge branch 'feature/pt2' of github.com:wavefunction91/MACIS into feature/pt2
2 parents 105b973 + 82285a2 commit f6a2270

File tree

2 files changed

+128
-127
lines changed

2 files changed

+128
-127
lines changed

include/macis/asci/determinant_search.hpp

Lines changed: 116 additions & 118 deletions
Original file line numberDiff line numberDiff line change
@@ -7,11 +7,11 @@
77
*/
88

99
#pragma once
10+
#include <omp.h>
1011
#include <spdlog/sinks/null_sink.h>
1112
#include <spdlog/sinks/stdout_color_sinks.h>
1213
#include <spdlog/spdlog.h>
1314

14-
#include <omp.h>
1515
#include <chrono>
1616
#include <fstream>
1717
#include <macis/asci/determinant_contributions.hpp>
@@ -220,7 +220,6 @@ asci_contrib_container<wfn_t<N>> asci_contributions_constraint(
220220
}
221221
}
222222

223-
224223
// const auto n_occ_alpha = wfn_traits::count(uniq_alpha_wfn[0]);
225224
const auto n_occ_alpha = spin_wfn_traits::count(uniq_alpha[0].first);
226225
const auto n_vir_alpha = norb - n_occ_alpha;
@@ -260,14 +259,14 @@ asci_contrib_container<wfn_t<N>> asci_contributions_constraint(
260259

261260
auto gen_c_st = clock_type::now();
262261
auto constraints = gen_constraints_general<wfn_t<N>>(
263-
asci_settings.constraint_level, norb, n_sing_beta,
264-
n_doub_beta, uniq_alpha, world_size);
262+
asci_settings.constraint_level, norb, n_sing_beta, n_doub_beta,
263+
uniq_alpha, world_size);
265264
auto gen_c_en = clock_type::now();
266265
duration_type gen_c_dur = gen_c_en - gen_c_st;
267266
logger->info(" * GEN_DUR = {:.2e} ms", gen_c_dur.count());
268267

269268
size_t max_size =
270-
std::min(std::min(ntdets,asci_settings.pair_size_max),
269+
std::min(std::min(ntdets, asci_settings.pair_size_max),
271270
ncdets * (n_sing_alpha + n_sing_beta + // AA + BB
272271
n_doub_alpha + n_doub_beta + // AAAA + BBBB
273272
n_sing_alpha * n_sing_beta // AABB
@@ -281,122 +280,121 @@ asci_contrib_container<wfn_t<N>> asci_contributions_constraint(
281280
asci_contrib_container<wfn_t<N>> asci_pairs_total;
282281
#pragma omp parallel
283282
{
284-
// Process ASCI pair contributions for each constraint
285-
asci_contrib_container<wfn_t<N>> asci_pairs;
286-
asci_pairs.reserve(max_size);
287-
288-
size_t ic = 0;
289-
while(ic < ncon_total) {
290-
auto size_before = asci_pairs.size();
291-
const double h_el_tol = asci_settings.h_el_tol;
292-
293-
// Atomically get the next task ID and increment for other
294-
// MPI ranks and threads
295-
size_t ntake = ic < 1000 ? 1 : 10;
296-
ic = nxtval.fetch_and_add(ntake);
297-
298-
// Loop over assigned tasks
299-
const size_t c_end = std::min(ncon_total, ic + ntake);
300-
for(; ic < c_end; ++ic) {
301-
const auto& con = constraints[ic].first;
302-
printf("[rank %4d tid:%4d] %10lu / %10lu\n", world_rank,
303-
omp_get_thread_num(), ic, ncon_total);
304-
305-
for(size_t i_alpha = 0, iw = 0; i_alpha < nuniq_alpha; ++i_alpha) {
306-
const auto& alpha_det = uniq_alpha[i_alpha].first;
307-
const auto occ_alpha = bits_to_indices(alpha_det);
308-
const bool alpha_satisfies_con = satisfies_constraint(alpha_det, con);
309-
310-
const auto& bcd = uad[i_alpha];
311-
const size_t nbeta = bcd.size();
312-
for(size_t j_beta = 0; j_beta < nbeta; ++j_beta, ++iw) {
313-
const auto w = *(cdets_begin + iw);
314-
const auto c = C[iw];
315-
const auto& beta_det = bcd[j_beta].beta_string;
316-
const auto h_diag = bcd[j_beta].h_diag;
317-
const auto& occ_beta = bcd[j_beta].occ_beta;
318-
const auto& vir_beta = bcd[j_beta].vir_beta;
319-
const auto& orb_ens_alpha = bcd[j_beta].orb_ens_alpha;
320-
const auto& orb_ens_beta = bcd[j_beta].orb_ens_beta;
321-
322-
// AA excitations
323-
generate_constraint_singles_contributions_ss(
324-
c, w, con, occ_alpha, occ_beta, orb_ens_alpha.data(), T_pq, norb,
325-
G_red, norb, V_red, norb, h_el_tol, h_diag, E_ASCI, ham_gen,
326-
asci_pairs);
327-
328-
// AAAA excitations
329-
generate_constraint_doubles_contributions_ss(
330-
c, w, con, occ_alpha, occ_beta, orb_ens_alpha.data(), G_pqrs, norb,
331-
h_el_tol, h_diag, E_ASCI, ham_gen, asci_pairs);
332-
333-
// AABB excitations
334-
generate_constraint_doubles_contributions_os(
335-
c, w, con, occ_alpha, occ_beta, vir_beta, orb_ens_alpha.data(),
336-
orb_ens_beta.data(), V_pqrs, norb, h_el_tol, h_diag, E_ASCI,
337-
ham_gen, asci_pairs);
338-
339-
if(alpha_satisfies_con) {
340-
// BB excitations
341-
append_singles_asci_contributions<Spin::Beta>(
342-
c, w, beta_det, occ_beta, vir_beta, occ_alpha,
343-
orb_ens_beta.data(), T_pq, norb, G_red, norb, V_red, norb,
344-
h_el_tol, h_diag, E_ASCI, ham_gen, asci_pairs);
345-
346-
// BBBB excitations
347-
append_ss_doubles_asci_contributions<Spin::Beta>(
348-
c, w, beta_det, alpha_det, occ_beta, vir_beta, occ_alpha,
349-
orb_ens_beta.data(), G_pqrs, norb, h_el_tol, h_diag, E_ASCI,
350-
ham_gen, asci_pairs);
351-
352-
// No excitation (push inf to remove from list)
353-
asci_pairs.push_back(
354-
{w, std::numeric_limits<double>::infinity(), 1.0});
283+
// Process ASCI pair contributions for each constraint
284+
asci_contrib_container<wfn_t<N>> asci_pairs;
285+
asci_pairs.reserve(max_size);
286+
287+
size_t ic = 0;
288+
while(ic < ncon_total) {
289+
auto size_before = asci_pairs.size();
290+
const double h_el_tol = asci_settings.h_el_tol;
291+
292+
// Atomically get the next task ID and increment for other
293+
// MPI ranks and threads
294+
size_t ntake = ic < 1000 ? 1 : 10;
295+
ic = nxtval.fetch_and_add(ntake);
296+
297+
// Loop over assigned tasks
298+
const size_t c_end = std::min(ncon_total, ic + ntake);
299+
for(; ic < c_end; ++ic) {
300+
const auto& con = constraints[ic].first;
301+
printf("[rank %4d tid:%4d] %10lu / %10lu\n", world_rank,
302+
omp_get_thread_num(), ic, ncon_total);
303+
304+
for(size_t i_alpha = 0, iw = 0; i_alpha < nuniq_alpha; ++i_alpha) {
305+
const auto& alpha_det = uniq_alpha[i_alpha].first;
306+
const auto occ_alpha = bits_to_indices(alpha_det);
307+
const bool alpha_satisfies_con = satisfies_constraint(alpha_det, con);
308+
309+
const auto& bcd = uad[i_alpha];
310+
const size_t nbeta = bcd.size();
311+
for(size_t j_beta = 0; j_beta < nbeta; ++j_beta, ++iw) {
312+
const auto w = *(cdets_begin + iw);
313+
const auto c = C[iw];
314+
const auto& beta_det = bcd[j_beta].beta_string;
315+
const auto h_diag = bcd[j_beta].h_diag;
316+
const auto& occ_beta = bcd[j_beta].occ_beta;
317+
const auto& vir_beta = bcd[j_beta].vir_beta;
318+
const auto& orb_ens_alpha = bcd[j_beta].orb_ens_alpha;
319+
const auto& orb_ens_beta = bcd[j_beta].orb_ens_beta;
320+
321+
// AA excitations
322+
generate_constraint_singles_contributions_ss(
323+
c, w, con, occ_alpha, occ_beta, orb_ens_alpha.data(), T_pq,
324+
norb, G_red, norb, V_red, norb, h_el_tol, h_diag, E_ASCI,
325+
ham_gen, asci_pairs);
326+
327+
// AAAA excitations
328+
generate_constraint_doubles_contributions_ss(
329+
c, w, con, occ_alpha, occ_beta, orb_ens_alpha.data(), G_pqrs,
330+
norb, h_el_tol, h_diag, E_ASCI, ham_gen, asci_pairs);
331+
332+
// AABB excitations
333+
generate_constraint_doubles_contributions_os(
334+
c, w, con, occ_alpha, occ_beta, vir_beta, orb_ens_alpha.data(),
335+
orb_ens_beta.data(), V_pqrs, norb, h_el_tol, h_diag, E_ASCI,
336+
ham_gen, asci_pairs);
337+
338+
if(alpha_satisfies_con) {
339+
// BB excitations
340+
append_singles_asci_contributions<Spin::Beta>(
341+
c, w, beta_det, occ_beta, vir_beta, occ_alpha,
342+
orb_ens_beta.data(), T_pq, norb, G_red, norb, V_red, norb,
343+
h_el_tol, h_diag, E_ASCI, ham_gen, asci_pairs);
344+
345+
// BBBB excitations
346+
append_ss_doubles_asci_contributions<Spin::Beta>(
347+
c, w, beta_det, alpha_det, occ_beta, vir_beta, occ_alpha,
348+
orb_ens_beta.data(), G_pqrs, norb, h_el_tol, h_diag, E_ASCI,
349+
ham_gen, asci_pairs);
350+
351+
// No excitation (push inf to remove from list)
352+
asci_pairs.push_back(
353+
{w, std::numeric_limits<double>::infinity(), 1.0});
354+
}
355+
}
356+
357+
// Prune Down Contributions
358+
if(asci_pairs.size() > asci_settings.pair_size_max) {
359+
throw std::runtime_error("DIE DIE DIE");
360+
}
361+
362+
} // Unique Alpha Loop
363+
364+
// Local S&A for each quad
365+
{
366+
if(size_before > asci_pairs.size())
367+
throw std::runtime_error("DIE DIE DIE");
368+
auto uit = sort_and_accumulate_asci_pairs(
369+
asci_pairs.begin() + size_before, asci_pairs.end());
370+
asci_pairs.erase(uit, asci_pairs.end());
371+
372+
// Remove small contributions
373+
uit = std::partition(asci_pairs.begin() + size_before,
374+
asci_pairs.end(), [=](const auto& x) {
375+
return std::abs(x.rv()) >
376+
asci_settings.rv_prune_tol;
377+
});
378+
asci_pairs.erase(uit, asci_pairs.end());
355379
}
356-
}
357-
358-
// Prune Down Contributions
359-
if(asci_pairs.size() > asci_settings.pair_size_max) {
360-
throw std::runtime_error("DIE DIE DIE");
361-
}
380+
} // Loc constraint loop
381+
} // Constraint Loop
362382

363-
} // Unique Alpha Loop
364-
365-
366-
// Local S&A for each quad
383+
// Insert into list
384+
#pragma omp critical
367385
{
368-
if(size_before > asci_pairs.size())
369-
throw std::runtime_error("DIE DIE DIE");
370-
auto uit = sort_and_accumulate_asci_pairs(
371-
asci_pairs.begin() + size_before, asci_pairs.end());
372-
asci_pairs.erase(uit, asci_pairs.end());
373-
374-
// Remove small contributions
375-
uit =
376-
std::partition(asci_pairs.begin() + size_before, asci_pairs.end(),
377-
[=](const auto& x) {
378-
return std::abs(x.rv()) > asci_settings.rv_prune_tol;
379-
});
380-
asci_pairs.erase(uit, asci_pairs.end());
386+
if(asci_pairs_total.size()) {
387+
// Preallocate space for insertion
388+
asci_pairs_total.reserve(asci_pairs.size() + asci_pairs_total.size());
389+
asci_pairs_total.insert(asci_pairs_total.end(), asci_pairs.begin(),
390+
asci_pairs.end());
391+
} else {
392+
asci_pairs_total = std::move(asci_pairs);
393+
}
394+
asci_contrib_container<wfn_t<N>>().swap(asci_pairs);
381395
}
382-
} // Loc constraint loop
383-
} // Constraint Loop
384-
385-
// Insert into list
386-
#pragma omp critical
387-
{
388-
if(asci_pairs_total.size()) {
389-
// Preallocate space for insertion
390-
asci_pairs_total.reserve(asci_pairs.size() + asci_pairs_total.size());
391-
asci_pairs_total.insert(asci_pairs_total.end(), asci_pairs.begin(),
392-
asci_pairs.end());
393-
} else {
394-
asci_pairs_total = std::move(asci_pairs);
395-
}
396-
asci_contrib_container<wfn_t<N>>().swap(asci_pairs);
397-
}
398396

399-
} // OpenMP
397+
} // OpenMP
400398

401399
return asci_pairs_total;
402400
}
@@ -462,8 +460,8 @@ std::vector<wfn_t<N>> asci_search(
462460
// #ifdef MACIS_ENABLE_MPI
463461
// else
464462
asci_pairs = asci_contributions_constraint(
465-
asci_settings, ndets_max, cdets_begin, cdets_end, E_ASCI, C, norb, T_pq, G_red,
466-
V_red, G_pqrs, V_pqrs, ham_gen MACIS_MPI_CODE(, comm));
463+
asci_settings, ndets_max, cdets_begin, cdets_end, E_ASCI, C, norb, T_pq,
464+
G_red, V_red, G_pqrs, V_pqrs, ham_gen MACIS_MPI_CODE(, comm));
467465
// #endif
468466
auto pairs_en = clock_type::now();
469467

include/macis/asci/pt2.hpp

Lines changed: 12 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -107,7 +107,8 @@ double asci_pt2_constraint(wavefunction_iterator_t<N> cdets_begin,
107107
// auto constraints = dist_constraint_general<wfn_t<N>>(
108108
// 5, norb, n_sing_beta, n_doub_beta, uniq_alpha, comm);
109109
auto constraints = gen_constraints_general<wfn_t<N>>(
110-
10, norb, n_sing_beta, n_doub_beta, uniq_alpha, world_size * omp_get_max_threads());
110+
10, norb, n_sing_beta, n_doub_beta, uniq_alpha,
111+
world_size * omp_get_max_threads());
111112
auto gen_c_en = clock_type::now();
112113
duration_type gen_c_dur = gen_c_en - gen_c_st;
113114
logger->info(" * GEN_DUR = {:.2e} ms", gen_c_dur.count());
@@ -194,15 +195,17 @@ double asci_pt2_constraint(wavefunction_iterator_t<N> cdets_begin,
194195
}
195196
}
196197

197-
//if(not (i_alpha%10)) {
198+
// if(not (i_alpha%10)) {
198199
//// Cleanup
199-
//auto uit = sort_and_accumulate_asci_pairs(asci_pairs.begin(),
200-
// asci_pairs.end());
201-
//asci_pairs.erase(uit, asci_pairs.end());
202-
// printf("[rank %4d tid:%4d] IC = %lu / %lu IA = %lu / %lu SZ = %lu\n", world_rank,
203-
// omp_get_thread_num(), ic, ncon_total, i_alpha, nuniq_alpha, asci_pairs.size());
204-
//}
205-
200+
// auto uit = sort_and_accumulate_asci_pairs(asci_pairs.begin(),
201+
// asci_pairs.end());
202+
// asci_pairs.erase(uit, asci_pairs.end());
203+
// printf("[rank %4d tid:%4d] IC = %lu / %lu IA = %lu / %lu SZ =
204+
// %lu\n", world_rank,
205+
// omp_get_thread_num(), ic, ncon_total, i_alpha,
206+
// nuniq_alpha, asci_pairs.size());
207+
// }
208+
206209
} // Unique Alpha Loop
207210

208211
double EPT2_local = 0.0;

0 commit comments

Comments
 (0)