Skip to content

Commit c7ef6ef

Browse files
authored
Merge pull request #39 from quadbio/feat/confidence
Return the mapping probabilities
2 parents a6605cd + 5650d6f commit c7ef6ef

File tree

1 file changed: 25 additions (+25) and 10 deletions (-10).

src/cellmapper/model/cellmapper.py

Lines changed: 25 additions & 10 deletions
Original file line number | Diff line number | Diff line change
@@ -1,6 +1,5 @@
11
"""k-NN based mapping of labels, embeddings, and expression values."""
22

3-
import gc
43
from typing import Any, Literal
54

65
import numpy as np
@@ -544,7 +543,8 @@ def map(
544543
"hnoca",
545544
"equal",
546545
"umap",
547-
] = "gauss",
546+
]
547+
| None = None,
548548
symmetrize: bool | None = None,
549549
self_edges: bool | None = None,
550550
prediction_postfix: str = "pred",
@@ -680,7 +680,8 @@ def map_obs(
680680
method: Literal["iterative", "spectral"] = "iterative",
681681
prediction_postfix: str = "pred",
682682
confidence_postfix: str = "conf",
683-
) -> None:
683+
return_probabilities: bool = False,
684+
) -> np.ndarray | csr_matrix | None:
684685
"""
685686
Map observation data from reference dataset to query dataset.
686687
@@ -705,10 +706,16 @@ def map_obs(
705706
confidence_postfix
706707
Postfix added to create new keys in ``query.obs`` for confidence scores
707708
(only applicable for categorical data)
709+
return_probabilities
710+
If True, return the probability matrix for categorical data.
711+
Only applicable for categorical data. The matrix is never densified.
708712
709713
Returns
710714
-------
711-
None
715+
np.ndarray, csr_matrix or None
716+
For categorical data with ``return_probabilities=True``: dense or sparse matrix
717+
of shape (n_query_cells, n_categories) containing probabilities.
718+
For numerical data or when ``return_probabilities=False``: None.
712719
713720
Notes
714721
-----
@@ -743,9 +750,14 @@ def map_obs(
743750
logger.info("Mapping %s data for key '%s' with t=%d steps using %s method.", data_type, key, t, method)
744751

745752
if is_categorical:
746-
self._map_obs_categorical(key, prediction_postfix, confidence_postfix, t, method)
753+
return self._map_obs_categorical(
754+
key, prediction_postfix, confidence_postfix, t, method, return_probabilities
755+
)
747756
else:
757+
if return_probabilities:
758+
logger.warning("return_probabilities=True is only applicable for categorical data, ignoring.")
748759
self._map_obs_numerical(key, prediction_postfix, t, method)
760+
return None
749761

750762
def _map_obs_categorical(
751763
self,
@@ -754,7 +766,8 @@ def _map_obs_categorical(
754766
confidence_postfix: str,
755767
t: int | None,
756768
method: Literal["iterative", "spectral"],
757-
) -> None:
769+
return_probabilities: bool = False,
770+
) -> np.ndarray | csr_matrix | None:
758771
"""Map categorical observation data using one-hot encoding."""
759772
onehot = OneHotEncoder(dtype=np.float32)
760773
xtab = onehot.fit_transform(
@@ -765,7 +778,7 @@ def _map_obs_categorical(
765778
) # shape = (n_query_cells x n_categories), sparse csr matrix, float32
766779

767780
pred = pd.Series(
768-
data=np.array(onehot.categories_[0])[ytab.argmax(axis=1).A1],
781+
data=np.array(onehot.categories_[0])[ytab.argmax(axis=1).A1 if issparse(ytab) else ytab.argmax(axis=1)],
769782
index=self.query.obs_names,
770783
dtype=self.reference.obs[key].dtype,
771784
)
@@ -788,9 +801,11 @@ def _map_obs_categorical(
788801

789802
logger.info("Categorical data mapped and stored in query.obs['%s'].", f"{key}_{prediction_postfix}")
790803

791-
# Free memory explicitly
792-
del onehot, xtab, ytab, pred, conf
793-
gc.collect()
804+
# Return probabilities if requested (never densify)
805+
if return_probabilities:
806+
return ytab
807+
else:
808+
return None
794809

795810
def _map_obs_numerical(
796811
self, key: str, prediction_postfix: str, t: int | None, method: Literal["iterative", "spectral"]

0 commit comments

Comments
 (0)