2323
2424
2525def _get_category_colors (
26- adata_list : list [ "AnnData" ] ,
26+ adata : "AnnData | None" ,
2727 label_key : str ,
2828 categories : list [str ],
2929) -> list [str ]:
3030 """Get colors for categories from adata.uns, falling back to gray.
3131
3232 Parameters
3333 ----------
34- adata_list
35- List of AnnData objects to search for colors (in order of priority) .
34+ adata
35+ AnnData object to get colors from .
3636 label_key
3737 Key in .obs storing the categorical annotation.
3838 categories
@@ -45,17 +45,12 @@ def _get_category_colors(
4545 colors_key = f"{ label_key } _colors"
4646 colors_dict : dict [str , str ] = {}
4747
48- # Collect colors from all adatas (first found takes priority)
49- for adata in adata_list :
50- if adata is not None and colors_key in adata .uns :
51- full_categories = adata .obs [label_key ].cat .categories
52- full_colors = adata .uns [colors_key ]
53- for i , cat in enumerate (full_categories ):
54- if i < len (full_colors ):
55- cat_str = str (cat )
56- # Only add if not already found (first adata takes priority)
57- if cat_str not in colors_dict :
58- colors_dict [cat_str ] = full_colors [i ]
48+ if adata is not None and colors_key in adata .uns :
49+ full_categories = adata .obs [label_key ].cat .categories
50+ full_colors = adata .uns [colors_key ]
51+ for i , cat in enumerate (full_categories ):
52+ if i < len (full_colors ):
53+ colors_dict [str (cat )] = full_colors [i ]
5954
6055 return [colors_dict .get (str (cat ), "gray" ) for cat in categories ]
6156
@@ -589,8 +584,8 @@ def plot_confusion_matrix(
589584 # Annotation color strips
590585 if show_annotation_colors :
591586 # Row colors (true labels) from query, column colors (predicted) from reference
592- row_colors = _get_category_colors ([ self .query , self . reference ] , label_key , list (cm_display .index ))
593- col_colors = _get_category_colors ([ self .reference , self . query ] , label_key , list (cm_display .columns ))
587+ row_colors = _get_category_colors (self .query , label_key , list (cm_display .index ))
588+ col_colors = _get_category_colors (self .reference , label_key , list (cm_display .columns ))
594589 _draw_annotation_strips (ax , row_colors , col_colors , xlabel_position )
595590
596591 if save :
0 commit comments