@@ -313,8 +313,8 @@ def plot_confusion_matrix(
313313 if not show_grid :
314314 ax .grid (False )
315315
316- # Move x-axis labels to top if requested
317- if xlabel_position == "top" :
316+ # Move x-axis labels to top if requested (only relevant when annotation colors are not shown)
317+ if xlabel_position == "top" and not show_annotation_colors :
318318 ax .xaxis .tick_top ()
319319 ax .xaxis .set_label_position ("top" )
320320 ax .set_title ("" ) # Remove title to avoid overlap
@@ -333,35 +333,39 @@ def plot_confusion_matrix(
333333 break
334334
335335 if colors_dict is not None :
336- from matplotlib .colors import ListedColormap
336+ import matplotlib as mpl
337+ from matplotlib .colors import BoundaryNorm , ListedColormap
337338 from mpl_toolkits .axes_grid1 import make_axes_locatable
338339
339340 # Get colors in label order
340341 colors_list = [colors_dict .get (label , "gray" ) for label in labels ]
341342 cat_cmap = ListedColormap (colors_list )
342343 n_labels = len (labels )
344+ bounds = np .arange (n_labels + 1 )
345+ norm = BoundaryNorm (bounds , cat_cmap .N )
346+ sm = mpl .cm .ScalarMappable (cmap = cat_cmap , norm = norm )
347+
348+ # Remove tick labels from main heatmap (will be on color bar axes)
349+ ax .set_xticks ([])
350+ ax .set_yticks ([])
351+ ax .set_xlabel ("" )
352+ ax .set_ylabel ("" )
343353
344354 divider = make_axes_locatable (ax )
345355
346356 # Add color bar on top (x-axis, predicted labels)
347- col_ax = divider .append_axes ("top" , size = "2%" , pad = 0.05 )
348- col_ax .imshow (
349- np .arange (n_labels ).reshape (1 , - 1 ),
350- cmap = cat_cmap ,
351- aspect = "auto" ,
352- )
353- col_ax .set_xticks ([])
354- col_ax .set_yticks ([])
357+ col_ax = divider .append_axes ("top" , size = "2%" , pad = 0 )
358+ cb_col = ax .figure .colorbar (sm , cax = col_ax , orientation = "horizontal" , ticklocation = "top" )
359+ cb_col .set_ticks (np .arange (n_labels ) + 0.5 )
360+ cb_col .ax .set_xticklabels (labels , rotation = 90 , ha = "center" , fontsize = "small" )
361+ cb_col .ax .tick_params (length = 0 ) # hide tick marks
355362
356363 # Add color bar on left (y-axis, true labels)
357- row_ax = divider .append_axes ("left" , size = "2%" , pad = 0.05 )
358- row_ax .imshow (
359- np .arange (n_labels ).reshape (- 1 , 1 ),
360- cmap = cat_cmap ,
361- aspect = "auto" ,
362- )
363- row_ax .set_xticks ([])
364- row_ax .set_yticks ([])
364+ row_ax = divider .append_axes ("left" , size = "2%" , pad = 0 )
365+ cb_row = ax .figure .colorbar (sm , cax = row_ax , orientation = "vertical" , ticklocation = "left" )
366+ cb_row .set_ticks (np .arange (n_labels ) + 0.5 )
367+ cb_row .ax .set_yticklabels (labels [::- 1 ], fontsize = "small" ) # reversed to match heatmap
368+ cb_row .ax .tick_params (length = 0 ) # hide tick marks
365369
366370 if save :
367371 ax .figure .savefig (save , bbox_inches = "tight" )
0 commit comments