Skip to content

Commit fd7b4dd

Browse files
committed
simplify: remove fallback in _get_category_colors, use single adata
1 parent 6bd740d commit fd7b4dd

File tree

1 file changed

+11
-16
lines changed

1 file changed

+11
-16
lines changed

src/cellmapper/model/evaluate.py

Lines changed: 11 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -23,16 +23,16 @@
2323

2424

2525
def _get_category_colors(
26-
adata_list: list["AnnData"],
26+
adata: "AnnData | None",
2727
label_key: str,
2828
categories: list[str],
2929
) -> list[str]:
3030
"""Get colors for categories from adata.uns, falling back to gray.
3131
3232
Parameters
3333
----------
34-
adata_list
35-
List of AnnData objects to search for colors (in order of priority).
34+
adata
35+
AnnData object to get colors from.
3636
label_key
3737
Key in .obs storing the categorical annotation.
3838
categories
@@ -45,17 +45,12 @@ def _get_category_colors(
4545
colors_key = f"{label_key}_colors"
4646
colors_dict: dict[str, str] = {}
4747

48-
# Collect colors from all adatas (first found takes priority)
49-
for adata in adata_list:
50-
if adata is not None and colors_key in adata.uns:
51-
full_categories = adata.obs[label_key].cat.categories
52-
full_colors = adata.uns[colors_key]
53-
for i, cat in enumerate(full_categories):
54-
if i < len(full_colors):
55-
cat_str = str(cat)
56-
# Only add if not already found (first adata takes priority)
57-
if cat_str not in colors_dict:
58-
colors_dict[cat_str] = full_colors[i]
48+
if adata is not None and colors_key in adata.uns:
49+
full_categories = adata.obs[label_key].cat.categories
50+
full_colors = adata.uns[colors_key]
51+
for i, cat in enumerate(full_categories):
52+
if i < len(full_colors):
53+
colors_dict[str(cat)] = full_colors[i]
5954

6055
return [colors_dict.get(str(cat), "gray") for cat in categories]
6156

@@ -589,8 +584,8 @@ def plot_confusion_matrix(
589584
# Annotation color strips
590585
if show_annotation_colors:
591586
# Row colors (true labels) from query, column colors (predicted) from reference
592-
row_colors = _get_category_colors([self.query, self.reference], label_key, list(cm_display.index))
593-
col_colors = _get_category_colors([self.reference, self.query], label_key, list(cm_display.columns))
587+
row_colors = _get_category_colors(self.query, label_key, list(cm_display.index))
588+
col_colors = _get_category_colors(self.reference, label_key, list(cm_display.columns))
594589
_draw_annotation_strips(ax, row_colors, col_colors, xlabel_position)
595590

596591
if save:

0 commit comments

Comments
 (0)