Skip to content

Commit 0a68daf

Browse files
committed
Fix MTREXEE FE issue
1 parent 471f884 commit 0a68daf

File tree

2 files changed

+9
-4
lines changed

2 files changed

+9
-4
lines changed

ensemble_md/analysis/analyze_free_energy.py

Lines changed: 7 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -247,7 +247,7 @@ def _combine_df_adjacent(df_adjacent, state_ranges, df_err_adjacent=None, err_ty
247247
return df, df_err, overlap_bool
248248

249249

250-
def calculate_free_energy(data, state_ranges, df_method="MBAR", err_method="propagate", n_bootstrap=None, seed=None):
250+
def calculate_free_energy(data, state_ranges, df_method="MBAR", err_method="propagate", n_bootstrap=None, seed=None, MTREXEE=False):
251251
"""
252252
Caculates the averaged free energy profile with the chosen method given :math:`u_{nk}` or :math:`dH/dλ` data
253253
obtained from all replicas of the REXEE simulation. Available methods include TI, BAR, and MBAR. TI
@@ -275,6 +275,8 @@ def calculate_free_energy(data, state_ranges, df_method="MBAR", err_method="prop
275275
seed : int, Optional
276276
The random seed for bootstrapping. Only relevant when :code:`err_method` is :code:`"bootstrap"`.
277277
The default is :code:`None`.
278+
MTREXEE : bool
279+
Whether this is a MT-REXEE simulation or not
278280
279281
Returns
280282
-------
@@ -299,7 +301,10 @@ def calculate_free_energy(data, state_ranges, df_method="MBAR", err_method="prop
299301
>>> f, _, _ = analyze_free_energy.calculate_free_energy(data_list, state_ranges, "MBAR", "propagate")
300302
"""
301303
n_sim = len(data)
302-
n_tot = state_ranges[-1][-1] + 1
304+
if MTREXEE is False:
305+
n_tot = state_ranges[-1][-1] + 1
306+
else:
307+
n_tot = state_ranges[-1] + 1
303308
estimators = _apply_estimators(data, df_method)
304309
df_adjacent, df_err_adjacent = _calculate_df_adjacent(estimators)
305310
df, df_err, overlap_bool = _combine_df_adjacent(df_adjacent, state_ranges, df_err_adjacent, err_type='propagate')

ensemble_md/cli/analyze_REXEE.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -182,7 +182,7 @@ def main():
182182
rmse = analyze_traj.calc_hist_rmse(hist_data, REXEE.state_ranges)
183183
print(f'The RMSE of accumulated histogram counts of the state index: {rmse:.0f}')
184184

185-
if REXEE.proposal != 'forced_random' and REXEE.proposal != 'forced_swap': # Need to FIX THIS FOR FORCED-RANDOM
185+
if REXEE.proposal != 'random_range': # Need to FIX THIS FOR RANDOM-Range
186186
# 2-4. Stitch the time series of state index for different replicas
187187
if os.path.isfile(args.state_trajs_for_sim) is True:
188188
print('\n2-4. Reading in the stitched time series of state index for different replicas ...')
@@ -485,7 +485,7 @@ def main():
485485
pickle.dump(data_all, handle, protocol=pickle.HIGHEST_PROTOCOL)
486486

487487
# 4-2. Calculate the free energy profile
488-
f, f_err, estimators = analyze_free_energy.calculate_free_energy(data_list, REXEE.state_ranges[sim], REXEE.df_method, REXEE.err_method, REXEE.n_bootstrap, REXEE.seed) # noqa: E501
488+
f, f_err, estimators = analyze_free_energy.calculate_free_energy(data_list, REXEE.state_ranges[sim], REXEE.df_method, REXEE.err_method, REXEE.n_bootstrap, REXEE.seed, MTREXEE=True) # noqa: E501
489489

490490
print('Plotting the full-range free energy profile ...')
491491
analyze_free_energy.plot_free_energy(f, f_err, f'{args.dir}/free_energy_profile_{sim}.png')

0 commit comments

Comments
 (0)