Skip to content

Commit 5650d6f

Browse files
committed
Add option to return probabilities
1 parent 10d4750 commit 5650d6f

File tree

1 file changed

+23
-9
lines changed

1 file changed

+23
-9
lines changed

src/cellmapper/model/cellmapper.py

Lines changed: 23 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,5 @@
11
"""k-NN based mapping of labels, embeddings, and expression values."""
22

3-
import gc
43
from typing import Any, Literal
54

65
import numpy as np
@@ -681,7 +680,8 @@ def map_obs(
681680
method: Literal["iterative", "spectral"] = "iterative",
682681
prediction_postfix: str = "pred",
683682
confidence_postfix: str = "conf",
684-
) -> None:
683+
return_probabilities: bool = False,
684+
) -> np.ndarray | csr_matrix | None:
685685
"""
686686
Map observation data from reference dataset to query dataset.
687687
@@ -706,10 +706,16 @@ def map_obs(
706706
confidence_postfix
707707
Postfix added to create new keys in ``query.obs`` for confidence scores
708708
(only applicable for categorical data)
709+
return_probabilities
710+
If True, return the probability matrix for categorical data.
711+
Only applicable for categorical data. The matrix is never densified.
709712
710713
Returns
711714
-------
712-
None
715+
np.ndarray, csr_matrix or None
716+
For categorical data with ``return_probabilities=True``: dense or sparse matrix
717+
of shape (n_query_cells, n_categories) containing probabilities.
718+
For numerical data or when ``return_probabilities=False``: None.
713719
714720
Notes
715721
-----
@@ -744,9 +750,14 @@ def map_obs(
744750
logger.info("Mapping %s data for key '%s' with t=%d steps using %s method.", data_type, key, t, method)
745751

746752
if is_categorical:
747-
self._map_obs_categorical(key, prediction_postfix, confidence_postfix, t, method)
753+
return self._map_obs_categorical(
754+
key, prediction_postfix, confidence_postfix, t, method, return_probabilities
755+
)
748756
else:
757+
if return_probabilities:
758+
logger.warning("return_probabilities=True is only applicable for categorical data, ignoring.")
749759
self._map_obs_numerical(key, prediction_postfix, t, method)
760+
return None
750761

751762
def _map_obs_categorical(
752763
self,
@@ -755,7 +766,8 @@ def _map_obs_categorical(
755766
confidence_postfix: str,
756767
t: int | None,
757768
method: Literal["iterative", "spectral"],
758-
) -> None:
769+
return_probabilities: bool = False,
770+
) -> np.ndarray | csr_matrix | None:
759771
"""Map categorical observation data using one-hot encoding."""
760772
onehot = OneHotEncoder(dtype=np.float32)
761773
xtab = onehot.fit_transform(
@@ -766,7 +778,7 @@ def _map_obs_categorical(
766778
) # shape = (n_query_cells x n_categories), sparse csr matrix, float32
767779

768780
pred = pd.Series(
769-
data=np.array(onehot.categories_[0])[ytab.argmax(axis=1).A1],
781+
data=np.array(onehot.categories_[0])[ytab.argmax(axis=1).A1 if issparse(ytab) else ytab.argmax(axis=1)],
770782
index=self.query.obs_names,
771783
dtype=self.reference.obs[key].dtype,
772784
)
@@ -789,9 +801,11 @@ def _map_obs_categorical(
789801

790802
logger.info("Categorical data mapped and stored in query.obs['%s'].", f"{key}_{prediction_postfix}")
791803

792-
# Free memory explicitly
793-
del onehot, xtab, ytab, pred, conf
794-
gc.collect()
804+
# Return probabilities if requested (never densify)
805+
if return_probabilities:
806+
return ytab
807+
else:
808+
return None
795809

796810
def _map_obs_numerical(
797811
self, key: str, prediction_postfix: str, t: int | None, method: Literal["iterative", "spectral"]

0 commit comments

Comments
 (0)