Skip to content

Commit 55598e6

Browse files
committed
Merge branch 'forced-swap' of github.com:wehs7661/ensemble_md into forced-swap
2 parents 4fc0d30 + 5ea8fde commit 55598e6

18 files changed

+777
-750
lines changed

ensemble_md/analysis/analyze_free_energy.py

Lines changed: 71 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -189,6 +189,8 @@ def _combine_df_adjacent(df_adjacent, state_ranges, df_err_adjacent=None, err_ty
189189
A list of lists of free energy differences between adjacent states for all replicas.
190190
state_ranges : list
191191
A list of lists showing the state indices sampled by each replica.
192+
n_tot : int
193+
Number of lambda states
192194
df_err_adjacent : list, Optional
193195
A list of lists of uncertainties corresponding to the values of :code:`df_adjacent`. Notably, if
194196
:code:`df_err_adjacent` is :code:`None`, simple means will be used. Otherwise, inverse-variance weighted
@@ -247,7 +249,48 @@ def _combine_df_adjacent(df_adjacent, state_ranges, df_err_adjacent=None, err_ty
247249
return df, df_err, overlap_bool
248250

249251

250-
def calculate_free_energy(data, state_ranges, df_method="MBAR", err_method="propagate", n_bootstrap=None, seed=None):
252+
def _calculate_df(estimators):
    """
    An internal function used in :func:`calculate_free_energy` to get the free energy difference between
    the first and the last states, together with its uncertainty, from the first estimator in the list.

    Parameters
    ----------
    estimators : list
        A list of estimators fitting the input data for all replicas. In our code, these estimators
        come from the function :func:`_apply_estimators`. Only :code:`estimators[0]` is read here, and
        its :code:`delta_f_` and :code:`d_delta_f_` attributes are accessed as square data frames
        of pairwise free energy differences and their uncertainties, respectively.

    Returns
    -------
    df : float
        Free energy difference between the first and the last states of :code:`estimators[0]`.
    df_err : float
        Uncertainty corresponding to the value of :code:`df`.

    See also
    --------
    :func:`calculate_free_energy`
    """
    estimator = estimators[0]

    # The element at (first row, last column) is the difference between the two end states.
    # Positional indexing is used so that the labels of the estimator's own data frames are
    # left untouched (the estimators are handed back to the caller afterwards).
    est = estimator.delta_f_.iloc[0, -1]
    err = estimator.d_delta_f_.iloc[0, -1]

    return est, err
291+
292+
293+
def calculate_free_energy(data, state_ranges, df_method="MBAR", err_method="propagate", n_bootstrap=None, seed=None, MTREXEE=False): # noqa: E501
251294
"""
252295
Calculates the averaged free energy profile with the chosen method given :math:`u_{nk}` or :math:`dH/dλ` data
253296
obtained from all replicas of the REXEE simulation. Available methods include TI, BAR, and MBAR. TI
@@ -275,6 +318,8 @@ def calculate_free_energy(data, state_ranges, df_method="MBAR", err_method="prop
275318
seed : int, Optional
276319
The random seed for bootstrapping. Only relevant when :code:`err_method` is :code:`"bootstrap"`.
277320
The default is :code:`None`.
321+
MTREXEE : bool
322+
Whether this is a MT-REXEE simulation or not
278323
279324
Returns
280325
-------
@@ -299,10 +344,17 @@ def calculate_free_energy(data, state_ranges, df_method="MBAR", err_method="prop
299344
>>> f, _, _ = analyze_free_energy.calculate_free_energy(data_list, state_ranges, "MBAR", "propagate")
300345
"""
301346
n_sim = len(data)
302-
n_tot = state_ranges[-1][-1] + 1
347+
if MTREXEE is False:
348+
n_tot = state_ranges[-1][-1] + 1
349+
else:
350+
n_tot = state_ranges[-1] + 1
303351
estimators = _apply_estimators(data, df_method)
304-
df_adjacent, df_err_adjacent = _calculate_df_adjacent(estimators)
305-
df, df_err, overlap_bool = _combine_df_adjacent(df_adjacent, state_ranges, df_err_adjacent, err_type='propagate')
352+
print(estimators)
353+
if MTREXEE is False:
354+
df_adjacent, df_err_adjacent = _calculate_df_adjacent(estimators)
355+
df, df_err, overlap_bool = _combine_df_adjacent(df_adjacent, state_ranges, df_err_adjacent, err_type='propagate') # noqa: E501
356+
else:
357+
df, df_err = _calculate_df(estimators)
306358

307359
if err_method == 'bootstrap':
308360
if seed is not None:
@@ -314,26 +366,33 @@ def calculate_free_energy(data, state_ranges, df_method="MBAR", err_method="prop
314366
for b in range(n_bootstrap):
315367
sampled_data = [sampled_data_all[i].iloc[b * len(data[i]):(b + 1) * len(data[i])] for i in range(n_sim)]
316368
bootstrap_estimators = _apply_estimators(sampled_data, df_method)
317-
df_adjacent, df_err_adjacent = _calculate_df_adjacent(bootstrap_estimators)
318-
df_sampled, _, overlap_bool = _combine_df_adjacent(df_adjacent, state_ranges, df_err_adjacent, err_type='propagate') # doesn't matter what value err_type here is # noqa: E501
369+
if MTREXEE is False:
370+
df_adjacent, df_err_adjacent = _calculate_df_adjacent(bootstrap_estimators)
371+
df_sampled, _, overlap_bool = _combine_df_adjacent(df_adjacent, state_ranges, df_err_adjacent, err_type='propagate') # doesn't matter what value err_type here is # noqa: E501
372+
else:
373+
df_sampled, _ = _calculate_df(bootstrap_estimators)
319374
df_bootstrap.append(df_sampled)
320375
error_bootstrap = np.std(df_bootstrap, axis=0, ddof=1)
321376

322377
# Replace the value in df_err with value in error_bootstrap if df_err corresponds to
323378
# the df between overlapping states
324379
for i in range(n_tot - 1):
325-
if overlap_bool[i] is True:
380+
if MTREXEE is True or overlap_bool[i] is True:
326381
print(f'Replaced the propagated error with the bootstrapped error for states {i} and {i + 1}: {df_err[i]:.5f} -> {error_bootstrap[i]:.5f}.') # noqa: E501
327382
df_err[i] = error_bootstrap[i]
328383
elif err_method == 'propagate':
329384
pass
330385
else:
331386
raise ParameterError('Specified err_method not available.')
332-
333-
df.insert(0, 0)
334-
df_err.insert(0, 0)
335-
f = [sum(df[:(i + 1)]) for i in range(len(df))]
336-
f_err = [np.sqrt(sum([x**2 for x in df_err[:(i+1)]])) for i in range(len(df_err))]
387+
388+
if MTREXEE is False:
389+
df.insert(0, 0)
390+
df_err.insert(0, 0)
391+
f = [sum(df[:(i + 1)]) for i in range(len(df))]
392+
f_err = [np.sqrt(sum([x**2 for x in df_err[:(i+1)]])) for i in range(len(df_err))]
393+
else:
394+
f = df
395+
f_err = df_err
337396

338397
return f, f_err, estimators
339398

ensemble_md/analysis/analyze_traj.py

Lines changed: 62 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -106,6 +106,7 @@ def stitch_time_series(files, rep_trajs, shifts=None, dhdl=True, col_idx=-1, sav
106106
# files_sorted[i] contains the dhdl/plumed output files for starting configuration i sorted
107107
# based on iteration indices
108108
files_sorted = [[] for i in range(n_configs)]
109+
print(n_iter)
109110
for i in range(n_configs):
110111
for j in range(n_iter):
111112
files_sorted[i].append(files[rep_trajs[i][j]][j])
@@ -185,6 +186,8 @@ def stitch_time_series_for_sim(files, shifts=None, dhdl=True, col_idx=-1, save_n
185186
:func:`.stitch_time_series`
186187
:func:`.stitch_xtc_trajs`
187188
"""
189+
#if os.path.exists('track_swap_frame.npy'):
190+
188191
n_sim = len(files) # number of replicas
189192
n_iter = len(files[0]) # number of iterations per replica
190193
trajs = [[] for i in range(n_sim)]
@@ -543,6 +546,7 @@ def plot_state_hist(trajs, state_ranges, fig_name, stack=True, figsize=None, pre
543546
dir_list = []
544547
for i in fig_name.split('/')[:-1]:
545548
dir_list.append(i)
549+
dir_list.append('/')
546550
dir_path = ''.join(dir_list)
547551
np.save(f'{dir_path}/hist_data.npy', hist_data)
548552
else:
@@ -833,7 +837,7 @@ def plot_transit_time(trajs, N, fig_prefix=None, dt=None, folder='.'):
833837
plt.savefig(f'{folder}/hist_{fig_names[t]}', dpi=600)
834838
else:
835839
plt.savefig(f'{folder}/{fig_prefix}_hist_{fig_names[t]}', dpi=600)
836-
#Save to csv
840+
# Save to csv
837841
sim_list, rt_list = [], []
838842
for n in range(len(t_roundtrip_list)):
839843
for rt in t_roundtrip_list[n]:
@@ -1350,6 +1354,30 @@ def get_delta_w_updates(log_file, plot=False):
13501354

13511355

13521356
def end_states_only_traj(working_dir, n_sim, n_iter, l0_states, l1_states, swap_rep_pattern, ps_per_frame):
1357+
"""
1358+
Create a trajectory which is a concatenation of all frames for each unique end state.
1359+
1360+
Parameters
1361+
----------
1362+
working_dir : str
1363+
path for the current working directory
1364+
n_sim : int
1365+
the number of simulations run
1366+
n_iter : int
1367+
the number of iterations run
1368+
l0_states : list of int
1369+
the lambda states which correspond to lambda=0
1370+
l1_states : list of int
1371+
the lambda states which correspond to lambda=1
1372+
swap_rep_pattern : list of int
1373+
the replica swapping pattern which will indicate which end states are common
1374+
ps_per_frame : float
1375+
the timestep to convert the time in the GROMACS dh/dl file to frames in the trajectory
1376+
1377+
Returns
1378+
-------
1379+
None
1380+
"""
13531381
import pandas as pd
13541382
import os
13551383
import mdtraj as md
@@ -1434,18 +1462,39 @@ def end_states_only_traj(working_dir, n_sim, n_iter, l0_states, l1_states, swap_
14341462
traj = md.join(traj, traj_add)
14351463
traj.save_xtc(f'{working_dir}/analysis/{state}_{rep}.xtc')
14361464

1437-
def concat_sim_traj(working_dir, n_sim, n_iter):
1465+
1466+
def concat_sim_traj(working_dir, n_sim, n_iter, gro):
1467+
"""
1468+
Create a trajectory which is a concatenation of each iteration's trajectory
1469+
1470+
Parameters
1471+
----------
1472+
working_dir : str
1473+
path for the current working directory
1474+
n_sim : int
1475+
the number of simulations run
1476+
n_iter : int
1477+
the number of iterations run
1478+
1479+
Returns
1480+
-------
1481+
None
1482+
"""
14381483
import mdtraj as md
14391484
import os
1440-
1485+
from tqdm import tqdm
1486+
14411487
for rep in range(n_sim):
1442-
if os.path.exists(f'{working_dir}/sim_{rep}/iteration_0/confout_backup.gro'):
1443-
name = 'confout_backup'
1444-
else:
1445-
name = 'confout'
1446-
1447-
traj = md.load(f'{working_dir}/sim_{rep}/iteration_0/traj.trr', top=f'{working_dir}/sim_{rep}/iteration_0/{name}.gro')
1448-
for iteration in range(1, n_iter):
1449-
traj_add = md.load(f'{working_dir}/sim_{rep}/iteration_{iteration}/traj.trr', top=f'{working_dir}/sim_{rep}/iteration_0/{name}.gro')
1450-
traj = md.join([traj, traj_add])
1451-
traj.save_xtc(f'{working_dir}/analysis/sim{rep}_concat.xtc')
1488+
if not os.path.exists(f'{working_dir}/analysis/sim{rep}_concat.xtc'):
1489+
if os.path.exists(f'{working_dir}/sim_{rep}/iteration_0/confout_backup.gro'):
1490+
name = 'confout_backup'
1491+
else:
1492+
name = 'confout'
1493+
gro_ref = md.load(f'{working_dir}/{gro[rep]}')
1494+
traj = md.load(f'{working_dir}/sim_{rep}/iteration_0/traj.trr', top=f'{working_dir}/sim_{rep}/iteration_0/{name}.gro') # noqa: E501
1495+
traj.superpose(gro_ref, frame=0)
1496+
for iteration in tqdm(range(1, n_iter)):
1497+
traj_add = md.load(f'{working_dir}/sim_{rep}/iteration_{iteration}/traj.trr', top=f'{working_dir}/sim_{rep}/iteration_0/{name}.gro') # noqa: E501
1498+
traj_add.superpose(gro_ref, frame=0)
1499+
traj = md.join([traj, traj_add[1:]])
1500+
traj.save_xtc(f'{working_dir}/analysis/sim{rep}_concat.xtc')

0 commit comments

Comments
 (0)