Skip to content

Commit 3739064

Browse files
committed
Fix linting
1 parent 01f1a6d commit 3739064

File tree

3 files changed

+54
-14
lines changed

3 files changed

+54
-14
lines changed

ensemble_md/analysis/analyze_traj.py

Lines changed: 44 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -833,7 +833,7 @@ def plot_transit_time(trajs, N, fig_prefix=None, dt=None, folder='.'):
833833
plt.savefig(f'{folder}/hist_{fig_names[t]}', dpi=600)
834834
else:
835835
plt.savefig(f'{folder}/{fig_prefix}_hist_{fig_names[t]}', dpi=600)
836-
#Save to csv
836+
# Save to csv
837837
sim_list, rt_list = [], []
838838
for n in range(len(t_roundtrip_list)):
839839
for rt in t_roundtrip_list[n]:
@@ -1350,6 +1350,30 @@ def get_delta_w_updates(log_file, plot=False):
13501350

13511351

13521352
def end_states_only_traj(working_dir, n_sim, n_iter, l0_states, l1_states, swap_rep_pattern, ps_per_frame):
1353+
"""
1354+
Create a trajectory which is a concatenation off all frames for each unique end state.
1355+
1356+
Parameters
1357+
----------
1358+
working_dir : str
1359+
path for the current working directory
1360+
n_sim : int
1361+
the number of simulations run
1362+
n_iter : int
1363+
the number of iterations run
1364+
l0_states : list of int
1365+
the lambda states which correspond to lambda=0
1366+
l1_states : list of int
1367+
the lambda states which correspond to lambda=1
1368+
swap_rep_pattern : list of int
1369+
the replica swapping pattern which will indicate which end states are common
1370+
ps_per_frame : float
1371+
the timestep to convert the time in the GROMACS dh/dl file to frames in the trajecotry
1372+
1373+
Returns
1374+
-------
1375+
None
1376+
"""
13531377
import pandas as pd
13541378
import os
13551379
import mdtraj as md
@@ -1434,7 +1458,24 @@ def end_states_only_traj(working_dir, n_sim, n_iter, l0_states, l1_states, swap_
14341458
traj = md.join(traj, traj_add)
14351459
traj.save_xtc(f'{working_dir}/analysis/{state}_{rep}.xtc')
14361460

1461+
14371462
def concat_sim_traj(working_dir, n_sim, n_iter):
1463+
"""
1464+
Create a trajectory which is a concatenation off each iterations trajectory
1465+
1466+
Parameters
1467+
----------
1468+
working_dir : str
1469+
path for the current working directory
1470+
n_sim : int
1471+
the number of simulations run
1472+
n_iter : int
1473+
the number of iterations run
1474+
1475+
Returns
1476+
-------
1477+
None
1478+
"""
14381479
import mdtraj as md
14391480
import os
14401481

@@ -1444,8 +1485,8 @@ def concat_sim_traj(working_dir, n_sim, n_iter):
14441485
else:
14451486
name = 'confout'
14461487

1447-
traj = md.load(f'{working_dir}/sim_{rep}/iteration_0/traj.trr', top=f'{working_dir}/sim_{rep}/iteration_0/{name}.gro')
1488+
traj = md.load(f'{working_dir}/sim_{rep}/iteration_0/traj.trr', top=f'{working_dir}/sim_{rep}/iteration_0/{name}.gro') # noqa: E501
14481489
for iteration in range(1, n_iter):
1449-
traj_add = md.load(f'{working_dir}/sim_{rep}/iteration_{iteration}/traj.trr', top=f'{working_dir}/sim_{rep}/iteration_0/{name}.gro')
1490+
traj_add = md.load(f'{working_dir}/sim_{rep}/iteration_{iteration}/traj.trr', top=f'{working_dir}/sim_{rep}/iteration_0/{name}.gro') # noqa: E501
14501491
traj = md.join([traj, traj_add])
14511492
traj.save_xtc(f'{working_dir}/analysis/sim{rep}_concat.xtc')

ensemble_md/cli/analyze_REXEE.py

Lines changed: 9 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -121,7 +121,6 @@ def main():
121121
print('\nData analysis of the simulation ensemble')
122122
print('========================================')
123123

124-
125124
# Section 1. Analysis based on transitions between state sets
126125
print('[ Section 1. Analysis based on transitions between state sets/replicas ]')
127126
section_idx += 1
@@ -182,9 +181,9 @@ def main():
182181
)
183182
rmse = analyze_traj.calc_hist_rmse(hist_data, REXEE.state_ranges)
184183
print(f'The RMSE of accumulated histogram counts of the state index: {rmse:.0f}')
185-
184+
186185
if REXEE.proposal != 'forced_random' and REXEE.proposal != 'forced_swap': # Need to FIX THIS FOR FORCED-RANDOM
187-
# 2-4. Stitch the time series of state index for different replicas
186+
# 2-4. Stitch the time series of state index for different replicas
188187
if os.path.isfile(args.state_trajs_for_sim) is True:
189188
print('\n2-4. Reading in the stitched time series of state index for different replicas ...')
190189
state_trajs_for_sim = np.load(args.state_trajs_for_sim)
@@ -393,8 +392,8 @@ def main():
393392
if REXEE.free_energy is True:
394393
section_idx += 1
395394
print(f'\n[ Section {section_idx}. Free energy calculations ]')
396-
397-
if REXEE.modify_coords == False:
395+
396+
if REXEE.modify_coords is False:
398397
# 4-1. Subsampling the data
399398
data_list = [] # either a list of u_nk or a list of dhdl
400399
if REXEE.df_data_type == 'u_nk':
@@ -411,7 +410,7 @@ def main():
411410
data_list, t_idx_list, g_list = data_all[0], data_all[1], data_all[2]
412411

413412
if data_list == []:
414-
files_list = [natsort.natsorted(glob.glob(f'sim_{i}/iteration_*/*dhdl*xvg')) for i in range(REXEE.n_sim)]
413+
files_list = [natsort.natsorted(glob.glob(f'sim_{i}/iteration_*/*dhdl*xvg')) for i in range(REXEE.n_sim)] # noqa: E501
415414
data_list, t_idx_list, g_list = analyze_free_energy.preprocess_data(files_list, REXEE.temp, REXEE.df_data_type, REXEE.df_spacing) # noqa: E501
416415

417416
data_all = [data_list, t_idx_list, g_list]
@@ -496,14 +495,14 @@ def main():
496495
print(f'The free energy difference between the coupled and decoupled states: {f[-1]:.3f} +/- {f_err[-1]:.3f} kT') # noqa: E501
497496

498497
if REXEE.df_ref is not None:
499-
rmse_list = analyze_free_energy.calculate_df_rmse(estimators, REXEE.df_ref, REXEE.state_ranges[sim])
498+
rmse_list = analyze_free_energy.calculate_df_rmse(estimators, REXEE.df_ref, REXEE.state_ranges[sim]) # noqa: E501
500499
for i in range(REXEE.n_sim):
501500
print(f'RMSE of the free energy profile for alchemical range {i} (states {REXEE.state_ranges[i][0]} to {REXEE.state_ranges[i][-1]}): {rmse_list[i]:.2f} kT') # noqa: E501
502501

503502
# 4-3. Recalculate the free energy profile if subsampling_avg is True
504503
if REXEE.subsampling_avg is True:
505504
print('\nUsing averaged start index of the equilibrated data and the avearged statistic inefficiency to re-perform free energy calculations ...') # noqa: E501
506-
t_avg = int(np.mean(t_idx_list)) + 1 # Using the ceiling function to be a little more conservative
505+
t_avg = int(np.mean(t_idx_list)) + 1 # Using the ceiling function to be a little more conservative # noqa: E501
507506
g_avg = np.array(g_list).prod() ** (1/len(g_list)) # geometric mean
508507
print(f'Averaged start index: {t_avg}')
509508
print(f'Averaged statistical inefficiency: {g_avg:.2f}')
@@ -515,14 +514,14 @@ def main():
515514

516515
f, f_err, estimators = analyze_free_energy.calculate_free_energy(data_list, REXEE.state_ranges[sim], REXEE.df_method, REXEE.err_method, REXEE.n_bootstrap, REXEE.seed) # noqa: E501
517516
print('Plotting the full-range free energy profile ...')
518-
analyze_free_energy.plot_free_energy(f, f_err, f'{args.dir}/free_energy_profile_avg_subsampling_{sim}.png')
517+
analyze_free_energy.plot_free_energy(f, f_err, f'{args.dir}/free_energy_profile_avg_subsampling_{sim}.png') # noqa: E501
519518

520519
print('The full-range free energy profile averaged over all replicas:')
521520
print(f" {', '.join(f'{f[i]:.3f} +/- {f_err[i]:.3f} kT' for i in range(REXEE.n_tot))}")
522521
print(f'The free energy difference between the coupled and decoupled states: {f[-1]:.3f} +/- {f_err[-1]:.3f} kT') # noqa: E501
523522

524523
if REXEE.df_ref is not None:
525-
rmse_list = analyze_free_energy.calculate_df_rmse(estimators, REXEE.df_ref, REXEE.state_ranges[sim])
524+
rmse_list = analyze_free_energy.calculate_df_rmse(estimators, REXEE.df_ref, REXEE.state_ranges[sim]) # noqa: E501
526525
for i in range(REXEE.n_sim):
527526
print(f'RMSE of the free energy profile for alchemical range {i} (states {REXEE.state_ranges[i][0]} to {REXEE.state_ranges[i][-1]}): {rmse_list[i]:.2f} kT') # noqa: E501
528527

ensemble_md/utils/coordinate_swap.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -361,7 +361,7 @@ def get_miss_coord(mol_align, mol_ref, name_align, name_ref, df_atom_swap, dir,
361361
for a in range(mol_align_select.n_atoms):
362362
if a != conn0_align:
363363
mol_align_select.xyz[0, a, :] = _rotate_point_around_axis(mol_align_select.xyz[0, a, :], mol_ref_select.xyz[0, conn0_ref, :], axis_rot, theta_min) # noqa: E501
364-
364+
365365
# Add coordinates to df
366366
for r in range(len(df_atom_swap.index)):
367367
if df_atom_swap.iloc[r]['Swap'] == dir:

0 commit comments

Comments
 (0)