Skip to content

Commit 4e87f0b

Browse files
committed
fix: use colorbar for annotation colors with proper label placement (moscot-style)
1 parent 30f905a commit 4e87f0b

File tree

1 file changed

+23
-19
lines changed

1 file changed

+23
-19
lines changed

src/cellmapper/model/evaluate.py

Lines changed: 23 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -313,8 +313,8 @@ def plot_confusion_matrix(
313313
if not show_grid:
314314
ax.grid(False)
315315

316-
# Move x-axis labels to top if requested
317-
if xlabel_position == "top":
316+
# Move x-axis labels to top if requested (only relevant when annotation colors are not shown)
317+
if xlabel_position == "top" and not show_annotation_colors:
318318
ax.xaxis.tick_top()
319319
ax.xaxis.set_label_position("top")
320320
ax.set_title("") # Remove title to avoid overlap
@@ -333,35 +333,39 @@ def plot_confusion_matrix(
333333
break
334334

335335
if colors_dict is not None:
336-
from matplotlib.colors import ListedColormap
336+
import matplotlib as mpl
337+
from matplotlib.colors import BoundaryNorm, ListedColormap
337338
from mpl_toolkits.axes_grid1 import make_axes_locatable
338339

339340
# Get colors in label order
340341
colors_list = [colors_dict.get(label, "gray") for label in labels]
341342
cat_cmap = ListedColormap(colors_list)
342343
n_labels = len(labels)
344+
bounds = np.arange(n_labels + 1)
345+
norm = BoundaryNorm(bounds, cat_cmap.N)
346+
sm = mpl.cm.ScalarMappable(cmap=cat_cmap, norm=norm)
347+
348+
# Remove tick labels from main heatmap (will be on color bar axes)
349+
ax.set_xticks([])
350+
ax.set_yticks([])
351+
ax.set_xlabel("")
352+
ax.set_ylabel("")
343353

344354
divider = make_axes_locatable(ax)
345355

346356
# Add color bar on top (x-axis, predicted labels)
347-
col_ax = divider.append_axes("top", size="2%", pad=0.05)
348-
col_ax.imshow(
349-
np.arange(n_labels).reshape(1, -1),
350-
cmap=cat_cmap,
351-
aspect="auto",
352-
)
353-
col_ax.set_xticks([])
354-
col_ax.set_yticks([])
357+
col_ax = divider.append_axes("top", size="2%", pad=0)
358+
cb_col = ax.figure.colorbar(sm, cax=col_ax, orientation="horizontal", ticklocation="top")
359+
cb_col.set_ticks(np.arange(n_labels) + 0.5)
360+
cb_col.ax.set_xticklabels(labels, rotation=90, ha="center", fontsize="small")
361+
cb_col.ax.tick_params(length=0) # hide tick marks
355362

356363
# Add color bar on left (y-axis, true labels)
357-
row_ax = divider.append_axes("left", size="2%", pad=0.05)
358-
row_ax.imshow(
359-
np.arange(n_labels).reshape(-1, 1),
360-
cmap=cat_cmap,
361-
aspect="auto",
362-
)
363-
row_ax.set_xticks([])
364-
row_ax.set_yticks([])
364+
row_ax = divider.append_axes("left", size="2%", pad=0)
365+
cb_row = ax.figure.colorbar(sm, cax=row_ax, orientation="vertical", ticklocation="left")
366+
cb_row.set_ticks(np.arange(n_labels) + 0.5)
367+
cb_row.ax.set_yticklabels(labels[::-1], fontsize="small") # reversed to match heatmap
368+
cb_row.ax.tick_params(length=0) # hide tick marks
365369

366370
if save:
367371
ax.figure.savefig(save, bbox_inches="tight")

0 commit comments

Comments
 (0)