diff --git a/src/squidpy/_utils.py b/src/squidpy/_utils.py index 4e9165942..b9d220d53 100644 --- a/src/squidpy/_utils.py +++ b/src/squidpy/_utils.py @@ -15,6 +15,8 @@ import joblib as jl import numpy as np +from anndata import AnnData +from spatialdata import SpatialData __all__ = ["singledispatchmethod", "Signal", "SigQueue", "NDArray", "NDArrayA"] @@ -337,3 +339,16 @@ def new_func2(*args: Any, **kwargs: Any) -> Any: else: raise TypeError(repr(type(reason))) + + +def _get_adata_from_input(data: AnnData | SpatialData, table: str | None = None) -> None: + if isinstance(data, AnnData): + return data + elif isinstance(data, SpatialData): + if table is None: + raise ValueError("If using a SpatialData object, a table name must be provided with `table`.") + if table not in data.tables.keys(): + raise ValueError(f"Table `{table}` not found in `SpatialData` object.") + return data.tables[table] + else: + raise TypeError(f"Expected `data` to be of type `AnnData` or `SpatialData`, found `{type(data).__name__}`.") diff --git a/src/squidpy/pl/_var_by_distance.py b/src/squidpy/pl/_var_by_distance.py index 51a4715ef..d4572e3e4 100644 --- a/src/squidpy/pl/_var_by_distance.py +++ b/src/squidpy/pl/_var_by_distance.py @@ -17,8 +17,10 @@ from scanpy.plotting._tools.scatterplots import _panel_grid from scanpy.plotting._utils import _set_default_colors_for_categorical_obs from scipy.sparse import issparse +from spatialdata import SpatialData from squidpy._docs import d +from squidpy._utils import _get_adata_from_input from squidpy.pl._utils import save_fig __all__ = ["var_by_distance"] @@ -26,9 +28,10 @@ @d.dedent def var_by_distance( - adata: AnnData, + data: AnnData | SpatialData, var: str | list[str], anchor_key: str | list[str], + table: str | None = None, design_matrix_key: str = "design_matrix", stack_vars: bool = False, covariate: str | None = None, @@ -97,6 +100,9 @@ def var_by_distance( regplot_kwargs = dict(regplot_kwargs) scatterplot_kwargs = dict(scatterplot_kwargs) + # potentially extract table from SpatialData object + adata = _get_adata_from_input(data, table) + # if several variables are plotted, make a panel grid if isinstance(var, list) and not stack_vars: fig, grid = _panel_grid( @@ -111,16 +117,16 @@ def var_by_distance( var = [var] axs = [] - df = adata.obsm[design_matrix_key] # get design matrix + design_matrix = adata.obsm[design_matrix_key] # add var column to design matrix for name in var: if name in adata.var_names: - df[name] = ( + design_matrix[name] = ( np.array(adata[:, name].X.toarray()) if issparse(adata[:, name].X) else np.array(adata[:, name].X) ) elif name in adata.obs: - df[name] = adata.obs[name].values + design_matrix[name] = adata.obs[name].values else: raise ValueError(f"Variable {name} not found in `adata.var` or `adata.obs`.") @@ -130,7 +136,7 @@ def var_by_distance( line_palette = sns.color_palette("bright", len(var)) for i, v in enumerate(var): sns.regplot( - data=df, + data=design_matrix, x=anchor_key, y=v, label=v, @@ -140,8 +146,60 @@ def var_by_distance( ax=ax, line_kws=regplot_kwargs, ) +<<<<<<< HEAD + else: + # make a categorical color palette if none was specified and there are several regplots to be plotted + if isinstance(line_palette, str) or line_palette is None: + _set_default_colors_for_categorical_obs(adata, covariate) + line_palette = adata.uns[covariate + "_colors"] + covariate_instances = design_matrix[covariate].unique() + + # iterate over all covariate values and make 'sns.regplot' for each + for i, co in enumerate(covariate_instances): + sns.regplot( + data=design_matrix.loc[design_matrix[covariate] == co], + x=anchor_key, + y=v, + order=order, + color=line_palette[i], + scatter=show_scatter, + ax=ax, + label=co, + line_kws=regplot_kwargs, + ) + label_colors, _ = ax.get_legend_handles_labels() + ax.legend(label_colors, covariate_instances) + # add scatter plot if specified + if show_scatter: + if color is None: + plt.scatter(data=design_matrix, x=anchor_key, y=v, color="grey", **scatterplot_kwargs) + # if variable to plot on color palette is categorical, make categorical color palette + elif design_matrix[color].dtype.name == "category": + unique_colors = design_matrix[color].unique() + cNorm = colors.Normalize(vmin=0, vmax=len(unique_colors)) + scalarMap = cm.ScalarMappable(norm=cNorm, cmap=scatter_palette) + for i in range(len(unique_colors)): + plt.scatter( + data=design_matrix.loc[design_matrix[color] == unique_colors[i]], + x=anchor_key, + y=v, + color=scalarMap.to_rgba(i), + **scatterplot_kwargs, + ) + # if variable to plot on color palette is not categorical + else: + plt.scatter( + data=design_matrix, + x=anchor_key, + y=v, + c=color, + cmap=scatter_palette, + **scatterplot_kwargs, + ) +======= ax.legend(title=None) ax.set(ylabel="var") +>>>>>>> e76e147e021fb06922cc63c189f4db6bab4c4103 if title is not None: ax.set(title=title) if axis_label is None: @@ -240,3 +298,211 @@ def var_by_distance( save_fig(fig, path=save, transparent=False, dpi=dpi) if return_ax: return axs + + + +# @d.dedent +# def var_by_distance( +# data: AnnData | SpatialData, +# var: str | list[str], +# table: str | None = None, +# color: str | None = None, +# covariate: str | None = None, +# order: int = 5, +# show_scatter: bool = True, +# line_palette: str | Sequence[str] | Cycler | None = None, +# scatter_palette: str | Sequence[str] | Cycler | None = "viridis", +# dpi: int | None = None, +# figsize: tuple[int, int] | None = None, +# save: str | Path | None = None, +# title: str | list[str] | None = None, +# axis_label: str | None = None, +# return_ax: bool | None = None, +# regplot_kwargs: Mapping[str, Any] = MappingProxyType({}), +# scatterplot_kwargs: Mapping[str, Any] = MappingProxyType({}), +# ) -> Axes | None: +# """ +# Plot variables using a smooth regression line with increasing distance to an anchor point. + +# Parameters +# ---------- +# data +# AnnData or SpatialData object returned by the `var_embeddings` function. +# var +# Variables (genes) to plot on y-axis. +# table +# Name of the table in `SpatialData` object. +# color +# Variable to color the scatter plot. +# covariate +# A covariate for which separate regression lines are plotted for each category. +# order +# Order of the polynomial fit for :func:`seaborn.regplot`. +# show_scatter +# Whether to show a scatter plot underlying the regression line. +# line_palette +# Categorical color palette used in case a covariate is specified. +# scatter_palette +# Color palette for the scatter plot underlying the :func:`seaborn.regplot`. +# dpi +# Dots per inch. +# figsize +# Size of the figure in inches. +# save +# Path to save the plot. +# title +# Panel titles. +# axis_label +# Panel axis labels. +# return_ax +# Whether to return :class:`matplotlib.axes.Axes` object(s). +# regplot_kwargs +# Additional keyword arguments for :func:`seaborn.regplot`. +# scatterplot_kwargs +# Additional keyword arguments for :func:`matplotlib.pyplot.scatter`. + +# Returns +# ------- +# Axes or None +# """ +# dpi = rcParams["figure.dpi"] if dpi is None else dpi +# regplot_kwargs = dict(regplot_kwargs) +# scatterplot_kwargs = dict(scatterplot_kwargs) + +# # Validate data type and extract AnnData object +# if isinstance(data, AnnData): +# adata = data +# elif isinstance(data, SpatialData): +# if table is None: +# raise ValueError("If using a SpatialData object, a table name must be provided with `table`.") +# if table not in data.tables: +# raise KeyError(f"Table '{table}' not found in SpatialData object. Available tables: {list(data.tables.keys())}") +# adata = data.tables[table] +# else: +# raise TypeError(f"Expected `data` to be of type `AnnData` or `SpatialData`, found '{type(data).__name__}'.") + +# if isinstance(var, str): +# var = [var] + +# # If multiple variables, set up a grid of plots +# if len(var) > 1: +# fig, grid = _panel_grid( +# hspace=0.25, +# wspace=0.75 / rcParams["figure.figsize"][0] + 0.02, +# ncols=4, +# num_panels=len(var), +# ) +# axs = [] +# else: +# fig, ax = plt.subplots(1, 1, figsize=figsize) +# axs = [ax] + +# # Create dataframe from adata +# df = adata.to_df() + +# # Ensure all values of `var` are in adata.obs and have float values +# for v in var: +# if v not in adata.obs_names: +# raise KeyError(f"Variable '{v}' not found in adata.obs_names") +# df.loc[v] = df.loc[v].astype(float) + +# # If color is specified and is in adata.obs or adata.var_names, add to df +# if color is not None: +# if color in adata.obs: +# df[color] = adata.obs[color].values +# elif color in adata.var_names: +# df[color] = adata[:, color].X.flatten() +# else: +# raise ValueError(f"Color variable '{color}' not found in adata.obs or adata.var_names.") + +# # If covariate is specified and is in adata.obs, add to df +# if covariate is not None: +# if covariate in adata.obs: +# df[covariate] = adata.obs[covariate].values +# else: +# raise ValueError(f"Covariate '{covariate}' not found in adata.obs.") + +# # Iterate over the variables to plot +# for i, v in enumerate(var): +# if len(var) > 1: +# ax = plt.subplot(grid[i]) +# axs.append(ax) +# else: +# ax = axs[0] + +# # if no covariate is specified, use seaborn regplot directly +# if covariate is None: +# sns.regplot( +# data=df, +# x='distance', +# y=v, +# order=order, +# color=line_palette, +# scatter=show_scatter, +# ax=ax, +# line_kws=regplot_kwargs, +# scatter_kws=scatterplot_kwargs, +# ) +# else: +# # Generate color palette if not provided +# if line_palette is None: +# _set_default_colors_for_categorical_obs(adata, covariate) +# line_palette = adata.uns[covariate + "_colors"] +# covariate_instances = df[covariate].unique() + +# # Iterate over each category in covariate +# for idx, category in enumerate(covariate_instances): +# sns.regplot( +# data=df[df[covariate] == category], +# x='distance', +# y=v, +# order=order, +# color=line_palette[idx % len(line_palette)], +# scatter=show_scatter, +# ax=ax, +# label=str(category), +# line_kws=regplot_kwargs, +# scatter_kws=scatterplot_kwargs, +# ) +# ax.legend(title=covariate) + +# # Add scatter plot if specified +# if show_scatter and color is not None: +# if df[color].dtype.name == "category": +# unique_colors = df[color].unique() +# palette = sns.color_palette(scatter_palette, len(unique_colors)) +# for idx, cat in enumerate(unique_colors): +# ax.scatter( +# x=df.loc[df[color] == cat, 'distance'], +# y=df.loc[df[color] == cat, v], +# color=palette[idx], +# label=str(cat), +# **scatterplot_kwargs, +# ) +# else: +# sc = ax.scatter( +# x=df['distance'], +# y=df[v], +# c=df[color], +# cmap=scatter_palette, +# **scatterplot_kwargs, +# ) +# fig.colorbar(sc, ax=ax) + +# if isinstance(title, list): +# ax.set_title(title[i]) +# elif title is not None: +# ax.set_title(title) +# if axis_label is None: +# ax.set_xlabel("Distance") +# ax.set_ylabel(v) +# else: +# ax.set_xlabel(axis_label) + +# if save is not None: +# save_fig(fig, path=save, transparent=False, dpi=dpi) + +# if return_ax: +# return axs if len(axs) > 1 else axs[0] +# else: +# plt.show() diff --git a/src/squidpy/tl/__init__.py b/src/squidpy/tl/__init__.py index 6d5abe98c..e4e111a72 100644 --- a/src/squidpy/tl/__init__.py +++ b/src/squidpy/tl/__init__.py @@ -4,3 +4,4 @@ from squidpy.tl._sliding_window import _calculate_window_corners, sliding_window from squidpy.tl._var_by_distance import var_by_distance +from squidpy.tl._var_embeddings import var_embeddings diff --git a/src/squidpy/tl/_var_by_distance.py b/src/squidpy/tl/_var_by_distance.py index e845c0c29..a0b4a6108 100644 --- a/src/squidpy/tl/_var_by_distance.py +++ b/src/squidpy/tl/_var_by_distance.py @@ -12,9 +12,10 @@ from sklearn.metrics import DistanceMetric from sklearn.neighbors import KDTree from sklearn.preprocessing import MinMaxScaler +from spatialdata import SpatialData from squidpy._docs import d -from squidpy._utils import NDArrayA +from squidpy._utils import NDArrayA, _get_adata_from_input from squidpy.gr._utils import _save_data __all__ = ["var_by_distance"] @@ -22,11 +23,15 @@ @d.dedent def var_by_distance( - adata: AnnData, + data: AnnData | SpatialData, groups: str | list[str] | NDArrayA, cluster_key: str | None = None, library_key: str | None = None, +<<<<<<< HEAD + table: str | None = None, +======= library_id: str | list[str] | None = None, +>>>>>>> e76e147e021fb06922cc63c189f4db6bab4c4103 design_matrix_key: str = "design_matrix", covariates: str | list[str] | None = None, metric: str = "euclidean", @@ -60,10 +65,26 @@ def var_by_distance( If ``copy = True``, returns the design_matrix with the distances to an anchor point Otherwise, stores design_matrix in `.obsm`. """ + # potentially extract table from SpatialData object + adata = _get_adata_from_input(data, table) + start = logg.info(f"Creating {design_matrix_key}") +<<<<<<< HEAD + # list of columns which will be categorical later on + # categorical_columns = [cluster_key] + # save initial metadata to adata.uns if copy == False + if not copy: + adata.uns[design_matrix_key] = _add_metadata( + cluster_key, groups, metric=metric, library_key=library_key, covariates=covariates + ) + + if isinstance(groups, str | np.ndarray): + anchor: list[Any] = [groups] +======= if isinstance(groups, str): anchor = [groups] +>>>>>>> e76e147e021fb06922cc63c189f4db6bab4c4103 elif isinstance(groups, list): anchor = groups elif isinstance(groups, np.ndarray): @@ -89,7 +110,12 @@ def var_by_distance( else: batch = adata.obs[library_key].unique().tolist() else: +<<<<<<< HEAD + batch = adata.obs[library_key].unique() + # categorical_columns.append(library_key) +======= raise TypeError(f"Invalid type for library_key: {type(library_key)}.") +>>>>>>> e76e147e021fb06922cc63c189f4db6bab4c4103 batch_design_matrices = {} max_distances = {} @@ -200,6 +226,45 @@ def var_by_distance( _save_data(adata, attr="obsm", key=design_matrix_key, data=df, time=start) +<<<<<<< HEAD +def _add_metadata( + cluster_key: str, + groups: str | list[str] | NDArrayA, + library_key: str | None = None, + covariates: str | list[str] | None = None, + metric: str = "euclidean", +) -> dict[str, Any]: + """Add metadata to adata.uns.""" + metadata = {} + if isinstance(groups, np.ndarray): + metadata["anchor_scaled"] = "custom_anchor" + metadata["anchor_raw"] = "custom_anchor_raw" + elif isinstance(groups, list): + for i, anchor in enumerate(groups): + metadata[f"anchor_scaled_{str(i)}"] = anchor + metadata[f"anchor_raw_{str(i)}"] = anchor + "_raw" + else: + metadata["anchor_scaled"] = groups + metadata["anchor_raw"] = groups + "_raw" + + metadata["annotation"] = cluster_key + + if library_key is not None: + metadata["library_key"] = library_key + + metadata["metric"] = metric + + if covariates is not None: + if isinstance(covariates, str): + covariates = [covariates] + for i, covariate in enumerate(covariates): + metadata[f"covariate_{str(i)}"] = covariate + + return metadata + + +======= +>>>>>>> e76e147e021fb06922cc63c189f4db6bab4c4103 def _init_design_matrix( adata: AnnData, cluster_key: str | None, diff --git a/src/squidpy/tl/_var_embeddings.py b/src/squidpy/tl/_var_embeddings.py new file mode 100644 index 000000000..4e5681457 --- /dev/null +++ b/src/squidpy/tl/_var_embeddings.py @@ -0,0 +1,114 @@ +from __future__ import annotations + +from typing import Any, Optional + +import numpy as np +import pandas as pd +import scanpy as sc +from anndata import AnnData +from scanpy import logging as logg +from spatialdata import SpatialData + +from squidpy._docs import d +from squidpy._utils import _get_adata_from_input + +__all__ = ["var_embeddings"] + + +@d.dedent +def var_embeddings( + data: AnnData | SpatialData, + group: str, + design_matrix_key: str = "design_matrix", + table: str | None = None, + n_bins: int = 100, + include_anchor: bool = False, + return_as_adata: bool = False, +) -> AnnData | None: + """ + Bin variables by previously calculated distance to an anchor point. + + Parameters + ---------- + data + AnnData or SpatialData object. + group + Annotation column in design matrix, given by `design_matrix_key`, that is used as anchor. + design_matrix_key + Name of the design matrix saved to `.obsm`. + table + Name of the table in `SpatialData` object. + n_bins + Number of bins to use for aggregation. + include_anchor + Whether to include the variable counts belonging to the anchor point in the aggregation. + return_as_adata + Only evaluated, if `data` is a SpatialData object. Whether to return the result or store it as a new table. + + Returns + ------- + AnnData or None + If `data` is an `AnnData` object or `return_as_adata` is True, returns the new `AnnData` object. + If `data` is a `SpatialData` object and `return_as_adata` is False, modifies `data` in place and returns None. + """ + + adata = _get_adata_from_input(data, table) + + if design_matrix_key not in adata.obsm: + raise KeyError( + f"Design matrix key '{design_matrix_key}' not found in .obsm. Available keys are: {list(adata.obsm.keys())}" + ) + + design_matrix = adata.obsm[design_matrix_key].copy() + if group not in design_matrix.columns: + raise KeyError( + f"Group column '{group}' not found in design matrix. Available columns: {list(design_matrix.columns)}" + ) + if not pd.api.types.is_numeric_dtype(design_matrix[group]): + raise TypeError(f"The group column '{group}' must be numeric.") + + logg.info("Calculating embeddings for distance aggregations by gene.") + + # bin the data by distance and calculate the median distance for each bin + intervals = pd.cut(design_matrix[group], bins=n_bins) + + # Extract the interval bounds as tuples and midpoints in a single pass + design_matrix["bins"] = [ + (interval.left, interval.right) if pd.notnull(interval) else (0.0, 0.0) for interval in intervals + ] + design_matrix["median_value"] = [interval.mid if pd.notnull(interval) else 0.0 for interval in intervals] + + # turn categorical NaNs into float 0s + design_matrix["median_value"] = ( + pd.to_numeric(design_matrix["median_value"], errors="coerce").fillna(0).astype(float) + ) + + # get count matrix and add binned distance to each .obs + X_df = adata.to_df() + X_df["distance"] = design_matrix["median_value"] + + # aggregate the count matrix by the bins + aggregated_df = X_df.groupby(["distance"]).sum() + + result = aggregated_df.T + + # optionally include or remove variable values for distance 0 (anchor point) + start_bin = 0 + if not include_anchor: + result = result.drop(result.columns[0], axis=1) + start_bin = 1 + + # rename column names for plotting + result.columns = range(start_bin, n_bins + 1) + + adata_new = AnnData(X=result) + adata_new.uns[design_matrix_key] = design_matrix + + if isinstance(data, AnnData): + return adata_new + elif isinstance(data, SpatialData): + if return_as_adata: + return adata_new + else: + data.tables["var_by_dist_bins"] = adata_new + return None