77import pandas as pd
88import seaborn as sns
99import matplotlib .pyplot as plt
10+ import matplotlib
11+ matplotlib .use ('Agg' )
1012from argparse import ArgumentDefaultsHelpFormatter , ArgumentParser , Namespace
1113
1214def parse () -> Namespace :
@@ -33,7 +35,7 @@ def process_csv(input_csv: str) -> pd.DataFrame:
3335 df .loc [df ['Metric' ] == 'present' , 'Metric' ] = 'segmented reads'
3436 df .loc [df ['Metric' ] == 'missing' , 'Metric' ] = 'missing reads'
3537
36- #! Collect meta data for controls and dorado
38+ # Collect meta data for controls and dorado
3739 # total_reads = df.loc[df['Metric'] == 'total reads', 'Value'].values[0]
3840 # min_length = df.loc[df['Metric'] == 'min length', 'Value'].values[0]
3941 # max_length = df.loc[df['Metric'] == 'max length', 'Value'].values[0]
@@ -53,48 +55,161 @@ def process_csv(input_csv: str) -> pd.DataFrame:
5355 # ], ignore_index=True
5456 # )
5557
56- # #! Add trivial values for Dorado
58+ #! Add trivial values for Dorado
59+ # print(df)
5760 # df = pd.concat(
5861 # [
5962 # df, pd.DataFrame({
6063 # "Tool": ["Dorado"] * 10,
6164 # "Metric": ["segmented reads", "missing reads", "truncated reads", "identical reads", "nt changed", "min length", "mean length", "median length", "n50 length", "max length"],
6265 # "Value": [total_reads, 0, 0, total_reads, 0, min_length, mean_length, median_length, n50_length, max_length],
63- # "Metric Score": [1.0, 1.0, 1.0, 1.0, 1.0, df.loc[(df['Tool'] == 'Ctrl R.') & (df['Metric'] == 'min length'), 'Metric Score'].squeeze() , 1.0, 1.0, 1.0, 1.0],
66+ # "Metric Score": [1.0, 1.0, 1.0, 1.0, 1.0, 1.0 , 1.0, 1.0, 1.0, 1.0],
6467 # })
6568 # ], ignore_index=True
6669 # )
70+ # print(df)
71+
72+ #! include specific metrics
73+ df_for_agg = df [df ["Metric" ].isin ([
74+ "median delta" ,
75+ "mad delta" ,
76+ "homogeneity" ,
77+ "segmented reads" ,
78+ # "missing reads",
79+ "truncated reads" ,
80+ # "identical reads",
81+ # "nt changed",
82+ "min length" ,
83+ # "mean length",
84+ # "median length",
85+ "n50 length" ,
86+ "max length" ,
87+ "flye total length" ,
88+ "flye n50" ,
89+ "flye mean coverage" ,
90+ "svim structural variants" ,
91+ ])]
6792
6893 # Calculate metric score sum for all tools
69- ams = df .groupby ("Tool" )["Metric Score" ].sum ().reset_index ()
94+ # ams = df.groupby("Tool")["Metric Score"].sum().reset_index()
95+ ams = df_for_agg .groupby ('Tool' )["Metric Score" ].sum ().reset_index ()
7096 ams ["Dataset" ] = f"{ input_csv .split ('/' )[1 ]} { input_csv .split ('/' )[2 ]} " # Extract dataset name from file path
7197 return ams
7298
7399def plot_heatmap (ams : pd .DataFrame , output_file : str ):
74100 # Pivot the data for heatmap
75101 heatmap_data = ams .pivot (index = "Tool" , columns = "Dataset" , values = "Metric Score" )
76102
103+ # print(heatmap_data.columns)
104+
105+ # set column order
106+ column_order = [
107+ "rna002 h_sapiens" ,
108+ "rna002 e_coli" ,
109+ "rna002 sarscov2" ,
110+ "rna002 ivt" ,
111+ "rna004 h_sapiens" ,
112+ "rna004 s_cerevisiae" ,
113+ "rna004 cevd" ,
114+ "rna004 ivt" ,
115+ "dna_r10.4.1_5kHz h_sapiens" ,
116+ "dna_r10.4.1_5kHz zymo_hmw" ,
117+ "dna_r10.4.1_5kHz s_aureus" ,
118+ "dna_r10.4.1_5kHz p_anserina" ,
119+ ]
120+ heatmap_data = heatmap_data [column_order ]
121+
122+ # Rename columns
123+ column_rename_map = {
124+ "rna002 h_sapiens" : r"$H.\ sapiens$" ,
125+ "rna002 e_coli" : r"$E.\ coli$" ,
126+ "rna002 sarscov2" : r"SARS-CoV-2" ,
127+ "rna002 ivt" : r"IVT" ,
128+ "rna004 h_sapiens" : r"$H.\ sapiens$" ,
129+ "rna004 s_cerevisiae" : r"$S.\ cerevisiae$" ,
130+ "rna004 cevd" : r"CEVD" ,
131+ "rna004 ivt" : r"IVT" ,
132+ "dna_r10.4.1_5kHz h_sapiens" : r"$H.\ sapiens$" ,
133+ "dna_r10.4.1_5kHz zymo_hmw" : r"Zymo HMW" ,
134+ "dna_r10.4.1_5kHz s_aureus" : r"$S.\ Aureus$" ,
135+ "dna_r10.4.1_5kHz p_anserina" : r"$P.\ Anserina$" ,
136+ }
137+ heatmap_data = heatmap_data .rename (columns = column_rename_map )
138+
139+ # Add superlabels (multi-index for columns)
140+ # superlabels = [
141+ # "RNA002", "RNA002", "RNA002", "RNA002",
142+ # "RNA004", "RNA004", "RNA004", "RNA004",
143+ # "DNA R10.4.1 5kHz", "DNA R10.4.1 5kHz"
144+ # ]
145+ # heatmap_data.columns = pd.MultiIndex.from_tuples(
146+ # zip(superlabels, heatmap_data.columns),
147+ # names=["Dataset Type", "Dataset"]
148+ # )
149+
77150 # Sort tools by their mean metric score (descending order)
78- tool_order = heatmap_data .mean (axis = 1 ).sort_values (ascending = False ).index
151+ print (heatmap_data )
152+ tool_order = heatmap_data .mean (axis = 1 ).sort_values (ascending = False ).index .tolist ()
153+ # ensure that dorado is the top row
154+ if "Dorado" in tool_order :
155+ tool_order .remove ("Dorado" )
156+ tool_order = ["Dorado" ] + tool_order
79157 heatmap_data = heatmap_data .loc [tool_order ]
80158
81159 # Plot the heatmap
82- plt .figure (figsize = (9 , 7 )) # Adjust figure size for better readability
83- sns .heatmap (
160+ plt .figure (figsize = (9 , 6 )) # Adjust figure size for better readability
161+ ax = sns .heatmap (
84162 heatmap_data ,
85163 annot = True , # Display values in cells
86164 fmt = ".2f" , # Format values to 2 decimal places
87165 cmap = "coolwarm" , # Use a visually appealing color palette
88- cbar_kws = {'label' : 'Metric Score' }, # Add a label to the color bar
166+ cbar_kws = {'label' : 'Score' , 'shrink' : 0.8 }, # Adjust color bar size
89167 linewidths = 0.5 , # Add grey lines between cells
90168 linecolor = "grey" , # Set the line color to grey
91169 annot_kws = {"fontsize" : 9 }, # Adjust font size for annotations
170+ square = True , # Make cells square
92171 )
93- plt .title ("Metric Score Heatmap" , fontsize = 16 , fontweight = "bold" ) # Add a bold title
172+
173+ # Add superlabels above the dataset labels
174+ superlabels = [
175+ "RNA002" , "RNA002" , "RNA002" , "RNA002" ,
176+ "RNA004" , "RNA004" , "RNA004" , "RNA004" ,
177+ "DNA R10.4.1 5kHz" , "DNA R10.4.1 5kHz" , "DNA R10.4.1 5kHz" , "DNA R10.4.1 5kHz"
178+ ]
179+ dataset_labels = [
180+ r"$H.\ sapiens$" , r"$E.\ coli$" , "SARS-CoV-2" , "IVT" ,
181+ r"$H.\ sapiens$" , r"$S.\ cerevisiae$" , "CEVD" , "IVT" ,
182+ r"$H.\ sapiens$" , "Zymo HMW" , r"$S.\ Aureus$" , r"$P.\ Anserina$"
183+ ]
184+
185+ # Set the dataset labels
186+ ax .set_xticks ([i + 0.5 for i in range (len (dataset_labels ))]) # Center labels
187+ ax .set_xticklabels (dataset_labels , rotation = 45 , ha = "right" , fontsize = 10 )
188+
189+ # Add superlabels
190+ for i , label in enumerate (superlabels ):
191+ if i == 0 or superlabels [i ] != superlabels [i - 1 ]: # Only add label once per group
192+ start = i
193+ end = i + superlabels .count (superlabels [i ]) - 1
194+ ax .text (
195+ (start + end ) / 2 + 0.5 , 1.25 * len (tool_order ), # Center above group
196+ label ,
197+ ha = "center" ,
198+ va = "bottom" ,
199+ fontsize = 10 ,
200+ fontweight = "bold" ,
201+ transform = ax .transData
202+ )
203+
204+ # Adjust layout to fit the labels
205+ plt .subplots_adjust (bottom = 0.2 , top = 0.85 )
206+
207+ plt .title ("Aggregated Metric Score" , fontsize = 14 ) # Add a bold title
94208 plt .ylabel ("Tool" , fontsize = 12 ) # Adjust y-axis label font size
95- plt .xlabel ("Dataset" , fontsize = 12 ) # Adjust x-axis label font size
96- plt .xticks (rotation = 45 , ha = "right" , fontsize = 10 ) # Rotate x-axis labels for better readability
97- plt .yticks (rotation = 0 , fontsize = 10 ) # Adjust y-axis label font size
209+ plt .xlabel ("Dataset" , fontsize = 12 , labelpad = 25 ) # Adjust x-axis label font size
210+ # plt.xticks(rotation=45, ha="right", fontsize=10) # Rotate x-axis labels for better readability
211+ plt .xticks (rotation = 25 , ha = "center" , fontsize = 9 ) # Rotate x-axis labels for better readability
212+ plt .yticks (rotation = 0 , fontsize = 9 ) # Adjust y-axis label font size
98213 plt .tight_layout () # Ensure everything fits within the figure
99214
100215 # Save the heatmap
0 commit comments