Skip to content

Commit f21128a

Browse files
committed
fix: use Rectangle patches for annotation colors to support constrained_layout with subplots
1 parent 84abb44 commit f21128a

File tree

1 file changed

+44
-36
lines changed

1 file changed

+44
-36
lines changed

src/cellmapper/model/evaluate.py

Lines changed: 44 additions & 36 deletions
Original file line numberDiff line numberDiff line change
@@ -333,49 +333,57 @@ def plot_confusion_matrix(
333333
break
334334

335335
if colors_dict is not None:
336-
import matplotlib as mpl
337-
from matplotlib.colors import BoundaryNorm, ListedColormap
338-
from mpl_toolkits.axes_grid1 import make_axes_locatable
336+
from matplotlib.patches import Rectangle
339337

340338
# Get colors in label order
341339
colors_list = [colors_dict.get(label, "gray") for label in labels]
342-
cat_cmap = ListedColormap(colors_list)
343340
n_labels = len(labels)
344-
bounds = np.arange(n_labels + 1)
345-
norm = BoundaryNorm(bounds, cat_cmap.N)
346-
sm = mpl.cm.ScalarMappable(cmap=cat_cmap, norm=norm)
347341

348-
# Remove tick labels from main heatmap (will be on color bar axes)
349-
ax.set_xticks([])
350-
ax.set_yticks([])
351-
ax.set_xlabel("")
352-
ax.set_ylabel("")
353-
ax.set_title("") # Remove title when using annotation colors
342+
# Color strip thickness as fraction of plot
343+
strip_size = 0.03
354344

355-
divider = make_axes_locatable(ax)
356-
357-
# Determine x-axis label position
345+
# Move x-axis labels to top if requested
358346
if xlabel_position == "top":
359-
# Add color bar on top (x-axis, predicted labels)
360-
col_ax = divider.append_axes("top", size="2%", pad=0)
361-
cb_col = ax.figure.colorbar(sm, cax=col_ax, orientation="horizontal", ticklocation="top")
362-
cb_col.set_ticks(np.arange(n_labels) + 0.5)
363-
cb_col.ax.set_xticklabels(labels, rotation=90, ha="center", fontsize="small")
364-
cb_col.ax.tick_params(length=0)
365-
else:
366-
# Add color bar on bottom (x-axis, predicted labels)
367-
col_ax = divider.append_axes("bottom", size="2%", pad=0)
368-
cb_col = ax.figure.colorbar(sm, cax=col_ax, orientation="horizontal", ticklocation="bottom")
369-
cb_col.set_ticks(np.arange(n_labels) + 0.5)
370-
cb_col.ax.set_xticklabels(labels, rotation=90, ha="center", va="top", fontsize="small")
371-
cb_col.ax.tick_params(length=0)
372-
373-
# Add color bar on left (y-axis, true labels)
374-
row_ax = divider.append_axes("left", size="2%", pad=0)
375-
cb_row = ax.figure.colorbar(sm, cax=row_ax, orientation="vertical", ticklocation="left")
376-
cb_row.set_ticks(np.arange(n_labels) + 0.5)
377-
cb_row.ax.set_yticklabels(labels[::-1], fontsize="small") # reversed to match heatmap
378-
cb_row.ax.tick_params(length=0)
347+
ax.xaxis.tick_top()
348+
ax.xaxis.set_label_position("top")
349+
plt.setp(ax.get_xticklabels(), rotation=90, ha="center", va="bottom")
350+
351+
# Draw color strips using rectangles in axes coordinates
352+
for i, color in enumerate(colors_list):
353+
# Left strip (y-axis, true labels)
354+
rect_left = Rectangle(
355+
(-strip_size, i / n_labels),
356+
strip_size,
357+
1 / n_labels,
358+
facecolor=color,
359+
edgecolor="none",
360+
clip_on=False,
361+
transform=ax.transAxes,
362+
)
363+
ax.add_patch(rect_left)
364+
365+
# Top or bottom strip (x-axis, predicted labels)
366+
if xlabel_position == "top":
367+
rect_x = Rectangle(
368+
(i / n_labels, 1),
369+
1 / n_labels,
370+
strip_size,
371+
facecolor=color,
372+
edgecolor="none",
373+
clip_on=False,
374+
transform=ax.transAxes,
375+
)
376+
else:
377+
rect_x = Rectangle(
378+
(i / n_labels, -strip_size),
379+
1 / n_labels,
380+
strip_size,
381+
facecolor=color,
382+
edgecolor="none",
383+
clip_on=False,
384+
transform=ax.transAxes,
385+
)
386+
ax.add_patch(rect_x)
379387

380388
if save:
381389
ax.figure.savefig(save, bbox_inches="tight")

0 commit comments

Comments
 (0)