@@ -108,15 +108,20 @@ def plot_heatmap(ams: pd.DataFrame, output_file: str):
108108 "rna002 e_coli" ,
109109 "rna002 sarscov2" ,
110110 "rna002 ivt" ,
111+ "rna002 ivt_h_sapiens" ,
112+ "rna002 m1Y" ,
111113 "rna004 h_sapiens" ,
112114 "rna004 s_cerevisiae" ,
113115 "rna004 cevd" ,
114116 "rna004 ivt" ,
117+ "rna004 psU" ,
115118 "dna_r10.4.1_5kHz h_sapiens" ,
116119 "dna_r10.4.1_5kHz zymo_hmw" ,
117120 "dna_r10.4.1_5kHz s_aureus" ,
118121 "dna_r10.4.1_5kHz p_anserina" ,
122+ "dna_r10.4.1_5kHz mod_5mc" ,
119123 ]
124+ # print(heatmap_data.columns)
120125 heatmap_data = heatmap_data [column_order ]
121126
122127 # Rename columns
@@ -125,14 +130,18 @@ def plot_heatmap(ams: pd.DataFrame, output_file: str):
125130 "rna002 e_coli" : r"$E.\ coli$" ,
126131 "rna002 sarscov2" : r"SARS-CoV-2" ,
127132 "rna002 ivt" : r"IVT" ,
133+ "rna002 ivt_h_sapiens" : r"IVT $H.\ sapiens$" ,
134+ "rna002 m1Y" : r"M1Y" ,
128135 "rna004 h_sapiens" : r"$H.\ sapiens$" ,
129136 "rna004 s_cerevisiae" : r"$S.\ cerevisiae$" ,
130137 "rna004 cevd" : r"CEVD" ,
131138 "rna004 ivt" : r"IVT" ,
139+ "rna004 psU" : r"psU" ,
132140 "dna_r10.4.1_5kHz h_sapiens" : r"$H.\ sapiens$" ,
133141 "dna_r10.4.1_5kHz zymo_hmw" : r"Zymo HMW" ,
134142 "dna_r10.4.1_5kHz s_aureus" : r"$S.\ Aureus$" ,
135143 "dna_r10.4.1_5kHz p_anserina" : r"$P.\ Anserina$" ,
144+ "dna_r10.4.1_5kHz mod_5mc" : r"5mC" ,
136145 }
137146 heatmap_data = heatmap_data .rename (columns = column_rename_map )
138147
@@ -163,41 +172,52 @@ def plot_heatmap(ams: pd.DataFrame, output_file: str):
163172 cbar_kws = {'label' : 'Score' , 'shrink' : 0.8 }, # Adjust color bar size
164173 linewidths = 0.5 , # Add grey lines between cells
165174 linecolor = "grey" , # Set the line color to grey
166- annot_kws = {"fontsize" : 9 }, # Adjust font size for annotations
175+ annot_kws = {"fontsize" : 7 }, # Adjust font size for annotations
167176 square = True , # Make cells square
168177 )
169178
170179 # Add superlabels above the dataset labels
171- superlabels = [
172- "RNA002" , "RNA002" , "RNA002" , "RNA002" ,
173- "RNA004" , "RNA004" , "RNA004" , "RNA004" ,
174- "DNA R10.4.1 5kHz" , "DNA R10.4.1 5kHz" , "DNA R10.4.1 5kHz" , "DNA R10.4.1 5kHz"
175- ""
176- ]
180+ # Define group labels aligned to the columns: 6x RNA002, 5x RNA004, 5x DNA, 1x empty for the tool average column
181+ superlabels = (
182+ ["RNA002" ] * 6
183+ + ["RNA004" ] * 5
184+ + ["DNA R10.4.1 5kHz" ] * 5
185+ + ["" ]
186+ )
177187 dataset_labels = [
178- r"$H.\ sapiens$" , r"$E.\ coli$" , "SARS-CoV-2" , "IVT" ,
179- r"$H.\ sapiens$" , r"$S.\ cerevisiae$" , "CEVd" , "IVT" ,
180- r"$H.\ sapiens$" , "Zymo HMW" , r"$S.\ Aureus$" , r"$P.\ Anserina$" , "tool average"
188+ r"$H.\ sapiens$" , r"$E.\ coli$" , "SARS-CoV-2" , "IVT" , r"IVT $H.\ sapiens$" , "m1Y" ,
189+ r"$H.\ sapiens$" , r"$S.\ cerevisiae$" , "CEVd" , "IVT" , "psU" ,
190+ r"$H.\ sapiens$" , "Zymo HMW" , r"$S.\ Aureus$" , r"$P.\ Anserina$" , "5mC" , " tool average"
181191 ]
182192
183193 # Set the dataset labels
184194 ax .set_xticks ([i + 0.5 for i in range (len (dataset_labels ))]) # Center labels
185195 ax .set_xticklabels (dataset_labels , rotation = 45 , ha = "right" , fontsize = 10 )
186196
187- # Add superlabels
188- for i , label in enumerate (superlabels ):
189- if i == 0 or superlabels [i ] != superlabels [i - 1 ]: # Only add label once per group
190- start = i
191- end = i + superlabels .count (superlabels [i ]) - 1
192- ax .text (
193- (start + end ) / 2 + 0.5 , 1.25 * len (tool_order ), # Center above group
194- label ,
195- ha = "center" ,
196- va = "bottom" ,
197- fontsize = 10 ,
198- fontweight = "bold" ,
199- transform = ax .transData
200- )
197+ # Add superlabels (draw each group label exactly once)
198+ groups = []
199+ if superlabels :
200+ current = superlabels [0 ]
201+ start_idx = 0
202+ for i , label in enumerate (superlabels [1 :], start = 1 ):
203+ if label != current :
204+ groups .append ((current , start_idx , i - 1 ))
205+ current = label
206+ start_idx = i
207+ groups .append ((current , start_idx , len (superlabels ) - 1 ))
208+
209+ for label , start , end in groups :
210+ if not label : # skip empty label for the tool average column
211+ continue
212+ ax .text (
213+ (start + end ) / 2 + 0.5 , 1.45 * len (tool_order ), # Center above group
214+ label ,
215+ ha = "center" ,
216+ va = "bottom" ,
217+ fontsize = 10 ,
218+ fontweight = "bold" ,
219+ transform = ax .transData
220+ )
201221
202222 # Adjust layout to fit the labels
203223 plt .subplots_adjust (bottom = 0.2 , top = 0.85 )
@@ -206,7 +226,7 @@ def plot_heatmap(ams: pd.DataFrame, output_file: str):
206226 # plt.ylabel("Tool", fontsize=12) # Adjust y-axis label font size
207227 # plt.xlabel("Dataset", fontsize=12, labelpad=25) # Adjust x-axis label font size
208228 # plt.xticks(rotation=45, ha="right", fontsize=10) # Rotate x-axis labels for better readability
209- plt .xticks (rotation = 25 , ha = "center" , fontsize = 9 ) # Rotate x-axis labels for better readability
229+ plt .xticks (rotation = 45 , ha = "center" , fontsize = 9 ) # Rotate x-axis labels for better readability
210230 plt .yticks (rotation = 0 , fontsize = 9 ) # Adjust y-axis label font size
211231 plt .tight_layout () # Ensure everything fits within the figure
212232
0 commit comments