Skip to content

Commit b53159b

Browse files
committed
minor bug fixes
1 parent 7935303 commit b53159b

File tree

3 files changed

+77
-53
lines changed

3 files changed

+77
-53
lines changed

src/python/misc/collectMetrics.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -86,6 +86,7 @@ def main() -> None:
8686
scores = pd.concat([scores, new_entry], ignore_index=True)
8787

8888
for name, time_path in times.items():
89+
# try:
8990
with open(time_path, "r") as time_file:
9091
time = time_file.readline()[14:22]
9192
memory = time_file.readline().strip()[13:].split(" MB")[0]
@@ -95,6 +96,9 @@ def main() -> None:
9596
"Metric": ["Time in hh:mm:ss", "Memory in MB"]
9697
})
9798
scores = pd.concat([scores, new_entry], ignore_index=True)
99+
# except FileNotFoundError:
100+
# print(name)
101+
# exit(1)
98102

99103
for name, downstream_path in downstream_tools.items():
100104
with open(downstream_path, "r") as downstream_file:

src/python/misc/compareTools.py

Lines changed: 28 additions & 28 deletions
Original file line numberDiff line numberDiff line change
@@ -800,34 +800,34 @@ def scoreTools(toolsResult: dict, pod5: str, output: str, pool, window : int) ->
800800

801801
output_file = output + "_score.csv"
802802

803-
if not os.path.exists(output_file):
804-
toolNames = list(toolsResult.keys())
805-
806-
# Find reads segmented by all tools
807-
print("Finding reads segmented by all tools...")
808-
all_reads = list(set.intersection(*[set(toolsResult[tool].keys()) for tool in toolNames]))
809-
810-
# Parallel processing of reads
811-
print("Start multiprocessing...")
812-
read_chunks = [(read, toolsResult, toolNames, pod5, window) for read in all_reads]
813-
814-
# Open output file for incremental writing
815-
with open(output_file, "w") as f:
816-
f.write("Tool,Score,Segment Quality\n") # Write header
817-
818-
# Process reads in parallel and aggregate results
819-
for toolReadScores in tqdm(
820-
pool.imap_unordered(processReadScores, read_chunks, chunksize=10),
821-
total=len(all_reads),
822-
desc="Scoring reads"
823-
):
824-
for tool, scores in toolReadScores.items():
825-
if scores.size > 0: # Ensure non-empty scores
826-
for i, quality in enumerate(["Median Delta", "Mad Delta", "Homogeneity"]):
827-
for score in scores[:, i]:
828-
f.write(f"{tool},{score},{quality}\n")
829-
830-
print(f"Scoring complete. Results saved to {output_file}")
803+
# if not os.path.exists(output_file):
804+
toolNames = list(toolsResult.keys())
805+
806+
# Find reads segmented by all tools
807+
print("Finding reads segmented by all tools...")
808+
all_reads = list(set.intersection(*[set(toolsResult[tool].keys()) for tool in toolNames]))
809+
810+
# Parallel processing of reads
811+
print("Start multiprocessing...")
812+
read_chunks = [(read, toolsResult, toolNames, pod5, window) for read in all_reads]
813+
814+
# Open output file for incremental writing
815+
with open(output_file, "w") as f:
816+
f.write("Tool,Score,Segment Quality\n") # Write header
817+
818+
# Process reads in parallel and aggregate results
819+
for toolReadScores in tqdm(
820+
pool.imap_unordered(processReadScores, read_chunks, chunksize=10),
821+
total=len(all_reads),
822+
desc="Scoring reads"
823+
):
824+
for tool, scores in toolReadScores.items():
825+
if scores.size > 0: # Ensure non-empty scores
826+
for i, quality in enumerate(["Median Delta", "Mad Delta", "Homogeneity"]):
827+
for score in scores[:, i]:
828+
f.write(f"{tool},{score},{quality}\n")
829+
830+
print(f"Scoring complete. Results saved to {output_file}")
831831

832832
return output_file
833833

src/python/misc/csv_to_ms_heatmap.py

Lines changed: 45 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -108,15 +108,20 @@ def plot_heatmap(ams: pd.DataFrame, output_file: str):
108108
"rna002 e_coli",
109109
"rna002 sarscov2",
110110
"rna002 ivt",
111+
"rna002 ivt_h_sapiens",
112+
"rna002 m1Y",
111113
"rna004 h_sapiens",
112114
"rna004 s_cerevisiae",
113115
"rna004 cevd",
114116
"rna004 ivt",
117+
"rna004 psU",
115118
"dna_r10.4.1_5kHz h_sapiens",
116119
"dna_r10.4.1_5kHz zymo_hmw",
117120
"dna_r10.4.1_5kHz s_aureus",
118121
"dna_r10.4.1_5kHz p_anserina",
122+
"dna_r10.4.1_5kHz mod_5mc",
119123
]
124+
# print(heatmap_data.columns)
120125
heatmap_data = heatmap_data[column_order]
121126

122127
# Rename columns
@@ -125,14 +130,18 @@ def plot_heatmap(ams: pd.DataFrame, output_file: str):
125130
"rna002 e_coli": r"$E.\ coli$",
126131
"rna002 sarscov2": r"SARS-CoV-2",
127132
"rna002 ivt": r"IVT",
133+
"rna002 ivt_h_sapiens": r"IVT $H.\ sapiens$",
134+
"rna002 m1Y": r"M1Y",
128135
"rna004 h_sapiens": r"$H.\ sapiens$",
129136
"rna004 s_cerevisiae": r"$S.\ cerevisiae$",
130137
"rna004 cevd": r"CEVD",
131138
"rna004 ivt": r"IVT",
139+
"rna004 psU": r"psU",
132140
"dna_r10.4.1_5kHz h_sapiens": r"$H.\ sapiens$",
133141
"dna_r10.4.1_5kHz zymo_hmw": r"Zymo HMW",
134142
"dna_r10.4.1_5kHz s_aureus": r"$S.\ Aureus$",
135143
"dna_r10.4.1_5kHz p_anserina": r"$P.\ Anserina$",
144+
"dna_r10.4.1_5kHz mod_5mc": r"5mC",
136145
}
137146
heatmap_data = heatmap_data.rename(columns=column_rename_map)
138147

@@ -163,41 +172,52 @@ def plot_heatmap(ams: pd.DataFrame, output_file: str):
163172
cbar_kws={'label': 'Score', 'shrink': 0.8}, # Adjust color bar size
164173
linewidths=0.5, # Add grey lines between cells
165174
linecolor="grey", # Set the line color to grey
166-
annot_kws={"fontsize": 9}, # Adjust font size for annotations
175+
annot_kws={"fontsize": 7}, # Adjust font size for annotations
167176
square=True, # Make cells square
168177
)
169178

170179
# Add superlabels above the dataset labels
171-
superlabels = [
172-
"RNA002", "RNA002", "RNA002", "RNA002",
173-
"RNA004", "RNA004", "RNA004", "RNA004",
174-
"DNA R10.4.1 5kHz", "DNA R10.4.1 5kHz", "DNA R10.4.1 5kHz", "DNA R10.4.1 5kHz"
175-
""
176-
]
180+
# Define group labels aligned to the columns: 6x RNA002, 5x RNA004, 5x DNA, 1x empty for the tool average column
181+
superlabels = (
182+
["RNA002"] * 6
183+
+ ["RNA004"] * 5
184+
+ ["DNA R10.4.1 5kHz"] * 5
185+
+ [""]
186+
)
177187
dataset_labels = [
178-
r"$H.\ sapiens$", r"$E.\ coli$", "SARS-CoV-2", "IVT",
179-
r"$H.\ sapiens$", r"$S.\ cerevisiae$", "CEVd", "IVT",
180-
r"$H.\ sapiens$", "Zymo HMW", r"$S.\ Aureus$", r"$P.\ Anserina$", "tool average"
188+
r"$H.\ sapiens$", r"$E.\ coli$", "SARS-CoV-2", "IVT", r"IVT $H.\ sapiens$", "m1Y",
189+
r"$H.\ sapiens$", r"$S.\ cerevisiae$", "CEVd", "IVT", "psU",
190+
r"$H.\ sapiens$", "Zymo HMW", r"$S.\ Aureus$", r"$P.\ Anserina$", "5mC", "tool average"
181191
]
182192

183193
# Set the dataset labels
184194
ax.set_xticks([i + 0.5 for i in range(len(dataset_labels))]) # Center labels
185195
ax.set_xticklabels(dataset_labels, rotation=45, ha="right", fontsize=10)
186196

187-
# Add superlabels
188-
for i, label in enumerate(superlabels):
189-
if i == 0 or superlabels[i] != superlabels[i - 1]: # Only add label once per group
190-
start = i
191-
end = i + superlabels.count(superlabels[i]) - 1
192-
ax.text(
193-
(start + end) / 2 + 0.5, 1.25 * len(tool_order), # Center above group
194-
label,
195-
ha="center",
196-
va="bottom",
197-
fontsize=10,
198-
fontweight="bold",
199-
transform=ax.transData
200-
)
197+
# Add superlabels (draw each group label exactly once)
198+
groups = []
199+
if superlabels:
200+
current = superlabels[0]
201+
start_idx = 0
202+
for i, label in enumerate(superlabels[1:], start=1):
203+
if label != current:
204+
groups.append((current, start_idx, i - 1))
205+
current = label
206+
start_idx = i
207+
groups.append((current, start_idx, len(superlabels) - 1))
208+
209+
for label, start, end in groups:
210+
if not label: # skip empty label for the tool average column
211+
continue
212+
ax.text(
213+
(start + end) / 2 + 0.5, 1.45 * len(tool_order), # Center above group
214+
label,
215+
ha="center",
216+
va="bottom",
217+
fontsize=10,
218+
fontweight="bold",
219+
transform=ax.transData
220+
)
201221

202222
# Adjust layout to fit the labels
203223
plt.subplots_adjust(bottom=0.2, top=0.85)
@@ -206,7 +226,7 @@ def plot_heatmap(ams: pd.DataFrame, output_file: str):
206226
# plt.ylabel("Tool", fontsize=12) # Adjust y-axis label font size
207227
# plt.xlabel("Dataset", fontsize=12, labelpad=25) # Adjust x-axis label font size
208228
# plt.xticks(rotation=45, ha="right", fontsize=10) # Rotate x-axis labels for better readability
209-
plt.xticks(rotation=25, ha="center", fontsize=9) # Rotate x-axis labels for better readability
229+
plt.xticks(rotation=45, ha="center", fontsize=9) # Rotate x-axis labels for better readability
210230
plt.yticks(rotation=0, fontsize=9) # Adjust y-axis label font size
211231
plt.tight_layout() # Ensure everything fits within the figure
212232

0 commit comments

Comments
 (0)