@@ -965,15 +965,76 @@ def plot_swaps(swaps, swap_type='', stack=True, figsize=None):
965965 plt .savefig (f'{ swap_type } _swaps.png' , dpi = 600 )
966966
967967
968- def get_dg_evolution (log_file , start_state , end_state ):
968+ def get_g_evolution (log_files , N_states , avg_frac = 0 ):
969+ """
970+ For weight-updating simulations, gets the time series of the alchemical
971+ weights of all states.
972+
973+ Parameters
974+ ----------
975+ log_files : list
976+ The list of log file names.
977+ N_states : int
978+ The total number of states in the whole alchemical range.
979+ avg_frac : float
980+ The fraction of the last part of the simulation to be averaged. The
981+ default is 0, which means no averaging.
982+
983+
984+ Returns
985+ -------
986+ g_vecs_all : list
987+ The alchemical weights of all states as a function of time.
988+ It should be a list of lists.
989+ g_vecs_avg : list
990+ The alchemical weights of all states averaged over the last part of
991+ the simulation. If :code:`avg_frac` is 0, :code:`None` will be returned.
992+ """
993+ g_vecs_all = []
994+ for log_file in log_files :
995+ f = open (log_file , "r" )
996+ lines = f .readlines ()
997+ f .close ()
998+
999+ n = - 1
1000+ find_equil = False
1001+ for line in lines :
1002+ n += 1
1003+ if "Count G(in kT)" in line : # this line is lines[n]
1004+ w = [] # the list of weights at this time frame
1005+ for i in range (1 , N_states + 1 ):
1006+ if "<<" in lines [n + i ]:
1007+ w .append (float (lines [n + i ].split ()[- 3 ]))
1008+ else :
1009+ w .append (float (lines [n + i ].split ()[- 2 ]))
1010+
1011+ if find_equil is False :
1012+ g_vecs_all .append (w )
1013+
1014+ if "Weights have equilibrated" in line :
1015+ find_equil = True
1016+ w = [float (i ) for i in lines [n - 2 ].split (':' )[- 1 ].split ()]
1017+ g_vecs_all .append (w )
1018+ break
1019+
1020+ if avg_frac != 0 :
1021+ n_avg = int (avg_frac * len (g_vecs_all ))
1022+ g_vecs_avg = np .mean (g_vecs_all [- n_avg :], axis = 0 )
1023+ else :
1024+ g_vecs_avg = None
1025+
1026+ return g_vecs_all , g_vecs_avg
1027+
1028+
1029+ def get_dg_evolution (log_files , start_state , end_state ):
9691030 """
9701031 For weight-updating simulations, gets the time series of the weight
9711032 difference (:math:`Δg = g_2-g_1`) between the specified states.
9721033
9731034 Parameters
9741035 ----------
975- log_file : str
976- The log file name .
1036+ log_files : list
1037+ The list of log file names .
9771038 start_state : int
9781039 The index of the state (starting from 0) whose weight is :math:`g_1`.
9791040 end_state : int
@@ -984,45 +1045,22 @@ def get_dg_evolution(log_file, start_state, end_state):
9841045 dg : list
9851046 A list of :math:`Δg` values.
9861047 """
987- f = open (log_file , "r" )
988- lines = f .readlines ()
989- f .close ()
990-
991- n = - 1
992- find_equil = False
993- dg = []
9941048 N_states = end_state - start_state + 1 # number of states for the range of insterest
995- for line in lines :
996- n += 1
997- if "Count G(in kT)" in line : # this line is lines[n]
998- w = [] # the list of weights at this time frame
999- for i in range (1 , N_states + 1 ):
1000- if "<<" in lines [n + i ]:
1001- w .append (float (lines [n + i ].split ()[- 3 ]))
1002- else :
1003- w .append (float (lines [n + i ].split ()[- 2 ]))
1004-
1005- if find_equil is False :
1006- dg .append (w [end_state ] - w [start_state ])
1007-
1008- if "Weights have equilibrated" in line :
1009- find_equil = True
1010- w = [float (i ) for i in lines [n - 2 ].split (':' )[- 1 ].split ()]
1011- dg .append (w [end_state ] - w [start_state ])
1012- break
1049+ g_vecs = get_g_evolution (log_files , N_states )
1050+ dg = [g_vecs [i ][end_state ] - g_vecs [i ][start_state ] for i in range (len (g_vecs ))]
10131051
10141052 return dg
10151053
10161054
1017- def plot_dg_evolution (log_file , start_state , end_state , start_idx = 0 , end_idx = - 1 , dt_log = 2 ):
1055+ def plot_dg_evolution (log_files , start_state , end_state , start_idx = 0 , end_idx = - 1 , dt_log = 2 ):
10181056 """
10191057 For weight-updating simulations, plots the time series of the weight
10201058 difference (:math:`Δg = g_2-g_1`) between the specified states.
10211059
10221060 Parameters
10231061 ----------
1024- log_file : str or list
1025- The log file name or a list of log file names.
1062+ log_files : list
1063+ The list of log file names.
10261064 start_state : int
10271065 The index of the state (starting from 0) whose weight is :math:`g_1`.
10281066 end_state : int
@@ -1035,12 +1073,7 @@ def plot_dg_evolution(log_file, start_state, end_state, start_idx=0, end_idx=-1,
10351073 The time interval between two consecutive frames in the log file. The
10361074 default is 2 ps.
10371075 """
1038- if isinstance (log_file , str ):
1039- dg = get_dg_evolution (log_file , start_state , end_state )
1040- else :
1041- dg = []
1042- for f in log_file :
1043- dg += get_dg_evolution (f , start_state , end_state )
1076+ dg = get_dg_evolution (log_files , start_state , end_state )
10441077
10451078 # Now we plot
10461079 dg = dg [start_idx :end_idx ]
0 commit comments