Skip to content

Commit 9cb910f

Browse files
authored
Merge pull request #25 from wehs7661/hist_correction
Implement the method for histogram correction
2 parents a3023ff + 9ee86f9 commit 9cb910f

File tree

4 files changed

+179
-84
lines changed

4 files changed

+179
-84
lines changed

ensemble_md/analysis/analyze_traj.py

Lines changed: 70 additions & 37 deletions
Original file line numberDiff line numberDiff line change
@@ -965,15 +965,76 @@ def plot_swaps(swaps, swap_type='', stack=True, figsize=None):
965965
plt.savefig(f'{swap_type}_swaps.png', dpi=600)
966966

967967

968-
def get_dg_evolution(log_file, start_state, end_state):
968+
def get_g_evolution(log_files, N_states, avg_frac=0):
969+
"""
970+
For weight-updating simulations, gets the time series of the alchemical
971+
weights of all states.
972+
973+
Parameters
974+
----------
975+
log_files : list
976+
The list of log file names.
977+
N_states : int
978+
The total number of states in the whole alchemical range.
979+
avg_frac : float
980+
The fraction of the last part of the simulation to be averaged. The
981+
default is 0, which means no averaging.
982+
983+
984+
Returns
985+
-------
986+
g_vecs_all : list
987+
The alchemical weights of all states as a function of time.
988+
It should be a list of lists.
989+
g_vecs_avg : list
990+
The alchemical weights of all states averaged over the last part of
991+
the simulation. If :code:`avg_frac` is 0, :code:`None` will be returned.
992+
"""
993+
g_vecs_all = []
994+
for log_file in log_files:
995+
f = open(log_file, "r")
996+
lines = f.readlines()
997+
f.close()
998+
999+
n = -1
1000+
find_equil = False
1001+
for line in lines:
1002+
n += 1
1003+
if "Count G(in kT)" in line: # this line is lines[n]
1004+
w = [] # the list of weights at this time frame
1005+
for i in range(1, N_states + 1):
1006+
if "<<" in lines[n + i]:
1007+
w.append(float(lines[n + i].split()[-3]))
1008+
else:
1009+
w.append(float(lines[n + i].split()[-2]))
1010+
1011+
if find_equil is False:
1012+
g_vecs_all.append(w)
1013+
1014+
if "Weights have equilibrated" in line:
1015+
find_equil = True
1016+
w = [float(i) for i in lines[n - 2].split(':')[-1].split()]
1017+
g_vecs_all.append(w)
1018+
break
1019+
1020+
if avg_frac != 0:
1021+
n_avg = int(avg_frac * len(g_vecs_all))
1022+
g_vecs_avg = np.mean(g_vecs_all[-n_avg:], axis=0)
1023+
else:
1024+
g_vecs_avg = None
1025+
1026+
return g_vecs_all, g_vecs_avg
1027+
1028+
1029+
def get_dg_evolution(log_files, start_state, end_state):
9691030
"""
9701031
For weight-updating simulations, gets the time series of the weight
9711032
difference (:math:`Δg = g_2-g_1`) between the specified states.
9721033
9731034
Parameters
9741035
----------
975-
log_file : str
976-
The log file name.
1036+
log_files : list
1037+
The list of log file names.
9771038
start_state : int
9781039
The index of the state (starting from 0) whose weight is :math:`g_1`.
9791040
end_state : int
@@ -984,45 +1045,22 @@ def get_dg_evolution(log_file, start_state, end_state):
9841045
dg : list
9851046
A list of :math:`Δg` values.
9861047
"""
987-
f = open(log_file, "r")
988-
lines = f.readlines()
989-
f.close()
990-
991-
n = -1
992-
find_equil = False
993-
dg = []
9941048
N_states = end_state - start_state + 1 # number of states for the range of insterest
995-
for line in lines:
996-
n += 1
997-
if "Count G(in kT)" in line: # this line is lines[n]
998-
w = [] # the list of weights at this time frame
999-
for i in range(1, N_states + 1):
1000-
if "<<" in lines[n + i]:
1001-
w.append(float(lines[n + i].split()[-3]))
1002-
else:
1003-
w.append(float(lines[n + i].split()[-2]))
1004-
1005-
if find_equil is False:
1006-
dg.append(w[end_state] - w[start_state])
1007-
1008-
if "Weights have equilibrated" in line:
1009-
find_equil = True
1010-
w = [float(i) for i in lines[n - 2].split(':')[-1].split()]
1011-
dg.append(w[end_state] - w[start_state])
1012-
break
1049+
g_vecs = get_g_evolution(log_files, N_states)
1050+
dg = [g_vecs[i][end_state] - g_vecs[i][start_state] for i in range(len(g_vecs))]
10131051

10141052
return dg
10151053

10161054

1017-
def plot_dg_evolution(log_file, start_state, end_state, start_idx=0, end_idx=-1, dt_log=2):
1055+
def plot_dg_evolution(log_files, start_state, end_state, start_idx=0, end_idx=-1, dt_log=2):
10181056
"""
10191057
For weight-updating simulations, plots the time series of the weight
10201058
difference (:math:`Δg = g_2-g_1`) between the specified states.
10211059
10221060
Parameters
10231061
----------
1024-
log_file : str or list
1025-
The log file name or a list of log file names.
1062+
log_files : list
1063+
The list of log file names.
10261064
start_state : int
10271065
The index of the state (starting from 0) whose weight is :math:`g_1`.
10281066
end_state : int
@@ -1035,12 +1073,7 @@ def plot_dg_evolution(log_file, start_state, end_state, start_idx=0, end_idx=-1,
10351073
The time interval between two consecutive frames in the log file. The
10361074
default is 2 ps.
10371075
"""
1038-
if isinstance(log_file, str):
1039-
dg = get_dg_evolution(log_file, start_state, end_state)
1040-
else:
1041-
dg = []
1042-
for f in log_file:
1043-
dg += get_dg_evolution(f, start_state, end_state)
1076+
dg = get_dg_evolution(log_files, start_state, end_state)
10441077

10451078
# Now we plot
10461079
dg = dg[start_idx:end_idx]

ensemble_md/cli/run_EEXE.py

Lines changed: 19 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -146,71 +146,72 @@ def main():
146146
dhdl_files = [f'sim_{j}/iteration_{i - 1}/dhdl.xvg' for j in range(EEXE.n_sim)]
147147
log_files = [f'sim_{j}/iteration_{i - 1}/md.log' for j in range(EEXE.n_sim)]
148148
states_ = EEXE.extract_final_dhdl_info(dhdl_files)
149-
wl_delta, weights_, counts = EEXE.extract_final_log_info(log_files)
149+
wl_delta, weights_, counts_ = EEXE.extract_final_log_info(log_files)
150150
print()
151151

152152
# 3-2. Identify swappable pairs, propose swap(s), calculate P_acc, and accept/reject swap(s)
153153
# Note after `get_swapping_pattern`, `states_` and `weights_` won't be necessarily
154154
# since they are updated by `get_swapping_pattern`. (Even if the function does not explicitly
155155
# returns `states_` and `weights_`, `states_` and `weights_` can still be different after
156156
# the use of the function.) Therefore, here we create copies for `states_` and `weights_`
157-
# before the use of `get_swapping_pattern`, so we can use them in `histogram_correction`,
157+
# before the use of `get_swapping_pattern`, so we can use them in `weight_correction`,
158158
# `combine_weights` and `update_MDP`.
159159
states = copy.deepcopy(states_)
160160
weights = copy.deepcopy(weights_)
161+
counts = copy.deepcopy(counts_)
161162
swap_pattern, swap_list = EEXE.get_swapping_pattern(dhdl_files, states_, weights_) # swap_list will only be used for modify_coords # noqa: E501
162163

163-
# 3-3. Perform histogram correction/weight combination
164+
# 3-3. Perform weight correction/weight combination
164165
if wl_delta != [None for i in range(EEXE.n_sim)]: # weight-updating
165166
print(f'\nCurrent Wang-Landau incrementors: {wl_delta}\n')
166167

167168
# (1) First we prepare the weights to be combined.
168-
# Note that although averaged weights are sometimes used for histogram correction/weight combination,
169+
# Note that although averaged weights are sometimes used for weight correction/weight combination,
169170
# the final weights are always used for calculating the acceptance ratio.
170171
if EEXE.N_cutoff != -1 or EEXE.w_combine is not None:
171-
# Only when histogram correction/weight combination is needed.
172+
# Only when weight correction/weight combination is needed.
172173
weights_avg, weights_err = EEXE.get_averaged_weights(log_files)
173174
weights_input = EEXE.prepare_weights(weights_avg, weights) # weights_input is for weight combination # noqa: E501
174175

175-
# (2) Now we perform histogram correction/weight combination.
176+
# (2) Now we perform weight correction/weight combination.
176177
# The product of this step should always be named as "weights" to be used in update_MDP
177178
if EEXE.N_cutoff != -1 and EEXE.w_combine is not None:
178179
# perform both
179180
if weights_input is None:
180-
# Then only histogram correction will be performed
181+
# Then only weight correction will be performed
181182
print('Note: Weight combination is deactivated because the weights are too noisy.')
182-
weights = EEXE.histogram_correction(weights, counts)
183-
_ = EEXE.combine_weights(weights, print_weights=False)[1] # just to print the combiend weights
183+
weights = EEXE.weight_correction(weights, counts)
184+
_ = EEXE.combine_weights(counts_, weights, print_values=False)[1] # just to print the combiend weights # noqa: E501
184185
else:
185-
weights_preprocessed = EEXE.histogram_correction(weights_input, counts)
186+
weights_preprocessed = EEXE.weight_correction(weights_input, counts)
186187
if EEXE.verbose is True:
187188
print('Performing weight combination ...')
188189
else:
189190
print('Performing weight combination ...', end='')
190-
weights, g_vec = EEXE.combine_weights(weights_preprocessed) # inverse-variance weighting seems worse # noqa: E501
191+
counts, weights, g_vec = EEXE.combine_weights(counts_, weights_preprocessed) # inverse-variance weighting seems worse # noqa: E501
191192
EEXE.g_vecs.append(g_vec)
192193
elif EEXE.N_cutoff == -1 and EEXE.w_combine is not None:
193194
# only perform weight combination
194-
print('Note: No histogram correction will be performed.')
195+
print('Note: No weight correction will be performed.')
195196
if weights_input is None:
196197
print('Note: Weight combination is deactivated because the weights are too noisy.')
197-
_ = EEXE.combine_weights(weights, print_weights=False)[1] # just to print the combined weights
198+
_ = EEXE.combine_weights(counts_, weights, print_values=False)[1] # just to print the combined weights # noqa: E501
198199
else:
199200
if EEXE.verbose is True:
200201
print('Performing weight combination ...')
201202
else:
202203
print('Performing weight combination ...', end='')
203-
weights, g_vec = EEXE.combine_weights(weights_input) # inverse-variance weighting seems worse
204+
counts, weights, g_vec = EEXE.combine_weights(counts_, weights_input) # inverse-variance weighting seems worse # noqa: E501
204205
EEXE.g_vecs.append(g_vec)
205206
elif EEXE.N_cutoff != -1 and EEXE.w_combine is None:
206-
# only perform histogram correction
207+
# only perform weight correction
207208
print('Note: No weight combination will be performed.')
208209
weights = EEXE.histogram_correction(weights_input, counts)
209-
_ = EEXE.combine_weights(weights, print_weights=False)[1] # just to print the combined weights
210+
_ = EEXE.combine_weights(counts_, weights, print_values=False)[1] # just to print the combined weights # noqa: E501
210211
else:
211-
print('Note: No histogram correction will be performed.')
212+
print('Note: No weight correction will be performed.')
212213
print('Note: No weight combination will be performed.')
213-
_ = EEXE.combine_weights(weights, print_weights=False)[1] # just to print the combiend weights
214+
_ = EEXE.combine_weights(counts_, weights, print_values=False)[1] # just to print the combiend weights # noqa: E501
214215

215216
# 3-5. Modify the MDP files and swap out the GRO files (if needed)
216217
# Here we keep the lambda range set in mdp the same across different iterations in the same folder but swap out the gro file # noqa: E501

0 commit comments

Comments
 (0)