Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
39 commits
Select commit Hold shift + click to select a range
5a52976
Add method to calculate embeddings for variable by distance aggregation
LLehner Mar 4, 2024
eb84518
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Mar 4, 2024
488da20
Fix pre-commit
LLehner Mar 4, 2024
8fce577
Fix pre-commit
LLehner Mar 4, 2024
0b72494
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Mar 4, 2024
edcca87
Update param name
LLehner Mar 4, 2024
4be2529
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Mar 4, 2024
f91c1af
Merge branch 'var_by_distance_clustering' of https://github.com/scver…
LLehner Mar 4, 2024
cfe496c
Remove duplicate code
LLehner Apr 22, 2024
c4fca29
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Apr 22, 2024
64e38df
Improve performance, Update output
LLehner Apr 22, 2024
3ab8467
Improve performance, Update output
LLehner Apr 22, 2024
9eabd0d
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Apr 22, 2024
a40a8cf
Remove import
LLehner Apr 22, 2024
90108ad
Merge branch 'var_by_distance_clustering' of https://github.com/scver…
LLehner Apr 22, 2024
09c72b0
Remove import
LLehner Apr 22, 2024
3396146
Update return
LLehner May 26, 2024
a44f661
Merge branch 'var_by_distance_clustering' of https://github.com/scver…
LLehner May 26, 2024
41a2ae4
Merge branch 'main' into var_by_distance_clustering
LLehner May 26, 2024
67bdd5c
Fix pre-commit
LLehner May 26, 2024
99b41b0
Merge branch 'var_by_distance_clustering' of https://github.com/scver…
LLehner May 26, 2024
876c4ed
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] May 26, 2024
8ee07ba
Fix pre-commit
LLehner May 26, 2024
d3cefff
Fix pre-commit
LLehner May 27, 2024
80e23fc
Merge branch 'main' into var_by_distance_clustering
timtreis Jun 20, 2024
f2b0e12
Merge branch 'main' into var_by_distance_clustering
timtreis Jul 9, 2024
2a863a4
Merge branch 'main' into var_by_distance_clustering
timtreis Aug 7, 2024
5729676
Fix indices; Update return type
LLehner Aug 8, 2024
7dfa933
Add spatialdata as input
LLehner Aug 26, 2024
bf1dcff
Merge branch 'main' into var_by_distance_clustering
LLehner Aug 27, 2024
d6e5ecd
Update docstring
LLehner Aug 27, 2024
6e724f0
Merge branch 'main' into var_by_distance_clustering
timtreis Oct 1, 2024
6e28662
Merge branch 'main' into var_by_distance_clustering
LLehner Oct 8, 2024
1b1c05a
Merge branch 'main' into var_by_distance_clustering
timtreis Nov 13, 2024
abf1364
Merge branch 'main' into var_by_distance_clustering
timtreis Jan 18, 2025
e76e147
Merge branch 'main' into var_by_distance_clustering
timtreis May 9, 2025
dc4028c
stash
timtreis May 9, 2025
6f1d062
stash
timtreis May 9, 2025
f914920
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] May 9, 2025
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
15 changes: 15 additions & 0 deletions src/squidpy/_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,8 @@

import joblib as jl
import numpy as np
from anndata import AnnData
from spatialdata import SpatialData

__all__ = ["singledispatchmethod", "Signal", "SigQueue", "NDArray", "NDArrayA"]

Expand Down Expand Up @@ -337,3 +339,16 @@ def new_func2(*args: Any, **kwargs: Any) -> Any:

else:
raise TypeError(repr(type(reason)))


def _get_adata_from_input(data: AnnData | SpatialData, table: str | None = None) -> AnnData:
    """
    Resolve the ``AnnData`` object to operate on from an ``AnnData`` or ``SpatialData`` input.

    Parameters
    ----------
    data
        Either an ``AnnData`` object, which is returned as-is, or a ``SpatialData`` object
        from which the table named ``table`` is looked up in ``data.tables``.
    table
        Name of the table in ``data.tables``. Required when ``data`` is a ``SpatialData``
        object; not used when ``data`` is an ``AnnData`` object.

    Returns
    -------
    ``data`` itself if it is an ``AnnData`` object, otherwise ``data.tables[table]``.

    Raises
    ------
    ValueError
        If ``data`` is a ``SpatialData`` object and ``table`` is `None` or is not a key
        of ``data.tables``.
    TypeError
        If ``data`` is neither an ``AnnData`` nor a ``SpatialData`` object.
    """
    if isinstance(data, AnnData):
        return data

    if isinstance(data, SpatialData):
        if table is None:
            raise ValueError("If using a SpatialData object, a table name must be provided with `table`.")
        # membership is checked up front so a missing table surfaces as a descriptive
        # ValueError rather than a bare KeyError from the lookup below
        if table not in data.tables:
            raise ValueError(f"Table `{table}` not found in `SpatialData` object.")
        return data.tables[table]

    raise TypeError(f"Expected `data` to be of type `AnnData` or `SpatialData`, found `{type(data).__name__}`.")
276 changes: 271 additions & 5 deletions src/squidpy/pl/_var_by_distance.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,18 +17,21 @@
from scanpy.plotting._tools.scatterplots import _panel_grid
from scanpy.plotting._utils import _set_default_colors_for_categorical_obs
from scipy.sparse import issparse
from spatialdata import SpatialData

from squidpy._docs import d
from squidpy._utils import _get_adata_from_input
from squidpy.pl._utils import save_fig

__all__ = ["var_by_distance"]


@d.dedent
def var_by_distance(
adata: AnnData,
data: AnnData | SpatialData,
var: str | list[str],
anchor_key: str | list[str],
table: str | None = None,
design_matrix_key: str = "design_matrix",
stack_vars: bool = False,
covariate: str | None = None,
Expand Down Expand Up @@ -97,6 +100,9 @@ def var_by_distance(
regplot_kwargs = dict(regplot_kwargs)
scatterplot_kwargs = dict(scatterplot_kwargs)

# potentially extract table from SpatialData object
adata = _get_adata_from_input(data, table)

# if several variables are plotted, make a panel grid
if isinstance(var, list) and not stack_vars:
fig, grid = _panel_grid(
Expand All @@ -111,16 +117,16 @@ def var_by_distance(
var = [var]
axs = []

df = adata.obsm[design_matrix_key] # get design matrix
design_matrix = adata.obsm[design_matrix_key]

# add var column to design matrix
for name in var:
if name in adata.var_names:
df[name] = (
design_matrix[name] = (
np.array(adata[:, name].X.toarray()) if issparse(adata[:, name].X) else np.array(adata[:, name].X)
)
elif name in adata.obs:
df[name] = adata.obs[name].values
design_matrix[name] = adata.obs[name].values
else:
raise ValueError(f"Variable {name} not found in `adata.var` or `adata.obs`.")

Expand All @@ -130,7 +136,7 @@ def var_by_distance(
line_palette = sns.color_palette("bright", len(var))
for i, v in enumerate(var):
sns.regplot(
data=df,
data=design_matrix,
x=anchor_key,
y=v,
label=v,
Expand All @@ -140,8 +146,60 @@ def var_by_distance(
ax=ax,
line_kws=regplot_kwargs,
)
<<<<<<< HEAD
else:
# make a categorical color palette if none was specified and there are several regplots to be plotted
if isinstance(line_palette, str) or line_palette is None:
_set_default_colors_for_categorical_obs(adata, covariate)
line_palette = adata.uns[covariate + "_colors"]
covariate_instances = design_matrix[covariate].unique()

# iterate over all covariate values and make 'sns.regplot' for each
for i, co in enumerate(covariate_instances):
sns.regplot(
data=design_matrix.loc[design_matrix[covariate] == co],
x=anchor_key,
y=v,
order=order,
color=line_palette[i],
scatter=show_scatter,
ax=ax,
label=co,
line_kws=regplot_kwargs,
)
label_colors, _ = ax.get_legend_handles_labels()
ax.legend(label_colors, covariate_instances)
# add scatter plot if specified
if show_scatter:
if color is None:
plt.scatter(data=design_matrix, x=anchor_key, y=v, color="grey", **scatterplot_kwargs)
# if variable to plot on color palette is categorical, make categorical color palette
elif design_matrix[color].dtype.name == "category":
unique_colors = design_matrix[color].unique()
cNorm = colors.Normalize(vmin=0, vmax=len(unique_colors))
scalarMap = cm.ScalarMappable(norm=cNorm, cmap=scatter_palette)
for i in range(len(unique_colors)):
plt.scatter(
data=design_matrix.loc[design_matrix[color] == unique_colors[i]],
x=anchor_key,
y=v,
color=scalarMap.to_rgba(i),
**scatterplot_kwargs,
)
# if variable to plot on color palette is not categorical
else:
plt.scatter(
data=design_matrix,
x=anchor_key,
y=v,
c=color,
cmap=scatter_palette,
**scatterplot_kwargs,
)
=======
ax.legend(title=None)
ax.set(ylabel="var")
>>>>>>> e76e147e021fb06922cc63c189f4db6bab4c4103
if title is not None:
ax.set(title=title)
if axis_label is None:
Expand Down Expand Up @@ -240,3 +298,211 @@ def var_by_distance(
save_fig(fig, path=save, transparent=False, dpi=dpi)
if return_ax:
return axs



# @d.dedent
# def var_by_distance(
# data: AnnData | SpatialData,
# var: str | list[str],
# table: str | None = None,
# color: str | None = None,
# covariate: str | None = None,
# order: int = 5,
# show_scatter: bool = True,
# line_palette: str | Sequence[str] | Cycler | None = None,
# scatter_palette: str | Sequence[str] | Cycler | None = "viridis",
# dpi: int | None = None,
# figsize: tuple[int, int] | None = None,
# save: str | Path | None = None,
# title: str | list[str] | None = None,
# axis_label: str | None = None,
# return_ax: bool | None = None,
# regplot_kwargs: Mapping[str, Any] = MappingProxyType({}),
# scatterplot_kwargs: Mapping[str, Any] = MappingProxyType({}),
# ) -> Axes | None:
# """
# Plot variables using a smooth regression line with increasing distance to an anchor point.

# Parameters
# ----------
# data
# AnnData or SpatialData object returned by the `var_embeddings` function.
# var
# Variables (genes) to plot on y-axis.
# table
# Name of the table in `SpatialData` object.
# color
# Variable to color the scatter plot.
# covariate
# A covariate for which separate regression lines are plotted for each category.
# order
# Order of the polynomial fit for :func:`seaborn.regplot`.
# show_scatter
# Whether to show a scatter plot underlying the regression line.
# line_palette
# Categorical color palette used in case a covariate is specified.
# scatter_palette
# Color palette for the scatter plot underlying the :func:`seaborn.regplot`.
# dpi
# Dots per inch.
# figsize
# Size of the figure in inches.
# save
# Path to save the plot.
# title
# Panel titles.
# axis_label
# Panel axis labels.
# return_ax
# Whether to return :class:`matplotlib.axes.Axes` object(s).
# regplot_kwargs
# Additional keyword arguments for :func:`seaborn.regplot`.
# scatterplot_kwargs
# Additional keyword arguments for :func:`matplotlib.pyplot.scatter`.

# Returns
# -------
# Axes or None
# """
# dpi = rcParams["figure.dpi"] if dpi is None else dpi
# regplot_kwargs = dict(regplot_kwargs)
# scatterplot_kwargs = dict(scatterplot_kwargs)

# # Validate data type and extract AnnData object
# if isinstance(data, AnnData):
# adata = data
# elif isinstance(data, SpatialData):
# if table is None:
# raise ValueError("If using a SpatialData object, a table name must be provided with `table`.")
# if table not in data.tables:
# raise KeyError(f"Table '{table}' not found in SpatialData object. Available tables: {list(data.tables.keys())}")
# adata = data.tables[table]
# else:
# raise TypeError(f"Expected `data` to be of type `AnnData` or `SpatialData`, found '{type(data).__name__}'.")

# if isinstance(var, str):
# var = [var]

# # If multiple variables, set up a grid of plots
# if len(var) > 1:
# fig, grid = _panel_grid(
# hspace=0.25,
# wspace=0.75 / rcParams["figure.figsize"][0] + 0.02,
# ncols=4,
# num_panels=len(var),
# )
# axs = []
# else:
# fig, ax = plt.subplots(1, 1, figsize=figsize)
# axs = [ax]

# # Create dataframe from adata
# df = adata.to_df()

# # Ensure all values of `var` are in adata.obs and have float values
# for v in var:
# if v not in adata.obs_names:
# raise KeyError(f"Variable '{v}' not found in adata.obs_names")
# df.loc[v] = df.loc[v].astype(float)

# # If color is specified and is in adata.obs or adata.var_names, add to df
# if color is not None:
# if color in adata.obs:
# df[color] = adata.obs[color].values
# elif color in adata.var_names:
# df[color] = adata[:, color].X.flatten()
# else:
# raise ValueError(f"Color variable '{color}' not found in adata.obs or adata.var_names.")

# # If covariate is specified and is in adata.obs, add to df
# if covariate is not None:
# if covariate in adata.obs:
# df[covariate] = adata.obs[covariate].values
# else:
# raise ValueError(f"Covariate '{covariate}' not found in adata.obs.")

# # Iterate over the variables to plot
# for i, v in enumerate(var):
# if len(var) > 1:
# ax = plt.subplot(grid[i])
# axs.append(ax)
# else:
# ax = axs[0]

# # if no covariate is specified, use seaborn regplot directly
# if covariate is None:
# sns.regplot(
# data=df,
# x='distance',
# y=v,
# order=order,
# color=line_palette,
# scatter=show_scatter,
# ax=ax,
# line_kws=regplot_kwargs,
# scatter_kws=scatterplot_kwargs,
# )
# else:
# # Generate color palette if not provided
# if line_palette is None:
# _set_default_colors_for_categorical_obs(adata, covariate)
# line_palette = adata.uns[covariate + "_colors"]
# covariate_instances = df[covariate].unique()

# # Iterate over each category in covariate
# for idx, category in enumerate(covariate_instances):
# sns.regplot(
# data=df[df[covariate] == category],
# x='distance',
# y=v,
# order=order,
# color=line_palette[idx % len(line_palette)],
# scatter=show_scatter,
# ax=ax,
# label=str(category),
# line_kws=regplot_kwargs,
# scatter_kws=scatterplot_kwargs,
# )
# ax.legend(title=covariate)

# # Add scatter plot if specified
# if show_scatter and color is not None:
# if df[color].dtype.name == "category":
# unique_colors = df[color].unique()
# palette = sns.color_palette(scatter_palette, len(unique_colors))
# for idx, cat in enumerate(unique_colors):
# ax.scatter(
# x=df.loc[df[color] == cat, 'distance'],
# y=df.loc[df[color] == cat, v],
# color=palette[idx],
# label=str(cat),
# **scatterplot_kwargs,
# )
# else:
# sc = ax.scatter(
# x=df['distance'],
# y=df[v],
# c=df[color],
# cmap=scatter_palette,
# **scatterplot_kwargs,
# )
# fig.colorbar(sc, ax=ax)

# if isinstance(title, list):
# ax.set_title(title[i])
# elif title is not None:
# ax.set_title(title)
# if axis_label is None:
# ax.set_xlabel("Distance")
# ax.set_ylabel(v)
# else:
# ax.set_xlabel(axis_label)

# if save is not None:
# save_fig(fig, path=save, transparent=False, dpi=dpi)

# if return_ax:
# return axs if len(axs) > 1 else axs[0]
# else:
# plt.show()
1 change: 1 addition & 0 deletions src/squidpy/tl/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,3 +4,4 @@

from squidpy.tl._sliding_window import _calculate_window_corners, sliding_window
from squidpy.tl._var_by_distance import var_by_distance
from squidpy.tl._var_embeddings import var_embeddings
Loading
Loading