Skip to content

Commit 9f409f7

Browse files
Major fixes
(1) Made constraint generation MPI-parallel and fixed the initial (n_min) constraint recursion; (2) made index precomputation optional (reduces memory footprint); (3) made PT2 progress printing more configurable
1 parent d7c0d6a commit 9f409f7

File tree

5 files changed

+91
-40
lines changed

5 files changed

+91
-40
lines changed

include/macis/asci/determinant_search.hpp

Lines changed: 6 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -44,6 +44,8 @@ struct ASCISettings {
4444
size_t pt2_reserve_count = 70000000;
4545
bool pt2_prune = false;
4646
bool pt2_precompute_eps = false;
47+
bool pt2_precompute_idx = false;
48+
bool pt2_print_progress = false;
4749
size_t pt2_bigcon_thresh = 250;
4850

4951
size_t nxtval_bcount_thresh = 1000;
@@ -59,7 +61,8 @@ struct ASCISettings {
5961

6062
// bool dist_triplet_random = false;
6163
int constraint_level = 2; // Up To Quints
62-
int pt2_constraint_level = 5;
64+
int pt2_max_constraint_level = 5;
65+
int pt2_min_constraint_level = 0;
6366
};
6467

6568
template <size_t N>
@@ -309,8 +312,8 @@ asci_contrib_container<wfn_t<N>> asci_contributions_constraint(
309312
const size_t c_end = std::min(ncon_total, ic + ntake);
310313
for(; ic < c_end; ++ic) {
311314
const auto& con = constraints[ic].first;
312-
printf("[rank %4d tid:%4d] %10lu / %10lu\n", world_rank,
313-
omp_get_thread_num(), ic, ncon_total);
315+
//printf("[rank %4d tid:%4d] %10lu / %10lu\n", world_rank,
316+
// omp_get_thread_num(), ic, ncon_total);
314317

315318
for(size_t i_alpha = 0, iw = 0; i_alpha < nuniq_alpha; ++i_alpha) {
316319
const auto& alpha_det = uniq_alpha[i_alpha].first;

include/macis/asci/mask_constraints.hpp

Lines changed: 21 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -673,18 +673,20 @@ auto gen_constraints_general(size_t nlevels, size_t norb, size_t ns_othr,
673673
}
674674
// Build up higher-order constraints as base if requested
675675
for(size_t ilevel = 0; ilevel < nlevel_min; ++ilevel) {
676-
std::vector cur_constraints = constraint_sizes;
677-
for(auto [c,nw] : cur_constraints) {
676+
decltype(constraint_sizes) cur_constraints;
677+
cur_constraints.reserve(constraint_sizes.size() * norb);
678+
for(auto [c,nw] : constraint_sizes) {
678679
const auto C_min = c.C_min();
679680
for(auto q_l = 0; q_l < C_min; ++q_l) {
680681
// Generate masks / counts
681682
string_type cn_C = c.C();
682683
cn_C.flip(q_l);
683684
string_type cn_B = c.B() >> (C_min - q_l);
684685
constraint_type c_next(cn_C, cn_B, q_l);
685-
constraint_sizes.emplace_back(c_next, 0ul);
686+
cur_constraints.emplace_back(c_next, 0ul);
686687
}
687688
}
689+
constraint_sizes = std::move(cur_constraints);
688690
}
689691

690692
struct atomic_wrapper {
@@ -701,10 +703,14 @@ auto gen_constraints_general(size_t nlevels, size_t norb, size_t ns_othr,
701703
// Compute histogram
702704
const auto ntrip_full = constraint_sizes.size();
703705
std::vector<atomic_wrapper> constraint_work(ntrip_full, 0ul);
704-
int world_rank = comm_rank(MPI_COMM_WORLD);
705-
#pragma omp parallel for schedule(dynamic)
706-
for(auto i_trip = 0ul; i_trip < ntrip_full; ++i_trip) {
707-
if(!world_rank and !(i_trip%1000)) printf("cgen %lu / %lu\n", i_trip, ntrip_full);
706+
global_atomic<size_t> nxtval(MPI_COMM_WORLD);
707+
#pragma omp parallel
708+
{
709+
size_t i_trip = 0;
710+
while(i_trip < ntrip_full) {
711+
i_trip = nxtval.fetch_and_add(1);
712+
if(i_trip >= ntrip_full) break;
713+
if(!(i_trip%1000)) printf("cgen %lu / %lu\n", i_trip, ntrip_full);
708714
auto& [constraint, __nw] = constraint_sizes[i_trip];
709715
auto& c_nw = constraint_work[i_trip];
710716
size_t nw = 0;
@@ -718,10 +724,17 @@ auto gen_constraints_general(size_t nlevels, size_t norb, size_t ns_othr,
718724
}
719725
if(nw) c_nw.value.fetch_add(nw);
720726
}
727+
}
728+
729+
std::vector<size_t> constraint_work_bare(ntrip_full);
730+
for(auto i_trip = 0; i_trip < ntrip_full; ++i_trip) {
731+
constraint_work_bare[i_trip] = constraint_work[i_trip].value.load();
732+
}
733+
allreduce(constraint_work_bare.data(), ntrip_full, MPI_SUM, MPI_COMM_WORLD);
721734

722735
// Copy over constraint work
723736
for(auto i_trip = 0; i_trip < ntrip_full; ++i_trip) {
724-
constraint_sizes[i_trip].second = constraint_work[i_trip].value.load();
737+
constraint_sizes[i_trip].second = constraint_work_bare[i_trip];
725738
}
726739

727740
// Remove zeros

include/macis/asci/pt2.hpp

Lines changed: 53 additions & 28 deletions
Original file line numberDiff line numberDiff line change
@@ -43,16 +43,17 @@ double asci_pt2_constraint(ASCISettings asci_settings,
4343

4444
const size_t ncdets = std::distance(cdets_begin, cdets_end);
4545
logger->info("[ASCI PT2 Settings]");
46-
logger->info(" * NDETS = {}", ncdets);
47-
logger->info(" * PT2_TOL = {}", asci_settings.pt2_tol);
48-
logger->info(" * PT2_RESERVE_COUNT = {}", asci_settings.pt2_reserve_count);
49-
logger->info(" * PT2_CONSTRAINT_LVL = {}", asci_settings.pt2_constraint_level);
50-
logger->info(" * PT2_PRUNE = {}", asci_settings.pt2_prune);
51-
logger->info(" * PT2_PRECOMP_EPS = {}", asci_settings.pt2_precompute_eps);
52-
logger->info(" * PT2_BIGCON_THRESH = {}", asci_settings.pt2_bigcon_thresh);
53-
logger->info(" * NXTVAL_BCOUNT_THRESH = {}",
46+
logger->info(" * NDETS = {}", ncdets);
47+
logger->info(" * PT2_TOL = {}", asci_settings.pt2_tol);
48+
logger->info(" * PT2_RESERVE_COUNT = {}", asci_settings.pt2_reserve_count);
49+
logger->info(" * PT2_CONSTRAINT_LVL_MAX = {}", asci_settings.pt2_max_constraint_level);
50+
logger->info(" * PT2_CONSTRAINT_LVL_MIN = {}", asci_settings.pt2_min_constraint_level);
51+
logger->info(" * PT2_PRUNE = {}", asci_settings.pt2_prune);
52+
logger->info(" * PT2_PRECOMP_EPS = {}", asci_settings.pt2_precompute_eps);
53+
logger->info(" * PT2_BIGCON_THRESH = {}", asci_settings.pt2_bigcon_thresh);
54+
logger->info(" * NXTVAL_BCOUNT_THRESH = {}",
5455
asci_settings.nxtval_bcount_thresh);
55-
logger->info(" * NXTVAL_BCOUNT_INC = {}",
56+
logger->info(" * NXTVAL_BCOUNT_INC = {}",
5657
asci_settings.nxtval_bcount_inc);
5758
logger->info("");
5859

@@ -74,7 +75,7 @@ double asci_pt2_constraint(ASCISettings asci_settings,
7475

7576
beta_coeff_data(double c, size_t norb,
7677
const std::vector<uint32_t>& occ_alpha, wfn_t<N> w,
77-
const HamiltonianGenerator<wfn_t<N>>& ham_gen, bool pce) {
78+
const HamiltonianGenerator<wfn_t<N>>& ham_gen, bool pce, bool pci) {
7879
coeff = c;
7980

8081
beta_string = wfn_traits::beta_string(w);
@@ -84,11 +85,13 @@ double asci_pt2_constraint(ASCISettings asci_settings,
8485

8586
// Compute occ/vir for beta string
8687
std::vector<uint32_t> o_32, v_32;
87-
spin_wfn_traits::state_to_occ_vir(norb, beta_string, o_32, v_32);
88-
occ_beta.resize(o_32.size());
89-
std::copy(o_32.begin(), o_32.end(), occ_beta.begin());
90-
vir_beta.resize(v_32.size());
91-
std::copy(v_32.begin(), v_32.end(), vir_beta.begin());
88+
if(pce or pci) {
89+
spin_wfn_traits::state_to_occ_vir(norb, beta_string, o_32, v_32);
90+
occ_beta.resize(o_32.size());
91+
std::copy(o_32.begin(), o_32.end(), occ_beta.begin());
92+
vir_beta.resize(v_32.size());
93+
std::copy(v_32.begin(), v_32.end(), vir_beta.begin());
94+
}
9295

9396
// Precompute orbital energies
9497
if(pce) {
@@ -117,23 +120,23 @@ double asci_pt2_constraint(ASCISettings asci_settings,
117120
uad[i].reserve(nbeta);
118121
for(auto j = 0; j < nbeta; ++j, ++iw) {
119122
const auto& w = *(cdets_begin + iw);
120-
uad[i].emplace_back(C[iw], norb, occ_alpha, w, ham_gen,asci_settings.pt2_precompute_eps);
123+
uad[i].emplace_back(C[iw], norb, occ_alpha, w, ham_gen,asci_settings.pt2_precompute_eps, asci_settings.pt2_precompute_idx);
121124
}
122125
}
123126

124127
if(world_rank == 0) {
125128
constexpr double gib = 1024 * 1024 * 1024;
126-
printf("MEM REQ DETS = %.2e\n", ncdets * sizeof(wfn_t<N>) / gib);
127-
printf("MEM REQ C = %.2e\n", ncdets * sizeof(double) / gib);
129+
logger->info("MEM REQ DETS = {:.2e}", ncdets * sizeof(wfn_t<N>) / gib);
130+
logger->info("MEM REQ C = {:.2e}", ncdets * sizeof(double) / gib);
128131
size_t mem_alpha = 0;
129132
for( auto i = 0ul; i < nuniq_alpha; ++i) {
130133
mem_alpha += sizeof(spin_wfn_type);
131134
for(auto j = 0ul; j < uad[i].size(); ++j) {
132135
mem_alpha += uad[i][j].mem();
133136
}
134137
}
135-
printf("MEM REQ ALPH = %.2e\n", mem_alpha / gib);
136-
printf("MEM REQ CONT = %.2e\n", asci_settings.pt2_reserve_count * sizeof(asci_contrib<wfn_t<N>>)/ gib);
138+
logger->info("MEM REQ ALPH = {:.2e}", mem_alpha / gib);
139+
logger->info("MEM REQ CONT = {:.2e}", asci_settings.pt2_reserve_count * sizeof(asci_contrib<wfn_t<N>>)/ gib);
137140
}
138141
MPI_Barrier(comm);
139142

@@ -153,8 +156,9 @@ double asci_pt2_constraint(ASCISettings asci_settings,
153156
// auto constraints = dist_constraint_general<wfn_t<N>>(
154157
// 5, norb, n_sing_beta, n_doub_beta, uniq_alpha, comm);
155158
auto constraints = gen_constraints_general<wfn_t<N>>(
156-
asci_settings.pt2_constraint_level, norb, n_sing_beta,
157-
n_doub_beta, uniq_alpha, world_size * omp_get_max_threads(), 0);
159+
asci_settings.pt2_max_constraint_level, norb, n_sing_beta,
160+
n_doub_beta, uniq_alpha, world_size * omp_get_max_threads(),
161+
asci_settings.pt2_min_constraint_level);
158162
auto gen_c_en = clock_type::now();
159163
duration_type gen_c_dur = gen_c_en - gen_c_st;
160164
logger->info(" * GEN_DUR = {:.2e} ms", gen_c_dur.count());
@@ -198,7 +202,8 @@ double asci_pt2_constraint(ASCISettings asci_settings,
198202
// MPI ranks
199203
ic = nxtval_big.fetch_and_add(1);
200204
if(ic >= ncon_big) continue;
201-
printf("[pt2_big rank %4d] %10lu / %10lu\n", world_rank, ic, ncon_total);
205+
if(asci_settings.pt2_print_progress)
206+
printf("[pt2_big rank %4d] %10lu / %10lu\n", world_rank, ic, ncon_total);
202207
const auto& con = constraints[ic].first;
203208

204209
asci_contrib_container<wfn_t<N>> asci_pairs_con;
@@ -224,11 +229,16 @@ double asci_pt2_constraint(ASCISettings asci_settings,
224229
const auto h_diag = bcd[j_beta].h_diag;
225230

226231
// TODO: These copies are slow
232+
#if 0
227233
const auto& occ_beta_8 = bcd[j_beta].occ_beta;
228234
const auto& vir_beta_8 = bcd[j_beta].vir_beta;
229235
std::vector<uint32_t> occ_beta(occ_beta_8.size()), vir_beta(vir_beta_8.size());
230236
std::copy(occ_beta_8.begin(), occ_beta_8.end(), occ_beta.begin());
231237
std::copy(vir_beta_8.begin(), vir_beta_8.end(), vir_beta.begin());
238+
#else
239+
std::vector<uint32_t> occ_beta, vir_beta;
240+
spin_wfn_traits::state_to_occ_vir(norb, beta_det, occ_beta, vir_beta);
241+
#endif
232242

233243
std::vector<double> orb_ens_alpha, orb_ens_beta;
234244
if(asci_settings.pt2_precompute_eps) {
@@ -309,17 +319,21 @@ double asci_pt2_constraint(ASCISettings asci_settings,
309319

310320
double EPT2_local = 0.0;
311321
size_t NPT2_local = 0;
322+
size_t pair_size = 0;
312323
// Local S&A for each quad + update EPT2
313324
{
314325
auto uit = sort_and_accumulate_asci_pairs(asci_pairs_con.begin(),
315326
asci_pairs_con.end());
327+
pair_size = std::distance(asci_pairs_con.begin(), uit);
316328
for(auto it = asci_pairs_con.begin(); it != uit; ++it) {
317329
if(!std::isinf(it->c_times_matel)) {
318330
EPT2_local += it->pt2();
319331
NPT2_local++;
320332
}
321333
}
322334
asci_pairs_con.clear();
335+
if(asci_settings.pt2_print_progress)
336+
printf("[pt2_big rank %4d] CAPACITY %lu SZ %lu\n", world_rank, asci_pairs_con.capacity(), pair_size);
323337
}
324338

325339
EPT2 += EPT2_local;
@@ -345,8 +359,9 @@ double asci_pt2_constraint(ASCISettings asci_settings,
345359
const size_t c_end = std::min(ncon_total, ic + ntake);
346360
for(; ic < c_end; ++ic) {
347361
const auto& con = constraints[ic].first;
348-
printf("[pt2_small rank %4d tid:%4d] %10lu / %10lu\n", world_rank,
349-
omp_get_thread_num(), ic, ncon_total);
362+
if(asci_settings.pt2_print_progress)
363+
printf("[pt2_small rank %4d tid:%4d] %10lu / %10lu\n", world_rank,
364+
omp_get_thread_num(), ic, ncon_total);
350365

351366
for(size_t i_alpha = 0; i_alpha < nuniq_alpha; ++i_alpha) {
352367
const size_t old_pair_size = asci_pairs.size();
@@ -366,11 +381,16 @@ double asci_pt2_constraint(ASCISettings asci_settings,
366381
const auto h_diag = bcd[j_beta].h_diag;
367382

368383
// TODO: These copies are slow
384+
#if 0
369385
const auto& occ_beta_8 = bcd[j_beta].occ_beta;
370386
const auto& vir_beta_8 = bcd[j_beta].vir_beta;
371387
std::vector<uint32_t> occ_beta(occ_beta_8.size()), vir_beta(vir_beta_8.size());
372388
std::copy(occ_beta_8.begin(), occ_beta_8.end(), occ_beta.begin());
373389
std::copy(vir_beta_8.begin(), vir_beta_8.end(), vir_beta.begin());
390+
#else
391+
std::vector<uint32_t> occ_beta, vir_beta;
392+
spin_wfn_traits::state_to_occ_vir(norb, beta_det, occ_beta, vir_beta);
393+
#endif
374394

375395
std::vector<double> orb_ens_alpha, orb_ens_beta;
376396
if(asci_settings.pt2_precompute_eps) {
@@ -423,9 +443,14 @@ double asci_pt2_constraint(ASCISettings asci_settings,
423443
asci_pairs.erase(uit, asci_pairs.end());
424444
//uit = std::stable_partition(asci_pairs.begin(), asci_pairs.end(), [&](const auto& p){ return std::abs(p.pt2()) > h_el_tol; });
425445
//asci_pairs.erase(uit, asci_pairs.end());
426-
printf("[rank %4d tid:%4d] IC = %lu / %lu IA = %lu / %lu SZ = %lu\n", world_rank,
427-
omp_get_thread_num(), ic, ncon_total, i_alpha,
428-
nuniq_alpha, asci_pairs.size());
446+
if(asci_settings.pt2_print_progress)
447+
printf("[pt2_prune rank %4d tid:%4d] IC = %lu / %lu IA = %lu / %lu SZ = %lu\n", world_rank,
448+
omp_get_thread_num(), ic, ncon_total, i_alpha,
449+
nuniq_alpha, asci_pairs.size());
450+
451+
if(asci_pairs.size() > asci_settings.pt2_reserve_count) {
452+
printf("* WARNING: PRUNED SIZE LARGER THAN RESERVE COUNT\n");
453+
}
429454
}
430455

431456
} // Unique Alpha Loop

include/macis/hamiltonian_generator/sorted_double_loop.hpp

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -40,6 +40,7 @@ class SortedDoubleLoopHamiltonianGenerator
4040
const size_t nket_dets = std::distance(ket_begin, ket_end);
4141

4242
const bool is_symm = bra_begin == ket_begin and bra_end == ket_end;
43+
auto world_rank = comm_rank(MPI_COMM_WORLD);
4344

4445
// Get unique alpha strings
4546
auto setup_st = std::chrono::high_resolution_clock::now();
@@ -102,6 +103,7 @@ class SortedDoubleLoopHamiltonianGenerator
102103
#pragma omp for schedule(dynamic)
103104
for(size_t ia_bra = 0; ia_bra < nuniq_bra; ++ia_bra) {
104105
if(unique_alpha_bra[ia_bra].first.any()) {
106+
if(!(ia_bra%100))printf("[ham_gen rank %d] IA_BRA = %lu / %lu\n", world_rank, ia_bra, nuniq_bra);
105107
// Extract alpha bra
106108
const auto bra_alpha = unique_alpha_bra[ia_bra].first;
107109
const size_t beta_st_bra = unique_alpha_bra_idx[ia_bra];
@@ -227,10 +229,12 @@ class SortedDoubleLoopHamiltonianGenerator
227229

228230
// Sort for CSR Conversion
229231
auto sort_st = std::chrono::high_resolution_clock::now();
232+
printf("[ham_gen rank %d] BEFORE SORT\n", world_rank);
230233
coo_mat.sort_by_row_index();
231234
auto sort_en = std::chrono::high_resolution_clock::now();
232235

233236
auto conv_st = std::chrono::high_resolution_clock::now();
237+
printf("[ham_gen rank %d] BEFORE CONV\n", world_rank);
234238
sparse_matrix_type<index_t> csr_mat(coo_mat); // Convert to CSR Matrix
235239
auto conv_en = std::chrono::high_resolution_clock::now();
236240

tests/standalone_driver.cxx

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -222,9 +222,12 @@ int main(int argc, char** argv) {
222222
OPT_KEYWORD("ASCI.PT2", pt2, bool);
223223
OPT_KEYWORD("ASCI.PT2_TOL", asci_settings.pt2_tol, double);
224224
OPT_KEYWORD("ASCI.PT2_RESERVE_COUNT", asci_settings.pt2_reserve_count, size_t);
225-
OPT_KEYWORD("ASCI.PT2_CONSTRAINT_LVL", asci_settings.pt2_constraint_level, int);
225+
OPT_KEYWORD("ASCI.PT2_CONSTRAINT_LVL_MAX", asci_settings.pt2_max_constraint_level, int);
226+
OPT_KEYWORD("ASCI.PT2_CONSTRAINT_LVL_MIN", asci_settings.pt2_min_constraint_level, int);
226227
OPT_KEYWORD("ASCI.PT2_PRUNE", asci_settings.pt2_prune, bool);
227228
OPT_KEYWORD("ASCI.PT2_PRECOMPUTE_EPS", asci_settings.pt2_precompute_eps, bool);
229+
OPT_KEYWORD("ASCI.PT2_PRECOMPUTE_IDX", asci_settings.pt2_precompute_idx, bool);
230+
OPT_KEYWORD("ASCI.PT2_PRINT_PROGRESS", asci_settings.pt2_print_progress, bool);
228231
OPT_KEYWORD("ASCI.PT2_BIGCON_THRESH", asci_settings.pt2_bigcon_thresh, size_t);
229232
OPT_KEYWORD("ASCI.NXTVAL_BCOUNT_THRESH", asci_settings.nxtval_bcount_thresh, size_t);
230233
OPT_KEYWORD("ASCI.NXTVAL_BCOUNT_INC", asci_settings.nxtval_bcount_inc, size_t);
@@ -233,6 +236,9 @@ int main(int argc, char** argv) {
233236
OPT_KEYWORD("MCSCF.MP2_GUESS", mp2_guess, bool);
234237

235238
if(!world_rank) {
239+
console->info("[Standalone MACIS Driver]:");
240+
console->info(" * NMPI = {}", world_size);
241+
console->info(" * NTHREADS = {}", omp_get_max_threads());
236242
console->info("[Wavefunction Data]:");
237243
console->info(" * JOB = {}", job_str);
238244
console->info(" * CIEXP = {}", ciexp_str);

0 commit comments

Comments
 (0)