11"""k-NN based mapping of labels, embeddings, and expression values."""
22
3- import gc
43from typing import Any , Literal
54
65import numpy as np
@@ -681,7 +680,8 @@ def map_obs(
681680 method : Literal ["iterative" , "spectral" ] = "iterative" ,
682681 prediction_postfix : str = "pred" ,
683682 confidence_postfix : str = "conf" ,
684- ) -> None :
683+ return_probabilities : bool = False ,
684+ ) -> np .ndarray | csr_matrix | None :
685685 """
686686 Map observation data from reference dataset to query dataset.
687687
@@ -706,10 +706,16 @@ def map_obs(
706706 confidence_postfix
707707 Postfix added to create new keys in ``query.obs`` for confidence scores
708708 (only applicable for categorical data)
709+ return_probabilities
710+ If True, return the probability matrix computed for categorical data
711+ (ignored with a warning for numerical data). The matrix is returned as computed; a sparse result is never densified.
709712
710713 Returns
711714 -------
712- None
715+ np.ndarray, csr_matrix or None
716+ For categorical data with ``return_probabilities=True``: dense or sparse matrix
717+ of shape (n_query_cells, n_categories) containing probabilities.
718+ For numerical data or when ``return_probabilities=False``: None.
713719
714720 Notes
715721 -----
@@ -744,9 +750,14 @@ def map_obs(
744750 logger .info ("Mapping %s data for key '%s' with t=%d steps using %s method." , data_type , key , t , method )
745751
746752 if is_categorical :
747- self ._map_obs_categorical (key , prediction_postfix , confidence_postfix , t , method )
753+ return self ._map_obs_categorical (
754+ key , prediction_postfix , confidence_postfix , t , method , return_probabilities
755+ )
748756 else :
757+ if return_probabilities :
758+ logger .warning ("return_probabilities=True is only applicable for categorical data, ignoring." )
749759 self ._map_obs_numerical (key , prediction_postfix , t , method )
760+ return None
750761
751762 def _map_obs_categorical (
752763 self ,
@@ -755,7 +766,8 @@ def _map_obs_categorical(
755766 confidence_postfix : str ,
756767 t : int | None ,
757768 method : Literal ["iterative" , "spectral" ],
758- ) -> None :
769+ return_probabilities : bool = False ,
770+ ) -> np .ndarray | csr_matrix | None :
759771 """Map categorical observation data using one-hot encoding."""
760772 onehot = OneHotEncoder (dtype = np .float32 )
761773 xtab = onehot .fit_transform (
@@ -766,7 +778,7 @@ def _map_obs_categorical(
766778 ) # shape = (n_query_cells x n_categories), sparse csr matrix, float32
767779
768780 pred = pd .Series (
769- data = np .array (onehot .categories_ [0 ])[ytab .argmax (axis = 1 ).A1 ],
781+ data = np .array (onehot .categories_ [0 ])[ytab .argmax (axis = 1 ).A1 if issparse ( ytab ) else ytab . argmax ( axis = 1 ) ],
770782 index = self .query .obs_names ,
771783 dtype = self .reference .obs [key ].dtype ,
772784 )
@@ -789,9 +801,11 @@ def _map_obs_categorical(
789801
790802 logger .info ("Categorical data mapped and stored in query.obs['%s']." , f"{ key } _{ prediction_postfix } " )
791803
792- # Free memory explicitly
793- del onehot , xtab , ytab , pred , conf
794- gc .collect ()
804+ # Return probabilities if requested (never densify)
805+ if return_probabilities :
806+ return ytab
807+ else :
808+ return None
795809
796810 def _map_obs_numerical (
797811 self , key : str , prediction_postfix : str , t : int | None , method : Literal ["iterative" , "spectral" ]
0 commit comments