Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -58,6 +58,7 @@ optional-dependencies.doc = [
optional-dependencies.test = [
"coverage",
"pytest",
"squidpy",
]
optional-dependencies.tutorials = [
"harmony-pytorch",
Expand Down
227 changes: 157 additions & 70 deletions src/cellmapper/cellmapper.py

Large diffs are not rendered by default.

9 changes: 6 additions & 3 deletions src/cellmapper/check.py
Original file line number Diff line number Diff line change
Expand Up @@ -57,14 +57,17 @@ def check(self) -> None:


INSTALL_HINTS = types.SimpleNamespace(
rapids="To speed up k-NN search on GPU, you may install rapids following the guide from "
"https://docs.rapids.ai/install/. Note that you will only need cuML.",
cuml="To speed up k-NN search on GPU, you may install cuML following the guide from "
"https://docs.rapids.ai/install/.",
cupy="To speed up k-NN search on GPU, you may install cuPy following the guide from "
"https://docs.rapids.ai/install/.",
faiss="To speed up k-NN search on GPU, you may install faiss following the guide from "
"https://github.com/facebookresearch/faiss/blob/main/INSTALL.md",
)

CHECKERS = {
"rapids": Checker("rapids", vmin=None, install_hint=INSTALL_HINTS.rapids),
"cuml": Checker("cuml", vmin=None, install_hint=INSTALL_HINTS.cuml),
"cupy": Checker("cupy", vmin=None, install_hint=INSTALL_HINTS.cupy),
"faiss": Checker("faiss", package_name="faiss", vmin="1.7.0", install_hint=INSTALL_HINTS.faiss),
}

Expand Down
32 changes: 18 additions & 14 deletions src/cellmapper/evaluate.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,8 +35,6 @@ def _jensen_shannon_divergence(p: np.ndarray, q: np.ndarray) -> float:
q = np.clip(q, 0, None)
if p.sum() == 0 or q.sum() == 0:
return np.nan
p = p / p.sum()
q = q / q.sum()
return jensenshannon(p, q, base=10)


Expand Down Expand Up @@ -226,7 +224,11 @@ def evaluate_expression_transfer(

# Helper to compute metrics for a given mask of cells
def compute_metrics(mask):
return np.array([metric_func(original_x[mask, i], imputed_x[mask, i]) for i in range(imputed_x.shape[1])])
# Explicitly return as float32 to match DataFrame's dtype
return np.array(
[metric_func(original_x[mask, i], imputed_x[mask, i]) for i in range(imputed_x.shape[1])],
dtype=np.float32,
)

# Compute metrics for all cells
overall_mask = np.ones(original_x.shape[0], dtype=bool)
Expand Down Expand Up @@ -273,7 +275,9 @@ def _get_aligned_expression_arrays(self, layer_key: str) -> tuple[np.ndarray, np
imputed_x, original_x, shared_genes
"""
if self.query_imputed is None:
raise ValueError("Imputed query data not found. Run transfer_expression() first.")
raise ValueError(
"Imputed query data not found. Either run transfer_expression() first or set query_imputed manually."
)
shared_genes = list(self.query_imputed.var_names.intersection(self.query.var_names))
if len(shared_genes) == 0:
raise ValueError("No shared genes between query_imputed and query.")
Expand Down Expand Up @@ -376,8 +380,8 @@ def estimate_presence_score(
groupby
Column in self.query.obs to group query cells by (e.g., cell type, batch). If None, computes a single score for all query cells.
key_added
Key to store the presence score: always writes the score across all query cells to self.ref.obs[key_added].
If groupby is not None, also writes per-group scores as a DataFrame to self.ref.obsm[key_added].
Key to store the presence score: always writes the score across all query cells to self.reference.obs[key_added].
If groupby is not None, also writes per-group scores as a DataFrame to self.reference.obsm[key_added].
log
Whether to apply log1p transformation to the scores.
percentile
Expand All @@ -387,30 +391,30 @@ def estimate_presence_score(
raise ValueError("Neighbors must be computed before estimating presence scores.")

conn = self.knn.yx.knn_graph_connectivities()
ref_names = self.ref.obs_names
reference_names = self.reference.obs_names

# Always compute and post-process the overall score (all query cells)
scores_all = np.array(conn.sum(axis=0)).flatten()
df_all = pd.DataFrame({"all": scores_all}, index=ref_names)
df_all = pd.DataFrame({"all": scores_all}, index=reference_names)
df_all_processed = process_presence_scores(df_all, log=log, percentile=percentile)
self.ref.obs[key_added] = df_all_processed["all"]
logger.info("Presence score across all query cells computed and stored in `ref.obs['%s']`", key_added)
self.reference.obs[key_added] = df_all_processed["all"]
logger.info("Presence score across all query cells computed and stored in `reference.obs['%s']`", key_added)

# If groupby, also compute and post-process per-group scores
if groupby is not None:
group_labels = self.query.obs[groupby]
groups = group_labels.unique()
score_matrix = np.zeros((len(ref_names), len(groups)), dtype=np.float32)
score_matrix = np.zeros((len(reference_names), len(groups)), dtype=np.float32)
for i, group in enumerate(groups):
mask = group_labels == group
group_conn = conn[mask.values, :]
score_matrix[:, i] = np.array(group_conn.sum(axis=0)).flatten()
df_groups = pd.DataFrame(score_matrix, index=ref_names, columns=groups)
df_groups = pd.DataFrame(score_matrix, index=reference_names, columns=groups)
df_groups_processed = process_presence_scores(df_groups, log=log, percentile=percentile)
self.ref.obsm[key_added] = df_groups_processed
self.reference.obsm[key_added] = df_groups_processed

logger.info(
"Presence scores per group defined in `query.obs['%s']` computed and stored in `ref.obsm['%s']`",
"Presence scores per group defined in `query.obs['%s']` computed and stored in `reference.obsm['%s']`",
groupby,
key_added,
)
Expand Down
Loading
Loading