Skip to content

Commit 228d211

Browse files
committed
fix merge
2 parents ee284de + 18eff2f commit 228d211

File tree

3 files changed

+18
-12
lines changed

3 files changed

+18
-12
lines changed

ensemble_md/analysis/analyze_free_energy.py

Lines changed: 14 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -178,7 +178,7 @@ def _calculate_df_adjacent(estimators):
178178
return df_adjacent, df_err_adjacent
179179

180180

181-
def _combine_df_adjacent(df_adjacent, state_ranges, n_tot, df_err_adjacent=None, err_type="propagate"):
181+
def _combine_df_adjacent(df_adjacent, state_ranges, df_err_adjacent=None, err_type="propagate"):
182182
"""
183183
An internal function used in :func:`calculate_free_energy` to combine the free energy differences between
184184
adjacent states in different state ranges using either simple means or inverse-variance weighted means.
@@ -215,6 +215,7 @@ def _combine_df_adjacent(df_adjacent, state_ranges, n_tot, df_err_adjacent=None,
215215
--------
216216
:func:`calculate_free_energy`
217217
"""
218+
n_tot = state_ranges[-1][-1] + 1
218219
df, df_err, overlap_bool = [], [], []
219220
for i in range(n_tot - 1):
220221
# df_list is a list of free energy difference between sates i and i+1 in different replicas
@@ -248,7 +249,7 @@ def _combine_df_adjacent(df_adjacent, state_ranges, n_tot, df_err_adjacent=None,
248249
return df, df_err, overlap_bool
249250

250251

251-
def calculate_free_energy(data, state_ranges, df_method="MBAR", err_method="propagate", n_bootstrap=None, seed=None, MTREXEE=False):
252+
def calculate_free_energy(data, state_ranges, df_method="MBAR", err_method="propagate", n_bootstrap=None, seed=None, MTREXEE=False): # noqa: E501
252253
"""
253254
Caculates the averaged free energy profile with the chosen method given :math:`u_{nk}` or :math:`dH/dλ` data
254255
obtained from all replicas of the REXEE simulation. Available methods include TI, BAR, and MBAR. TI
@@ -307,8 +308,11 @@ def calculate_free_energy(data, state_ranges, df_method="MBAR", err_method="prop
307308
else:
308309
n_tot = state_ranges[-1] + 1
309310
estimators = _apply_estimators(data, df_method)
310-
df_adjacent, df_err_adjacent = _calculate_df_adjacent(estimators)
311-
df, df_err, overlap_bool = _combine_df_adjacent(df_adjacent, state_ranges, n_tot, df_err_adjacent, err_type='propagate')
311+
if MTREXEE is False:
312+
df_adjacent, df_err_adjacent = _calculate_df_adjacent(estimators)
313+
df, df_err, overlap_bool = _combine_df_adjacent(df_adjacent, state_ranges, df_err_adjacent, err_type='propagate') # noqa: E501
314+
else:
315+
df, df_err = _calculate_df_adjacent(estimators)
312316

313317
if err_method == 'bootstrap':
314318
if seed is not None:
@@ -320,15 +324,18 @@ def calculate_free_energy(data, state_ranges, df_method="MBAR", err_method="prop
320324
for b in range(n_bootstrap):
321325
sampled_data = [sampled_data_all[i].iloc[b * len(data[i]):(b + 1) * len(data[i])] for i in range(n_sim)]
322326
bootstrap_estimators = _apply_estimators(sampled_data, df_method)
323-
df_adjacent, df_err_adjacent = _calculate_df_adjacent(bootstrap_estimators)
324-
df_sampled, _, overlap_bool = _combine_df_adjacent(df_adjacent, state_ranges, n_tot, df_err_adjacent, err_type='propagate') # doesn't matter what value err_type here is # noqa: E501
327+
if MTREXEE is False:
328+
df_adjacent, df_err_adjacent = _calculate_df_adjacent(bootstrap_estimators)
329+
df_sampled, _, overlap_bool = _combine_df_adjacent(df_adjacent, state_ranges, df_err_adjacent, err_type='propagate') # doesn't matter what value err_type here is # noqa: E501
330+
else:
331+
df_sampled, _ = _calculate_df_adjacent(bootstrap_estimators)
325332
df_bootstrap.append(df_sampled)
326333
error_bootstrap = np.std(df_bootstrap, axis=0, ddof=1)
327334

328335
# Replace the value in df_err with value in error_bootstrap if df_err corresponds to
329336
# the df between overlapping states
330337
for i in range(n_tot - 1):
331-
if overlap_bool[i] is True:
338+
if MTREXEE is True or overlap_bool[i] is True:
332339
print(f'Replaced the propagated error with the bootstrapped error for states {i} and {i + 1}: {df_err[i]:.5f} -> {error_bootstrap[i]:.5f}.') # noqa: E501
333340
df_err[i] = error_bootstrap[i]
334341
elif err_method == 'propagate':

ensemble_md/cli/run_REXEE.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -322,7 +322,7 @@ def main():
322322
gro = f'{REXEE.working_dir}/sim_{j}/iteration_{i-1}/confout.gro'
323323
if os.path.exists(gro_backup):
324324
os.rename(gro_backup, gro)
325-
325+
326326
for j in range(len(swap_list)):
327327
print('\nModifying the coordinates of the following output GRO files ...')
328328
# gro_1 and gro_2 are the simlation outputs (that we want to back up) and the inputs to modify_coords # noqa: E501

ensemble_md/tests/test_analyze_free_energy.py

Lines changed: 3 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -186,26 +186,25 @@ def test_combine_df_adjacent():
186186
df_adjacent = [[1, 3], [4, 6]]
187187
df_err_adjacent = [[0.1, 0.1], [0.1, 0.1]]
188188
state_ranges = [[0, 1, 2], [1, 2, 3]]
189-
n_tot = state_ranges[-1][-1] + 1
190189

191190
# Test 1: df_err_adjacent is None (in which case err_type is ignored)
192191
# Note that this test would lead to two harmless RuntimWarnings due to calculations like np.std([1], ddof=1), which return NaN # noqa: E501
193-
results = analyze_free_energy._combine_df_adjacent(df_adjacent, state_ranges, n_tot, None, "propagate")
192+
results = analyze_free_energy._combine_df_adjacent(df_adjacent, state_ranges, None, "propagate")
194193
assert results[0] == [1, 3.5, 6]
195194
assert math.isnan(results[1][0])
196195
assert results[1][1] == np.std([3, 4], ddof=1)
197196
assert math.isnan(results[1][2])
198197
assert results[2] == [False, True, False]
199198

200199
# Test 2: df_err_adjacent is not None and err_type is "std"
201-
results = analyze_free_energy._combine_df_adjacent(df_adjacent, state_ranges, n_tot, df_err_adjacent, "std")
200+
results = analyze_free_energy._combine_df_adjacent(df_adjacent, state_ranges, df_err_adjacent, "std")
202201
assert results[0] == [1, 3.5, 6]
203202
np.testing.assert_array_almost_equal(results[1], [0.1, np.std([3, 4], ddof=1), 0.1])
204203
assert results[2] == [False, True, False]
205204

206205
# Test 3: df_err_adjacent is not None and err_type is "propagate"
207206
df_err_adjacent = [[0.1, 0.1], [0.2, 0.1]] # make the errs different so that the weighted mean will not be equal to simple mean # noqa: E501
208-
results = analyze_free_energy._combine_df_adjacent(df_adjacent, state_ranges, n_tot, df_err_adjacent, "propagate")
207+
results = analyze_free_energy._combine_df_adjacent(df_adjacent, state_ranges, df_err_adjacent, "propagate")
209208
assert results[0] == [1, utils.weighted_mean([3, 4], [0.1, 0.2])[0], 6]
210209
assert results[1] == [0.1, utils.weighted_mean([3, 4], [0.1, 0.2])[1], 0.1]
211210
assert results[2] == [False, True, False]

0 commit comments

Comments
 (0)