diff --git a/src/squidpy/_docs.py b/src/squidpy/_docs.py
index 7898574fc..a715a2ed6 100644
--- a/src/squidpy/_docs.py
+++ b/src/squidpy/_docs.py
@@ -48,6 +48,12 @@ def decorator2(obj: Any) -> Any:
 _n_perms = """\
 n_perms
     Number of permutations for the permutation test."""
+_normalization = """\
+normalization
+    Normalization of neighbor counts: either `none`, `total` (divide by the total neighbor count of the index cell type) or `conditional` (divide by the number of index cells with at least one neighbor of the given neighbor cell type)."""
+_min_cell_count = """\
+min_cell_count
+    Minimum number of cells a cluster must contain to be included in the analysis. Pairs involving clusters with fewer than `min_cell_count` cells are reported as NaN."""
 _img_layer = """\
 layer
     Image layer in ``img`` that should be processed. If `None` and only 1 layer is present, it will be selected."""
@@ -367,6 +373,8 @@ def decorator2(obj: Any) -> Any:
     numba_parallel=_numba_parallel,
     seed=_seed,
     n_perms=_n_perms,
+    normalization=_normalization,
+    min_cell_count=_min_cell_count,
     img_layer=_img_layer,
     feature_name=_feature_name,
     yx=_yx,
diff --git a/src/squidpy/gr/_nhood.py b/src/squidpy/gr/_nhood.py
index 5ff39e6d2..96db48853 100644
--- a/src/squidpy/gr/_nhood.py
+++ b/src/squidpy/gr/_nhood.py
@@ -2,6 +2,7 @@

 from __future__ import annotations

+import warnings
 from collections.abc import Callable, Iterable, Sequence
 from functools import partial
 from typing import Any, NamedTuple
@@ -41,10 +42,13 @@ class NhoodEnrichmentResult(NamedTuple):
         Z-score values of enrichment statistic.
     count : NDArray[np.number]
         Enrichment count.
+    conditional_ratio : NDArray[np.number] | None
+        Conditional ratio; only computed if ``normalization='conditional'``, otherwise `None`.
     """

     zscore: NDArray[np.number]
     counts: NDArray[np.number]  # NamedTuple inherits from tuple so cannot use 'count' as attribute name
+    conditional_ratio: NDArray[np.number] | None = None


 # data type aliases (both for numpy and numba should match)
@@ -135,6 +139,43 @@ def _create_function(n_cls: int, parallel: bool = False) -> Callable[[NDArrayA,
     return globals()[fn_key]  # type: ignore[no-any-return]


+def filter_clusters_by_min_cell_count(
+    adata: AnnData,
+    int_clust: NDArrayA,
+    connectivity_key: str,
+    min_cell_count: int,
+) -> tuple[NDArrayA, NDArrayA]:
+    """
+    Filter clusters by minimum cell count.
+
+    Parameters
+    ----------
+    adata
+        Annotated data object.
+    int_clust
+        Array of cluster labels per cell.
+    connectivity_key
+        Key in ``adata.obsp`` with the adjacency matrix.
+    min_cell_count
+        Minimum number of cells required to keep a cluster.
+
+    Returns
+    -------
+    int_clust_filtered
+        Filtered cluster labels.
+    adj
+        Adjacency matrix corresponding to the filtered cells.
+    """
+    clust_sizes = pd.Series(int_clust).value_counts()
+    valid_clusters = clust_sizes[clust_sizes >= min_cell_count].index.to_numpy()
+
+    valid_mask = np.isin(int_clust, valid_clusters)
+    valid_cells_idx = np.where(valid_mask)[0]
+    int_clust = int_clust[valid_mask]
+
+    adj = adata.obsp[connectivity_key][np.ix_(valid_cells_idx, valid_cells_idx)]
+    return int_clust, adj
+
+
 @d.get_sections(base="nhood_ench", sections=["Parameters"])
 @d.dedent
 def nhood_enrichment(
@@ -148,6 +189,9 @@
     copy: bool = False,
     n_jobs: int | None = None,
     backend: str = "loky",
+    normalization: str = "none",
+    min_cell_count: int = 0,
+    handle_nan: str = "keep",
     show_progress_bar: bool = True,
 ) -> NhoodEnrichmentResult | None:
     """
@@ -164,15 +208,26 @@
     %(seed)s
     %(copy)s
     %(parallelize)s
+    normalization
+        Normalization mode to use:
+        - ``'none'``: no normalization of neighbor counts.
+        - ``'total'``: normalize neighbor counts by the total neighbor count of each index cluster (SEA).
+        - ``'conditional'``: normalize neighbor counts by the number of cells with at least one neighbor of the given type (COZI).
+    min_cell_count
+        Minimum number of cells a cluster must contain to be included in the analysis;
+        smaller clusters are excluded and their pairs are reported as NaN.
+    handle_nan
+        How to handle NaN values in z-scores:
+        - ``'zero'``: replace NaN values with 0.
+        - ``'keep'``: keep NaN values (undefined enrichment).

     Returns
     -------
     If ``copy = True``, returns a :class:`~squidpy.gr.NhoodEnrichmentResult` with the z-score and the enrichment count.
+    If ``normalization = 'conditional'``, it also contains the conditional ratio; otherwise ``conditional_ratio`` is `None`.
     Otherwise, modifies the ``adata`` with the following keys:

         - :attr:`anndata.AnnData.uns` ``['{cluster_key}_nhood_enrichment']['zscore']`` - the enrichment z-score.
         - :attr:`anndata.AnnData.uns` ``['{cluster_key}_nhood_enrichment']['count']`` - the enrichment count.
+        - :attr:`anndata.AnnData.uns` ``['{cluster_key}_nhood_enrichment']['conditional_ratio']`` - the ratio of cells of type A that have at least one neighbor of type B (only if ``normalization = 'conditional'``).
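+
+    Examples
+    --------
+    A minimal usage sketch of the new parameters; the cluster key and parameter values are illustrative::
+
+        sq.gr.spatial_neighbors(adata)
+        sq.gr.nhood_enrichment(
+            adata,
+            cluster_key="cell_type",
+            normalization="conditional",
+            min_cell_count=20,
+            handle_nan="keep",
+        )
+        res = adata.uns["cell_type_nhood_enrichment"]  # keys: 'zscore', 'count', 'conditional_ratio'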
""" if isinstance(adata, SpatialData): adata = adata.table @@ -183,9 +238,16 @@ def nhood_enrichment( adj = adata.obsp[connectivity_key] original_clust = adata.obs[cluster_key] - clust_map = {v: i for i, v in enumerate(original_clust.cat.categories.values)} # map categories + clust_map = {v: i for i, v in enumerate(original_clust.cat.categories.values)} int_clust = np.array([clust_map[c] for c in original_clust], dtype=ndt) + n_total_cells = len(int_clust) + int_clust, adj = filter_clusters_by_min_cell_count( + adata=adata, + int_clust=int_clust, + connectivity_key=connectivity_key, + min_cell_count=min_cell_count, + ) if library_key is not None: _assert_categorical_obs(adata, key=library_key) libraries: pd.Series | None = adata.obs[library_key] @@ -197,6 +259,55 @@ def nhood_enrichment( _test = _create_function(n_cls, parallel=numba_parallel) count = _test(indices, indptr, int_clust) + conditional_ratio = np.full((n_cls, n_cls), np.nan, dtype=np.float64) + + if normalization == "total": + row_sums = count.sum(axis=1, keepdims=True) + row_sums[row_sums == 0] = 1 + count_normalized = count / row_sums + elif normalization == "conditional": + res = np.zeros((len(int_clust), n_cls), dtype=np.uint32) + for i in range(len(int_clust)): + xs, xe = indptr[i], indptr[i + 1] + neighbors = indices[xs:xe] + for n in neighbors: + res[i, int_clust[n]] += 1 + + per_cell_neighbor_matrix = res > 0 + + cond_counts = np.zeros((n_cls, n_cls), dtype=np.float64) + conditional_ratio = np.full((n_cls, n_cls), np.nan, dtype=np.float64) + + for a in range(n_cls): + a_cells = int_clust == a + if not np.any(a_cells): + continue + n_type_a_cells = a_cells.sum() + for b in range(n_cls): + has_b_neighbor = per_cell_neighbor_matrix[a_cells, b] + cond_counts[a, b] = has_b_neighbor.sum() + conditional_ratio[a, :] = cond_counts[a, :] / n_type_a_cells + + safe_cond_counts = cond_counts.copy() + safe_cond_counts[safe_cond_counts == 0] = 1.0 + + count_normalized = count / safe_cond_counts + + n_retained_cells = len(int_clust) + n_filtered = n_total_cells - n_retained_cells + frac_filtered = n_filtered / n_total_cells * 100 + + if n_filtered > 0: + warnings.warn( + f"{frac_filtered:.3f}% of cells were excluded because their clusters had fewer than {min_cell_count} cells.", + UserWarning, + stacklevel=2, + ) + + elif normalization == "none": + count_normalized = count.copy() + else: + raise ValueError(f"Invalid normalization mode `{normalization}`. 
Choose from 'none', 'total', 'conditional'.") n_jobs = _get_n_cores(n_jobs) start = logg.info(f"Calculating neighborhood enrichment using `{n_jobs}` core(s)") @@ -216,17 +327,36 @@ def nhood_enrichment( libraries=libraries, n_cls=n_cls, seed=seed, + normalization=normalization, ) - zscore = (count - perms.mean(axis=0)) / perms.std(axis=0) + + std = perms.std(axis=0) + std[std == 0] = np.nan + zscore = (count_normalized - perms.mean(axis=0)) / std + + if handle_nan == "zero": + zscore = np.nan_to_num(zscore, nan=0.0) + elif handle_nan == "keep": + pass + else: + raise ValueError("handle_nan must be 'keep' or 'zero'") + + result_kwargs = {"zscore": zscore, "count": count} + if normalization == "conditional": + result_kwargs["conditional_ratio"] = conditional_ratio if copy: - return NhoodEnrichmentResult(zscore=zscore, counts=count) + return NhoodEnrichmentResult( + zscore=result_kwargs["zscore"], + counts=result_kwargs["count"], + conditional_ratio=result_kwargs.get("conditional_ratio"), + ) _save_data( adata, attr="uns", key=Key.uns.nhood_enrichment(cluster_key), - data={"zscore": zscore, "count": count}, + data=result_kwargs, time=start, ) @@ -442,9 +572,10 @@ def _nhood_enrichment_helper( n_cls: int, seed: int | None = None, queue: SigQueue | None = None, + normalization: str = "none", ) -> NDArrayA: perms = np.empty((len(ixs), n_cls, n_cls), dtype=np.float64) - int_clust = int_clust.copy() # threading + int_clust = int_clust.copy() rs = np.random.RandomState(seed=None if seed is None else seed + ixs[0]) for i in range(len(ixs)): @@ -452,7 +583,41 @@ def _nhood_enrichment_helper( int_clust = _shuffle_group(int_clust, libraries, rs) else: rs.shuffle(int_clust) - perms[i, ...] = callback(indices, indptr, int_clust) + + count_perms = callback(indices, indptr, int_clust) + + if normalization == "total": + row_sums = count_perms.sum(axis=1, keepdims=True) + row_sums[row_sums == 0] = 1 + count_perms = count_perms / row_sums + elif normalization == "conditional": + res = np.zeros((len(int_clust), n_cls), dtype=np.uint32) + for i_cell in range(len(int_clust)): + xs, xe = indptr[i_cell], indptr[i_cell + 1] + neighbors = indices[xs:xe] + for n in neighbors: + res[i_cell, int_clust[n]] += 1 + + per_cell_neighbor_matrix = res > 0 + + cond_counts = np.zeros((n_cls, n_cls), dtype=np.float64) + cluster_sizes = np.zeros(n_cls, dtype=np.float64) + for a in range(n_cls): + a_cells = int_clust == a + cluster_sizes[a] = a_cells.sum() + if not np.any(a_cells): + continue + for b in range(n_cls): + has_b_neighbor = per_cell_neighbor_matrix[a_cells, b] + cond_counts[a, b] = has_b_neighbor.sum() + + cond_counts[cond_counts == 0] = 1.0 + count_perms = count_perms / cond_counts + + elif normalization == "none": + pass + + perms[i, ...] 
= count_perms if queue is not None: queue.put(Signal.UPDATE) diff --git a/src/squidpy/im/_io.py b/src/squidpy/im/_io.py index c3ce5bce1..64d42bc5d 100644 --- a/src/squidpy/im/_io.py +++ b/src/squidpy/im/_io.py @@ -2,6 +2,7 @@ from collections.abc import Mapping from pathlib import Path +from typing import Any import dask.array as da import numpy as np diff --git a/src/squidpy/pl/__init__.py b/src/squidpy/pl/__init__.py index 0bcc00233..8879dbf00 100644 --- a/src/squidpy/pl/__init__.py +++ b/src/squidpy/pl/__init__.py @@ -7,6 +7,7 @@ co_occurrence, interaction_matrix, nhood_enrichment, + nhood_enrichment_dotplot, ripley, ) diff --git a/src/squidpy/pl/_graph.py b/src/squidpy/pl/_graph.py index 852906225..10ff335a0 100644 --- a/src/squidpy/pl/_graph.py +++ b/src/squidpy/pl/_graph.py @@ -2,6 +2,7 @@ from __future__ import annotations +import warnings from collections.abc import Mapping, Sequence from pathlib import Path from types import MappingProxyType @@ -13,6 +14,7 @@ import seaborn as sns from anndata import AnnData from matplotlib.axes import Axes +from matplotlib.lines import Line2D from squidpy._constants._constants import RipleyStat from squidpy._constants._pkg_constants import Key @@ -240,6 +242,160 @@ def nhood_enrichment( save_fig(fig, path=save) +@d.dedent +def nhood_enrichment_dotplot( + adata: AnnData, + cluster_key: str, + zscore_key: str = "ct_nhood_enrichment", + annotate: bool = False, + title: str | None = None, + cmap: str = "coolwarm", + palette: Palette_t = None, + cbar_kwargs: Mapping[str, Any] = MappingProxyType({}), + figsize: tuple[float, float] | None = None, + dpi: int | None = None, + size_range: tuple[float, float] = (10, 200), + save: str | Path | None = None, + ax: Axes | None = None, + **kwargs: Any, +) -> None: + """ + Dot plot of neighborhood enrichment. + + This plots the result of :func:`squidpy.gr.nhood_enrichment`, using: + - Color for z-score of enrichment + - Dot size for conditional cell ratio (CCR), scaled continuously + + Parameters + ---------- + adata : AnnData + Annotated data matrix. + cluster_key : str + Key in `adata.obs` where the cluster (cell type) annotation is stored. + zscore_key : str, optional + Key in `adata.uns` where the enrichment results are stored. + annotate : bool, optional + Whether to annotate dots with CCR values. + title : str, optional + Title of the plot. + cmap : str, optional + Colormap used for the z-score values. + palette : Palette_t, optional + Not used, reserved for compatibility. + cbar_kwargs : dict, optional + Keyword arguments for `fig.colorbar`. + figsize : tuple, optional + Figure size. + dpi : int, optional + Dots per inch for the figure. + size_range : tuple of float, optional + Min and max dot sizes for conditional cell ratio scaling. + save : str | Path, optional + Path to save the figure. + ax : matplotlib.axes.Axes, optional + Axes object to draw the plot onto, otherwise a new figure is created. + **kwargs : Any + Additional keyword arguments passed to `plt.scatter`. + + Returns + ------- + None + """ + _assert_categorical_obs(adata, key=cluster_key) + enrichment = _get_data(adata, cluster_key=cluster_key, func_name="nhood_enrichment") + + zscore = enrichment["zscore"] + ccr = enrichment.get("conditional_ratio") + + if ccr is None: + warnings.warn( + "'conditional_ratio' is None in nhood_enrichment results. Please run nhood_erichment with normalization = 'conditional'." 
+            "Dot size will not reflect conditional cell ratios.",
+            UserWarning,
+            stacklevel=2,
+        )
+        ccr = np.ones_like(zscore)
+
+    cats = adata.obs[cluster_key].cat.categories
+
+    df = pd.DataFrame(
+        {
+            "x": np.tile(np.arange(len(cats)), len(cats)),
+            "y": np.repeat(np.arange(len(cats)), len(cats)),
+            "zscore": zscore.flatten(),
+            "ccr": ccr.flatten(),
+        }
+    )
+
+    size_min, size_max = size_range
+    ccr_norm = (df["ccr"] - df["ccr"].min()) / (df["ccr"].max() - df["ccr"].min() + 1e-10)
+    df["size"] = size_min + ccr_norm * (size_max - size_min)
+
+    fig, ax = plt.subplots(figsize=figsize, dpi=dpi) if ax is None else (ax.figure, ax)
+    sc = ax.scatter(
+        df["x"],
+        df["y"],
+        c=df["zscore"],
+        s=df["size"],
+        cmap=cmap,
+        edgecolors="black",
+        linewidths=0.3,
+        **kwargs,
+    )
+
+    ax.set_xticks(np.arange(len(cats)))
+    ax.set_yticks(np.arange(len(cats)))
+    ax.set_xticklabels(cats, rotation=90)
+    ax.set_yticklabels(cats)
+    ax.set_xlabel("Neighbor cell type")
+    ax.set_ylabel("Index cell type")
+
+    ax.set_title(title or "Neighborhood enrichment (dot plot)")
+
+    # Colorbar
+    cbar = fig.colorbar(sc, ax=ax, **cbar_kwargs)
+    cbar.set_label("Z-score")
+
+    legend_ccr_vals = np.linspace(df["ccr"].min(), df["ccr"].max(), 5)
+    legend_sizes = size_min + (legend_ccr_vals - df["ccr"].min()) / (df["ccr"].max() - df["ccr"].min() + 1e-10) * (
+        size_max - size_min
+    )
+
+    legend_elements = [
+        Line2D(
+            [0],
+            [0],
+            marker="o",
+            color="w",
+            label=f"{v:.2f}",
+            markerfacecolor="gray",
+            markersize=np.sqrt(s),  # scatter size is area → sqrt for legend
+            markeredgecolor="black",
+        )
+        for v, s in zip(legend_ccr_vals, legend_sizes, strict=True)
+    ]
+
+    ax.legend(
+        handles=legend_elements,
+        title="CCR",
+        loc="center left",
+        bbox_to_anchor=(1.3, 0.5),
+        borderaxespad=0.0,
+        frameon=False,
+    )
+
+    if annotate:
+        for _, row in df.iterrows():
+            ax.text(row["x"], row["y"], f"{row['ccr']:.2f}", ha="center", va="center")
+
+    ax.invert_yaxis()
+    ax.set_aspect("equal")
+
+    if save is not None:
+        save_fig(fig, path=save)
+
+
 @d.dedent
 def ripley(
     adata: AnnData,
diff --git a/tests/_images/Graph_nhood_enrichment_dotplot.png b/tests/_images/Graph_nhood_enrichment_dotplot.png
new file mode 100644
index 000000000..325a25b25
Binary files /dev/null and b/tests/_images/Graph_nhood_enrichment_dotplot.png differ
diff --git a/tests/graph/test_nhood.py b/tests/graph/test_nhood.py
index 74764f236..aa3d460f3 100644
--- a/tests/graph/test_nhood.py
+++ b/tests/graph/test_nhood.py
@@ -143,3 +143,74 @@ def test_interaction_matrix_nan_values(adata_intmat: AnnData):
     np.testing.assert_array_equal(expected_weighted, result_weighted)
     np.testing.assert_array_equal(expected_unweighted, result_unweighted)
+
+
+@pytest.mark.parametrize("normalization", ["none", "total", "conditional"])
+def test_nhood_enrichment_normalization_modes(adata: AnnData, normalization: str):
+    spatial_neighbors(adata)
+    result = nhood_enrichment(adata, cluster_key=_CK, normalization=normalization, n_jobs=1, n_perms=20, copy=True)
+
+    z, count, ccr = result
+
+    assert isinstance(z, np.ndarray)
+    assert isinstance(count, np.ndarray)
+    if normalization == "conditional":
+        assert isinstance(ccr, np.ndarray)
+        assert z.shape == ccr.shape
+        assert count.shape == ccr.shape
+    assert z.shape == count.shape
+    assert z.shape[0] == adata.obs[_CK].cat.categories.shape[0]
+
+
+def test_conditional_normalization_zero_division(adata: AnnData):
+    adata = adata.copy()
+    min_cells = 10
+    if _CK not in adata.obs:
+        raise ValueError(f"Cluster key '{_CK}' not in adata.obs")
+    if not isinstance(adata.obs[_CK].dtype, pd.CategoricalDtype):
+        adata.obs[_CK] = adata.obs[_CK].astype("category")
+    adata.obs[_CK] = adata.obs[_CK].cat.add_categories("isolated")
+    adata.obs.loc[adata.obs.index[0], _CK] = "isolated"
+    spatial_neighbors(adata)
+    valid_clusters = [c for c, count in adata.obs[_CK].value_counts().items() if count >= min_cells]
+    valid_idx = [i for i, cat in enumerate(adata.obs[_CK].cat.categories) if cat in valid_clusters]
+
+    result = nhood_enrichment(adata, cluster_key=_CK, normalization="conditional", copy=True)
+    assert result is not None
+    zscore, count_normalized, conditional_ratio = result
+    assert not np.any(np.isinf(zscore))
+    assert not np.any(np.isinf(count_normalized))
+    assert not np.any(np.isinf(conditional_ratio))
+    assert not np.isnan(zscore[np.ix_(valid_idx, valid_idx)]).any()
+    assert not np.isnan(count_normalized[np.ix_(valid_idx, valid_idx)]).any()
+    assert not np.isnan(conditional_ratio[np.ix_(valid_idx, valid_idx)]).any()
+
+
+@pytest.mark.parametrize(
+    "normalization, expected_dtype",
+    [
+        ("none", np.uint32),
+        ("total", np.uint32),
+        ("conditional", np.uint32),
+    ],
+)
+def test_output_dtype(adata: AnnData, normalization: str, expected_dtype):
+    spatial_neighbors(adata)
+    result = nhood_enrichment(
+        adata,
+        cluster_key=_CK,
+        normalization=normalization,
+        n_jobs=1,
+        n_perms=20,
+        copy=True,
+    )
+
+    count = result.counts
+
+    assert count.dtype == expected_dtype
+
+
+def test_invalid_normalization_raises(adata: AnnData):
+    spatial_neighbors(adata)
+    with pytest.raises(ValueError, match="Invalid normalization mode"):
+        nhood_enrichment(adata, cluster_key=_CK, normalization="invalid_mode", copy=True)
diff --git a/tests/plotting/test_graph.py b/tests/plotting/test_graph.py
index 6e1c20f7d..27e7f4380 100644
--- a/tests/plotting/test_graph.py
+++ b/tests/plotting/test_graph.py
@@ -66,6 +66,12 @@ def test_plot_nhood_enrichment_ax(self, adata: AnnData):
         fig, ax = plt.subplots(figsize=(2, 2), constrained_layout=True)
         pl.nhood_enrichment(adata, cluster_key=C_KEY, ax=ax)

+    def test_plot_nhood_enrichment_dotplot(self, adata: AnnData):
+        gr.spatial_neighbors(adata)
+        gr.nhood_enrichment(adata, cluster_key=C_KEY, normalization="conditional")
+
+        pl.nhood_enrichment_dotplot(adata, cluster_key=C_KEY)
+
     def test_plot_nhood_enrichment_dendro(self, adata: AnnData):
         gr.spatial_neighbors(adata)
         gr.nhood_enrichment(adata, cluster_key=C_KEY)
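Note on the conditional ("COZI") normalization added in src/squidpy/gr/_nhood.py: for an ordered pair (A, B), the conditional ratio is the fraction of type-A cells that have at least one type-B neighbor, and the raw neighbor counts are divided by the corresponding per-pair counts. A minimal standalone sketch of that ratio, assuming a symmetric CSR adjacency matrix and an integer label vector; the helper name and the toy graph below are illustrative only, not squidpy API:

    import numpy as np
    from scipy.sparse import csr_matrix

    def conditional_ratio(adj: csr_matrix, labels: np.ndarray, n_cls: int) -> np.ndarray:
        # one-hot encode cluster labels: shape (n_cells, n_cls)
        onehot = np.zeros((labels.size, n_cls))
        onehot[np.arange(labels.size), labels] = 1.0
        # neighbor_counts[i, b] = number of neighbors of cell i belonging to cluster b
        neighbor_counts = adj @ onehot
        has_neighbor = neighbor_counts > 0
        ratio = np.full((n_cls, n_cls), np.nan)
        for a in range(n_cls):
            mask = labels == a
            if mask.any():
                # fraction of type-a cells with at least one type-b neighbor
                ratio[a] = has_neighbor[mask].mean(axis=0)
        return ratio

    # toy example: path graph 0-1-2-3 with cluster labels [0, 0, 1, 1]
    adj = csr_matrix(np.array([[0, 1, 0, 0], [1, 0, 1, 0], [0, 1, 0, 1], [0, 0, 1, 0]]))
    labels = np.array([0, 0, 1, 1])
    print(conditional_ratio(adj, labels, n_cls=2))  # [[1.0, 0.5], [0.5, 1.0]]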