Skip to content

Commit 9ee86f9

Browse files
committed
Developed get_g_evolution and modified get_dg_evolution
1 parent 43ac209 commit 9ee86f9

File tree

1 file changed

+70
-37
lines changed

1 file changed

+70
-37
lines changed

ensemble_md/analysis/analyze_traj.py

Lines changed: 70 additions & 37 deletions
Original file line numberDiff line numberDiff line change
@@ -965,15 +965,76 @@ def plot_swaps(swaps, swap_type='', stack=True, figsize=None):
965965
plt.savefig(f'{swap_type}_swaps.png', dpi=600)
966966

967967

968-
def get_dg_evolution(log_file, start_state, end_state):
968+
def get_g_evolution(log_files, N_states, avg_frac=0):
969+
"""
970+
For weight-updating simulations, gets the time series of the alchemical
971+
weights of all states.
972+
973+
Parameters
974+
----------
975+
log_files : list
976+
The list of log file names.
977+
N_states : int
978+
The total number of states in the whole alchemical range.
979+
avg_frac : float
980+
The fraction of the last part of the simulation to be averaged. The
981+
default is 0, which means no averaging.
982+
983+
984+
Returns
985+
-------
986+
g_vecs_all : list
987+
The alchemical weights of all states as a function of time.
988+
It should be a list of lists.
989+
g_vecs_avg : list
990+
The alchemical weights of all states averaged over the last part of
991+
the simulation. If :code:`avg_frac` is 0, :code:`None` will be returned.
992+
"""
993+
g_vecs_all = []
994+
for log_file in log_files:
995+
f = open(log_file, "r")
996+
lines = f.readlines()
997+
f.close()
998+
999+
n = -1
1000+
find_equil = False
1001+
for line in lines:
1002+
n += 1
1003+
if "Count G(in kT)" in line: # this line is lines[n]
1004+
w = [] # the list of weights at this time frame
1005+
for i in range(1, N_states + 1):
1006+
if "<<" in lines[n + i]:
1007+
w.append(float(lines[n + i].split()[-3]))
1008+
else:
1009+
w.append(float(lines[n + i].split()[-2]))
1010+
1011+
if find_equil is False:
1012+
g_vecs_all.append(w)
1013+
1014+
if "Weights have equilibrated" in line:
1015+
find_equil = True
1016+
w = [float(i) for i in lines[n - 2].split(':')[-1].split()]
1017+
g_vecs_all.append(w)
1018+
break
1019+
1020+
if avg_frac != 0:
1021+
n_avg = int(avg_frac * len(g_vecs_all))
1022+
g_vecs_avg = np.mean(g_vecs_all[-n_avg:], axis=0)
1023+
else:
1024+
g_vecs_avg = None
1025+
1026+
return g_vecs_all, g_vecs_avg
1027+
1028+
1029+
def get_dg_evolution(log_files, start_state, end_state):
9691030
"""
9701031
For weight-updating simulations, gets the time series of the weight
9711032
difference (:math:`Δg = g_2-g_1`) between the specified states.
9721033
9731034
Parameters
9741035
----------
975-
log_file : str
976-
The log file name.
1036+
log_files : list
1037+
The list of log file names.
9771038
start_state : int
9781039
The index of the state (starting from 0) whose weight is :math:`g_1`.
9791040
end_state : int
@@ -984,45 +1045,22 @@ def get_dg_evolution(log_file, start_state, end_state):
9841045
dg : list
9851046
A list of :math:`Δg` values.
9861047
"""
987-
f = open(log_file, "r")
988-
lines = f.readlines()
989-
f.close()
990-
991-
n = -1
992-
find_equil = False
993-
dg = []
9941048
N_states = end_state - start_state + 1 # number of states for the range of insterest
995-
for line in lines:
996-
n += 1
997-
if "Count G(in kT)" in line: # this line is lines[n]
998-
w = [] # the list of weights at this time frame
999-
for i in range(1, N_states + 1):
1000-
if "<<" in lines[n + i]:
1001-
w.append(float(lines[n + i].split()[-3]))
1002-
else:
1003-
w.append(float(lines[n + i].split()[-2]))
1004-
1005-
if find_equil is False:
1006-
dg.append(w[end_state] - w[start_state])
1007-
1008-
if "Weights have equilibrated" in line:
1009-
find_equil = True
1010-
w = [float(i) for i in lines[n - 2].split(':')[-1].split()]
1011-
dg.append(w[end_state] - w[start_state])
1012-
break
1049+
g_vecs = get_g_evolution(log_files, N_states)
1050+
dg = [g_vecs[i][end_state] - g_vecs[i][start_state] for i in range(len(g_vecs))]
10131051

10141052
return dg
10151053

10161054

1017-
def plot_dg_evolution(log_file, start_state, end_state, start_idx=0, end_idx=-1, dt_log=2):
1055+
def plot_dg_evolution(log_files, start_state, end_state, start_idx=0, end_idx=-1, dt_log=2):
10181056
"""
10191057
For weight-updating simulations, plots the time series of the weight
10201058
difference (:math:`Δg = g_2-g_1`) between the specified states.
10211059
10221060
Parameters
10231061
----------
1024-
log_file : str or list
1025-
The log file name or a list of log file names.
1062+
log_files : list
1063+
The list of log file names.
10261064
start_state : int
10271065
The index of the state (starting from 0) whose weight is :math:`g_1`.
10281066
end_state : int
@@ -1035,12 +1073,7 @@ def plot_dg_evolution(log_file, start_state, end_state, start_idx=0, end_idx=-1,
10351073
The time interval between two consecutive frames in the log file. The
10361074
default is 2 ps.
10371075
"""
1038-
if isinstance(log_file, str):
1039-
dg = get_dg_evolution(log_file, start_state, end_state)
1040-
else:
1041-
dg = []
1042-
for f in log_file:
1043-
dg += get_dg_evolution(f, start_state, end_state)
1076+
dg = get_dg_evolution(log_files, start_state, end_state)
10441077

10451078
# Now we plot
10461079
dg = dg[start_idx:end_idx]

0 commit comments

Comments
 (0)