@@ -189,6 +189,8 @@ def _combine_df_adjacent(df_adjacent, state_ranges, df_err_adjacent=None, err_ty
189189 A list of lists free energy differences between adjacent states for all replicas.
190190 state_ranges : list
191191 A list of lists of showing the state indices sampled by each replica.
192+ n_tot : int
193+ Number of lambda states
192194 df_err_adjacent : list, Optional
193195 A list of lists of uncertainties corresponding to the values of :code:`df_adjacent`. Notably, if
194196 :code:`df_err_adjacent` is :code:`None`, simple means will be used. Otherwise, inverse-variance weighted
@@ -247,7 +249,48 @@ def _combine_df_adjacent(df_adjacent, state_ranges, df_err_adjacent=None, err_ty
247249 return df , df_err , overlap_bool
248250
249251
250- def calculate_free_energy (data , state_ranges , df_method = "MBAR" , err_method = "propagate" , n_bootstrap = None , seed = None ):
252+ def _calculate_df (estimators ):
253+ """
254+ An internal function used in :func:`calculate_free_energy` to calculate a list of free energies between adjacent
255+ states for all replicas.
256+
257+ Parameters
258+ ----------
259+ estimators : list
260+ A list of estimators fitting the input data for all replicas. With this, the user
261+ can access all the free energies and their associated uncertainties for all states and replicas.
262+ In our code, these estimators come from the function :func:`_apply_estimators`.
263+
264+ Returns
265+ -------
266+ df : float
267+ Free energy differences between for specified replica.
268+ df_err : float
269+ Uncertainties corresponding to the values in :code:`df`.
270+
271+ See also
272+ --------
273+ :func:`calculate_free_energy`
274+ """
275+ # Compute FE estimate
276+ df = estimators [0 ].delta_f_
277+ l = np .linspace (0 , 1 , num = len (df .index ))
278+ df .index = l
279+ df .columns = l
280+ est = df .loc [0 , 1 ]
281+ print (df )
282+
283+ # Compute FE extimate error
284+ df_err = estimators [0 ].d_delta_f_
285+ l = np .linspace (0 , 1 , num = len (df_err .index ))
286+ df_err .index = l
287+ df_err .columns = l
288+ err = df_err .loc [0 , 1 ]
289+
290+ return est , err
291+
292+
293+ def calculate_free_energy (data , state_ranges , df_method = "MBAR" , err_method = "propagate" , n_bootstrap = None , seed = None , MTREXEE = False ): # noqa: E501
251294 """
252295 Caculates the averaged free energy profile with the chosen method given :math:`u_{nk}` or :math:`dH/dλ` data
253296 obtained from all replicas of the REXEE simulation. Available methods include TI, BAR, and MBAR. TI
@@ -275,6 +318,8 @@ def calculate_free_energy(data, state_ranges, df_method="MBAR", err_method="prop
275318 seed : int, Optional
276319 The random seed for bootstrapping. Only relevant when :code:`err_method` is :code:`"bootstrap"`.
277320 The default is :code:`None`.
321+ MTREXEE : bool
322+ Whether this is a MT-REXEE simulation or not
278323
279324 Returns
280325 -------
@@ -299,10 +344,17 @@ def calculate_free_energy(data, state_ranges, df_method="MBAR", err_method="prop
299344 >>> f, _, _ = analyze_free_energy.calculate_free_energy(data_list, state_ranges, "MBAR", "propagate")
300345 """
301346 n_sim = len (data )
302- n_tot = state_ranges [- 1 ][- 1 ] + 1
347+ if MTREXEE is False :
348+ n_tot = state_ranges [- 1 ][- 1 ] + 1
349+ else :
350+ n_tot = state_ranges [- 1 ] + 1
303351 estimators = _apply_estimators (data , df_method )
304- df_adjacent , df_err_adjacent = _calculate_df_adjacent (estimators )
305- df , df_err , overlap_bool = _combine_df_adjacent (df_adjacent , state_ranges , df_err_adjacent , err_type = 'propagate' )
352+ print (estimators )
353+ if MTREXEE is False :
354+ df_adjacent , df_err_adjacent = _calculate_df_adjacent (estimators )
355+ df , df_err , overlap_bool = _combine_df_adjacent (df_adjacent , state_ranges , df_err_adjacent , err_type = 'propagate' ) # noqa: E501
356+ else :
357+ df , df_err = _calculate_df (estimators )
306358
307359 if err_method == 'bootstrap' :
308360 if seed is not None :
@@ -314,26 +366,33 @@ def calculate_free_energy(data, state_ranges, df_method="MBAR", err_method="prop
314366 for b in range (n_bootstrap ):
315367 sampled_data = [sampled_data_all [i ].iloc [b * len (data [i ]):(b + 1 ) * len (data [i ])] for i in range (n_sim )]
316368 bootstrap_estimators = _apply_estimators (sampled_data , df_method )
317- df_adjacent , df_err_adjacent = _calculate_df_adjacent (bootstrap_estimators )
318- df_sampled , _ , overlap_bool = _combine_df_adjacent (df_adjacent , state_ranges , df_err_adjacent , err_type = 'propagate' ) # doesn't matter what value err_type here is # noqa: E501
369+ if MTREXEE is False :
370+ df_adjacent , df_err_adjacent = _calculate_df_adjacent (bootstrap_estimators )
371+ df_sampled , _ , overlap_bool = _combine_df_adjacent (df_adjacent , state_ranges , df_err_adjacent , err_type = 'propagate' ) # doesn't matter what value err_type here is # noqa: E501
372+ else :
373+ df_sampled , _ = _calculate_df (bootstrap_estimators )
319374 df_bootstrap .append (df_sampled )
320375 error_bootstrap = np .std (df_bootstrap , axis = 0 , ddof = 1 )
321376
322377 # Replace the value in df_err with value in error_bootstrap if df_err corresponds to
323378 # the df between overlapping states
324379 for i in range (n_tot - 1 ):
325- if overlap_bool [i ] is True :
380+ if MTREXEE is True or overlap_bool [i ] is True :
326381 print (f'Replaced the propagated error with the bootstrapped error for states { i } and { i + 1 } : { df_err [i ]:.5f} -> { error_bootstrap [i ]:.5f} .' ) # noqa: E501
327382 df_err [i ] = error_bootstrap [i ]
328383 elif err_method == 'propagate' :
329384 pass
330385 else :
331386 raise ParameterError ('Specified err_method not available.' )
332-
333- df .insert (0 , 0 )
334- df_err .insert (0 , 0 )
335- f = [sum (df [:(i + 1 )]) for i in range (len (df ))]
336- f_err = [np .sqrt (sum ([x ** 2 for x in df_err [:(i + 1 )]])) for i in range (len (df_err ))]
387+
388+ if MTREXEE is False :
389+ df .insert (0 , 0 )
390+ df_err .insert (0 , 0 )
391+ f = [sum (df [:(i + 1 )]) for i in range (len (df ))]
392+ f_err = [np .sqrt (sum ([x ** 2 for x in df_err [:(i + 1 )]])) for i in range (len (df_err ))]
393+ else :
394+ f = df
395+ f_err = df_err
337396
338397 return f , f_err , estimators
339398
0 commit comments