Skip to content

Commit a9a19cf

Browse files
committed
Fix Plots
- add missing xlabel - change "Mean Difference" to "Mean Distance" - reduce plotting data - plot max 1000000 random positions to reduce runtime and prevent crashing
1 parent 30962ad commit a9a19cf

File tree

1 file changed

+16
-10
lines changed

1 file changed

+16
-10
lines changed

magnipore/magnipore.py

Lines changed: 16 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -244,10 +244,10 @@ def magnipore(mapping : dict, unaligned : dict, seq_dict : dict, aln_dict: dict,
244244
num_indels, sign_pos, nans = 0, 0, 0
245245

246246
# TODO add some quality value
247-
plotting_data = pd.DataFrame(columns=['Mean Difference', 'Avg Stdev', 'Strand', 'Mutational Context', 'Significant', 'TD Score', 'KL Divergence'])
247+
plotting_data = pd.DataFrame(columns=['Mean Distance', 'Avg Stdev', 'Strand', 'Mutational Context', 'Significant', 'TD Score', 'KL Divergence'])
248248
plotting_data = plotting_data.astype(
249249
{
250-
'Mean Difference': 'float32',
250+
'Mean Distance': 'float32',
251251
'Avg Stdev': 'float32',
252252
'Strand': 'bool',
253253
'Mutational Context': 'bool',
@@ -309,7 +309,7 @@ def magnipore(mapping : dict, unaligned : dict, seq_dict : dict, aln_dict: dict,
309309
td = td_score(mDiff, sAvg)
310310
significant = td>=1
311311
new_entry = pd.DataFrame({
312-
'Mean Difference' : [mDiff],
312+
'Mean Distance' : [mDiff],
313313
'Avg Stdev' : [sAvg],
314314
'Strand' : [strand],
315315
'Mutational Context' : [mut_context],
@@ -385,10 +385,15 @@ def plotStatistics(plotting_data : pd.DataFrame, working_dir : str, first_sample
385385
plot_dir = os.path.join(working_dir, 'magnipore', f'{first_sample_label}_{sec_sample_label}', 'plots')
386386
if not os.path.exists(plot_dir):
387387
os.mkdir(plot_dir)
388-
### Mean Dist vs Std Avg plot
388+
# reduce plotting_data if it got too large, to reduce runtime and prevent the kernel from killing the process
389+
plotting_threshold = 1000000 # arbitrary threshold
390+
if len(plotting_data.index) > plotting_threshold:
391+
LOGGER.printLog(f'The number of positions exceeds the threshold of {plotting_data} ({len(plotting_data.index)}). To prevent the kernel from killing the process, Magnipore will only plot a subset of {plotting_data} positions. Plots will not include the full data.')
392+
plotting_data = plotting_data.sample(plotting_threshold, replace=False)
393+
# Mean Dist vs Std Avg plot
389394
LOGGER.printLog('Plotting Mean vs Stdev')
390395
plotMeanDiffStdAvg(plotting_data, plot_dir, first_sample_label, sec_sample_label)
391-
### plot scores
396+
# plot scores
392397
LOGGER.printLog(f'Plotting TD score and KL divergence')
393398
plotScores(plotting_data, plot_dir, first_sample_label, sec_sample_label)
394399

@@ -434,17 +439,17 @@ def plotMeanDiffStdAvg(dataframe : pd.DataFrame, working_dir : str, first_sample
434439
label1 = first_sample_label.replace("_", " ")
435440
label2 = sec_sample_label.replace("_", " ")
436441

437-
g = sns.JointGrid(x='Mean Difference', y='Avg Stdev', data=dataframe, hue='Mutational Context', marginal_ticks=True, palette=['blue', '#d95f02'], hue_order=[True, False], height = 10)
442+
g = sns.JointGrid(x='Mean Distance', y='Avg Stdev', data=dataframe, hue='Mutational Context', marginal_ticks=True, palette=['blue', '#d95f02'], hue_order=[True, False], height = 10)
438443
g.plot_joint(func=sns.scatterplot, s = 8)
439444
g.ax_joint.cla()
440445
for _, row in dataframe.iterrows():
441-
g.ax_joint.plot(row['Mean Difference'], row['Avg Stdev'], color = color(row['Mutational Context']), marker = marker(row['Mutational Context']), markersize=3, alpha = 0.6)
446+
g.ax_joint.plot(row['Mean Distance'], row['Avg Stdev'], color = color(row['Mutational Context']), marker = marker(row['Mutational Context']), markersize=3, alpha = 0.6)
442447

443-
g.fig.suptitle(f'{len(dataframe.index)} compared bases mean difference against\naverage standard deviation\n{label1} and {label2}', y=0.98)
448+
g.fig.suptitle(f'{len(dataframe.index)} compared bases mean distance against\naverage standard deviation\n{label1} and {label2}', y=0.98)
444449
g.ax_joint.grid(True, 'both', 'both', alpha = 0.4, linestyle = '--', linewidth = 0.5)
445450

446451
lims = np.array([
447-
[-.02, max(dataframe['Mean Difference']) + 0.1],
452+
[-.02, max(dataframe['Mean Distance']) + 0.1],
448453
[-.02, max(dataframe['Avg Stdev']) + 0.1]
449454
])
450455

@@ -457,7 +462,8 @@ def plotMeanDiffStdAvg(dataframe : pd.DataFrame, working_dir : str, first_sample
457462
g.ax_joint.set_xlim(tuple(lims[0]))
458463
g.ax_joint.set_ylim(tuple(lims[1]))
459464

460-
g.ax_joint.set_ylabel('Average Standard Deviation')
465+
g.ax_joint.set_ylabel('Average standard deviation')
466+
g.ax_joint.set_xlabel('Mean distance')
461467
legend_mut = mlines.Line2D([], [], color='blue', marker='D', linestyle='None', markersize=10, label='mutation')
462468
legend_mod = mlines.Line2D([], [], color='#d95f02', marker='o', linestyle='None', markersize=10, label='matching reference')
463469
sign = mlines.Line2D([], [], color='#1b9e77', marker='s', linestyle='None', markersize=10, label='significant, TD>=1')

0 commit comments

Comments
 (0)