Skip to content

Commit b362286

Browse files
authored
Merge pull request #16 from quadbio/feat/presence_scores
Implement presence scores and allow evaluation per group
2 parents 6c461b6 + 9b8c544 commit b362286

File tree

4 files changed

+219
-80
lines changed

4 files changed

+219
-80
lines changed

src/cellmapper/evaluate.py

Lines changed: 164 additions & 70 deletions
Original file line numberDiff line numberDiff line change
@@ -175,9 +175,10 @@ def evaluate_expression_transfer(
175175
self,
176176
layer_key: str = "X",
177177
method: Literal["pearson", "spearman", "js", "rmse"] = "pearson",
178+
groupby: str | None = None,
178179
) -> None:
179180
"""
180-
Evaluate the agreement between imputed and original expression in the query dataset.
181+
Evaluate the agreement between imputed and original expression in the query dataset, optionally per group.
181182
182183
These metrics are inspired by Li et al., Nature Methods 2022 (https://www.nature.com/articles/s41592-022-01480-9).
183184
@@ -186,59 +187,69 @@ def evaluate_expression_transfer(
186187
layer_key
187188
Key in `self.query.layers` to use as the original expression. Use "X" to use `self.query.X`.
188189
method
189-
Method to quantify agreement. Currently supported: "pearson" (average Pearson correlation across shared genes).
190+
Method to quantify agreement. Supported: "pearson", "spearman", "js", "rmse".
191+
groupby
192+
Column in self.query.obs to group query cells by (e.g., cell type, batch). If None, computes a single score for all query cells.
190193
191194
Returns
192195
-------
193196
Nothing, but updates the following attributes:
194197
expression_transfer_metrics
195-
Dictionary containing the average correlation and number of genes used for the evaluation.
198+
Dictionary containing the average metric and number of genes used for the evaluation.
199+
query.var[f"metric_{method}"]
200+
Per-gene metric values (overall, across all cells).
201+
query.varm[f"metric_{method}"]
202+
Per-gene, per-group metric values (if groupby is provided).
196203
"""
197204
imputed_x, original_x, shared_genes = self._get_aligned_expression_arrays(layer_key)
198-
log_msg = "Expression transfer evaluation (%s): average value = %.4f (n_genes=%d, n_valid_genes=%d)"
199-
if method == "pearson" or method == "spearman":
200-
if method == "pearson":
201-
corr_func = pearsonr
202-
elif method == "spearman":
203-
corr_func = spearmanr
204-
205-
# Compute per-gene correlation
206-
corrs = np.array([corr_func(original_x[:, i], imputed_x[:, i])[0] for i in range(imputed_x.shape[1])])
207-
208-
# Store per-gene correlation in query_imputed.var, only for shared genes
209-
self._store_expression_metric(
210-
shared_genes,
211-
corrs,
212-
f"metric_{method}",
213-
"average_correlation",
214-
method,
215-
log_msg,
216-
)
205+
206+
# Select metric function
207+
if method == "pearson":
208+
metric_func = lambda a, b: pearsonr(a, b)[0]
209+
elif method == "spearman":
210+
metric_func = lambda a, b: spearmanr(a, b)[0]
217211
elif method in ("js", "jensen-shannon"):
218-
jsds = np.array(
219-
[_jensen_shannon_divergence(imputed_x[:, i], original_x[:, i]) for i in range(imputed_x.shape[1])]
220-
)
221-
self._store_expression_metric(
222-
shared_genes,
223-
jsds,
224-
"metric_js",
225-
"average_jsd",
226-
"js",
227-
log_msg,
228-
)
212+
metric_func = _jensen_shannon_divergence
229213
elif method == "rmse":
230-
rmses = np.array([_rmse_zscore(imputed_x[:, i], original_x[:, i]) for i in range(imputed_x.shape[1])])
231-
self._store_expression_metric(
232-
shared_genes,
233-
rmses,
234-
"metric_rmse",
235-
"average_rmse",
236-
"rmse",
237-
log_msg,
238-
)
214+
metric_func = _rmse_zscore
239215
else:
240216
raise NotImplementedError(f"Method '{method}' is not implemented.")
241217

218+
# Helper to compute metrics for a given mask of cells
219+
def compute_metrics(mask):
220+
return np.array([metric_func(original_x[mask, i], imputed_x[mask, i]) for i in range(imputed_x.shape[1])])
221+
222+
# Compute metrics for all cells
223+
overall_mask = np.ones(original_x.shape[0], dtype=bool)
224+
overall_metrics = compute_metrics(overall_mask)
225+
self._store_expression_metric(
226+
shared_genes,
227+
overall_metrics,
228+
method,
229+
)
230+
231+
if groupby is not None:
232+
# Prepare DataFrame to store per-group metrics
233+
group_labels = self.query.obs[groupby]
234+
groups = group_labels.unique()
235+
metrics_df = pd.DataFrame(
236+
np.full((self.query.n_vars, len(groups)), np.nan, dtype=np.float32),
237+
index=self.query.var_names,
238+
columns=groups,
239+
)
240+
241+
# Compute and store metrics for each group
242+
for group in groups:
243+
mask = group_labels == group
244+
metrics_df.loc[shared_genes, group] = compute_metrics(mask.values)
245+
self.query.varm[f"metric_{method}"] = metrics_df
246+
247+
logger.info(
248+
"Metrics per group defined in `query.obs['%s']` computed and stored in `query.varm['%s']`",
249+
groupby,
250+
f"metric_{method}",
251+
)
252+
242253
def _get_aligned_expression_arrays(self, layer_key: str) -> tuple[np.ndarray, np.ndarray, list[str]]:
243254
"""
244255
Extract and align imputed and original expression arrays for shared genes between query_imputed and query.
@@ -271,53 +282,136 @@ def _store_expression_metric(
271282
self,
272283
shared_genes: list[str],
273284
values: np.ndarray,
274-
metric_name: str,
275-
summary_key: str,
276-
method_label: str,
277-
log_msg: str,
285+
method: str,
278286
) -> None:
279287
"""
280-
Store per-gene and summary expression transfer metrics in the query_imputed AnnData object and log the results.
288+
Store per-gene and summary expression transfer metrics in the query AnnData object and log the results.
281289
282290
Parameters
283291
----------
284292
shared_genes
285293
List of shared gene names.
286294
values
287-
Array of per-gene metric values (e.g., correlation, JSD).
288-
metric_name
289-
Name of the column to store per-gene values in query_imputed.var.
290-
summary_key
291-
Key for the average metric in self.expression_transfer_metrics.
292-
method_label
295+
Array of per-gene metric values (e.g., correlation, JSD) or 2D array (genes x groups).
296+
method
293297
Name of the method/metric (for logging and summary dict).
294-
log_msg
295-
Logging message format string, should accept (avg_value, n_genes, n_valid_genes).
296-
297-
Returns
298-
-------
299-
Nothing, but updates the following attributes:
300-
query_imputed.var[metric_name]
301-
DataFrame column with per-gene metric values.
302-
expression_transfer_metrics
303-
Dictionary containing the average metric value and number of genes used for the evaluation.
304298
"""
305-
self.query_imputed.var[metric_name] = None
306-
self.query_imputed.var.loc[shared_genes, metric_name] = values
299+
# Store overall metric in .var
300+
self.query.var[f"metric_{method}"] = np.nan
301+
self.query.var.loc[shared_genes, f"metric_{method}"] = values
307302
valid_values = values[~np.isnan(values)]
303+
308304
if valid_values.size == 0:
309-
raise ValueError(f"No valid genes for {method_label} calculation.")
305+
raise ValueError(f"No valid genes for {method} calculation.")
310306
avg_value = float(np.mean(valid_values))
311307
self.expression_transfer_metrics = {
312-
"method": method_label,
313-
summary_key: avg_value,
308+
"method": method,
309+
"average": avg_value,
314310
"n_genes": len(shared_genes),
315311
"n_valid_genes": int(valid_values.size),
316312
}
313+
317314
logger.info(
318-
log_msg,
319-
method_label,
315+
"Expression transfer evaluation (%s): average value = %.4f (n_genes=%d, n_valid_genes=%d)",
316+
method,
320317
avg_value,
321318
len(shared_genes),
322319
int(valid_values.size),
323320
)
321+
322+
def estimate_presence_score(
    self,
    groupby: str | None = None,
    key_added: str = "presence_score",
    log: bool = False,
    percentile: tuple[float, float] = (1, 99),
):
    """
    Score every reference cell by how strongly the query cells connect to it.

    The raw score of a reference cell is the column sum of the query-to-reference
    connectivity matrix (rows are indexed with query-cell masks, columns are labelled
    with the reference cell names). Raw scores are post-processed by
    `process_presence_scores` before being stored.

    Adapted from the HNOCA-tools package: https://github.com/devsystemslab/HNOCA-tools

    Parameters
    ----------
    groupby
        Column in `self.query.obs` used to split the query cells (e.g., cell type, batch).
        If None, only the score across all query cells is computed.
    key_added
        Storage key. The score across all query cells is always written to
        `self.ref.obs[key_added]`; if `groupby` is given, the per-group scores are
        additionally written as a DataFrame to `self.ref.obsm[key_added]`.
    log
        Forwarded to `process_presence_scores`: whether to log1p-transform the scores.
    percentile
        Forwarded to `process_presence_scores`: (low, high) percentiles used for clipping.

    Raises
    ------
    ValueError
        If `self.knn` or `self.knn.yx` is None, i.e. neighbors are not available.
    """
    if self.knn is None or self.knn.yx is None:
        raise ValueError("Neighbors must be computed before estimating presence scores.")

    connectivities = self.knn.yx.knn_graph_connectivities()
    reference_cells = self.ref.obs_names

    # The overall score (all query cells) is stored regardless of `groupby`.
    raw_overall = pd.DataFrame(
        {"all": np.array(connectivities.sum(axis=0)).flatten()},
        index=reference_cells,
    )
    processed_overall = process_presence_scores(raw_overall, log=log, percentile=percentile)
    self.ref.obs[key_added] = processed_overall["all"]
    logger.info("Presence score across all query cells computed and stored in `ref.obs['%s']`", key_added)

    if groupby is None:
        return

    # One column per group: sum connectivities over the query cells of that group only.
    labels = self.query.obs[groupby]
    categories = labels.unique()
    per_group = np.zeros((len(reference_cells), len(categories)), dtype=np.float32)
    for column, category in enumerate(categories):
        member_rows = (labels == category).values
        per_group[:, column] = np.array(connectivities[member_rows, :].sum(axis=0)).flatten()

    raw_groups = pd.DataFrame(per_group, index=reference_cells, columns=categories)
    self.ref.obsm[key_added] = process_presence_scores(raw_groups, log=log, percentile=percentile)

    logger.info(
        "Presence scores per group defined in `query.obs['%s']` computed and stored in `ref.obsm['%s']`",
        groupby,
        key_added,
    )
377+
378+
379+
def process_presence_scores(
    scores: pd.DataFrame,
    log: bool = False,
    percentile: tuple[float, float] = (1, 99),
) -> pd.DataFrame:
    """
    Post-process presence scores with log1p, percentile clipping, and min-max normalization.

    All steps are applied independently per column. The input frame is not modified.

    Parameters
    ----------
    scores
        DataFrame of raw presence scores (rows: reference cells, columns: groups or 'all').
    log
        Whether to apply log1p transformation to the scores.
    percentile
        Tuple of (low, high) percentiles for clipping scores before normalization.
        Use (0, 100) to disable clipping.

    Returns
    -------
    pd.DataFrame
        Post-processed presence scores in [0, 1], with the same shape, index and columns
        as the input. Columns without a positive value range (constant columns, or
        columns whose min/max is NaN) are returned as all zeros.

    Raises
    ------
    ValueError
        If `percentile` does not satisfy 0 <= low <= high <= 100.
    """
    low, high = percentile
    # An inverted or out-of-range pair would otherwise clip silently to nonsense values.
    if not 0 <= low <= high <= 100:
        raise ValueError(f"`percentile` must satisfy 0 <= low <= high <= 100, got {percentile!r}.")

    values = scores.to_numpy()
    if not np.issubdtype(values.dtype, np.floating):
        # Integer/bool scores need a float container for normalization.
        values = values.astype(np.float64)
    work_dtype = values.dtype

    # Percentiles and min/max are undefined for empty input; pass it through unchanged.
    if values.size > 0:
        # Log1p transformation (optional)
        if log:
            values = np.log1p(values)

        # Percentile clipping (optional), one (low, high) bound pair per column
        if (low, high) != (0, 100):
            lower, upper = np.percentile(values, [low, high], axis=0)
            values = np.clip(values, lower, upper).astype(work_dtype, copy=False)

        # Min-max normalization (always); columns with no positive range stay at zero
        col_min = values.min(axis=0)
        span = values.max(axis=0) - col_min
        shifted = values - col_min
        values = np.divide(shifted, span, out=np.zeros_like(shifted), where=span > 0)

    return pd.DataFrame(values, index=scores.index, columns=scores.columns)

tests/conftest.py

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -88,6 +88,10 @@ def query_ref_adata(adata_pbmc3k):
8888
)
8989
query = query[:, query_genes].copy()
9090

91+
# Introduce two deterministic batch categories in the query AnnData object
92+
query.obs["batch"] = np.repeat(["A", "B"], repeats=n_query_cells // 2).tolist()
93+
query.obs["batch"] = query.obs["batch"].astype("category")
94+
9195
return query, ref
9296

9397

@@ -124,7 +128,7 @@ def expected_label_transfer_metrics():
124128
def expected_expression_transfer_metrics():
125129
return {
126130
"method": "pearson",
127-
"average_correlation": 0.376,
131+
"average": 0.376,
128132
"n_genes": 300,
129133
"n_valid_genes": 300,
130134
}

tests/test_cellmapper.py

Lines changed: 0 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -40,15 +40,6 @@ def test_expression_transfer_layers(self, cmap, layer_key):
4040
assert cmap.query_imputed is not None
4141
assert cmap.query_imputed.X.shape[0] == cmap.query.n_obs
4242

43-
@pytest.mark.parametrize("eval_layer", ["X", "counts"])
44-
@pytest.mark.parametrize("method", ["pearson", "spearman", "js", "rmse"])
45-
def test_evaluate_expression_transfer_layers_and_methods(self, cmap, eval_layer, method):
46-
cmap.transfer_expression(layer_key="X")
47-
cmap.evaluate_expression_transfer(layer_key=eval_layer, method=method)
48-
metrics = cmap.expression_transfer_metrics
49-
assert metrics["method"] == method
50-
assert metrics["n_valid_genes"] > 0
51-
5243
@pytest.mark.parametrize(
5344
"joint_pca_key,n_pca_components,pca_kwargs",
5445
[

tests/test_evaluate.py

Lines changed: 50 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,50 @@
1+
import numpy as np
2+
import pandas as pd
3+
import pytest
4+
5+
6+
class TestEvaluate:
    """Tests for expression-transfer evaluation and presence-score estimation."""

    @pytest.mark.parametrize("eval_layer", ["X", "counts"])
    @pytest.mark.parametrize("method", ["pearson", "spearman", "js", "rmse"])
    @pytest.mark.parametrize("groupby", ["batch", "modality"])
    def test_evaluate_expression_transfer_layers_and_methods(self, cmap, eval_layer, method, groupby):
        """Every layer/method/groupby combination yields summary and per-gene metrics."""
        metric_key = f"metric_{method}"

        cmap.transfer_expression(layer_key="X")
        cmap.evaluate_expression_transfer(layer_key=eval_layer, method=method, groupby=groupby)

        summary = cmap.expression_transfer_metrics
        assert summary["method"] == method
        assert summary["n_valid_genes"] > 0
        assert cmap.query_imputed is not None
        assert cmap.query.var[metric_key] is not None
        # Per-group metrics are only checked for the "batch" grouping.
        if groupby == "batch":
            assert cmap.query.varm[metric_key] is not None

    @pytest.mark.parametrize(
        "log,percentile",
        [
            (False, (0, 100)),
            (True, (0, 100)),
            (False, (5, 95)),
            (True, (1, 99)),
        ],
    )
    def test_presence_score_overall(self, cmap, log, percentile):
        """The overall presence score lands in `ref.obs`, lies in [0, 1] and is not all zero."""
        cmap.estimate_presence_score(log=log, percentile=percentile)

        assert "presence_score" in cmap.ref.obs
        overall = cmap.ref.obs["presence_score"]
        assert isinstance(overall, (pd.Series, np.ndarray))
        within_unit_interval = (overall >= 0) & (overall <= 1)
        assert np.all(within_unit_interval)
        assert not np.all(overall == 0)  # Should not be all zeros

    @pytest.mark.parametrize("groupby", ["batch", "modality"])
    def test_presence_score_groupby(self, cmap, groupby):
        """Grouped scoring writes the overall score to `.obs` and one column per group to `.obsm`."""
        cmap.estimate_presence_score(groupby=groupby)

        # Overall score should always be present in .obs
        assert "presence_score" in cmap.ref.obs
        # Per-group scores should be present in .obsm
        assert "presence_score" in cmap.ref.obsm

        per_group = cmap.ref.obsm["presence_score"]
        assert isinstance(per_group, pd.DataFrame)
        assert all(np.all((per_group[name] >= 0) & (per_group[name] <= 1)) for name in per_group.columns)

        # Columns should match group names
        expected_groups = cmap.query.obs[groupby].unique()
        assert set(per_group.columns) == set(expected_groups)

0 commit comments

Comments
 (0)