@@ -313,7 +313,7 @@ def plot_confusion_matrix(
313313 if not show_grid :
314314 ax .grid (False )
315315
316- # Move x-axis labels to top if requested (only relevant when annotation colors are not shown)
316+ # Move x-axis labels to top if requested (when annotation colors are not shown)
317317 if xlabel_position == "top" and not show_annotation_colors :
318318 ax .xaxis .tick_top ()
319319 ax .xaxis .set_label_position ("top" )
@@ -350,22 +350,32 @@ def plot_confusion_matrix(
350350 ax .set_yticks ([])
351351 ax .set_xlabel ("" )
352352 ax .set_ylabel ("" )
353+ ax .set_title ("" ) # Remove title when using annotation colors
353354
354355 divider = make_axes_locatable (ax )
355356
356- # Add color bar on top (x-axis, predicted labels)
357- col_ax = divider .append_axes ("top" , size = "2%" , pad = 0 )
358- cb_col = ax .figure .colorbar (sm , cax = col_ax , orientation = "horizontal" , ticklocation = "top" )
359- cb_col .set_ticks (np .arange (n_labels ) + 0.5 )
360- cb_col .ax .set_xticklabels (labels , rotation = 90 , ha = "center" , fontsize = "small" )
361- cb_col .ax .tick_params (length = 0 ) # hide tick marks
357+ # Determine x-axis label position
358+ if xlabel_position == "top" :
359+ # Add color bar on top (x-axis, predicted labels)
360+ col_ax = divider .append_axes ("top" , size = "2%" , pad = 0 )
361+ cb_col = ax .figure .colorbar (sm , cax = col_ax , orientation = "horizontal" , ticklocation = "top" )
362+ cb_col .set_ticks (np .arange (n_labels ) + 0.5 )
363+ cb_col .ax .set_xticklabels (labels , rotation = 90 , ha = "center" , fontsize = "small" )
364+ cb_col .ax .tick_params (length = 0 )
365+ else :
366+ # Add color bar on bottom (x-axis, predicted labels)
367+ col_ax = divider .append_axes ("bottom" , size = "2%" , pad = 0 )
368+ cb_col = ax .figure .colorbar (sm , cax = col_ax , orientation = "horizontal" , ticklocation = "bottom" )
369+ cb_col .set_ticks (np .arange (n_labels ) + 0.5 )
370+ cb_col .ax .set_xticklabels (labels , rotation = 90 , ha = "center" , va = "top" , fontsize = "small" )
371+ cb_col .ax .tick_params (length = 0 )
362372
363373 # Add color bar on left (y-axis, true labels)
364374 row_ax = divider .append_axes ("left" , size = "2%" , pad = 0 )
365375 cb_row = ax .figure .colorbar (sm , cax = row_ax , orientation = "vertical" , ticklocation = "left" )
366376 cb_row .set_ticks (np .arange (n_labels ) + 0.5 )
367377 cb_row .ax .set_yticklabels (labels [::- 1 ], fontsize = "small" ) # reversed to match heatmap
368- cb_row .ax .tick_params (length = 0 ) # hide tick marks
378+ cb_row .ax .tick_params (length = 0 )
369379
370380 if save :
371381 ax .figure .savefig (save , bbox_inches = "tight" )
0 commit comments