Skip to content

Commit 319a21d

Browse files
committed
Add basic tests for the main class
1 parent 6e1ed0d commit 319a21d

File tree

3 files changed

+2836
-0
lines changed

3 files changed

+2836
-0
lines changed

tests/conftest.py

Lines changed: 108 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,11 @@
1+
from importlib.resources import files
2+
13
import numpy as np
4+
import pandas as pd
25
import pytest
6+
import scanpy as sc
7+
8+
from cellmapper.cellmapper import CellMapper
39

410

511
@pytest.fixture
@@ -20,3 +26,105 @@ def small_data():
2026
x = np.array([[0, 0], [1, 0], [0, 1], [1, 1], [2, 2]], dtype=np.float64)
2127
y = np.array([[0, 0], [1, 0], [0, 1], [1, 1], [2, 2]], dtype=np.float64)
2228
return x, y
29+
30+
31+
@pytest.fixture
def precomputed_leiden():
    """Load the precomputed leiden clustering shipped with the test data.

    Reads ``precomputed_leiden.csv`` from the ``tests.data`` package, using
    the first CSV column as the index, and returns its ``leiden`` column.
    """
    csv_path = files("tests.data") / "precomputed_leiden.csv"
    clustering_table = pd.read_csv(str(csv_path), index_col=0)
    return clustering_table["leiden"]
38+
39+
40+
@pytest.fixture
def adata_pbmc3k(precomputed_leiden):
    """Build a preprocessed PBMC3k AnnData object with leiden labels attached.

    Starting from ``sc.datasets.pbmc3k()``, the fixture filters cells and
    genes, stores a copy of the matrix in ``layers["counts"]``, normalizes,
    log-transforms, flags 2000 highly variable genes, runs PCA restricted to
    those genes, and finally writes the precomputed clustering to
    ``obs["leiden"]`` as a categorical column of strings.
    """
    pbmc = sc.datasets.pbmc3k()

    # Basic quality filtering of cells and genes.
    sc.pp.filter_cells(pbmc, min_genes=200)
    sc.pp.filter_genes(pbmc, min_cells=3)

    # Keep a copy of the matrix before the transformations below are applied.
    pbmc.layers["counts"] = pbmc.X.copy()

    # Normalize to median total counts, then logarithmize.
    sc.pp.normalize_total(pbmc)
    sc.pp.log1p(pbmc)

    # Highly variable genes drive the PCA via ``mask_var``.
    sc.pp.highly_variable_genes(pbmc, n_top_genes=2000)
    sc.tl.pca(pbmc, mask_var="highly_variable")

    # Attach the precomputed clustering as string-valued categories.
    leiden_labels = precomputed_leiden.astype("str").astype("category")
    pbmc.obs["leiden"] = leiden_labels

    return pbmc
66+
67+
68+
@pytest.fixture
def query_ref_adata(adata_pbmc3k):
    """Split the PBMC3k object into a ``(query, ref)`` pair of AnnData copies.

    The first 500 cells are labelled "query" and the remaining ones "ref" in
    ``adata_pbmc3k.obs["modality"]``; each group is copied into its own
    object. The query is then restricted to up to 300 highly variable genes,
    chosen by descending ``means``, while the reference keeps all genes.
    """
    n_query_cells = 500
    n_query_genes = 300
    n_ref_cells = adata_pbmc3k.n_obs - n_query_cells

    # Positional labelling: a block of "query" entries followed by "ref".
    query_labels = np.repeat("query", repeats=n_query_cells).tolist()
    ref_labels = np.repeat("ref", repeats=n_ref_cells).tolist()
    adata_pbmc3k.obs["modality"] = query_labels + ref_labels
    adata_pbmc3k.obs["modality"] = adata_pbmc3k.obs["modality"].astype("category")

    # Materialize independent query and reference objects from the labels.
    is_query = adata_pbmc3k.obs["modality"] == "query"
    is_ref = adata_pbmc3k.obs["modality"] == "ref"
    query = adata_pbmc3k[is_query].copy()
    ref = adata_pbmc3k[is_ref].copy()

    # Reduce the query to the highly variable genes with the largest means.
    hvg_table = query.var[query.var["highly_variable"]]
    ranked_hvgs = hvg_table.sort_values("means", ascending=False)
    query_genes = ranked_hvgs.head(n_query_genes).index.tolist()
    query = query[:, query_genes].copy()

    return query, ref
92+
93+
94+
@pytest.fixture
def cmap(query_ref_adata):
    """Return a CellMapper for the query/reference pair, ready for transfer tests.

    The mapper is constructed from the ``(query, ref)`` tuple, then neighbors
    and the mapping matrix are computed before it is handed to the test.
    """
    query, ref = query_ref_adata
    mapper = CellMapper(query=query, ref=ref)

    # Neighbors on the "X_pca" representation: 30 neighbors, sklearn method.
    mapper.compute_neighbors(n_neighbors=30, use_rep="X_pca", method="sklearn")

    # NOTE(review): the method name is spelled with three p's exactly as in
    # the original call; presumably this matches the CellMapper API — confirm.
    mapper.compute_mappping_matrix(method="gaussian")

    return mapper
109+
110+
111+
@pytest.fixture
def expected_label_transfer_metrics():
    """Return the expected label-transfer metric values, keyed by metric name.

    NOTE(review): presumably compared against CellMapper evaluation output by
    tests outside this file — confirm against the callers.
    """
    reference_metrics = {
        "accuracy": 0.954,
        "precision": 0.955,
        "recall": 0.954,
        "f1_weighted": 0.951,
        "f1_macro": 0.899,
        "excluded_fraction": 0.0,
    }
    return reference_metrics
121+
122+
123+
@pytest.fixture
def expected_expression_transfer_metrics():
    """Return the expected expression-transfer metric values, keyed by name.

    NOTE(review): presumably compared against CellMapper evaluation output by
    tests outside this file — confirm against the callers.
    """
    reference_metrics = {
        "method": "pearson",
        "average_correlation": 0.376,
        "n_genes": 300,
        "n_valid_genes": 300,
    }
    return reference_metrics

0 commit comments

Comments
 (0)