Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
27 commits
Select commit Hold shift + click to select a range
51a6e21
add conditional and total normalization
chiarasch Apr 25, 2025
a9bdafd
remove prints
chiarasch Apr 25, 2025
52cc04a
hande nan
chiarasch Apr 29, 2025
cafd835
add normalization tests
chiarasch Apr 29, 2025
eeef861
add docstring
chiarasch Apr 29, 2025
7b89081
filter out low cond counts
chiarasch May 6, 2025
26e90bf
add warnings low cond counts
chiarasch May 6, 2025
bb0eeb8
add min_cell_count filter
chiarasch Aug 7, 2025
e7a33ab
add CCR dotplot
chiarasch Aug 7, 2025
02f361d
Merge branch 'main' into cozi_nhood_enrichment
chiarasch Aug 7, 2025
fc8c1b8
add cozi tests
chiarasch Aug 11, 2025
c9956fc
clean script
chiarasch Aug 11, 2025
28a77dc
add cozi parameters in docs
chiarasch Aug 11, 2025
74e77fc
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Aug 12, 2025
23cf1a2
Merge branch 'main' into cozi_nhood_enrichment
chiarasch Aug 18, 2025
02fb792
fix precommit hooks
chiarasch Aug 19, 2025
3cd8c44
set min_cell_count default to 0
chiarasch Aug 19, 2025
785a3ee
add dotplot test image and function
chiarasch Aug 19, 2025
f14f31d
remove matplotlib loading from function
chiarasch Aug 19, 2025
005a57b
fix typo
chiarasch Aug 19, 2025
2d34d99
update code with changes in #1043
chiarasch Oct 6, 2025
b90dbb0
add optional cond_ratio to result class
chiarasch Oct 6, 2025
7ac43a4
adapt plotting to result class changes
chiarasch Oct 6, 2025
a82ddaa
fix optional cond ratio output
chiarasch Oct 6, 2025
7a71607
remove dict plotting changes again
chiarasch Oct 6, 2025
4972f96
adapt tests to new output format
chiarasch Oct 6, 2025
c2c40e7
typo
chiarasch Oct 6, 2025
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 8 additions & 0 deletions src/squidpy/_docs.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,6 +48,12 @@ def decorator2(obj: Any) -> Any:
_n_perms = """\
n_perms
Number of permutations for the permutation test."""
_normalization = """\
normalization
Normalization of neighbor counts: either `none`, `total` (divide by the total number of cells of the index cell type) or `conditional` (divide by the number of cells of the index cell type that have at least one neighbor of the neighbor cell type).
_min_cell_count = """\
min_cell_count
Minimum number of cells that have to be in a cluster for it to be included in the analysis. Cells of clusters with fewer than ``min_cell_count`` cells are excluded."""
_img_layer = """\
layer
Image layer in ``img`` that should be processed. If `None` and only 1 layer is present, it will be selected."""
Expand Down Expand Up @@ -367,6 +373,8 @@ def decorator2(obj: Any) -> Any:
numba_parallel=_numba_parallel,
seed=_seed,
n_perms=_n_perms,
normalization=_normalization,
min_cell_count=_min_cell_count,
img_layer=_img_layer,
feature_name=_feature_name,
yx=_yx,
Expand Down
177 changes: 171 additions & 6 deletions src/squidpy/gr/_nhood.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@

from __future__ import annotations

import warnings
from collections.abc import Callable, Iterable, Sequence
from functools import partial
from typing import Any, NamedTuple
Expand Down Expand Up @@ -41,10 +42,13 @@ class NhoodEnrichmentResult(NamedTuple):
Z-score values of enrichment statistic.
counts : NDArray[np.number]
Enrichment count.
conditional_ratio : NDArray[np.number] | None
Conditional ratio (only present if normalization='conditional').
"""

zscore: NDArray[np.number]
counts: NDArray[np.number] # NamedTuple inherits from tuple so cannot use 'count' as attribute name
conditional_ratio: NDArray[np.number] | None = None


# data type aliases (both for numpy and numba should match)
Expand Down Expand Up @@ -135,6 +139,43 @@ def _create_function(n_cls: int, parallel: bool = False) -> Callable[[NDArrayA,
return globals()[fn_key] # type: ignore[no-any-return]


def filter_clusters_by_min_cell_count(
    adata: AnnData,
    int_clust: NDArrayA,
    connectivity_key: str,
    min_cell_count: int,
) -> tuple[NDArrayA, NDArrayA]:
    """
    Drop cells whose cluster has fewer than ``min_cell_count`` members.

    Parameters
    ----------
    adata
        Annotated data object; only ``adata.obsp[connectivity_key]`` is read.
    int_clust
        Array with one cluster label per cell.
    connectivity_key
        Key in ``adata.obsp`` under which the adjacency matrix is stored.
    min_cell_count
        Minimum number of cells a cluster needs in order to be kept.
        A cluster with exactly ``min_cell_count`` cells is kept.

    Returns
    -------
    int_clust_filtered
        Cluster labels of the retained cells, in their original order.
    adj
        Adjacency matrix restricted to the rows and columns of the retained cells.
        When no cell is removed, the stored matrix is returned as-is (no copy).
    """
    int_clust = np.asarray(int_clust)

    # Size of the cluster each cell belongs to, without a pandas round-trip.
    _, inverse, cluster_sizes = np.unique(int_clust, return_inverse=True, return_counts=True)
    valid_mask = cluster_sizes[inverse] >= min_cell_count

    adj = adata.obsp[connectivity_key]
    if valid_mask.all():
        # Nothing to filter (always the case for ``min_cell_count <= 1``):
        # skip the fancy-index copy of the whole adjacency matrix.
        return int_clust[valid_mask], adj

    valid_cells_idx = np.flatnonzero(valid_mask)
    return int_clust[valid_mask], adj[np.ix_(valid_cells_idx, valid_cells_idx)]


@d.get_sections(base="nhood_ench", sections=["Parameters"])
@d.dedent
def nhood_enrichment(
Expand All @@ -148,6 +189,9 @@ def nhood_enrichment(
copy: bool = False,
n_jobs: int | None = None,
backend: str = "loky",
normalization: str = "none",
min_cell_count: int = 0,
handle_nan: str = "keep",
show_progress_bar: bool = True,
) -> NhoodEnrichmentResult | None:
"""
Expand All @@ -164,15 +208,26 @@ def nhood_enrichment(
%(seed)s
%(copy)s
%(parallelize)s
normalization
Normalization mode to use:
- ``'none'``: No normalization of neighbor counts
- ``'total'``: Normalize neighbor counts by total number of cells per cluster (SEA)
- ``'conditional'``: Normalize neighbor counts by number of cells with at least one neighbor of given type (COZI)
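min_cell_count
Minimum number of cells a cluster must contain to be included in the analysis; cells of smaller clusters are excluded before the enrichment is computed.
handle_nan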
How to handle NaN values in z-scores:
- ``'zero'``: Replace NaN values with 0
- ``'keep'``: Keep NaN values (undefined enrichment)

Returns
-------
If ``copy = True``, returns a :class:`~squidpy.gr.NhoodEnrichmentResult` with the z-score and the enrichment count.
If normalization = "conditional", also contains the conditional ratio, otherwise it is None.

Otherwise, modifies the ``adata`` with the following keys:

- :attr:`anndata.AnnData.uns` ``['{cluster_key}_nhood_enrichment']['zscore']`` - the enrichment z-score.
- :attr:`anndata.AnnData.uns` ``['{cluster_key}_nhood_enrichment']['count']`` - the enrichment count.
- :attr:`anndata.AnnData.uns` ``['{cluster_key}_nhood_enrichment']['conditional_ratio']`` - the ratio of cells of type A that neighbor type B (only stored if ``normalization = 'conditional'``).
"""
if isinstance(adata, SpatialData):
adata = adata.table
Expand All @@ -183,9 +238,16 @@ def nhood_enrichment(

adj = adata.obsp[connectivity_key]
original_clust = adata.obs[cluster_key]
clust_map = {v: i for i, v in enumerate(original_clust.cat.categories.values)} # map categories
clust_map = {v: i for i, v in enumerate(original_clust.cat.categories.values)}
int_clust = np.array([clust_map[c] for c in original_clust], dtype=ndt)
n_total_cells = len(int_clust)

int_clust, adj = filter_clusters_by_min_cell_count(
adata=adata,
int_clust=int_clust,
connectivity_key=connectivity_key,
min_cell_count=min_cell_count,
)
if library_key is not None:
_assert_categorical_obs(adata, key=library_key)
libraries: pd.Series | None = adata.obs[library_key]
Expand All @@ -197,6 +259,55 @@ def nhood_enrichment(

_test = _create_function(n_cls, parallel=numba_parallel)
count = _test(indices, indptr, int_clust)
conditional_ratio = np.full((n_cls, n_cls), np.nan, dtype=np.float64)

if normalization == "total":
row_sums = count.sum(axis=1, keepdims=True)
row_sums[row_sums == 0] = 1
count_normalized = count / row_sums
elif normalization == "conditional":
res = np.zeros((len(int_clust), n_cls), dtype=np.uint32)
for i in range(len(int_clust)):
xs, xe = indptr[i], indptr[i + 1]
neighbors = indices[xs:xe]
for n in neighbors:
res[i, int_clust[n]] += 1

per_cell_neighbor_matrix = res > 0

cond_counts = np.zeros((n_cls, n_cls), dtype=np.float64)
conditional_ratio = np.full((n_cls, n_cls), np.nan, dtype=np.float64)

for a in range(n_cls):
a_cells = int_clust == a
if not np.any(a_cells):
continue
n_type_a_cells = a_cells.sum()
for b in range(n_cls):
has_b_neighbor = per_cell_neighbor_matrix[a_cells, b]
cond_counts[a, b] = has_b_neighbor.sum()
conditional_ratio[a, :] = cond_counts[a, :] / n_type_a_cells

safe_cond_counts = cond_counts.copy()
safe_cond_counts[safe_cond_counts == 0] = 1.0

count_normalized = count / safe_cond_counts

n_retained_cells = len(int_clust)
n_filtered = n_total_cells - n_retained_cells
frac_filtered = n_filtered / n_total_cells * 100

if n_filtered > 0:
warnings.warn(
f"{frac_filtered:.3f}% of cells were excluded because their clusters had fewer than {min_cell_count} cells.",
UserWarning,
stacklevel=2,
)

elif normalization == "none":
count_normalized = count.copy()
else:
raise ValueError(f"Invalid normalization mode `{normalization}`. Choose from 'none', 'total', 'conditional'.")

n_jobs = _get_n_cores(n_jobs)
start = logg.info(f"Calculating neighborhood enrichment using `{n_jobs}` core(s)")
Expand All @@ -216,17 +327,36 @@ def nhood_enrichment(
libraries=libraries,
n_cls=n_cls,
seed=seed,
normalization=normalization,
)
zscore = (count - perms.mean(axis=0)) / perms.std(axis=0)

std = perms.std(axis=0)
std[std == 0] = np.nan
zscore = (count_normalized - perms.mean(axis=0)) / std

if handle_nan == "zero":
zscore = np.nan_to_num(zscore, nan=0.0)
elif handle_nan == "keep":
pass
else:
raise ValueError("handle_nan must be 'keep' or 'zero'")

result_kwargs = {"zscore": zscore, "count": count}
if normalization == "conditional":
result_kwargs["conditional_ratio"] = conditional_ratio

if copy:
return NhoodEnrichmentResult(zscore=zscore, counts=count)
return NhoodEnrichmentResult(
zscore=result_kwargs["zscore"],
counts=result_kwargs["count"],
conditional_ratio=result_kwargs.get("conditional_ratio"),
)

_save_data(
adata,
attr="uns",
key=Key.uns.nhood_enrichment(cluster_key),
data={"zscore": zscore, "count": count},
data=result_kwargs,
time=start,
)

Expand Down Expand Up @@ -442,17 +572,52 @@ def _nhood_enrichment_helper(
n_cls: int,
seed: int | None = None,
queue: SigQueue | None = None,
normalization: str = "none",
) -> NDArrayA:
perms = np.empty((len(ixs), n_cls, n_cls), dtype=np.float64)
int_clust = int_clust.copy() # threading
int_clust = int_clust.copy()
rs = np.random.RandomState(seed=None if seed is None else seed + ixs[0])

for i in range(len(ixs)):
if libraries is not None:
int_clust = _shuffle_group(int_clust, libraries, rs)
else:
rs.shuffle(int_clust)
perms[i, ...] = callback(indices, indptr, int_clust)

count_perms = callback(indices, indptr, int_clust)

if normalization == "total":
row_sums = count_perms.sum(axis=1, keepdims=True)
row_sums[row_sums == 0] = 1
count_perms = count_perms / row_sums
elif normalization == "conditional":
res = np.zeros((len(int_clust), n_cls), dtype=np.uint32)
for i_cell in range(len(int_clust)):
xs, xe = indptr[i_cell], indptr[i_cell + 1]
neighbors = indices[xs:xe]
for n in neighbors:
res[i_cell, int_clust[n]] += 1

per_cell_neighbor_matrix = res > 0

cond_counts = np.zeros((n_cls, n_cls), dtype=np.float64)
cluster_sizes = np.zeros(n_cls, dtype=np.float64)
for a in range(n_cls):
a_cells = int_clust == a
cluster_sizes[a] = a_cells.sum()
if not np.any(a_cells):
continue
for b in range(n_cls):
has_b_neighbor = per_cell_neighbor_matrix[a_cells, b]
cond_counts[a, b] = has_b_neighbor.sum()

cond_counts[cond_counts == 0] = 1.0
count_perms = count_perms / cond_counts

elif normalization == "none":
pass

perms[i, ...] = count_perms

if queue is not None:
queue.put(Signal.UPDATE)
Expand Down
1 change: 1 addition & 0 deletions src/squidpy/im/_io.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@

from collections.abc import Mapping
from pathlib import Path
from typing import Any

import dask.array as da
import numpy as np
Expand Down
1 change: 1 addition & 0 deletions src/squidpy/pl/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
co_occurrence,
interaction_matrix,
nhood_enrichment,
nhood_enrichment_dotplot,
ripley,
)

Expand Down
Loading