diff --git a/pyproject.toml b/pyproject.toml index cecbe86b7..43bbedf2c 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -94,7 +94,6 @@ test = [ "pytest>=7", "pytest-xdist>=3", "pytest-mock>=3.5.0", - # Just for VS Code "pytest-cov>=4", "coverage[toml]>=7", "pytest-timeout>=2.1.0", diff --git a/src/squidpy/_utils.py b/src/squidpy/_utils.py index 99f1b1348..a12e035c0 100644 --- a/src/squidpy/_utils.py +++ b/src/squidpy/_utils.py @@ -11,12 +11,13 @@ from multiprocessing import Manager, cpu_count from queue import Queue from threading import Thread -from typing import TYPE_CHECKING, Any +from typing import TYPE_CHECKING, Any, Literal import joblib as jl import numba import numpy as np import spatialdata as sd +import xarray as xr from spatialdata.models import Image2DModel, Labels2DModel __all__ = ["singledispatchmethod", "Signal", "SigQueue", "NDArray", "NDArrayA"] @@ -373,3 +374,17 @@ def _yx_from_shape(shape: tuple[int, ...]) -> tuple[int, int]: return shape[1], shape[2] raise ValueError(f"Unsupported shape {shape}. Expected (y, x) or (c, y, x).") + + +def _ensure_dim_order(img_da: xr.DataArray, order: Literal["cyx", "yxc"] = "yxc") -> xr.DataArray: + """ + Ensure dims are in the requested order and that a 'c' dim exists. + Only supports images with dims subset of {'y','x','c'}. + """ + dims = list(img_da.dims) + if "y" not in dims or "x" not in dims: + raise ValueError(f'Expected dims to include "y" and "x". Found dims={dims}') + if "c" not in dims: + img_da = img_da.expand_dims({"c": [0]}) + # After possible expand, just transpose to target + return img_da.transpose(*tuple(order)) diff --git a/src/squidpy/datasets/_utils.py b/src/squidpy/datasets/_utils.py index 8dcfff33d..d0257fcf4 100644 --- a/src/squidpy/datasets/_utils.py +++ b/src/squidpy/datasets/_utils.py @@ -185,7 +185,11 @@ def _extension(self) -> str: def _get_zipped_dataset(folderpath: Path, dataset_name: str, figshare_id: str) -> sd.SpatialData: """Returns a specific dataset as SpatialData object. If the file is not present on disk, it will be downloaded and extracted.""" - if not folderpath.is_dir(): + # Create directory if it doesn't exist + if not folderpath.exists(): + logg.info(f"Creating directory `{folderpath}`") + folderpath.mkdir(parents=True, exist_ok=True) + elif not folderpath.is_dir(): raise ValueError(f"Expected a directory path for `folderpath`, found: {folderpath}") download_zip = folderpath / f"{dataset_name}.zip" diff --git a/src/squidpy/experimental/__init__.py b/src/squidpy/experimental/__init__.py index df8681a40..435cd0098 100644 --- a/src/squidpy/experimental/__init__.py +++ b/src/squidpy/experimental/__init__.py @@ -6,7 +6,6 @@ from __future__ import annotations -from . import im -from .im._detect_tissue import detect_tissue +from . import im, pl -__all__ = ["detect_tissue", "im"] +__all__ = ["im", "pl"] diff --git a/src/squidpy/experimental/im/__init__.py b/src/squidpy/experimental/im/__init__.py index 5c43a7b78..d5296a611 100644 --- a/src/squidpy/experimental/im/__init__.py +++ b/src/squidpy/experimental/im/__init__.py @@ -5,5 +5,13 @@ FelzenszwalbParams, detect_tissue, ) +from ._qc_sharpness import qc_sharpness +from ._sharpness_metrics import SharpnessMetric -__all__ = ["detect_tissue", "BackgroundDetectionParams", "FelzenszwalbParams"] +__all__ = [ + "qc_sharpness", + "detect_tissue", + "SharpnessMetric", + "BackgroundDetectionParams", + "FelzenszwalbParams", +] diff --git a/src/squidpy/experimental/im/_detect_tissue.py b/src/squidpy/experimental/im/_detect_tissue.py index 3b0fa8a78..64599c344 100644 --- a/src/squidpy/experimental/im/_detect_tissue.py +++ b/src/squidpy/experimental/im/_detect_tissue.py @@ -17,9 +17,9 @@ from spatialdata.models import Labels2DModel from spatialdata.transformations import get_transformation -from squidpy._utils import _get_scale_factors, _yx_from_shape +from squidpy._utils import _ensure_dim_order, _get_scale_factors, _yx_from_shape -from ._utils import _flatten_channels, _get_image_data +from ._utils import _flatten_channels, _get_element_data class DETECT_TISSUE_METHOD(enum.Enum): @@ -170,7 +170,9 @@ def detect_tissue( manual_scale = scale.lower() != "auto" # Load smallest available or explicit scale - img_src = _get_image_data(sdata, image_key, scale=scale if manual_scale else "auto") + img_node = sdata.images[image_key] + img_da = _get_element_data(img_node, scale if manual_scale else "auto", "image", image_key) + img_src = _ensure_dim_order(img_da, "yxc") src_h, src_w = _yx_from_shape(img_src.shape) n_src_px = src_h * src_w diff --git a/src/squidpy/experimental/im/_qc_sharpness.py b/src/squidpy/experimental/im/_qc_sharpness.py new file mode 100644 index 000000000..d1277eb68 --- /dev/null +++ b/src/squidpy/experimental/im/_qc_sharpness.py @@ -0,0 +1,599 @@ +from __future__ import annotations + +from typing import Literal + +import dask.array as da +import geopandas as gpd +import numba +import numpy as np +import pandas as pd +import xarray as xr +from anndata import AnnData +from dask.diagnostics import ProgressBar +from sklearn.preprocessing import StandardScaler +from spatialdata import SpatialData +from spatialdata._logging import logger +from spatialdata.models import ShapesModel, TableModel + +from squidpy._docs import d +from squidpy._utils import _ensure_dim_order + +from ._detect_tissue import detect_tissue +from ._sharpness_metrics import SharpnessMetric, _get_sharpness_metric_function +from ._utils import TileGrid, _get_element_data + +# single-thread numba to avoid clashes with Dask +numba.set_num_threads(1) + + +@d.dedent +def qc_sharpness( + sdata: SpatialData, + image_key: str, + *, + scale: str = "scale0", + metrics: SharpnessMetric | list[SharpnessMetric] | None = None, + tile_size: Literal["auto"] | tuple[int, int] = "auto", + detect_outliers: bool = True, + detect_tissue: bool = True, + outlier_method: Literal["pvalue", "iqr", "zscore", "tenengrad_tissue"] = "pvalue", + outlier_cutoff: float = 0.1, + progress: bool = True, + tissue_mask_key: str | None = None, +) -> None: + """ + Perform quality control analysis on image sharpness. + + Parameters + ---------- + sdata + SpatialData object containing the image. + image_key + Key of the image in ``sdata.images`` to analyze. + scale + Scale level to use for processing. Defaults to ``"scale0"``. + metrics + Sharpness metrics to compute. Can be a single metric or list of metrics. + tile_size + Size of tiles for analysis. If ``"auto"``, automatically determines size. + detect_outliers + Whether to detect outlier tiles based on sharpness scores. + detect_tissue + Whether to detect tissue regions for context-aware outlier detection. + outlier_method + Method for outlier detection. Options: ``"pvalue"``, ``"iqr"``, ``"zscore"``, ``"tenengrad_tissue"``. + outlier_cutoff + Threshold for outlier detection. + progress + Whether to show progress bar during computation. + tissue_mask_key + Key of the tissue mask in ``sdata.labels`` to use. If ``None``, the function will + check if ``"{image_key}_tissue"`` already exists in ``sdata.labels`` and reuse it. + If it doesn't exist, tissue detection will be performed and the mask will be added + to ``sdata.labels`` with key ``"{image_key}_tissue"``. If provided, the existing + mask at this key will be used. + + Returns + ------- + None + Results are stored in the following locations: + + - ``sdata.tables[f"qc_img_{image_key}_sharpness"]``: AnnData object with sharpness scores + - ``sdata.shapes[f"qc_img_{image_key}_sharpness_grid"]``: GeoDataFrame with tile geometries + - ``sdata.tables[...].uns["qc_sharpness"]``: Metadata about the analysis + + Notes + ----- + This function performs tile-based sharpness analysis on images, computing + various sharpness metrics and optionally detecting outlier tiles. + """ + # Parameter validation + if image_key not in sdata.images: + raise KeyError(f"Image key '{image_key}' not found in sdata.images") + + if metrics is None: + metrics = [SharpnessMetric.TENENGRAD, SharpnessMetric.VAR_OF_LAPLACIAN] + elif isinstance(metrics, SharpnessMetric): + metrics = [metrics] + + if not isinstance(metrics, list) or not all(isinstance(m, SharpnessMetric) for m in metrics): + raise TypeError("metrics must be SharpnessMetric or list of SharpnessMetric") + + if isinstance(metrics, list) and not all(isinstance(m, SharpnessMetric) for m in metrics): + available = ", ".join(m.value for m in SharpnessMetric) + raise TypeError(f"Metrics must be one of: {available}") + + if outlier_method not in ["pvalue", "iqr", "zscore", "tenengrad_tissue"]: + raise ValueError( + f"Unknown outlier_method '{outlier_method}'. Must be one of: pvalue, iqr, zscore, tenengrad_tissue" + ) + + # Compute sharpness metrics + img_node = sdata.images[image_key] + img_da = _get_element_data(img_node, scale, "image", image_key) + img_yxc = _ensure_dim_order(img_da, "yxc") + gray = _to_gray_dask_yx(img_yxc) + H, W = int(gray.shape[0]), int(gray.shape[1]) + + tg = TileGrid(H, W, tile_size) + tile_indices = tg.indices() + obs_names = tg.names() + pixel_bounds = tg.bounds() + + logger.info("Quantifying image sharpness.") + logger.info(f"- Input image (x, y): ({W}, {H})") + logger.info(f"- Tile size (x, y): ({tg.tx}, {tg.ty})") + logger.info(f"- Number of tiles (n_x, n_y): ({tg.tiles_x}, {tg.tiles_y})") + + metrics_list = metrics if isinstance(metrics, list) else [metrics] + metric_names = [(m.value if isinstance(m, SharpnessMetric) else str(m)) for m in metrics_list] + + all_scores: dict[str, np.ndarray] = {} + for name in metric_names: + gray_re = gray.rechunk((tg.ty, tg.tx)) + metric_func = _get_sharpness_metric_function(name) + field = da.map_overlap(metric_func, gray_re, depth=0, boundary="reflect", dtype=np.float32) + + padded = tg.rechunk_and_pad(field) + + if name == "tenengrad": + tiles_da = tg.coarsen(padded, "sum") / float(tg.ty * tg.tx) + else: + tiles_da = tg.coarsen(padded, "mean") + + logger.info(f"- Calculating metric: '{name}'") + if progress: + with ProgressBar(): + all_scores[name] = tiles_da.compute() + else: + all_scores[name] = tiles_da.compute() + + # build AnnData + first = next(iter(all_scores.values())) + cents, polys = tg.centroids_and_polygons() + n_tiles = first.size + X = np.column_stack([all_scores[n].ravel() for n in metric_names]) + var_names = [f"sharpness_{n}" for n in metric_names] + + adata = AnnData(X=X) + adata.var_names = var_names + adata.obs_names = obs_names + adata.obs["centroid_y"] = cents[:, 0] + adata.obs["centroid_x"] = cents[:, 1] + adata.obsm["spatial"] = cents + + # defaults to avoid NameError when skipping tissue/outliers + tissue = np.zeros(n_tiles, dtype=bool) + back = ~tissue + t_sim = np.zeros(n_tiles, np.float32) + b_sim = np.zeros(n_tiles, np.float32) + outlier_labels = np.ones(n_tiles, dtype=int) + unfocus_scores: np.ndarray | None = None + + if detect_outliers: + if detect_tissue: + tissue, back, t_sim, b_sim = _detect_tissue_from_mask( + sdata, image_key, tile_indices, tg.ty, tg.tx, scale, tissue_mask_key=tissue_mask_key + ) + logger.info(f"- Classified tiles: background: {back.sum()}, tissue: {tissue.sum()}.") + + if detect_tissue and tissue.any(): + if outlier_method == "pvalue": + labels, pvals = _detect_outliers_pvalue(X, tissue_mask=tissue, var_names=var_names) + scores = 1.0 - pvals + lo, hi = float(np.min(scores)), float(np.max(scores)) + scores = (scores - lo) / (hi - lo) if hi > lo else np.zeros_like(scores) + outlier_labels = np.where(scores >= outlier_cutoff, -1, 1) + unfocus_scores = scores + elif outlier_method == "tenengrad_tissue": + scores = _detect_tenengrad_tissue_outliers(X, tissue_mask=tissue, var_names=var_names) + outlier_labels = np.where(scores >= outlier_cutoff, -1, 1) + unfocus_scores = scores + else: + tX = X[tissue] + t_labels = _detect_sharpness_outliers(tX, method=outlier_method) + outlier_labels = np.ones(n_tiles, dtype=int) + outlier_labels[tissue] = t_labels + else: + method = "zscore" if outlier_method == "pvalue" else outlier_method + outlier_labels = _detect_sharpness_outliers(X, method=method) + + adata.obs["sharpness_outlier"] = pd.Categorical( + (outlier_labels == -1).astype(str), categories=["False", "True"] + ) + if detect_tissue: + adata.obs["is_tissue"] = pd.Categorical(tissue.astype(str), categories=["False", "True"]) + adata.obs["is_background"] = pd.Categorical(back.astype(str), categories=["False", "True"]) + adata.obs["tissue_similarity"] = t_sim + adata.obs["background_similarity"] = b_sim + if unfocus_scores is not None: + adata.obs["unfocus_score"] = unfocus_scores + + logger.info(f"- Detected {int((outlier_labels == -1).sum())} outlier tiles.") + + adata.uns["qc_sharpness"] = { + "metrics": list(all_scores.keys()), + "tile_size_y": tg.ty, + "tile_size_x": tg.tx, + "image_height": H, + "image_width": W, + "n_tiles_y": tg.tiles_y, + "n_tiles_x": tg.tiles_x, + "image_key": image_key, + "scale": scale, + "detect_tissue": detect_tissue, + "outlier_method": outlier_method, + "n_tissue_tiles": int(tissue.sum()), + "n_background_tiles": int(back.sum()), + "n_outlier_tiles": int((outlier_labels == -1).sum()), + } + + table_key = f"qc_img_{image_key}_sharpness" + shapes_key = f"qc_img_{image_key}_sharpness_grid" + + sdata.tables[table_key] = TableModel.parse(adata) + logger.info(f"- Saved sharpness scores as 'sdata.tables[\"{table_key}\"]'") + + tile_gdf = gpd.GeoDataFrame( + { + "tile_id": [f"tile_x{ix}_y{iy}" for iy, ix in tile_indices], + "tile_y": tile_indices[:, 0], + "tile_x": tile_indices[:, 1], + "pixel_y0": pixel_bounds[:, 0], + "pixel_x0": pixel_bounds[:, 1], + "pixel_y1": pixel_bounds[:, 2], + "pixel_x1": pixel_bounds[:, 3], + "geometry": polys, + }, + geometry="geometry", + ) + sdata.shapes[shapes_key] = ShapesModel.parse(tile_gdf) + + sdata.tables[table_key].uns["spatialdata_attrs"] = { + "region": shapes_key, + "region_key": "grid_name", + "instance_key": "tile_id", + } + sdata.tables[table_key].obs["grid_name"] = pd.Categorical([shapes_key] * len(sdata.tables[table_key])) + sdata.tables[table_key].obs["tile_id"] = sdata.shapes[shapes_key].index + logger.info(f"- Saved tile grid as 'sdata.shapes[\"{shapes_key}\"]'") + + +def _to_gray_dask_yx(img_yxc: xr.DataArray, weights: tuple[float, float, float] = (0.2126, 0.7152, 0.0722)) -> da.Array: + """ + Convert multi-channel image to grayscale using luminance weights. + + Parameters + ---------- + img_yxc + Input image array with shape (y, x, c). + weights + RGB weights for luminance conversion. + + Returns + ------- + Grayscale image as dask array with shape (y, x). + """ + arr = img_yxc.data + if arr.ndim != 3: + raise ValueError(f"Expected image with shape `(y, x, c)`, found `{arr.shape}`.") + c = arr.shape[2] + if c == 1: + return arr[..., 0].astype(np.float32, copy=False) + rgb = arr[..., :3].astype(np.float32, copy=False) + w = da.from_array(np.asarray(weights, dtype=np.float32), chunks=(3,)) + gray = da.tensordot(rgb, w, axes=([2], [0])) + return gray.astype(np.float32, copy=False) + + +def _get_mask_from_labels(sdata: SpatialData, mask_key: str, scale: str) -> np.ndarray: + """ + Extract mask array from sdata.labels at the specified key and scale. + + Parameters + ---------- + sdata + SpatialData object. + mask_key + Key of the mask in sdata.labels. + scale + Scale level for processing. + + Returns + ------- + Mask array as numpy array with shape (y, x). + + """ + label_node = sdata.labels[mask_key] + mask_da = _get_element_data(label_node, scale, "label", mask_key) + + # Convert to numpy array if needed + if hasattr(mask_da, "compute"): + mask = np.asarray(mask_da.compute()) + elif hasattr(mask_da, "values"): + mask = np.asarray(mask_da.values) + else: + mask = np.asarray(mask_da) + + # Ensure 2D (y, x) shape - squeeze out any singleton dimensions + if mask.ndim > 2: + mask = mask.squeeze() + if mask.ndim != 2: + raise ValueError(f"Expected 2D mask with shape (y, x), got shape {mask.shape}") + return mask + + +def _detect_tissue_from_mask( + sdata: SpatialData, + image_key: str, + tile_indices: np.ndarray, + ty: int, + tx: int, + scale: str = "scale0", + tissue_mask_key: str | None = None, +) -> tuple[np.ndarray, np.ndarray, np.ndarray, np.ndarray]: + """ + Detect tissue regions from mask and classify tiles. + + Parameters + ---------- + sdata + SpatialData object. + image_key + Image key in sdata.images. + tile_indices + Tile indices array. + ty + Tile height. + tx + Tile width. + scale + Scale level for processing. + tissue_mask_key + Key of the tissue mask in sdata.labels to use. If None, tissue detection + will be performed and the mask will be added to sdata.labels. + + Returns + ------- + Tuple of (tissue_mask, background_mask, tissue_similarity, background_similarity). + """ + n_tiles = len(tile_indices) + + # If tissue_mask_key is provided, use existing mask from sdata.labels + if tissue_mask_key is None: + # Check if default mask key already exists, otherwise perform tissue detection + mask_key = f"{image_key}_tissue" + if mask_key not in sdata.labels: + # Perform tissue detection and save to sdata.labels + detect_tissue(sdata=sdata, image_key=image_key, scale=scale, inplace=True, new_labels_key=mask_key) + logger.info(f"- Saved tissue mask as 'sdata.labels[\"{mask_key}\"]'") + + mask = _get_mask_from_labels(sdata, mask_key, scale) + elif tissue_mask_key not in sdata.labels: + raise KeyError(f"Tissue mask key '{tissue_mask_key}' not found in sdata.labels") + else: + mask = _get_mask_from_labels(sdata, tissue_mask_key, scale) + if mask is None: + logger.warning("Tissue mask missing. Marking all tiles as tissue.") + t = np.ones(n_tiles, dtype=bool) + b = ~t + return t, b, np.ones(n_tiles, np.float32), np.zeros(n_tiles, np.float32) + + # Get image dimensions from the mask + H, W = mask.shape + + tissue = np.zeros(n_tiles, dtype=bool) + back = np.zeros(n_tiles, dtype=bool) + t_sim = np.zeros(n_tiles, dtype=np.float32) + b_sim = np.zeros(n_tiles, dtype=np.float32) + + for i, (iy, ix) in enumerate(tile_indices): + y0, y1 = iy * ty, min((iy + 1) * ty, H) + x0, x1 = ix * tx, min((ix + 1) * tx, W) + frac = float(np.mean(mask[y0:y1, x0:x1] > 0.0)) if (y1 > y0 and x1 > x0) else 0.0 + is_t = frac > 0.5 + tissue[i] = is_t + back[i] = not is_t + t_sim[i] = 1.0 if is_t else 0.0 + b_sim[i] = 0.0 if is_t else 1.0 + + return tissue, back, t_sim, b_sim + + +def _clean_sharpness_data(X: np.ndarray) -> np.ndarray: + """ + Clean sharpness data by handling inf/nan values and clipping outliers. + + Parameters + ---------- + X + Input sharpness data array. + + Returns + ------- + Cleaned data array with inf/nan values replaced and outliers clipped. + """ + Xc: np.ndarray = X.copy() + Xc[np.isinf(Xc)] = np.nan + for i in range(Xc.shape[1]): + col = Xc[:, i] + if np.any(np.isnan(col)): + med = np.nanmedian(col) + Xc[np.isnan(col), i] = med + lo, hi = np.percentile(Xc[:, i], [0.1, 99.9]) + clipped = np.clip(Xc[:, i], lo, hi) + Xc[:, i] = clipped + return Xc + + +def _detect_outliers_iqr(X_scaled: np.ndarray) -> np.ndarray: + """ + Detect outliers using Interquartile Range (IQR) method. + + Parameters + ---------- + X_scaled + Scaled sharpness data. + + Returns + ------- + Array with -1 for outliers, 1 for normal tiles. + """ + Q1 = np.percentile(X_scaled, 25, axis=0) + Q3 = np.percentile(X_scaled, 75, axis=0) + IQR = Q3 - Q1 + lower = Q1 - 1.5 * IQR + mask = np.any(X_scaled < lower, axis=1) + return np.where(mask, -1, 1) + + +def _detect_outliers_zscore(X_scaled: np.ndarray, threshold: float = 3.0) -> np.ndarray: + """ + Detect outliers using Z-score method. + + Parameters + ---------- + X_scaled + Scaled sharpness data. + threshold + Z-score threshold for outlier detection. + + Returns + ------- + Array with -1 for outliers, 1 for normal tiles. + """ + mask = np.any(X_scaled < -threshold, axis=1) + return np.where(mask, -1, 1) + + +def _detect_sharpness_outliers( + X: np.ndarray, method: str = "iqr", tissue_mask: np.ndarray | None = None, var_names: list[str] | None = None +) -> np.ndarray: + """ + Detect sharpness outliers using various methods. + + Parameters + ---------- + X + Sharpness data array. + method + Outlier detection method. + tissue_mask + Optional tissue mask for context-aware detection. + var_names + Variable names for metric identification. + + Returns + ------- + Array with -1 for outliers, 1 for normal tiles. + """ + Xc = _clean_sharpness_data(X) + if method == "tenengrad_tissue": + return _detect_tenengrad_tissue_outliers(Xc, tissue_mask, var_names) + if method == "pvalue": + return _detect_outliers_pvalue(Xc, tissue_mask, var_names)[0] + scaler = StandardScaler() + Xs = scaler.fit_transform(Xc) + if method == "iqr": + return _detect_outliers_iqr(Xs) + if method == "zscore": + return _detect_outliers_zscore(Xs) + raise ValueError(f"Unknown method '{method}'. Use 'iqr', 'zscore', 'tenengrad_tissue', or 'pvalue'.") + + +def _detect_tenengrad_tissue_outliers( + X: np.ndarray, tissue_mask: np.ndarray | None = None, var_names: list[str] | None = None +) -> np.ndarray: + """ + Detect outliers using Tenengrad metric with tissue context. + + Parameters + ---------- + X + Sharpness data array. + tissue_mask + Tissue mask for context-aware detection. + var_names + Variable names for metric identification. + + Returns + ------- + Array with outlier scores. + """ + if tissue_mask is None: + scaler = StandardScaler() + Xs = scaler.fit_transform(_clean_sharpness_data(X)) + return _detect_outliers_zscore(Xs) + + bg_mask = ~tissue_mask + if bg_mask.sum() == 0: + return np.zeros(len(X)) + ten_idx = None + if var_names is not None: + for i, n in enumerate(var_names): + if "tenengrad" in n.lower(): + ten_idx = i + break + if ten_idx is None: + scaler = StandardScaler() + Xs = scaler.fit_transform(_clean_sharpness_data(X)) + return _detect_outliers_zscore(Xs) + + t = X[tissue_mask, ten_idx] + b = X[bg_mask, ten_idx] + bmin, bmax = float(np.min(b)), float(np.max(b)) + rng = bmax - bmin + if rng <= 0: + mu_t, sd_t = float(np.mean(t)), float(np.std(t)) + 1e-10 + scores = np.clip((mu_t - t) / sd_t, 0, 1) + else: + norm = np.clip((t - bmin) / rng, 0, 1) + scores = 1.0 - norm + out = np.zeros(len(X)) + out[np.where(tissue_mask)[0]] = scores + return out + + +def _detect_outliers_pvalue( + X: np.ndarray, tissue_mask: np.ndarray | None = None, var_names: list[str] | None = None, alpha: float = 0.05 +) -> tuple[np.ndarray, np.ndarray]: + """ + Detect outliers using p-value method. + + Parameters + ---------- + X + Sharpness data array. + tissue_mask + Tissue mask for context-aware detection. + var_names + Variable names for metric identification. + alpha + Significance level for p-value threshold. + + Returns + ------- + Tuple of (outlier_labels, p_values). + """ + from scipy import stats + + if tissue_mask is None: + tissue_mask = np.ones(len(X), dtype=bool) + if var_names is None: + var_names = [f"metric_{i}" for i in range(X.shape[1])] + tX = X[tissue_mask] + bX = X[~tissue_mask] + if len(bX) < 10: + return np.ones(len(X), dtype=int), np.ones(len(X)) + P = np.ones((len(tX), len(var_names))) + for i in range(len(var_names)): + bg = bX[:, i] + mu, sd = float(np.mean(bg)), float(np.std(bg)) + if sd < 1e-10: + continue + P[:, i] = stats.norm.cdf(tX[:, i], loc=mu, scale=sd) + minP = np.min(P, axis=1) + fullP = np.ones(len(X)) + fullP[np.where(tissue_mask)[0]] = minP + out = np.where(fullP < alpha, -1, 1) + return out, fullP diff --git a/src/squidpy/experimental/im/_sharpness_metrics.py b/src/squidpy/experimental/im/_sharpness_metrics.py new file mode 100644 index 000000000..9063af293 --- /dev/null +++ b/src/squidpy/experimental/im/_sharpness_metrics.py @@ -0,0 +1,188 @@ +from __future__ import annotations + +from collections.abc import Callable +from enum import Enum + +import numba +import numpy as np +from numba import njit +from scipy.fft import fft2, fftfreq + +# One thread to avoid clashes with Dask +numba.set_num_threads(1) + + +MetricFn = Callable[[np.ndarray], np.ndarray] + + +class SharpnessMetric(str, Enum): + TENENGRAD = "tenengrad" + VAR_OF_LAPLACIAN = "var_of_laplacian" + VARIANCE = "variance" + FFT_HIGH_FREQ_ENERGY = "fft_high_freq_energy" + HAAR_WAVELET_ENERGY = "haar_wavelet_energy" + + +def _ensure_f32_2d(x: np.ndarray) -> np.ndarray: + if x.ndim != 2: + raise ValueError("block must be 2D") + return np.ascontiguousarray(x.astype(np.float32, copy=False)) + + +@njit(cache=True, fastmath=True) +def _clamp(v: int, lo: int, hi: int) -> int: + if v < lo: + return lo + if v > hi: + return hi + return v + + +@njit(cache=True, fastmath=True) +def _tenengrad_mean(block: np.ndarray) -> np.ndarray: + """Mean Tenengrad energy using Sobel 3×3.""" + h, w = block.shape + gxk = np.array([[-1.0, 0.0, 1.0], [-2.0, 0.0, 2.0], [-1.0, 0.0, 1.0]], dtype=np.float32) + gyk = np.array([[1.0, 2.0, 1.0], [0.0, 0.0, 0.0], [-1.0, -2.0, -1.0]], dtype=np.float32) + s = 0.0 + for i in range(h): + for j in range(w): + gx = 0.0 + gy = 0.0 + for di in range(-1, 2): + for dj in range(-1, 2): + ii = _clamp(i + di, 0, h - 1) + jj = _clamp(j + dj, 0, w - 1) + v = block[ii, jj] + gx += gxk[di + 1, dj + 1] * v + gy += gyk[di + 1, dj + 1] * v + s += gx * gx + gy * gy + mean_val = s / (h * w) + return np.full_like(block, mean_val, dtype=np.float32) + + +@njit(cache=True, fastmath=True) +def _laplacian_variance(block: np.ndarray) -> np.ndarray: + """Population variance of Laplacian response.""" + h, w = block.shape + lk = np.array([[0.0, 1.0, 0.0], [1.0, -4.0, 1.0], [0.0, 1.0, 0.0]], dtype=np.float32) + n = h * w + s = 0.0 + s2 = 0.0 + for i in range(h): + for j in range(w): + y = 0.0 + for di in range(-1, 2): + for dj in range(-1, 2): + ii = _clamp(i + di, 0, h - 1) + jj = _clamp(j + dj, 0, w - 1) + y += lk[di + 1, dj + 1] * block[ii, jj] + s += y + s2 += y * y + mean = s / n + # var = E[y^2] - (E[y])^2 + var = (s2 / n) - (mean * mean) + var_val = var if var > 0.0 else 0.0 + return np.full_like(block, var_val, dtype=np.float32) + + +@njit(cache=True, fastmath=True) +def _pop_variance(block: np.ndarray) -> np.ndarray: + """Population variance of pixel intensities.""" + h, w = block.shape + n = h * w + s = 0.0 + s2 = 0.0 + for i in range(h): + for j in range(w): + v = block[i, j] + s += v + s2 += v * v + mean = s / n + var = (s2 / n) - (mean * mean) + var_val = var if var > 0.0 else 0.0 + return np.full_like(block, var_val, dtype=np.float32) + + +def _fft_high_freq_energy(block: np.ndarray) -> np.ndarray: + x = _ensure_f32_2d(block).astype(np.float64, copy=False) + m = float(x.mean()) + s = float(x.std()) + x = (x - m) / s if s > 0.0 else (x - m) + + F = fft2(x) + mag2 = (F.real * F.real) + (F.imag * F.imag) + + h, w = x.shape + fy = fftfreq(h) + fx = fftfreq(w) + ry, rx = np.meshgrid(fy, fx, indexing="ij") + r = np.hypot(ry, rx) + mask = r > 0.1 + + total = float(mag2.sum()) + if not np.isfinite(total) or total <= 1e-12: + ratio = 0.0 + else: + hi = float(mag2[mask].sum()) + ratio = hi / total if np.isfinite(hi) else 0.0 + if ratio < 0.0: + ratio = 0.0 + if ratio > 1.0: + ratio = 1.0 + return np.full_like(block, ratio, dtype=np.float32) + + +def _haar_wavelet_energy(block: np.ndarray) -> np.ndarray: + """Detail-band (LH+HL+HH) energy ratio of a single-level Haar transform.""" + x = _ensure_f32_2d(block).astype(np.float64, copy=False) + m = float(x.mean()) + s = float(x.std()) + x = (x - m) / s if s > 0.0 else (x - m) + + h, w = x.shape + if h % 2 == 1: + x = np.vstack([x, x[-1:, :]]) + h += 1 + if w % 2 == 1: + x = np.hstack([x, x[:, -1:]]) + w += 1 + + cA_h = (x[::2, :] + x[1::2, :]) / 2.0 + cH_h = (x[::2, :] - x[1::2, :]) / 2.0 + + cA = (cA_h[:, ::2] + cA_h[:, 1::2]) / 2.0 # LL + cH = (cA_h[:, ::2] - cA_h[:, 1::2]) / 2.0 # LH + cV = (cH_h[:, ::2] + cH_h[:, 1::2]) / 2.0 # HL + cD = (cH_h[:, ::2] - cH_h[:, 1::2]) / 2.0 # HH + + total = float((cA * cA).sum() + (cH * cH).sum() + (cV * cV).sum() + (cD * cD).sum()) + if not np.isfinite(total) or total <= 1e-12: + ratio = 0.0 + else: + detail = float((cH * cH).sum() + (cV * cV).sum() + (cD * cD).sum()) + ratio = detail / total if np.isfinite(detail) else 0.0 + if ratio < 0.0: + ratio = 0.0 + if ratio > 1.0: + ratio = 1.0 + return np.full_like(block, ratio, dtype=np.float32) + + +_METRICS: dict[SharpnessMetric, MetricFn] = { + SharpnessMetric.TENENGRAD: lambda a: _tenengrad_mean(_ensure_f32_2d(a)), + SharpnessMetric.VAR_OF_LAPLACIAN: lambda a: _laplacian_variance(_ensure_f32_2d(a)), + SharpnessMetric.VARIANCE: lambda a: _pop_variance(_ensure_f32_2d(a)), + SharpnessMetric.FFT_HIGH_FREQ_ENERGY: _fft_high_freq_energy, + SharpnessMetric.HAAR_WAVELET_ENERGY: _haar_wavelet_energy, +} + + +def _get_sharpness_metric_function(metric: str | SharpnessMetric) -> MetricFn: + if isinstance(metric, str): + try: + metric = SharpnessMetric(metric.lower()) + except ValueError as e: + avail = ", ".join(m.value for m in SharpnessMetric) + raise ValueError(f"Unknown metric '{metric}'. Available: {avail}") from e + return _METRICS[metric] diff --git a/src/squidpy/experimental/im/_utils.py b/src/squidpy/experimental/im/_utils.py index 8075ac3ca..71f1a41d7 100644 --- a/src/squidpy/experimental/im/_utils.py +++ b/src/squidpy/experimental/im/_utils.py @@ -1,67 +1,126 @@ from __future__ import annotations -from typing import Literal +from typing import Any, Literal +import dask.array as da +import numpy as np import spatialdata as sd import xarray as xr -from spatialdata._logging import logger as logg - - -def _get_image_data( - sdata: sd.SpatialData, - image_key: str, - scale: str, +from shapely.geometry import Polygon +from spatialdata._logging import logger + +from squidpy._utils import _ensure_dim_order + + +class TileGrid: + def __init__( + self, + H: int, + W: int, + tile_size: Literal["auto"] | tuple[int, int] = "auto", + target_tiles: int = 100, + ): + self.H = int(H) + self.W = int(W) + if tile_size == "auto": + size = max(min(self.H // target_tiles, self.W // target_tiles), 100) + self.ty = int(size) + self.tx = int(size) + else: + self.ty = int(tile_size[0]) + self.tx = int(tile_size[1]) + self.tiles_y = (self.H + self.ty - 1) // self.ty + self.tiles_x = (self.W + self.tx - 1) // self.tx + + def indices(self) -> np.ndarray: + return np.array([[iy, ix] for iy in range(self.tiles_y) for ix in range(self.tiles_x)], dtype=int) + + def names(self) -> list[str]: + return [f"tile_x{ix}_y{iy}" for iy in range(self.tiles_y) for ix in range(self.tiles_x)] + + def bounds(self) -> np.ndarray: + b: list[list[int]] = [] + for iy in range(self.tiles_y): + for ix in range(self.tiles_x): + y0, x0 = iy * self.ty, ix * self.tx + y1 = (iy + 1) * self.ty if iy < self.tiles_y - 1 else self.H + x1 = (ix + 1) * self.tx if ix < self.tiles_x - 1 else self.W + b.append([y0, x0, y1, x1]) + return np.array(b, dtype=int) + + def centroids_and_polygons(self) -> tuple[np.ndarray, list[Polygon]]: + cents: list[list[float]] = [] + polys: list[Polygon] = [] + for y0, x0, y1, x1 in self.bounds(): + cy = (y0 + y1) / 2 + cx = (x0 + x1) / 2 + cents.append([cy, cx]) + polys.append(Polygon([(x0, y0), (x1, y0), (x1, y1), (x0, y1), (x0, y0)])) + return np.array(cents, dtype=float), polys + + def rechunk_and_pad(self, arr_yx: da.Array) -> da.Array: + if arr_yx.ndim != 2: + raise ValueError("Expected a 2D array shaped (y, x).") + pad_y = self.tiles_y * self.ty - int(arr_yx.shape[0]) + pad_x = self.tiles_x * self.tx - int(arr_yx.shape[1]) + a = arr_yx.rechunk((self.ty, self.tx)) + return da.pad(a, ((0, pad_y), (0, pad_x)), mode="edge") if (pad_y > 0 or pad_x > 0) else a + + def coarsen(self, arr_yx: da.Array, reduce: Literal["mean", "sum"] = "mean") -> da.Array: + reducer = np.mean if reduce == "mean" else np.sum + return da.coarsen(reducer, arr_yx, {0: self.ty, 1: self.tx}, trim_excess=False) + + +def _get_element_data( + element_node: Any, + scale: str | Literal["auto"], + element_type: str = "element", + element_key: str = "", ) -> xr.DataArray: """ - Extract image data from SpatialData object, handling both datatree and direct DataArray images. + Extract data array from a spatialdata element (image or label) node. + Supports multiscale and single-scale elements. Parameters ---------- - sdata : SpatialData - SpatialData object - image_key : str - Key in sdata.images - scale : str - Multiscale level, e.g. "scale0", or "auto" for the smallest available scale + element_node + The element node from sdata.images[key] or sdata.labels[key] + scale + Scale level to use, or "auto" for images (picks coarsest). + element_type + Type of element for error messages (e.g., "image", "label"). + element_key + Key of the element for error messages. Returns ------- - xr.DataArray - Image data in (c, y, x) format + xr.DataArray of the element data. """ - img_node = sdata.images[image_key] - - # Check if the image is a datatree (has multiple scales) or a direct DataArray - if hasattr(img_node, "keys"): - available_scales = list(img_node.keys()) + if hasattr(element_node, "keys"): # multiscale + available = list(element_node.keys()) + if not available: + raise ValueError(f"No scales for {element_type} {element_key!r}") if scale == "auto": - scale = available_scales[-1] - elif scale not in available_scales: - print(scale) - print(available_scales) - scale = available_scales[-1] - logg.warning(f"Scale '{scale}' not found, using available scale. Available scales: {available_scales}") - - img_da = img_node[scale].image - else: - # It's a direct DataArray (no scales) - img_da = img_node.image if hasattr(img_node, "image") else img_node - return _ensure_cyx(img_da) + def _idx(k: str) -> int: + num = "".join(ch for ch in k if ch.isdigit()) + return int(num) if num else -1 + chosen = max(available, key=_idx) + elif scale not in available: + logger.warning(f"Scale {scale!r} not found. Available: {available}") + # Try scale0 as fallback, otherwise use first available + chosen = "scale0" if "scale0" in available else available[0] + logger.info(f"Using scale {chosen!r}") + else: + chosen = scale -def _ensure_cyx(img_da: xr.DataArray) -> xr.DataArray: - """Ensure dims are (c, y, x). Adds a length-1 "c" if missing.""" - dims = list(img_da.dims) - if "y" not in dims or "x" not in dims: - raise ValueError(f'Expected dims to include "y" and "x". Found dims={dims}') + data = element_node[chosen].image + else: # single-scale + data = element_node.image if hasattr(element_node, "image") else element_node - # Handle case where dims are (c, y, x) - keep as is - if "c" in dims: - return img_da if dims[0] == "c" else img_da.transpose("c", "y", "x") - # If no "c" dimension, add one - return img_da.expand_dims({"c": [0]}).transpose("c", "y", "x") + return data def _flatten_channels( @@ -104,21 +163,21 @@ def _flatten_channels( # If user explicitly specifies multichannel, always use mean if channel_format == "multichannel": - logg.info(f"Converting {n_channels}-channel image to greyscale using mean across all channels") + logger.info(f"Converting {n_channels}-channel image to greyscale using mean across all channels") return img.mean(dim="c") # Handle explicit RGB specification if channel_format == "rgb": if n_channels != 3: raise ValueError(f"Cannot treat {n_channels}-channel image as RGB (requires exactly 3 channels)") - logg.info("Converting RGB image to greyscale using luminance formula") + logger.info("Converting RGB image to greyscale using luminance formula") weights = xr.DataArray([0.299, 0.587, 0.114], dims=["c"], coords={"c": img.coords["c"]}) return (img * weights).sum(dim="c") elif channel_format == "rgba": if n_channels != 4: raise ValueError(f"Cannot treat {n_channels}-channel image as RGBA (requires exactly 4 channels)") - logg.info("Converting RGBA image to greyscale using luminance formula (ignoring alpha)") + logger.info("Converting RGBA image to greyscale using luminance formula (ignoring alpha)") weights = xr.DataArray([0.299, 0.587, 0.114, 0.0], dims=["c"], coords={"c": img.coords["c"]}) return (img * weights).sum(dim="c") diff --git a/src/squidpy/experimental/pl/__init__.py b/src/squidpy/experimental/pl/__init__.py new file mode 100644 index 000000000..6d51786fa --- /dev/null +++ b/src/squidpy/experimental/pl/__init__.py @@ -0,0 +1,5 @@ +from __future__ import annotations + +from ._qc_sharpness import qc_sharpness + +__all__ = ["qc_sharpness"] diff --git a/src/squidpy/experimental/pl/_qc_sharpness.py b/src/squidpy/experimental/pl/_qc_sharpness.py new file mode 100644 index 000000000..903668907 --- /dev/null +++ b/src/squidpy/experimental/pl/_qc_sharpness.py @@ -0,0 +1,204 @@ +from __future__ import annotations + +from typing import Any + +import matplotlib.pyplot as plt +import numpy as np +from scipy.stats import gaussian_kde +from spatialdata import SpatialData +from spatialdata._logging import logger as logg + +from squidpy.experimental.im._sharpness_metrics import SharpnessMetric + + +def qc_sharpness( + sdata: SpatialData, + image_key: str, + metrics: SharpnessMetric | list[SharpnessMetric] | None = None, + figsize: tuple[int, int] | None = None, + return_fig: bool = False, + **kwargs: Any, +) -> plt.Figure | None: + """ + Plot a summary view of raw sharpness metrics from qc_sharpness results. + + Automatically scans adata.uns for calculated metrics and plots the raw sharpness values. + Creates a multi-panel plot: one panel per calculated sharpness metric. + Each panel shows: spatial view, KDE plot, and statistics. + + Parameters + ---------- + sdata : SpatialData + SpatialData object containing QC results. + image_key : str + Image key used in qc_sharpness function. + metrics : SharpnessMetric or list of SharpnessMetric, optional + Specific metrics to plot. If None, plots all calculated sharpness metrics. + Use SharpnessMetric enum values. + figsize : tuple, optional + Figure size (width, height). Auto-calculated if None. + return_fig : bool + Whether to return the figure object. Default is False. + **kwargs + Additional arguments passed to render_shapes(). + + Returns + ------- + fig : matplotlib.Figure or None + The matplotlib figure object if return_fig=True, otherwise None. + """ + + # Expected keys + table_key = f"qc_img_{image_key}_sharpness" + shapes_key = f"qc_img_{image_key}_sharpness_grid" + + if table_key not in sdata.tables: + raise ValueError(f"No QC data found for image '{image_key}'. Run sq.exp.im.qc_sharpness() first.") + + adata = sdata.tables[table_key] + + # Check if qc_sharpness metadata exists + if "qc_sharpness" not in adata.uns: + raise ValueError("No qc_sharpness metadata found. Run sq.exp.im.qc_sharpness() first.") + + # Get calculated metrics from metadata + calculated_metrics = adata.uns["qc_sharpness"]["metrics"] + + if not calculated_metrics: + raise ValueError("No sharpness metrics found in metadata.") + + # Filter for specific metrics if requested + if metrics is not None: + # Convert metrics to list if single metric provided + metrics_list = metrics if isinstance(metrics, list) else [metrics] + # Convert enum to string names using the same logic as main function + metrics_to_plot = [] + for metric in metrics_list: + metric_name = metric.name.lower() if isinstance(metric, SharpnessMetric) else metric + if metric_name not in calculated_metrics: + raise ValueError(f"Metric '{metric_name}' not found. Available: {calculated_metrics}") + metrics_to_plot.append(metric_name) + else: + metrics_to_plot = calculated_metrics + + logg.info(f"Plotting {len(metrics_to_plot)} sharpness metrics: {metrics_to_plot}") + + # Create subplots: 3 columns, one row per metric + n_metrics = len(metrics_to_plot) + ncols = 3 # spatial, histogram, stats + nrows = n_metrics + + if figsize is None: + figsize = (12, 4 * nrows) # 12 width for 3 columns, 4 height per row + + fig, axes = plt.subplots(nrows, ncols, figsize=figsize) + + # Ensure axes is always 2D array for consistent indexing + if nrows == 1: + axes = axes.reshape(1, -1) + if ncols == 1: + axes = axes.reshape(-1, 1) + + # Plot each metric + for i, metric_name in enumerate(metrics_to_plot): + # Find the metric in adata.var_names and get raw values + var_name = f"sharpness_{metric_name}" + if var_name not in adata.var_names: + logg.warning(f"Variable '{var_name}' not found in adata.var_names. Skipping.") + continue + + # Get metric index and raw values + metric_idx = list(adata.var_names).index(var_name) + raw_values = adata.X[:, metric_idx] + + # Get axes for this metric (row i, columns 0, 1, 2) + ax_spatial = axes[i, 0] + ax_hist = axes[i, 1] + ax_stats = axes[i, 2] + + # Panel 1: Spatial plot + try: + ( + sdata.pl.render_shapes(shapes_key, color=var_name, **kwargs).pl.show( + ax=ax_spatial, title=f"{metric_name.replace('_', ' ').title()}" + ) + ) + except (ValueError, KeyError, AttributeError) as e: + logg.warning(f"Error plotting spatial view for {metric_name}: {e}") + ax_spatial.text( + 0.5, 0.5, f"Error plotting\n{metric_name}", ha="center", va="center", transform=ax_spatial.transAxes + ) + ax_spatial.set_title(f"{metric_name.replace('_', ' ').title()}") + + # Panel 2: KDE plot (overlaid if tissue/background classification available) + # Create x-axis range for KDE + x_min, x_max = float(np.min(raw_values)), float(np.max(raw_values)) + x_range = np.linspace(x_min, x_max, 200) + + if "is_tissue" in adata.obs: + # Convert categorical to boolean for filtering + is_tissue = adata.obs["is_tissue"].astype(str) == "True" + tissue_values = raw_values[is_tissue] + background_values = raw_values[~is_tissue] + + # Create KDE plots for both tissue and background + if len(background_values) > 1: + kde_background = gaussian_kde(background_values) + density_background = kde_background(x_range) + ax_hist.plot(x_range, density_background, label="Background", alpha=0.7) + ax_hist.fill_between(x_range, density_background, alpha=0.3) + + if len(tissue_values) > 1: + kde_tissue = gaussian_kde(tissue_values) + density_tissue = kde_tissue(x_range) + ax_hist.plot(x_range, density_tissue, label="Tissue", alpha=0.7) + ax_hist.fill_between(x_range, density_tissue, alpha=0.3) + + ax_hist.legend() + + elif len(raw_values) > 1: + kde = gaussian_kde(raw_values) + density = kde(x_range) + ax_hist.plot(x_range, density, alpha=0.7) + ax_hist.fill_between(x_range, density, alpha=0.3) + + ax_hist.set_xlabel(f"{metric_name.replace('_', ' ').title()}") + ax_hist.set_ylabel("Density") + ax_hist.set_title("Distribution") + ax_hist.grid(True, alpha=0.3) + + # Panel 3: Statistics + ax_stats.axis("off") + stats_text = f""" + Raw {metric_name.replace("_", " ").title()} Statistics: + + Count: {len(raw_values):,} + Mean: {np.mean(raw_values):.4f} + Std: {np.std(raw_values):.4f} + Min: {np.min(raw_values):.4f} + Max: {np.max(raw_values):.4f} + + Percentiles: + 5%: {np.percentile(raw_values, 5):.4f} + 25%: {np.percentile(raw_values, 25):.4f} + 50%: {np.percentile(raw_values, 50):.4f} + 75%: {np.percentile(raw_values, 75):.4f} + 95%: {np.percentile(raw_values, 95):.4f} + + Non-zero: {np.count_nonzero(raw_values):,} + Zero: {np.sum(raw_values == 0):,} + """ + + ax_stats.text( + 0.05, + 0.95, + stats_text.strip(), + transform=ax_stats.transAxes, + fontsize=9, + verticalalignment="top", + fontfamily="monospace", + ) + + plt.tight_layout() + + return fig if return_fig else None diff --git a/tests/_images/QCSharpness_calc_qc_sharpness.png b/tests/_images/QCSharpness_calc_qc_sharpness.png new file mode 100644 index 000000000..46971dccb Binary files /dev/null and b/tests/_images/QCSharpness_calc_qc_sharpness.png differ diff --git a/tests/_images/QCSharpness_plot_qc_sharpness.png b/tests/_images/QCSharpness_plot_qc_sharpness.png new file mode 100644 index 000000000..aa3edb35d Binary files /dev/null and b/tests/_images/QCSharpness_plot_qc_sharpness.png differ diff --git a/tests/experimental/test_detect_tissue.py b/tests/experimental/test_detect_tissue.py index 7a54d41c5..2410beba9 100644 --- a/tests/experimental/test_detect_tissue.py +++ b/tests/experimental/test_detect_tissue.py @@ -1,5 +1,3 @@ -"""Test for experimental tissue detection.""" - from __future__ import annotations import spatialdata_plot as sdp diff --git a/tests/experimental/test_qc_sharpness.py b/tests/experimental/test_qc_sharpness.py new file mode 100644 index 000000000..bdf56f896 --- /dev/null +++ b/tests/experimental/test_qc_sharpness.py @@ -0,0 +1,45 @@ +from __future__ import annotations + +import spatialdata_plot as sdp + +import squidpy as sq +from tests.conftest import PlotTester, PlotTesterMeta + +_ = sdp + + +class TestQCSharpness(PlotTester, metaclass=PlotTesterMeta): + def test_plot_calc_qc_sharpness(self): + """Test QC sharpness on Visium H&E dataset.""" + sdata = sq.datasets.visium_hne_sdata() + + sq.experimental.im.qc_sharpness( + sdata, + image_key="hne", + # Only one metric for speed + metrics=[sq.experimental.im.SharpnessMetric.TENENGRAD], + ) + + ( + sdata.pl.render_images() + .pl.render_shapes( + "qc_img_hne_sharpness_grid", color="sharpness_outlier", groups="True", palette="red", fill_alpha=1.0 + ) + .pl.show() + ) + + def test_plot_plot_qc_sharpness(self): + """Test QC sharpness on Visium H&E dataset.""" + sdata = sq.datasets.visium_hne_sdata() + + sq.experimental.im.qc_sharpness( + sdata, + image_key="hne", + # Only one metric for speed + metrics=[sq.experimental.im.SharpnessMetric.TENENGRAD], + ) + + sq.experimental.pl.qc_sharpness( + sdata, + image_key="hne", + )