Skip to content

Commit b95cae7

Browse files
authored
Merge pull request #21 from quadbio/feat/input_imputed
Work with pre-computed objects for imputed values and k-NN graphs
2 parents f0fb6eb + 5fb88f1 commit b95cae7

File tree

11 files changed

+1352
-130
lines changed

11 files changed

+1352
-130
lines changed

pyproject.toml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -58,6 +58,7 @@ optional-dependencies.doc = [
5858
optional-dependencies.test = [
5959
"coverage",
6060
"pytest",
61+
"squidpy",
6162
]
6263
optional-dependencies.tutorials = [
6364
"harmony-pytorch",

src/cellmapper/cellmapper.py

Lines changed: 157 additions & 70 deletions
Large diffs are not rendered by default.

src/cellmapper/check.py

Lines changed: 6 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -57,14 +57,17 @@ def check(self) -> None:
5757

5858

5959
INSTALL_HINTS = types.SimpleNamespace(
60-
rapids="To speed up k-NN search on GPU, you may install rapids following the guide from "
61-
"https://docs.rapids.ai/install/. Note that you will only need cuML.",
60+
cuml="To speed up k-NN search on GPU, you may install cuML following the guide from "
61+
"https://docs.rapids.ai/install/.",
62+
cupy="To speed up k-NN search on GPU, you may install cuPy following the guide from "
63+
"https://docs.rapids.ai/install/.",
6264
faiss="To speed up k-NN search on GPU, you may install faiss following the guide from "
6365
"https://github.com/facebookresearch/faiss/blob/main/INSTALL.md",
6466
)
6567

6668
CHECKERS = {
67-
"rapids": Checker("rapids", vmin=None, install_hint=INSTALL_HINTS.rapids),
69+
"cuml": Checker("cuml", vmin=None, install_hint=INSTALL_HINTS.cuml),
70+
"cupy": Checker("cupy", vmin=None, install_hint=INSTALL_HINTS.cupy),
6871
"faiss": Checker("faiss", package_name="faiss", vmin="1.7.0", install_hint=INSTALL_HINTS.faiss),
6972
}
7073

src/cellmapper/evaluate.py

Lines changed: 18 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -35,8 +35,6 @@ def _jensen_shannon_divergence(p: np.ndarray, q: np.ndarray) -> float:
3535
q = np.clip(q, 0, None)
3636
if p.sum() == 0 or q.sum() == 0:
3737
return np.nan
38-
p = p / p.sum()
39-
q = q / q.sum()
4038
return jensenshannon(p, q, base=10)
4139

4240

@@ -226,7 +224,11 @@ def evaluate_expression_transfer(
226224

227225
# Helper to compute metrics for a given mask of cells
228226
def compute_metrics(mask):
229-
return np.array([metric_func(original_x[mask, i], imputed_x[mask, i]) for i in range(imputed_x.shape[1])])
227+
# Explicitly return as float32 to match DataFrame's dtype
228+
return np.array(
229+
[metric_func(original_x[mask, i], imputed_x[mask, i]) for i in range(imputed_x.shape[1])],
230+
dtype=np.float32,
231+
)
230232

231233
# Compute metrics for all cells
232234
overall_mask = np.ones(original_x.shape[0], dtype=bool)
@@ -273,7 +275,9 @@ def _get_aligned_expression_arrays(self, layer_key: str) -> tuple[np.ndarray, np
273275
imputed_x, original_x, shared_genes
274276
"""
275277
if self.query_imputed is None:
276-
raise ValueError("Imputed query data not found. Run transfer_expression() first.")
278+
raise ValueError(
279+
"Imputed query data not found. Either run transfer_expression() first or set query_imputed manually."
280+
)
277281
shared_genes = list(self.query_imputed.var_names.intersection(self.query.var_names))
278282
if len(shared_genes) == 0:
279283
raise ValueError("No shared genes between query_imputed and query.")
@@ -376,8 +380,8 @@ def estimate_presence_score(
376380
groupby
377381
Column in self.query.obs to group query cells by (e.g., cell type, batch). If None, computes a single score for all query cells.
378382
key_added
379-
Key to store the presence score: always writes the score across all query cells to self.ref.obs[key_added].
380-
If groupby is not None, also writes per-group scores as a DataFrame to self.ref.obsm[key_added].
383+
Key to store the presence score: always writes the score across all query cells to self.reference.obs[key_added].
384+
If groupby is not None, also writes per-group scores as a DataFrame to self.reference.obsm[key_added].
381385
log
382386
Whether to apply log1p transformation to the scores.
383387
percentile
@@ -387,30 +391,30 @@ def estimate_presence_score(
387391
raise ValueError("Neighbors must be computed before estimating presence scores.")
388392

389393
conn = self.knn.yx.knn_graph_connectivities()
390-
ref_names = self.ref.obs_names
394+
reference_names = self.reference.obs_names
391395

392396
# Always compute and post-process the overall score (all query cells)
393397
scores_all = np.array(conn.sum(axis=0)).flatten()
394-
df_all = pd.DataFrame({"all": scores_all}, index=ref_names)
398+
df_all = pd.DataFrame({"all": scores_all}, index=reference_names)
395399
df_all_processed = process_presence_scores(df_all, log=log, percentile=percentile)
396-
self.ref.obs[key_added] = df_all_processed["all"]
397-
logger.info("Presence score across all query cells computed and stored in `ref.obs['%s']`", key_added)
400+
self.reference.obs[key_added] = df_all_processed["all"]
401+
logger.info("Presence score across all query cells computed and stored in `reference.obs['%s']`", key_added)
398402

399403
# If groupby, also compute and post-process per-group scores
400404
if groupby is not None:
401405
group_labels = self.query.obs[groupby]
402406
groups = group_labels.unique()
403-
score_matrix = np.zeros((len(ref_names), len(groups)), dtype=np.float32)
407+
score_matrix = np.zeros((len(reference_names), len(groups)), dtype=np.float32)
404408
for i, group in enumerate(groups):
405409
mask = group_labels == group
406410
group_conn = conn[mask.values, :]
407411
score_matrix[:, i] = np.array(group_conn.sum(axis=0)).flatten()
408-
df_groups = pd.DataFrame(score_matrix, index=ref_names, columns=groups)
412+
df_groups = pd.DataFrame(score_matrix, index=reference_names, columns=groups)
409413
df_groups_processed = process_presence_scores(df_groups, log=log, percentile=percentile)
410-
self.ref.obsm[key_added] = df_groups_processed
414+
self.reference.obsm[key_added] = df_groups_processed
411415

412416
logger.info(
413-
"Presence scores per group defined in `query.obs['%s']` computed and stored in `ref.obsm['%s']`",
417+
"Presence scores per group defined in `query.obs['%s']` computed and stored in `reference.obsm['%s']`",
414418
groupby,
415419
key_added,
416420
)

0 commit comments

Comments
 (0)