Skip to content

Commit 9aaf126

Browse files
amalia-k510flying-sheeppre-commit-ci[bot]
authored
feat: add modularity to scanpy.metrics (#3613)
Co-authored-by: Philipp A. <flying-sheep@web.de> Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
1 parent f75575c commit 9aaf126

File tree

11 files changed

+348
-28
lines changed

11 files changed

+348
-28
lines changed

docs/api/metrics.md

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,7 @@ Collections of useful measurements for evaluating results.
1515
:nosignatures:
1616
:toctree: ../generated/
1717
18+
metrics.modularity
1819
metrics.confusion_matrix
1920
metrics.gearys_c
2021
metrics.morans_i

docs/conf.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -260,7 +260,10 @@ def setup(app: Sphinx) -> None:
260260
"scanpy.plotting._dotplot.DotPlot": "scanpy.pl.DotPlot",
261261
"scanpy.plotting._stacked_violin.StackedViolin": "scanpy.pl.StackedViolin",
262262
"pandas.core.series.Series": "pandas.Series",
263+
# https://github.com/pandas-dev/pandas/issues/63810
264+
"pandas.api.typing.aliases.AnyArrayLike": ("doc", "pandas:reference/aliases"),
263265
"numpy.bool_": "numpy.bool", # Since numpy 2, numpy.bool is the canonical dtype
266+
"numpy.typing.ArrayLike": ("py:data", "numpy.typing.ArrayLike"),
264267
}
265268

266269
nitpick_ignore = [

docs/release-notes/3613.feat.md

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1 @@
1+
Add modularity scoring via {func}`scanpy.metrics.modularity` with support for directed/undirected graphs {smaller}`A. Karesh`

hatch.toml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@ dependency-groups = [ "dev" ]
44

55
[envs.docs]
66
dependency-groups = [ "doc" ]
7+
extra-dependencies = [ "pandas>=3" ]
78
scripts.build = "sphinx-build -M html docs docs/_build -W {args}"
89
scripts.open = "python3 -m webbrowser -t docs/_build/html/index.html"
910
scripts.clean = "git clean -fdX -- {args:docs}"

pyproject.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -113,7 +113,7 @@ test = [
113113
doc = [
114114
"sphinx>=8.2.3",
115115
"sphinx-book-theme>=1.1.0",
116-
"scanpydoc>=0.16",
116+
"scanpydoc>=0.16.1",
117117
"sphinx-autodoc-typehints>=1.25.2",
118118
"sphinx-issues>=5.0.1",
119119
"myst-parser>=2",

src/scanpy/_utils/__init__.py

Lines changed: 19 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -883,20 +883,25 @@ class NeighborsView:
883883
This defines where to look for neighbors dictionary,
884884
connectivities, distances.
885885
886-
neigh = NeighborsView(adata, key)
887-
neigh['distances']
888-
neigh['connectivities']
889-
neigh['params']
890-
'connectivities' in neigh
891-
'params' in neigh
892-
893-
is the same as
894-
895-
adata.obsp[adata.uns[key]['distances_key']]
896-
adata.obsp[adata.uns[key]['connectivities_key']]
897-
adata.uns[key]['params']
898-
adata.uns[key]['connectivities_key'] in adata.obsp
899-
'params' in adata.uns[key]
886+
Examples
887+
--------
888+
>>> import scanpy as sc
889+
>>> adata = sc.datasets.pbmc68k_reduced()
890+
>>> key = "neighbors"
891+
892+
>>> neigh = NeighborsView(adata, key)
893+
>>> d = neigh["distances"]
894+
>>> c = neigh["connectivities"]
895+
>>> p = neigh["params"]
896+
897+
is the same as doing this manually
898+
899+
>>> d_key = adata.uns[key].get("distances_key", "distances")
900+
>>> c_key = adata.uns[key].get("connectivities_key", "connectivities")
901+
>>> assert d is adata.obsp[d_key]
902+
>>> assert c is adata.obsp[c_key]
903+
>>> assert p is adata.uns[key]["params"]
904+
>>> assert c_key in adata.obsp
900905
901906
"""
902907

src/scanpy/metrics/__init__.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@
33
from __future__ import annotations
44

55
from ._gearys_c import gearys_c
6-
from ._metrics import confusion_matrix
6+
from ._metrics import confusion_matrix, modularity
77
from ._morans_i import morans_i
88

9-
__all__ = ["confusion_matrix", "gearys_c", "morans_i"]
9+
__all__ = ["confusion_matrix", "gearys_c", "modularity", "morans_i"]

src/scanpy/metrics/_metrics.py

Lines changed: 130 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2,15 +2,28 @@
22

33
from __future__ import annotations
44

5-
from typing import TYPE_CHECKING
5+
from typing import TYPE_CHECKING, overload
66

77
import numpy as np
88
import pandas as pd
9+
from anndata import AnnData
910
from natsort import natsorted
1011
from pandas.api.types import CategoricalDtype
1112

13+
from .._utils import NeighborsView
14+
1215
if TYPE_CHECKING:
1316
from collections.abc import Sequence
17+
from typing import Literal
18+
19+
if TYPE_CHECKING:
20+
from pandas.api.typing.aliases import AnyArrayLike
21+
else: # sphinx-autodoc-typehints will execute the outer block, but end up here:
22+
AnyArrayLike = type(
23+
"AnyArrayLike", (), dict(__module__="pandas.api.typing.aliases")
24+
)
25+
26+
from .._compat import SpBase
1427

1528

1629
def confusion_matrix(
@@ -89,3 +102,119 @@ def confusion_matrix(
89102
df = df.loc[np.array(orig_idx), np.array(new_idx)]
90103

91104
return df
105+
106+
107+
@overload
108+
def modularity(
109+
connectivities: AnyArrayLike | SpBase, /, labels: AnyArrayLike, *, is_directed: bool
110+
) -> float: ...
111+
112+
113+
@overload
114+
def modularity(
115+
adata: AnnData,
116+
/,
117+
labels: str | AnyArrayLike = "leiden",
118+
*,
119+
neighbors_key: str | None = None,
120+
mode: Literal["calculate", "update", "retrieve"] = "calculate",
121+
) -> float: ...
122+
123+
124+
def modularity(
125+
adata_or_connectivities: AnnData | AnyArrayLike | SpBase,
126+
/,
127+
labels: str | AnyArrayLike = "leiden",
128+
*,
129+
neighbors_key: str | None = None,
130+
is_directed: bool | None = None,
131+
mode: Literal["calculate", "update", "retrieve"] = "calculate",
132+
) -> float:
133+
"""Compute the modularity of a graph given its connectivities and labels.
134+
135+
Parameters
136+
----------
137+
adata_or_connectivities
138+
The AnnData object containing the data or a weighted adjacency matrix representing the graph.
139+
labels
140+
Cluster labels for each node in the graph.
141+
When `AnnData` is provided, this can be the key in `adata.obs` that contains the clustering labels and defaults to `"leiden"`.
142+
neighbors_key
143+
When `AnnData` is provided, the key in `adata.obsp` that contains the connectivities.
144+
is_directed
145+
Whether the connectivities are directed or undirected.
146+
Always `False` if `AnnData` is provided, as connectivities are derived from (symmetric) neighbors.
147+
mode
148+
When `AnnData` is provided,
149+
this controls if the stored modularity is retrieved,
150+
or if we should calculate it (and optionally update it in `adata.uns[labels]`).
151+
152+
Returns
153+
-------
154+
The modularity of the graph based on the provided clustering.
155+
"""
156+
if isinstance(adata_or_connectivities, AnnData):
157+
if is_directed:
158+
msg = f"Connectivities stored in `AnnData` are undirected, can’t specify `{is_directed=!r}`"
159+
raise ValueError(msg)
160+
return modularity_adata(
161+
adata_or_connectivities,
162+
labels=labels,
163+
neighbors_key=neighbors_key,
164+
mode=mode,
165+
)
166+
if isinstance(labels, str):
167+
msg = "`labels` must be provided as array when passing a connectivities array"
168+
raise TypeError(msg)
169+
if is_directed is None:
170+
msg = "`is_directed` must be provided when passing a connectivities array"
171+
raise TypeError(msg)
172+
return modularity_array(
173+
adata_or_connectivities, labels=labels, is_directed=is_directed
174+
)
175+
176+
177+
def modularity_adata(
178+
adata: AnnData,
179+
/,
180+
*,
181+
labels: str | AnyArrayLike,
182+
neighbors_key: str | None,
183+
mode: Literal["calculate", "update", "retrieve"],
184+
) -> float:
185+
if mode in {"retrieve", "update"} and not isinstance(labels, str):
186+
msg = "`labels` must be a string when `mode` is `'retrieve'` or `'update'`"
187+
raise ValueError(msg)
188+
if mode == "retrieve":
189+
return adata.uns[labels]["modularity"]
190+
191+
labels_vec = adata.obs[labels] if isinstance(labels, str) else labels
192+
connectivities = NeighborsView(adata, neighbors_key)["connectivities"]
193+
194+
# distances are treated as symmetric, so connectivities as well
195+
m = modularity(connectivities, labels_vec, is_directed=False)
196+
if mode == "update":
197+
adata.uns[labels]["modularity"] = m
198+
return m
199+
200+
201+
def modularity_array(
202+
connectivities: AnyArrayLike | SpBase, /, *, labels: AnyArrayLike, is_directed: bool
203+
) -> float:
204+
try:
205+
import igraph as ig
206+
except ImportError as e: # pragma: no cover
207+
msg = "igraph is require for computing modularity"
208+
raise ImportError(msg) from e
209+
igraph_mode: str = ig.ADJ_DIRECTED if is_directed else ig.ADJ_UNDIRECTED
210+
graph: ig.Graph = ig.Graph.Weighted_Adjacency(connectivities, mode=igraph_mode)
211+
return graph.modularity(_codes(labels))
212+
213+
214+
def _codes(labels: AnyArrayLike) -> AnyArrayLike:
215+
"""Convert cluster labels to integer codes as required by igraph."""
216+
if isinstance(labels, pd.Series):
217+
labels = labels.astype("category").array
218+
if not isinstance(labels, pd.Categorical):
219+
labels = pd.Categorical(labels)
220+
return labels.codes

src/scanpy/tools/_leiden.py

Lines changed: 13 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
from __future__ import annotations
22

3-
from typing import TYPE_CHECKING
3+
from typing import TYPE_CHECKING, cast
44

55
import numpy as np
66
import pandas as pd
@@ -47,7 +47,7 @@ def leiden( # noqa: PLR0912, PLR0913, PLR0915
4747
flavor: Literal["leidenalg", "igraph"] | None = None,
4848
**clustering_args,
4949
) -> AnnData | None:
50-
"""Cluster cells into subgroups :cite:p:`Traag2019`.
50+
r"""Cluster cells into subgroups :cite:p:`Traag2019`.
5151
5252
Cluster cells using the Leiden algorithm :cite:p:`Traag2019`,
5353
an improved version of the Louvain algorithm :cite:p:`Blondel2008`.
@@ -120,6 +120,12 @@ def leiden( # noqa: PLR0912, PLR0913, PLR0915
120120
A dict with the values for the parameters `resolution`, `random_state`,
121121
and `n_iterations`.
122122
123+
`adata.uns['leiden' | key_added]['modularity']` : :class:`float`
124+
The modularity score of the final clustering,
125+
as calculated by the `flavor`.
126+
Use :func:`scanpy.metrics.modularity`\ `(adata, mode='calculate' | 'update')`
127+
to calculate a score independent of `flavor`.
128+
123129
"""
124130
if flavor is None:
125131
flavor = "leidenalg"
@@ -178,7 +184,10 @@ def leiden( # noqa: PLR0912, PLR0913, PLR0915
178184
if use_weights:
179185
clustering_args["weights"] = np.array(g.es["weight"]).astype(np.float64)
180186
clustering_args["seed"] = random_state
181-
part = leidenalg.find_partition(g, partition_type, **clustering_args)
187+
part = cast(
188+
"MutableVertexPartition",
189+
leidenalg.find_partition(g, partition_type, **clustering_args),
190+
)
182191
else:
183192
g = _utils.get_igraph_from_adjacency(adjacency, directed=False)
184193
if use_weights:
@@ -212,6 +221,7 @@ def leiden( # noqa: PLR0912, PLR0913, PLR0915
212221
random_state=random_state,
213222
n_iterations=n_iterations,
214223
)
224+
adata.uns[key_added]["modularity"] = part.modularity
215225
logg.info(
216226
" finished",
217227
time=start,

tests/test_clustering.py

Lines changed: 22 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
from __future__ import annotations
22

33
from functools import partial
4+
from typing import TYPE_CHECKING
45

56
import pandas as pd
67
import pytest
@@ -10,21 +11,27 @@
1011
from testing.scanpy._helpers.data import pbmc68k_reduced
1112
from testing.scanpy._pytest.marks import needs
1213

14+
if TYPE_CHECKING:
15+
from typing import Literal
16+
1317

1418
@pytest.fixture
1519
def adata_neighbors():
1620
return pbmc68k_reduced()
1721

1822

19-
FLAVORS = [
20-
pytest.param("igraph", marks=needs.igraph),
21-
pytest.param("leidenalg", marks=needs.leidenalg),
22-
]
23+
@pytest.fixture(
24+
params=[
25+
pytest.param("igraph", marks=needs.igraph),
26+
pytest.param("leidenalg", marks=needs.leidenalg),
27+
]
28+
)
29+
def flavor(request: pytest.FixtureRequest) -> Literal["igraph", "leidenalg"]:
30+
return request.param
2331

2432

2533
@needs.leidenalg
2634
@needs.igraph
27-
@pytest.mark.parametrize("flavor", FLAVORS)
2835
@pytest.mark.parametrize("resolution", [1, 2])
2936
@pytest.mark.parametrize("n_iterations", [-1, 3])
3037
def test_leiden_basic(adata_neighbors, flavor, resolution, n_iterations):
@@ -44,7 +51,6 @@ def test_leiden_basic(adata_neighbors, flavor, resolution, n_iterations):
4451

4552
@needs.leidenalg
4653
@needs.igraph
47-
@pytest.mark.parametrize("flavor", FLAVORS)
4854
def test_leiden_random_state(adata_neighbors, flavor):
4955
is_leiden_alg = flavor == "leidenalg"
5056
n_iterations = 2 if is_leiden_alg else -1
@@ -72,8 +78,18 @@ def test_leiden_random_state(adata_neighbors, flavor):
7278
directed=is_leiden_alg,
7379
n_iterations=n_iterations,
7480
)
81+
# reproducible
7582
pd.testing.assert_series_equal(adata_1.obs["leiden"], adata_1_again.obs["leiden"])
83+
assert (
84+
pytest.approx(adata_1.uns["leiden"]["modularity"])
85+
== adata_1_again.uns["leiden"]["modularity"]
86+
)
87+
# different clustering
7688
assert not adata_2.obs["leiden"].equals(adata_1_again.obs["leiden"])
89+
assert (
90+
pytest.approx(adata_2.uns["leiden"]["modularity"])
91+
!= adata_1_again.uns["leiden"]["modularity"]
92+
)
7793

7894

7995
@needs.igraph

0 commit comments

Comments
 (0)