@@ -333,49 +333,57 @@ def plot_confusion_matrix(
333333 break
334334
335335 if colors_dict is not None :
336- import matplotlib as mpl
337- from matplotlib .colors import BoundaryNorm , ListedColormap
338- from mpl_toolkits .axes_grid1 import make_axes_locatable
336+ from matplotlib .patches import Rectangle
339337
340338 # Get colors in label order
341339 colors_list = [colors_dict .get (label , "gray" ) for label in labels ]
342- cat_cmap = ListedColormap (colors_list )
343340 n_labels = len (labels )
344- bounds = np .arange (n_labels + 1 )
345- norm = BoundaryNorm (bounds , cat_cmap .N )
346- sm = mpl .cm .ScalarMappable (cmap = cat_cmap , norm = norm )
347341
348- # Remove tick labels from main heatmap (will be on color bar axes)
349- ax .set_xticks ([])
350- ax .set_yticks ([])
351- ax .set_xlabel ("" )
352- ax .set_ylabel ("" )
353- ax .set_title ("" ) # Remove title when using annotation colors
342+ # Color strip thickness as fraction of plot
343+ strip_size = 0.03
354344
355- divider = make_axes_locatable (ax )
356-
357- # Determine x-axis label position
345+ # Move x-axis labels to top if requested
358346 if xlabel_position == "top" :
359- # Add color bar on top (x-axis, predicted labels)
360- col_ax = divider .append_axes ("top" , size = "2%" , pad = 0 )
361- cb_col = ax .figure .colorbar (sm , cax = col_ax , orientation = "horizontal" , ticklocation = "top" )
362- cb_col .set_ticks (np .arange (n_labels ) + 0.5 )
363- cb_col .ax .set_xticklabels (labels , rotation = 90 , ha = "center" , fontsize = "small" )
364- cb_col .ax .tick_params (length = 0 )
365- else :
366- # Add color bar on bottom (x-axis, predicted labels)
367- col_ax = divider .append_axes ("bottom" , size = "2%" , pad = 0 )
368- cb_col = ax .figure .colorbar (sm , cax = col_ax , orientation = "horizontal" , ticklocation = "bottom" )
369- cb_col .set_ticks (np .arange (n_labels ) + 0.5 )
370- cb_col .ax .set_xticklabels (labels , rotation = 90 , ha = "center" , va = "top" , fontsize = "small" )
371- cb_col .ax .tick_params (length = 0 )
372-
373- # Add color bar on left (y-axis, true labels)
374- row_ax = divider .append_axes ("left" , size = "2%" , pad = 0 )
375- cb_row = ax .figure .colorbar (sm , cax = row_ax , orientation = "vertical" , ticklocation = "left" )
376- cb_row .set_ticks (np .arange (n_labels ) + 0.5 )
377- cb_row .ax .set_yticklabels (labels [::- 1 ], fontsize = "small" ) # reversed to match heatmap
378- cb_row .ax .tick_params (length = 0 )
347+ ax .xaxis .tick_top ()
348+ ax .xaxis .set_label_position ("top" )
349+ plt .setp (ax .get_xticklabels (), rotation = 90 , ha = "center" , va = "bottom" )
350+
351+ # Draw color strips using rectangles in axes coordinates
352+ for i , color in enumerate (colors_list ):
353+ # Left strip (y-axis, true labels)
354+ rect_left = Rectangle (
355+ (- strip_size , i / n_labels ),
356+ strip_size ,
357+ 1 / n_labels ,
358+ facecolor = color ,
359+ edgecolor = "none" ,
360+ clip_on = False ,
361+ transform = ax .transAxes ,
362+ )
363+ ax .add_patch (rect_left )
364+
365+ # Top or bottom strip (x-axis, predicted labels)
366+ if xlabel_position == "top" :
367+ rect_x = Rectangle (
368+ (i / n_labels , 1 ),
369+ 1 / n_labels ,
370+ strip_size ,
371+ facecolor = color ,
372+ edgecolor = "none" ,
373+ clip_on = False ,
374+ transform = ax .transAxes ,
375+ )
376+ else :
377+ rect_x = Rectangle (
378+ (i / n_labels , - strip_size ),
379+ 1 / n_labels ,
380+ strip_size ,
381+ facecolor = color ,
382+ edgecolor = "none" ,
383+ clip_on = False ,
384+ transform = ax .transAxes ,
385+ )
386+ ax .add_patch (rect_x )
379387
380388 if save :
381389 ax .figure .savefig (save , bbox_inches = "tight" )
0 commit comments