Skip to content

Commit 84abb44

Browse files
committed
fix: make xlabel_position work independently with annotation colors
1 parent 4e87f0b commit 84abb44

File tree

1 file changed

+18
-8
lines changed

1 file changed

+18
-8
lines changed

src/cellmapper/model/evaluate.py

Lines changed: 18 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -313,7 +313,7 @@ def plot_confusion_matrix(
313313
if not show_grid:
314314
ax.grid(False)
315315

316-
# Move x-axis labels to top if requested (only relevant when annotation colors are not shown)
316+
# Move x-axis labels to top if requested (when annotation colors are not shown)
317317
if xlabel_position == "top" and not show_annotation_colors:
318318
ax.xaxis.tick_top()
319319
ax.xaxis.set_label_position("top")
@@ -350,22 +350,32 @@ def plot_confusion_matrix(
350350
ax.set_yticks([])
351351
ax.set_xlabel("")
352352
ax.set_ylabel("")
353+
ax.set_title("") # Remove title when using annotation colors
353354

354355
divider = make_axes_locatable(ax)
355356

356-
# Add color bar on top (x-axis, predicted labels)
357-
col_ax = divider.append_axes("top", size="2%", pad=0)
358-
cb_col = ax.figure.colorbar(sm, cax=col_ax, orientation="horizontal", ticklocation="top")
359-
cb_col.set_ticks(np.arange(n_labels) + 0.5)
360-
cb_col.ax.set_xticklabels(labels, rotation=90, ha="center", fontsize="small")
361-
cb_col.ax.tick_params(length=0) # hide tick marks
357+
# Determine x-axis label position
358+
if xlabel_position == "top":
359+
# Add color bar on top (x-axis, predicted labels)
360+
col_ax = divider.append_axes("top", size="2%", pad=0)
361+
cb_col = ax.figure.colorbar(sm, cax=col_ax, orientation="horizontal", ticklocation="top")
362+
cb_col.set_ticks(np.arange(n_labels) + 0.5)
363+
cb_col.ax.set_xticklabels(labels, rotation=90, ha="center", fontsize="small")
364+
cb_col.ax.tick_params(length=0)
365+
else:
366+
# Add color bar on bottom (x-axis, predicted labels)
367+
col_ax = divider.append_axes("bottom", size="2%", pad=0)
368+
cb_col = ax.figure.colorbar(sm, cax=col_ax, orientation="horizontal", ticklocation="bottom")
369+
cb_col.set_ticks(np.arange(n_labels) + 0.5)
370+
cb_col.ax.set_xticklabels(labels, rotation=90, ha="center", va="top", fontsize="small")
371+
cb_col.ax.tick_params(length=0)
362372

363373
# Add color bar on left (y-axis, true labels)
364374
row_ax = divider.append_axes("left", size="2%", pad=0)
365375
cb_row = ax.figure.colorbar(sm, cax=row_ax, orientation="vertical", ticklocation="left")
366376
cb_row.set_ticks(np.arange(n_labels) + 0.5)
367377
cb_row.ax.set_yticklabels(labels[::-1], fontsize="small") # reversed to match heatmap
368-
cb_row.ax.tick_params(length=0) # hide tick marks
378+
cb_row.ax.tick_params(length=0)
369379

370380
if save:
371381
ax.figure.savefig(save, bbox_inches="tight")

0 commit comments

Comments
 (0)