Skip to content

Commit f4f833e

Browse files
committed
Fix FE function for MTREXEE
1 parent 8903a7d commit f4f833e

File tree

2 files changed

+9
-7
lines changed

2 files changed

+9
-7
lines changed

ensemble_md/analysis/analyze_free_energy.py

Lines changed: 5 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -178,7 +178,7 @@ def _calculate_df_adjacent(estimators):
178178
return df_adjacent, df_err_adjacent
179179

180180

181-
def _combine_df_adjacent(df_adjacent, state_ranges, df_err_adjacent=None, err_type="propagate"):
181+
def _combine_df_adjacent(df_adjacent, state_ranges, n_tot, df_err_adjacent=None, err_type="propagate"):
182182
"""
183183
An internal function used in :func:`calculate_free_energy` to combine the free energy differences between
184184
adjacent states in different state ranges using either simple means or inverse-variance weighted means.
@@ -189,6 +189,8 @@ def _combine_df_adjacent(df_adjacent, state_ranges, df_err_adjacent=None, err_ty
189189
A list of lists free energy differences between adjacent states for all replicas.
190190
state_ranges : list
191191
A list of lists of showing the state indices sampled by each replica.
192+
n_tot : int
193+
Number of lambda states
192194
df_err_adjacent : list, Optional
193195
A list of lists of uncertainties corresponding to the values of :code:`df_adjacent`. Notably, if
194196
:code:`df_err_adjacent` is :code:`None`, simple means will be used. Otherwise, inverse-variance weighted
@@ -213,7 +215,6 @@ def _combine_df_adjacent(df_adjacent, state_ranges, df_err_adjacent=None, err_ty
213215
--------
214216
:func:`calculate_free_energy`
215217
"""
216-
n_tot = state_ranges[-1][-1] + 1
217218
df, df_err, overlap_bool = [], [], []
218219
for i in range(n_tot - 1):
219220
# df_list is a list of free energy difference between sates i and i+1 in different replicas
@@ -307,7 +308,7 @@ def calculate_free_energy(data, state_ranges, df_method="MBAR", err_method="prop
307308
n_tot = state_ranges[-1] + 1
308309
estimators = _apply_estimators(data, df_method)
309310
df_adjacent, df_err_adjacent = _calculate_df_adjacent(estimators)
310-
df, df_err, overlap_bool = _combine_df_adjacent(df_adjacent, state_ranges, df_err_adjacent, err_type='propagate')
311+
df, df_err, overlap_bool = _combine_df_adjacent(df_adjacent, state_ranges, n_tot, df_err_adjacent, err_type='propagate')
311312

312313
if err_method == 'bootstrap':
313314
if seed is not None:
@@ -320,7 +321,7 @@ def calculate_free_energy(data, state_ranges, df_method="MBAR", err_method="prop
320321
sampled_data = [sampled_data_all[i].iloc[b * len(data[i]):(b + 1) * len(data[i])] for i in range(n_sim)]
321322
bootstrap_estimators = _apply_estimators(sampled_data, df_method)
322323
df_adjacent, df_err_adjacent = _calculate_df_adjacent(bootstrap_estimators)
323-
df_sampled, _, overlap_bool = _combine_df_adjacent(df_adjacent, state_ranges, df_err_adjacent, err_type='propagate') # doesn't matter what value err_type here is # noqa: E501
324+
df_sampled, _, overlap_bool = _combine_df_adjacent(df_adjacent, state_ranges, n_tot, df_err_adjacent, err_type='propagate') # doesn't matter what value err_type here is # noqa: E501
324325
df_bootstrap.append(df_sampled)
325326
error_bootstrap = np.std(df_bootstrap, axis=0, ddof=1)
326327

ensemble_md/tests/test_analyze_free_energy.py

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -186,25 +186,26 @@ def test_combine_df_adjacent():
186186
df_adjacent = [[1, 3], [4, 6]]
187187
df_err_adjacent = [[0.1, 0.1], [0.1, 0.1]]
188188
state_ranges = [[0, 1, 2], [1, 2, 3]]
189+
n_tot = state_ranges[-1][-1] + 1
189190

190191
# Test 1: df_err_adjacent is None (in which case err_type is ignored)
191192
# Note that this test would lead to two harmless RuntimWarnings due to calculations like np.std([1], ddof=1), which return NaN # noqa: E501
192-
results = analyze_free_energy._combine_df_adjacent(df_adjacent, state_ranges, None, "propagate")
193+
results = analyze_free_energy._combine_df_adjacent(df_adjacent, state_ranges, n_tot, None, "propagate")
193194
assert results[0] == [1, 3.5, 6]
194195
assert math.isnan(results[1][0])
195196
assert results[1][1] == np.std([3, 4], ddof=1)
196197
assert math.isnan(results[1][2])
197198
assert results[2] == [False, True, False]
198199

199200
# Test 2: df_err_adjacent is not None and err_type is "std"
200-
results = analyze_free_energy._combine_df_adjacent(df_adjacent, state_ranges, df_err_adjacent, "std")
201+
results = analyze_free_energy._combine_df_adjacent(df_adjacent, state_ranges, n_tot, df_err_adjacent, "std")
201202
assert results[0] == [1, 3.5, 6]
202203
np.testing.assert_array_almost_equal(results[1], [0.1, np.std([3, 4], ddof=1), 0.1])
203204
assert results[2] == [False, True, False]
204205

205206
# Test 3: df_err_adjacent is not None and err_type is "propagate"
206207
df_err_adjacent = [[0.1, 0.1], [0.2, 0.1]] # make the errs different so that the weighted mean will not be equal to simple mean # noqa: E501
207-
results = analyze_free_energy._combine_df_adjacent(df_adjacent, state_ranges, df_err_adjacent, "propagate")
208+
results = analyze_free_energy._combine_df_adjacent(df_adjacent, state_ranges, n_tot, df_err_adjacent, "propagate")
208209
assert results[0] == [1, utils.weighted_mean([3, 4], [0.1, 0.2])[0], 6]
209210
assert results[1] == [0.1, utils.weighted_mean([3, 4], [0.1, 0.2])[1], 0.1]
210211
assert results[2] == [False, True, False]

0 commit comments

Comments
 (0)