11"""k-NN based mapping of labels, embeddings, and expression values."""
22
3- import gc
43from typing import Any , Literal
54
65import numpy as np
@@ -544,7 +543,8 @@ def map(
544543 "hnoca" ,
545544 "equal" ,
546545 "umap" ,
547- ] = "gauss" ,
546+ ]
547+ | None = None ,
548548 symmetrize : bool | None = None ,
549549 self_edges : bool | None = None ,
550550 prediction_postfix : str = "pred" ,
@@ -680,7 +680,8 @@ def map_obs(
680680 method : Literal ["iterative" , "spectral" ] = "iterative" ,
681681 prediction_postfix : str = "pred" ,
682682 confidence_postfix : str = "conf" ,
683- ) -> None :
683+ return_probabilities : bool = False ,
684+ ) -> np .ndarray | csr_matrix | None :
684685 """
685686 Map observation data from reference dataset to query dataset.
686687
@@ -705,10 +706,16 @@ def map_obs(
705706 confidence_postfix
706707 Postfix added to create new keys in ``query.obs`` for confidence scores
707708 (only applicable for categorical data)
709+ return_probabilities
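710+ If True, return the per-category probability matrix computed for categorical data.
711+ For numerical data the flag is ignored and a warning is logged. The matrix is returned as computed (a sparse matrix stays sparse); it is never densified.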
708712
709713 Returns
710714 -------
711- None
715+ np.ndarray, csr_matrix or None
716+ For categorical data with ``return_probabilities=True``: dense or sparse matrix
717+ of shape (n_query_cells, n_categories) containing probabilities.
718+ For numerical data or when ``return_probabilities=False``: None.
712719
713720 Notes
714721 -----
@@ -743,9 +750,14 @@ def map_obs(
743750 logger .info ("Mapping %s data for key '%s' with t=%d steps using %s method." , data_type , key , t , method )
744751
745752 if is_categorical :
746- self ._map_obs_categorical (key , prediction_postfix , confidence_postfix , t , method )
753+ return self ._map_obs_categorical (
754+ key , prediction_postfix , confidence_postfix , t , method , return_probabilities
755+ )
747756 else :
757+ if return_probabilities :
758+ logger .warning ("return_probabilities=True is only applicable for categorical data, ignoring." )
748759 self ._map_obs_numerical (key , prediction_postfix , t , method )
760+ return None
749761
750762 def _map_obs_categorical (
751763 self ,
@@ -754,7 +766,8 @@ def _map_obs_categorical(
754766 confidence_postfix : str ,
755767 t : int | None ,
756768 method : Literal ["iterative" , "spectral" ],
757- ) -> None :
769+ return_probabilities : bool = False ,
770+ ) -> np .ndarray | csr_matrix | None :
758771 """Map categorical observation data using one-hot encoding."""
759772 onehot = OneHotEncoder (dtype = np .float32 )
760773 xtab = onehot .fit_transform (
@@ -765,7 +778,7 @@ def _map_obs_categorical(
765778 ) # shape = (n_query_cells x n_categories), sparse csr matrix, float32
766779
767780 pred = pd .Series (
768- data = np .array (onehot .categories_ [0 ])[ytab .argmax (axis = 1 ).A1 ],
781+ data = np .array (onehot .categories_ [0 ])[ytab .argmax (axis = 1 ).A1 if issparse ( ytab ) else ytab . argmax ( axis = 1 ) ],
769782 index = self .query .obs_names ,
770783 dtype = self .reference .obs [key ].dtype ,
771784 )
@@ -788,9 +801,11 @@ def _map_obs_categorical(
788801
789802 logger .info ("Categorical data mapped and stored in query.obs['%s']." , f"{ key } _{ prediction_postfix } " )
790803
791- # Free memory explicitly
792- del onehot , xtab , ytab , pred , conf
793- gc .collect ()
804+ # Return probabilities if requested (never densify)
805+ if return_probabilities :
806+ return ytab
807+ else :
808+ return None
794809
795810 def _map_obs_numerical (
796811 self , key : str , prediction_postfix : str , t : int | None , method : Literal ["iterative" , "spectral" ]
0 commit comments