Skip to content

Commit d0e910b

Browse files
authored
Merge pull request #27 from quadbio/feat/dual_pca
Introduce embedding mixin class
2 parents 15d22f0 + 0daa482 commit d0e910b

20 files changed

+1105
-479
lines changed

.gitignore

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -20,5 +20,6 @@ __pycache__/
2020
/docs/_build/
2121
.ipynb_checkpoints/
2222

23-
# datasets
23+
# datasets and models
2424
*.h5ad
25+
*.pt

docs/notebooks/tutorials/spatial_mapping.ipynb

Lines changed: 49 additions & 47 deletions
Large diffs are not rendered by default.

docs/notebooks/tutorials/spatial_smoothing.ipynb

Lines changed: 176 additions & 126 deletions
Large diffs are not rendered by default.

docs/references.bib

Lines changed: 24 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -232,3 +232,27 @@ @article{lopez2018deep
232232
publisher={Nature Publishing Group US New York},
233233
url={https://www.nature.com/articles/s41592-018-0229-2},
234234
}
235+
236+
@article{stuart2019comprehensive,
237+
title={Comprehensive integration of single-cell data},
238+
author={Stuart, Tim and Butler, Andrew and Hoffman, Paul and Hafemeister, Christoph and Papalexi, Efthymia and Mauck III, William M and Hao, Yuhan and Stoeckius, Marlon and Smibert, Peter and Satija, Rahul},
239+
journal={Cell},
240+
volume={177},
241+
number={7},
242+
pages={1888--1902},
243+
year={2019},
244+
publisher={Elsevier},
245+
url={https://www.sciencedirect.com/science/article/pii/S0092867419305598},
246+
}
247+
248+
@article{xia2023spatial,
249+
title={Spatial-linked alignment tool (SLAT) for aligning heterogenous slices},
250+
author={Xia, Chen-Rui and Cao, Zhi-Jie and Tu, Xin-Ming and Gao, Ge},
251+
journal={Nature Communications},
252+
volume={14},
253+
number={1},
254+
pages={7236},
255+
year={2023},
256+
publisher={Nature Publishing Group UK London},
257+
url={https://www.nature.com/articles/s41467-023-43105-5},
258+
}

pyproject.toml

Lines changed: 7 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -23,13 +23,12 @@ classifiers = [
2323
]
2424
dynamic = [ "version" ]
2525
dependencies = [
26-
"anndata",
26+
"anndata>=0.11",
2727
"numpy",
2828
"packaging",
2929
"pandas",
30-
"pynndescent",
3130
"rich",
32-
"scanpy",
31+
"scanpy>=1.11",
3332
"scikit-learn",
3433
"scipy",
3534
# for debug logging (referenced from the issue template)
@@ -58,17 +57,17 @@ optional-dependencies.doc = [
5857
optional-dependencies.test = [
5958
"coverage",
6059
"pytest",
61-
"squidpy",
60+
"squidpy>=1.6",
6261
]
6362
optional-dependencies.tutorials = [
64-
"cellmapper",
6563
"harmony-pytorch",
64+
"igraph",
6665
"netgraph",
6766
"python-louvain",
68-
"scvi-tools",
67+
"scvi-tools>=1.3",
6968
"seaborn",
70-
"sopa",
71-
"squidpy",
69+
"sopa>=2",
70+
"squidpy>=1.6",
7271
]
7372

7473
# https://docs.pypi.org/project_metadata/#project-urls

src/cellmapper/__init__.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,8 @@
11
from importlib.metadata import version
22

3-
from .cellmapper import CellMapper
4-
from .knn import Neighbors
53
from .logging import logger
4+
from .model.cellmapper import CellMapper
5+
from .model.knn import Neighbors
66

77
__all__ = ["logger", "CellMapper", "Neighbors"]
88

src/cellmapper/check.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -63,12 +63,14 @@ def check(self) -> None:
6363
"https://docs.rapids.ai/install/.",
6464
faiss="To speed up k-NN search on GPU, you may install faiss following the guide from "
6565
"https://github.com/facebookresearch/faiss/blob/main/INSTALL.md",
66+
pynndescent="To use fast approximate k-NN search, install pynndescent: pip install pynndescent",
6667
)
6768

6869
CHECKERS = {
6970
"cuml": Checker("cuml", vmin=None, install_hint=INSTALL_HINTS.cuml),
7071
"cupy": Checker("cupy", vmin=None, install_hint=INSTALL_HINTS.cupy),
7172
"faiss": Checker("faiss", package_name="faiss", vmin="1.7.0", install_hint=INSTALL_HINTS.faiss),
73+
"pynndescent": Checker("pynndescent", vmin=None, install_hint=INSTALL_HINTS.pynndescent),
7274
}
7375

7476

src/cellmapper/constants.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,4 @@
1+
class PackageConstants:
2+
"""Constants used throughout the package."""
3+
4+
n_comps: int = 50
Lines changed: 72 additions & 65 deletions
Original file line numberDiff line numberDiff line change
@@ -3,22 +3,21 @@
33
import gc
44
from typing import Any, Literal
55

6-
import anndata as ad
76
import numpy as np
87
import pandas as pd
98
import scanpy as sc
109
from anndata import AnnData
1110
from scipy.sparse import coo_matrix, csc_matrix, csr_matrix
1211
from sklearn.preprocessing import OneHotEncoder
1312

14-
from cellmapper.evaluate import CellMapperEvaluationMixin
1513
from cellmapper.logging import logger
16-
from cellmapper.utils import create_imputed_anndata
14+
from cellmapper.model.embedding import EmbeddingMixin
15+
from cellmapper.model.evaluate import EvaluationMixin
16+
from cellmapper.model.knn import Neighbors
17+
from cellmapper.utils import create_imputed_anndata, get_n_comps
1718

18-
from .knn import Neighbors
1919

20-
21-
class CellMapper(CellMapperEvaluationMixin):
20+
class CellMapper(EvaluationMixin, EmbeddingMixin):
2221
"""Mapping of labels, embeddings, and expression values between reference and query datasets."""
2322

2423
def __init__(self, query: AnnData, reference: AnnData | None = None) -> None:
@@ -137,79 +136,59 @@ def _validate_and_normalize_mapping_matrix(
137136

138137
return mapping_matrix
139138

140-
def compute_joint_pca(self, n_components: int = 50, key_added: str = "pca_joint", **kwargs) -> None:
141-
"""
142-
Compute a joint PCA on the normalized .X matrices of query and reference, using only overlapping genes.
143-
144-
Parameters
145-
----------
146-
n_components
147-
Number of principal components to compute.
148-
key_added
149-
Key under which to store the joint PCA embeddings in `.obsm` of both query and reference AnnData objects.
150-
**kwargs
151-
Additional keyword arguments to pass to scanpy's `pp.pca` function.
152-
153-
Notes
154-
-----
155-
This method performs an inner join on genes (variables) between the query and reference AnnData objects,
156-
concatenates the normalized expression matrices, and computes a joint PCA using Scanpy. The resulting
157-
PCA embeddings are stored in `.obsm[key_added]` for both objects. This is a fallback and not recommended
158-
for most use cases. Please provide a biologically meaningful representation if possible.
159-
"""
160-
logger.warning(
161-
"No representation provided (use_rep=None). "
162-
"Falling back to joint PCA on normalized .X of both datasets using only overlapping genes. "
163-
"This is NOT recommended for most use cases! Please provide a biologically meaningful representation."
164-
)
165-
# Concatenate with inner join on genes
166-
joint = ad.concat([self.reference, self.query], join="inner", label="batch", keys=["reference", "query"])
167-
168-
# Compute PCA using scanpy
169-
sc.pp.pca(joint, n_comps=n_components, **kwargs)
170-
171-
# Assign PCA embeddings back to each object using the batch key
172-
self.reference.obsm[key_added] = joint.obsm["X_pca"][joint.obs["batch"] == "reference"]
173-
self.query.obsm[key_added] = joint.obsm["X_pca"][joint.obs["batch"] == "query"]
174-
logger.info(
175-
"Joint PCA computed and stored as '%s' in both reference.obsm and query.obsm. "
176-
"Proceeding to use this as the representation for neighbor search.",
177-
key_added,
178-
)
179-
180139
def compute_neighbors(
181140
self,
182141
n_neighbors: int = 30,
183142
use_rep: str | None = None,
143+
n_comps: int | None = None,
184144
method: Literal["sklearn", "pynndescent", "rapids", "faiss"] = "sklearn",
185145
metric: str = "euclidean",
186146
only_yx: bool = False,
187-
joint_pca_key: str = "pca_joint",
188-
n_pca_components: int = 50,
189-
pca_kwargs: dict[str, Any] | None = None,
147+
fallback_representation: Literal["fast_cca", "joint_pca"] = "fast_cca",
148+
fallback_kwargs: dict[str, Any] | None = None,
190149
) -> None:
191150
"""
192151
Compute nearest neighbors between reference and query datasets.
193152
153+
The method computes k-nearest neighbor graphs to enable mapping between
154+
datasets. If no representation is provided (`use_rep=None`), a fallback
155+
representation will be computed automatically using either fast CCA
156+
(inspired by Seurat v3 :cite:`stuart2019comprehensive`), or joint PCA. In self-mapping mode,
157+
a simple PCA will be computed on the query dataset.
158+
194159
Parameters
195160
----------
196161
n_neighbors
197162
Number of nearest neighbors.
198163
use_rep
199-
Data representation based on which to find nearest neighbors. If None, a joint PCA will be computed.
164+
Data representation based on which to find nearest neighbors. If None,
165+
a fallback representation will be computed automatically.
166+
n_comps
167+
Number of components to use. If a pre-computed representation is provided via `use_rep`,
168+
we will use the number of components from that representation. Otherwise, if `use_rep=None`,
169+
we will compute the given number of components using the fallback representation method.
200170
method
201-
Method to use for computing neighbors. "sklearn" and "pynndescent" run on CPU, "rapids" and "faiss" run on GPU. Note that all but "pynndescent" perform exact neighbor search. With GPU acceleration, "faiss" is usually fastest and more memory efficient than "rapids".
202-
All methods return exactly `n_neighbors` neighbors, including the reference cell itself (in self-mapping mode). For faiss and sklearn, distances to self are very small positive numbers, for rapids and sklearn, they are exactly 0.
171+
Method to use for computing neighbors. "sklearn" and "pynndescent" run on CPU,
172+
"rapids" and "faiss" run on GPU. Note that all but "pynndescent" perform exact
173+
neighbor search. With GPU acceleration, "faiss" is usually fastest and more
174+
memory efficient than "rapids". All methods return exactly `n_neighbors` neighbors,
175+
including the reference cell itself (in self-mapping mode). For faiss and sklearn,
176+
distances to self are very small positive numbers, for rapids and sklearn, they are exactly 0.
203177
metric
204178
Distance metric to use for nearest neighbors.
205179
only_yx
206-
If True, only compute the xy neighbors. This is faster, but not suitable for Jaccard or HNOCA methods.
207-
joint_pca_key
208-
Key under which to store the joint PCA embeddings if use_rep is None.
209-
n_pca_components
210-
Number of principal components to compute for joint PCA if use_rep is None.
211-
pca_kwargs
212-
Additional keyword arguments to pass to scanpy's `pp.pca` function if use_rep is None.
180+
If True, only compute the xy neighbors. This is faster, but not suitable for
181+
Jaccard or HNOCA methods.
182+
fallback_representation
183+
Method to use for computing a cross-dataset representation when `use_rep=None`. Options:
184+
185+
- "fast_cca": Fast canonical correlation analysis, inspired by Seurat v3 :cite:`stuart2019comprehensive` and
186+
SLAT :cite:`xia2023spatial`.
187+
- "joint_pca": Principal component analysis on concatenated datasets.
188+
fallback_kwargs
189+
Additional keyword arguments to pass to the fallback representation method.
190+
For "fast_cca": see :meth:`~cellmapper.EmbeddingMixin.compute_fast_cca`.
191+
For "joint_pca": see :meth:`~cellmapper.EmbeddingMixin.compute_joint_pca`.
213192
214193
Returns
215194
-------
@@ -223,22 +202,50 @@ def compute_neighbors(
223202
- ``n_neighbors``: Number of nearest neighbors.
224203
- ``only_yx``: Whether only yx neighbors were computed.
225204
"""
205+
# Default to an empty dict so that `.pop()` and `**` unpacking below work
206+
if fallback_kwargs is None:
207+
fallback_kwargs = {}
208+
226209
self.only_yx = only_yx
210+
227211
if use_rep is None:
228-
if pca_kwargs is None:
229-
pca_kwargs = {}
212+
logger.warning(
213+
"No representation provided (`use_rep=None`). Computing a joint representation automatically "
214+
"using '%s'. For optimal results, consider pre-computing a representation and passing it to `use_rep`.",
215+
fallback_representation,
216+
)
217+
230218
if self._is_self_mapping:
231-
sc.pp.pca(self.query, n_comps=n_pca_components, **pca_kwargs)
232-
use_rep = "X_pca"
219+
logger.info("Self-mapping detected. Computing PCA on query dataset for representation.")
220+
key_added = fallback_kwargs.pop("key_added", "X_pca")
221+
sc.tl.pca(self.query, n_comps=n_comps, key_added=key_added, **fallback_kwargs)
233222
else:
234-
self.compute_joint_pca(n_components=n_pca_components, key_added=joint_pca_key, **pca_kwargs)
235-
use_rep = joint_pca_key
223+
if fallback_representation == "fast_cca":
224+
key_added = fallback_kwargs.pop("key_added", "X_cca")
225+
self.compute_fast_cca(n_comps=n_comps, key_added=key_added, **fallback_kwargs)
226+
elif fallback_representation == "joint_pca":
227+
key_added = fallback_kwargs.pop("key_added", "X_pca")
228+
self.compute_joint_pca(n_comps=n_comps, key_added=key_added, **fallback_kwargs)
229+
else:
230+
raise ValueError(
231+
f"Unknown fallback_representation: {fallback_representation}. "
232+
"Supported options are 'fast_cca' and 'joint_pca'."
233+
)
234+
use_rep = key_added
235+
236+
# Extract the representation from the query and reference datasets
236237
if use_rep == "X":
237238
xrep = self.reference.X
238239
yrep = self.query.X
239240
else:
240241
xrep = self.reference.obsm[use_rep]
241242
yrep = self.query.obsm[use_rep]
243+
244+
# handle the number of components
245+
n_comps = get_n_comps(n_comps, n_vars=xrep.shape[1])
246+
xrep = xrep[:, :n_comps]
247+
yrep = yrep[:, :n_comps]
248+
242249
self.knn = Neighbors(np.ascontiguousarray(xrep), np.ascontiguousarray(yrep))
243250
self.knn.compute_neighbors(n_neighbors=n_neighbors, method=method, metric=metric, only_yx=only_yx)
244251

0 commit comments

Comments
 (0)