diff --git a/pyproject.toml b/pyproject.toml index f7de064f7..cfddf044e 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -101,6 +101,7 @@ docs = [ "myst-nb>=0.17.1", "sphinx_copybutton>=0.5.0", ] +spatialleiden = ["spatialleiden>=0.3.0"] [project.urls] Homepage = "https://github.com/scverse/squidpy" diff --git a/src/squidpy/_constants/_constants.py b/src/squidpy/_constants/_constants.py index 403f072ba..f88e41cb6 100644 --- a/src/squidpy/_constants/_constants.py +++ b/src/squidpy/_constants/_constants.py @@ -130,5 +130,6 @@ class NicheDefinitions(ModeEnum): NEIGHBORHOOD = "neighborhood" UTAG = "utag" CELLCHARTER = "cellcharter" + SPATIALLEIDEN = "spatialleiden" SPOT = "spot" BANKSY = "banksy" diff --git a/src/squidpy/gr/_niche.py b/src/squidpy/gr/_niche.py index 13c7afef9..847853ba8 100644 --- a/src/squidpy/gr/_niche.py +++ b/src/squidpy/gr/_niche.py @@ -29,13 +29,13 @@ @inject_docs(fla=NicheDefinitions) def calculate_niche( data: AnnData | SpatialData, - flavor: Literal["neighborhood", "utag", "cellcharter"], + flavor: Literal["neighborhood", "utag", "cellcharter", "spatialleiden"], library_key: str | None = None, table_key: str | None = None, mask: pd.core.series.Series = None, groups: str | None = None, n_neighbors: int | None = None, - resolutions: float | list[float] | None = None, + resolutions: float | tuple[float, float] | list[float | tuple[float, float]] | None = None, min_niche_size: int | None = None, scale: bool = True, abs_nhood: bool = False, @@ -45,6 +45,10 @@ def calculate_niche( n_components: int | None = None, random_state: int = 42, spatial_connectivities_key: str = "spatial_connectivities", + latent_connectivities_key: str = "connectivities", + layer_ratio: float = 1.0, + n_iterations: int = -1, + use_weights: bool | tuple[bool, bool] = True, inplace: bool = True, ) -> AnnData: """ @@ -59,6 +63,7 @@ def calculate_niche( - `{fla.NEIGHBORHOOD.s!r}` - cluster the neighborhood profile. 
- `{fla.UTAG.s!r}` - use utag algorithm (matrix multiplication). - `{fla.CELLCHARTER.s!r}` - cluster adjacency matrix with Gaussian Mixture Model (GMM) using CellCharter's approach. + - `{fla.SPATIALLEIDEN.s!r}` - cluster spatially resolved omics data using Multiplex Leiden. %(library_key)s If provided, niches will be calculated separately for each unique value in this column. Each niche will be prefixed with the library identifier. @@ -75,7 +80,9 @@ def calculate_niche( Required if flavor == `{fla.NEIGHBORHOOD.s!r}` or flavor == `{fla.UTAG.s!r}`. resolutions List of resolutions to use for leiden clustering. + In the case of spatialleiden you can pass a tuple. Resolution for the latent space and spatial layer, respectively. A single float applies to both layers. Required if flavor == `{fla.NEIGHBORHOOD.s!r}` or flavor == `{fla.UTAG.s!r}`. + Optional if flavor == `{fla.SPATIALLEIDEN.s!r}`. min_niche_size Minimum required size of a niche. Niches with fewer cells will be labeled as 'not_a_niche'. Optional if flavor == `{fla.NEIGHBORHOOD.s!r}`. @@ -99,10 +106,23 @@ def calculate_niche( Number of components to use for GMM. Required if flavor == `{fla.CELLCHARTER.s!r}`. random_state - Random state to use for GMM. - Optional if flavor == `{fla.CELLCHARTER.s!r}`. + Random state to use for GMM or SpatialLeiden. + Optional if flavor == `{fla.CELLCHARTER.s!r}` or flavor == `{fla.SPATIALLEIDEN.s!r}`. spatial_connectivities_key Key in `adata.obsp` where spatial connectivities are stored. + Required if flavor == `{fla.SPATIALLEIDEN.s!r}`. + latent_connectivities_key + Key in `adata.obsp` where gene expression connectivities are stored. + Required if flavor == `{fla.SPATIALLEIDEN.s!r}`. + layer_ratio + The ratio of the weighting of the layers; latent space vs spatial. A higher ratio will increase relevance of the spatial neighbors and lead to more spatially homogeneous clusters. + Optional if flavor == `{fla.SPATIALLEIDEN.s!r}`. 
+ n_iterations + Number of iterations to run the Leiden algorithm. If the number is negative it runs until convergence. + Optional if flavor == `{fla.SPATIALLEIDEN.s!r}`. + use_weights + Whether to use weights for the edges for latent space and spatial neighbors, respectively. A single bool applies to both layers. + Optional if flavor == `{fla.SPATIALLEIDEN.s!r}`. inplace If 'True', perform the operation in place. If 'False', return a new AnnData object with the niche labels. @@ -127,6 +147,11 @@ def calculate_niche( aggregation, n_components, random_state, + spatial_connectivities_key, + latent_connectivities_key, + layer_ratio, + n_iterations, + use_weights, inplace, ) @@ -149,6 +174,12 @@ def calculate_niche( "If you haven't computed a spatial neighborhood graph yet, use `sq.gr.spatial_neighbors`." ) + if flavor == "spatialleiden" and (latent_connectivities_key not in adata.obsp.keys()): + raise KeyError( + f"Key '{latent_connectivities_key}' not found in `adata.obsp`. " + "If you haven't computed a latent neighborhood graph yet, use `sc.pp.neighbors`." 
+ ) + result_columns = _get_result_columns( flavor=flavor, resolutions=resolutions, @@ -197,6 +228,10 @@ def calculate_niche( n_components=n_components, random_state=random_state, spatial_connectivities_key=spatial_connectivities_key, + latent_connectivities_key=latent_connectivities_key, + layer_ratio=layer_ratio, + n_iterations=n_iterations, + use_weights=use_weights, inplace=False, ) @@ -225,6 +260,10 @@ def calculate_niche( n_components, random_state, spatial_connectivities_key, + latent_connectivities_key, + layer_ratio, + n_iterations, + use_weights, ) if not inplace: @@ -250,7 +289,7 @@ def calculate_niche( def _get_result_columns( flavor: str, - resolutions: float | list[float], + resolutions: float | tuple[float, float] | list[float | tuple[float, float]], library_key: str | None, libraries: list[str] | None, ) -> list[str]: @@ -265,11 +304,17 @@ def _get_result_columns( elif libraries is not None and len(libraries) > 0: return [f"{base_column}_{lib}" for lib in libraries] - # For neighborhood and utag, we need to handle resolutions + # For neighborhood, utag and spatialleiden, we need to handle resolutions if not isinstance(resolutions, list): resolutions = [resolutions] - prefix = f"nhood_niche{library_str}" if flavor == "neighborhood" else f"utag_niche{library_str}" + if flavor == "neighborhood": + prefix = f"nhood_niche{library_str}" + elif flavor == "utag": + prefix = f"utag_niche{library_str}" + elif flavor == "spatialleiden": + prefix = f"spatialleiden{library_str}" + if library_key is None: return [f"{prefix}_res={res}" for res in resolutions] else: @@ -283,7 +328,7 @@ def _calculate_niches( flavor: str, groups: str | None, n_neighbors: int | None, - resolutions: float | list[float], + resolutions: float | tuple[float, float] | list[float | tuple[float, float]], min_niche_size: int | None, scale: bool, abs_nhood: bool, @@ -293,9 +338,14 @@ def _calculate_niches( n_components: int | None, random_state: int, spatial_connectivities_key: str, + 
def _get_spatialleiden_domains(
    adata: AnnData,
    spatial_connectivities_key: str,
    latent_connectivities_key: str,
    resolutions: float | tuple[float, float] | list[float | tuple[float, float]],
    layer_ratio: float,
    use_weights: bool | tuple[bool, bool],
    n_iterations: int,
    random_state: int,
) -> None:
    """
    Run SpatialLeiden clustering once for every requested resolution.

    Thin wrapper that forwards ``adata`` and the clustering settings to
    ``spatialleiden.spatialleiden``. The optional ``spatialleiden`` package is
    imported lazily so that it is only needed when this flavor is used.

    Adapted from https://github.com/HiDiHlabs/SpatialLeiden/.

    Parameters
    ----------
    adata
        Annotated data object handed to ``spatialleiden.spatialleiden``.
    spatial_connectivities_key
        Forwarded as ``spatial_neighbors_key``.
    latent_connectivities_key
        Forwarded as ``latent_neighbors_key``.
    resolutions
        A single resolution or a list of resolutions. Anything that is not a
        list (a float or a tuple) is treated as one single resolution setting.
    layer_ratio
        Forwarded as ``layer_ratio``.
    use_weights
        Forwarded as ``use_weights``.
    n_iterations
        Forwarded as ``n_iterations``.
    random_state
        Forwarded as ``random_state``.

    Returns
    -------
    None. Each run is given ``key_added="spatialleiden_res=<resolution>"``;
    presumably the labels are stored on ``adata`` under that key by
    ``spatialleiden`` itself (not visible from here, confirm against its docs).

    Raises
    ------
    ImportError
        If the ``spatialleiden`` package is not installed.
    """
    try:
        import spatialleiden
    except ImportError as exc:
        raise ImportError(
            "Please install the spatialleiden algorithm: `conda install bioconda::spatialleiden` or `pip install spatialleiden`."
        ) from exc

    # A bare float or tuple describes exactly one clustering run.
    resolution_settings = resolutions if isinstance(resolutions, list) else [resolutions]

    # Settings that are identical for every run; only the resolution and the
    # result key vary between iterations.
    shared_kwargs = {
        "use_weights": use_weights,
        "n_iterations": n_iterations,
        "layer_ratio": layer_ratio,
        "latent_neighbors_key": latent_connectivities_key,
        "spatial_neighbors_key": spatial_connectivities_key,
        "random_state": random_state,
        "directed": False,
    }

    for setting in resolution_settings:
        spatialleiden.spatialleiden(
            adata,
            resolution=setting,
            key_added=f"spatialleiden_res={setting}",
            **shared_kwargs,
        )
@@ -667,12 +772,12 @@ def _jensen_shannon_divergence(adata: AnnData, niche_key: str, library_key: str) def _validate_niche_args( data: AnnData | SpatialData, - flavor: Literal["neighborhood", "utag", "cellcharter"], + flavor: Literal["neighborhood", "utag", "cellcharter", "spatialleiden"], library_key: str | None, table_key: str | None, groups: str | None, n_neighbors: int | None, - resolutions: float | list[float] | None, + resolutions: float | tuple[float, float] | list[float | tuple[float, float]] | None, min_niche_size: int | None, scale: bool, abs_nhood: bool, @@ -681,6 +786,11 @@ def _validate_niche_args( aggregation: str | None, n_components: int | None, random_state: int, + spatial_connectivities_key: str, + latent_connectivities_key: str, + layer_ratio: float, + n_iterations: int, + use_weights: bool | tuple[bool, bool], inplace: bool, ) -> None: """ @@ -697,8 +807,10 @@ def _validate_niche_args( if not isinstance(data, AnnData | SpatialData): raise TypeError(f"'data' must be an AnnData or SpatialData object, got {type(data).__name__}") - if flavor not in ["neighborhood", "utag", "cellcharter"]: - raise ValueError(f"Invalid flavor '{flavor}'. Please choose one of 'neighborhood', 'utag', 'cellcharter'.") + if flavor not in ["neighborhood", "utag", "cellcharter", "spatialleiden"]: + raise ValueError( + f"Invalid flavor '{flavor}'. Please choose one of 'neighborhood', 'utag', 'cellcharter', 'spatialleiden'." 
+ ) if library_key is not None: if not isinstance(library_key, str): @@ -718,10 +830,20 @@ def _validate_niche_args( raise TypeError(f"'n_neighbors' must be an integer, got {type(n_neighbors).__name__}") if resolutions is not None: - if not isinstance(resolutions, float | list): - raise TypeError(f"'resolutions' must be a float or list of floats, got {type(resolutions).__name__}") - if isinstance(resolutions, list) and not all(isinstance(res, float) for res in resolutions): - raise TypeError("All elements in 'resolutions' list must be floats") + if not isinstance(resolutions, float | tuple | list): + raise TypeError( + f"'resolutions' must be a float, a tuple of floats, a list of floats, or a list containing floats and/or tuples of floats, got {type(resolutions).__name__}" + ) + + if isinstance(resolutions, tuple): + if not all(isinstance(x, float) for x in resolutions): + raise TypeError("All elements in the tuple 'resolutions' must be floats.") + elif isinstance(resolutions, list): + for item in resolutions: + if not ( + isinstance(item, float) or (isinstance(item, tuple) and all(isinstance(i, float) for i in item)) + ): + raise TypeError("Each item in the list 'resolutions' must be a float or a tuple of floats.") if n_hop_weights is not None and not isinstance(n_hop_weights, list): raise TypeError(f"'n_hop_weights' must be a list of floats, got {type(n_hop_weights).__name__}") @@ -773,6 +895,24 @@ def _validate_niche_args( "n_hop_weights", ], }, + "spatialleiden": { + "required": ["latent_connectivities_key", "spatial_connectivities_key"], + "optional": [ + "resolutions", + "layer_ratio", + "n_iterations", + "use_weights", + "random_state", + ], + "unused": [ + "groups", + "min_niche_size", + "scale", + "abs_nhood", + "n_neighbors", + "n_hop_weights", + ], + }, } for param_name in flavor_param_specs[flavor]["required"]: @@ -832,6 +972,35 @@ def _validate_niche_args( if resolutions is None: resolutions = [0.0] + elif flavor == "spatialleiden": + if not 
isinstance(latent_connectivities_key, str): + raise TypeError( + f"'latent_connectivities_key' must be a string, got {type(latent_connectivities_key).__name__}" + ) + if not isinstance(spatial_connectivities_key, str): + raise TypeError( + f"'spatial_connectivities_key' must be a string, got {type(spatial_connectivities_key).__name__}" + ) + + if not isinstance(layer_ratio, float | int): + raise TypeError(f"'layer_ratio' must be a float, got {type(layer_ratio).__name__}") + if not isinstance(n_iterations, int): + raise TypeError(f"'n_iterations' must be an integer, got {type(n_iterations).__name__}") + if not ( + isinstance(use_weights, bool) + or ( + isinstance(use_weights, tuple) + and len(use_weights) == 2 + and all(isinstance(x, bool) for x in use_weights) + ) + ): + raise TypeError(f"'use_weights' must be a bool or a tuple of two bools, got {use_weights!r}") + if not isinstance(random_state, int): + raise TypeError(f"'random_state' must be an integer, got {type(random_state).__name__}") + + if resolutions is None: + resolutions = [1.0] + if not isinstance(inplace, bool): raise TypeError(f"'inplace' must be a boolean, got {type(inplace).__name__}") @@ -843,7 +1012,7 @@ def _check_unnecessary_args(flavor: str, param_dict: dict[str, Any], param_specs Parameters ---------- flavor - The flavor being used ('neighborhood', 'utag', or 'cellcharter') + The flavor being used ('neighborhood', 'utag', 'cellcharter', or 'spatialleiden') param_dict Dictionary of parameter names to their values param_specs