Skip to content

Commit e8325a9

Browse files
committed
feat: add min_cells parameter to filter sparse categories in confusion matrix
1 parent b9db406 commit e8325a9

File tree

1 file changed

+22
-0
lines changed

1 file changed

+22
-0
lines changed

src/cellmapper/model/evaluate.py

Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -213,6 +213,7 @@ def plot_confusion_matrix(
213213
show_annotation_colors: bool = True,
214214
xlabel_position: Literal["bottom", "top"] = "bottom",
215215
show_grid: bool = True,
216+
min_cells: int | None = None,
216217
**kwargs,
217218
) -> plt.Axes:
218219
"""
@@ -240,6 +241,10 @@ def plot_confusion_matrix(
240241
Position of x-axis tick labels. Either "bottom" (default) or "top".
241242
show_grid
242243
Whether to show gridlines on the heatmap. Default is True.
244+
min_cells
245+
Minimum number of cells required for a category to be included in the confusion matrix.
246+
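Categories with fewer than ``min_cells`` cells in both the true and the predicted labels are filtered out, and any cell whose true or predicted label belongs to a filtered category is dropped.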
247+
If None, all categories are shown.
243248
**kwargs
244249
Additional keyword arguments to pass to ConfusionMatrixDisplay.
245250
@@ -268,6 +273,23 @@ def plot_confusion_matrix(
268273
y_true = y_true[subset]
269274
y_pred = y_pred[subset]
270275

276+
# Filter categories by minimum cell count
277+
if min_cells is not None:
278+
true_counts = y_true.value_counts()
279+
pred_counts = y_pred.value_counts()
280+
# Keep categories that have at least min_cells in either true or predicted
281+
valid_categories = set(true_counts[true_counts >= min_cells].index) | set(
282+
pred_counts[pred_counts >= min_cells].index
283+
)
284+
mask = y_true.isin(valid_categories) & y_pred.isin(valid_categories)
285+
y_true = y_true[mask]
286+
y_pred = y_pred[mask]
287+
# Update categories if categorical
288+
if hasattr(y_true, "cat"):
289+
y_true = y_true.cat.remove_unused_categories()
290+
if hasattr(y_pred, "cat"):
291+
y_pred = y_pred.cat.remove_unused_categories()
292+
271293
# Get union of categories if categorical, to handle mismatched category sets
272294
# Also convert to string to avoid sklearn interpreting float categories as continuous
273295
labels = None

0 commit comments

Comments
 (0)