@@ -213,6 +213,7 @@ def plot_confusion_matrix(
213213 show_annotation_colors : bool = True ,
214214 xlabel_position : Literal ["bottom" , "top" ] = "bottom" ,
215215 show_grid : bool = True ,
216+ min_cells : int | None = None ,
216217 ** kwargs ,
217218 ) -> plt .Axes :
218219 """
@@ -240,6 +241,10 @@ def plot_confusion_matrix(
240241 Position of x-axis tick labels. Either "bottom" (default) or "top".
241242 show_grid
242243 Whether to show gridlines on the heatmap. Default is True.
244+ min_cells
245+ Minimum number of cells required for a category to be included in the confusion matrix.
246+ Categories with fewer cells in both true and predicted labels are filtered out.
247+ If None, all categories are shown.
243248 **kwargs
244249 Additional keyword arguments to pass to ConfusionMatrixDisplay.
245250
@@ -268,6 +273,23 @@ def plot_confusion_matrix(
268273 y_true = y_true [subset ]
269274 y_pred = y_pred [subset ]
270275
276+ # Filter categories by minimum cell count
277+ if min_cells is not None :
278+ true_counts = y_true .value_counts ()
279+ pred_counts = y_pred .value_counts ()
280+ # Keep categories that have at least min_cells in either true or predicted
281+ valid_categories = set (true_counts [true_counts >= min_cells ].index ) | set (
282+ pred_counts [pred_counts >= min_cells ].index
283+ )
284+ mask = y_true .isin (valid_categories ) & y_pred .isin (valid_categories )
285+ y_true = y_true [mask ]
286+ y_pred = y_pred [mask ]
287+ # Update categories if categorical
288+ if hasattr (y_true , "cat" ):
289+ y_true = y_true .cat .remove_unused_categories ()
290+ if hasattr (y_pred , "cat" ):
291+ y_pred = y_pred .cat .remove_unused_categories ()
292+
271293 # Get union of categories if categorical, to handle mismatched category sets
272294 # Also convert to string to avoid sklearn interpreting float categories as continuous
273295 labels = None
0 commit comments