diff --git a/pyproject.toml b/pyproject.toml index e8968966..34eb9cfa 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -37,16 +37,20 @@ authors = [ {name = "Giovanni Palla"}, {name = "Michal Klein"}, {name = "Hannah Spitzer"}, + {name = "Tim Treis"}, + {name = "Laurens Lehner"}, + {name = "Selman Ozleyen"}, ] maintainers = [ - {name = "Giovanni Palla", email = "giovanni.palla@helmholtz-muenchen.de"}, - {name = "Michal Klein", email = "michal.klein@helmholtz-muenchen.de"}, - {name = "Tim Treis", email = "tim.treis@helmholtz-muenchen.de"} + {name = "Tim Treis", email = "tim.treis@helmholtz-munich.de"}, + {name = "Selman Ozleyen", email = "selman.ozleyen@helmholtz-munich.de"} ] dependencies = [ "aiohttp>=3.8.1", "anndata>=0.9", + "spatialdata>=0.2.5", + "spatialdata-plot", "cycler>=0.11.0", "dask-image>=0.5.0", "dask[array]>=2021.02.0,<=2024.11.2", @@ -61,7 +65,7 @@ dependencies = [ "pandas>=2.1.0", "Pillow>=8.0.0", "scanpy>=1.9.3", - "scikit-image>=0.20", + "scikit-image>=0.25", # due to https://github.com/scikit-image/scikit-image/issues/6850 breaks rescale ufunc "scikit-learn>=0.24.0", "statsmodels>=0.12.0", @@ -70,21 +74,27 @@ dependencies = [ "tqdm>=4.50.2", "validators>=0.18.2", "xarray>=2024.10.0", - "zarr>=2.6.1,<3.0.0", - "spatialdata>=0.2.5", + "zarr>=2.6.1,<3.0.0", "imagecodecs>=2025.8.2,<2026", ] [project.optional-dependencies] dev = [ "pre-commit>=3.0.0", "hatch>=1.9.0", + "jupyterlab", + "notebook", + "ipykernel", + "ipywidgets", + "jupytext", + "pytest", + "pytest-cov", + "ruff", ] test = [ "scanpy[leiden]", "pytest>=7", "pytest-xdist>=3", "pytest-mock>=3.5.0", - # Just for VS Code "pytest-cov>=4", "coverage[toml]>=7", "pytest-timeout>=2.1.0", @@ -281,4 +291,41 @@ exclude_lines = [ show_missing = true precision = 2 skip_empty = true -sort = "Miss" \ No newline at end of file +sort = "Miss" + +[tool.pixi.workspace] +channels = ["conda-forge"] +platforms = ["osx-arm64", "linux-64"] + +[tool.pixi.dependencies] +python = ">=3.11" + 
+[tool.pixi.pypi-dependencies] +squidpy = { path = ".", editable = true } + +# for gh-actions +[tool.pixi.feature.py311.dependencies] +python = "3.11.*" + +[tool.pixi.feature.py313.dependencies] +python = "3.13.*" + +[tool.pixi.environments] +# 3.11 lane (for gh-actions) +dev-py311 = { features = ["dev", "test", "py311"], solve-group = "py311" } +docs-py311 = { features = ["docs", "py311"], solve-group = "py311" } + +# 3.13 lane +default = { features = ["py313"], solve-group = "py313" } +dev-py313 = { features = ["dev", "test", "py313"], solve-group = "py313" } +docs-py313 = { features = ["docs", "py313"], solve-group = "py313" } +test-py313 = { features = ["test", "py313"], solve-group = "py313" } + +[tool.pixi.tasks] +lab = "jupyter lab" +kernel-install = "python -m ipykernel install --user --name pixi-dev --display-name \"sdata-plot (dev)\"" +test = "pytest -v --color=yes --tb=short --durations=10" +lint = "ruff check ." +format = "ruff format ." +pre-commit-install = "pre-commit install" +pre-commit = "pre-commit run" \ No newline at end of file diff --git a/src/squidpy/__init__.py b/src/squidpy/__init__.py index 5fb2b848..1a56e7a6 100644 --- a/src/squidpy/__init__.py +++ b/src/squidpy/__init__.py @@ -3,7 +3,7 @@ from importlib import metadata from importlib.metadata import PackageMetadata -from squidpy import datasets, gr, im, pl, read, tl +from squidpy import datasets, exp, gr, im, pl, read, tl try: md: PackageMetadata = metadata.metadata(__name__) @@ -14,3 +14,5 @@ md = None # type: ignore[assignment] del metadata, md + +__all__ = ["datasets", "exp", "gr", "im", "pl", "read", "tl"] diff --git a/src/squidpy/_utils.py b/src/squidpy/_utils.py index 3d1f26b8..99f1b134 100644 --- a/src/squidpy/_utils.py +++ b/src/squidpy/_utils.py @@ -16,6 +16,8 @@ import joblib as jl import numba import numpy as np +import spatialdata as sd +from spatialdata.models import Image2DModel, Labels2DModel __all__ = ["singledispatchmethod", "Signal", "SigQueue", "NDArray", "NDArrayA"]
@@ -347,3 +349,27 @@ def new_func2(*args: Any, **kwargs: Any) -> Any: else: raise TypeError(repr(type(reason))) + + +def _get_scale_factors( + element: Image2DModel | Labels2DModel, +) -> list[float]: + """ + Get the scale factors of an image or labels. + """ + if not hasattr(element, "keys"): + return [] # element isn't a datatree -> single scale + + shapes = [_yx_from_shape(element[scale].image.shape) for scale in element.keys()] + + factors: list[float] = [(y0 / y1 + x0 / x1) / 2 for (y0, x0), (y1, x1) in zip(shapes, shapes[1:], strict=False)] + return [int(f) for f in factors] + + +def _yx_from_shape(shape: tuple[int, ...]) -> tuple[int, int]: + if len(shape) == 2: # (y, x) + return shape[0], shape[1] + if len(shape) == 3: # (c, y, x) + return shape[1], shape[2] + + raise ValueError(f"Unsupported shape {shape}. Expected (y, x) or (c, y, x).") diff --git a/src/squidpy/exp/__init__.py b/src/squidpy/exp/__init__.py new file mode 100644 index 00000000..d2e7760a --- /dev/null +++ b/src/squidpy/exp/__init__.py @@ -0,0 +1,11 @@ +"""Experimental module for Squidpy. + +This module contains experimental features that are still under development. +These features may change or be removed in future releases. 
+""" + +from __future__ import annotations + +from .im._detect_tissue import detect_tissue + +__all__ = ["detect_tissue"] diff --git a/src/squidpy/exp/im/__init__.py b/src/squidpy/exp/im/__init__.py new file mode 100644 index 00000000..52e3164f --- /dev/null +++ b/src/squidpy/exp/im/__init__.py @@ -0,0 +1,5 @@ +from __future__ import annotations + +from ._detect_tissue import detect_tissue + +__all__ = ["detect_tissue"] diff --git a/src/squidpy/exp/im/_detect_tissue.py b/src/squidpy/exp/im/_detect_tissue.py new file mode 100644 index 00000000..c6d769b2 --- /dev/null +++ b/src/squidpy/exp/im/_detect_tissue.py @@ -0,0 +1,562 @@ +from __future__ import annotations + +import enum +from dataclasses import dataclass +from typing import Any + +import numpy as np +import spatialdata as sd +import xarray as xr +from dask_image.ndinterp import affine_transform as da_affine +from scipy import ndimage +from skimage import filters, measure +from skimage.filters import gaussian, threshold_otsu +from skimage.morphology import binary_closing, disk, remove_small_holes +from skimage.segmentation import felzenszwalb +from skimage.util import img_as_float +from spatialdata._logging import logger as logg +from spatialdata.models import Labels2DModel +from spatialdata.transformations import get_transformation + +from squidpy._utils import _get_scale_factors, _yx_from_shape + +from ._utils import _flatten_channels, _get_image_data + + +class DETECT_TISSUE_METHOD(enum.Enum): + OTSU = enum.auto() + FELZENSZWALB = enum.auto() + + +@dataclass(slots=True) +class BackgroundDetectionParams: + """ + Which corners are background, and how large the corner boxes should be. + If no corners are flagged True, the orientation falls back to 'bright background'. 
+ """ + + ymin_xmin_is_bg: bool = True + ymax_xmin_is_bg: bool = True + ymin_xmax_is_bg: bool = True + ymax_xmax_is_bg: bool = True + corner_size_pct: float = 0.01 # size of each corner box as fraction of img height/width + + @property + def any_corner(self) -> bool: + return any((self.ymin_xmin_is_bg, self.ymax_xmin_is_bg, self.ymin_xmax_is_bg, self.ymax_xmax_is_bg)) + + +@dataclass(slots=True) +class FelzenszwalbParams: + """ + Size-aware superpixel defaults for felzenszwalb segmentation. + """ + + grid_rows: int = 100 # ~desired grid, only for scale heuristics + grid_cols: int = 100 + sigma_frac: float = 0.008 # blur = this * short side (clipped to [1,5] px) + scale_coef: float = 0.25 # scale = coef * target_area + min_size_coef: float = 0.20 # min_size = coef * target_area + + +def detect_tissue( + sdata: sd.SpatialData, + image_key: str, + scale: str = "auto", + method: DETECT_TISSUE_METHOD | str = DETECT_TISSUE_METHOD.OTSU, + corners_are_background: bool = True, + close_holes_smaller_than_frac: float = 0.0001, + mask_smoothing_cycles: int = 0, + new_labels_key: str | None = None, + inplace: bool = True, + **kwargs: Any, +) -> xr.DataArray: + """ + Detect tissue and return a boolean mask (y,x) where True = specimen. + + Parameters + ---------- + sdata : sd.SpatialData + SpatialData object containing the image + image_key : str + Key of the image in sdata.images + scale : str, default "auto" + - If a specific scale key is provided (e.g., "scale0", "scale1", ...), + that image scale is used verbatim. + - If "auto": pick the smallest available scale. If that smallest scale + exceeds the pixel threshold, it is further downscaled to be under the + threshold for calculations. 
+ method : str or DETECT_TISSUE_METHOD, default OTSU + Method to use ("otsu" or "felzenszwalb") + corners_are_background : bool, default True + Whether all corners should be treated as background + close_holes_smaller_than_frac : float, default 0.0001 + Holes smaller than this fraction of the image area will be closed. + mask_smoothing_cycles : int, default 0 + Number of cycles of (2D) morphological closing to apply to the mask + new_labels_key : str | None, default None + Key to store the new labels in the SpatialData object. If None, f"{image_key}_tissue" is used. + inplace : bool, default True + Whether to store the new labels in the SpatialData object or return the mask. + If the mask is saved to the SpatialData object, it will inherit the scale_factors + of the image, if present. + **kwargs + Optional keyword arguments: + + channel_format : {"infer", "rgb", "rgba", "multichannel"}, default "infer" + How to interpret image channels for grey conversion: + - "infer": Auto-detect (3 ch → RGB luminance, others → mean) + - "rgb": Force RGB treatment (requires exactly 3 channels) + - "rgba": Force RGBA treatment (requires exactly 4 channels) + - "multichannel": Force mean across all channels + + background_detection_params : BackgroundDetectionParams, optional + Custom background detection configuration. If provided, overrides + the `corners_are_background` parameter. Default creates config + based on `corners_are_background`. + + felzenszwalb_params : FelzenszwalbParams, optional + Felzenszwalb segmentation parameters (only used when method="FELZENSZWALB"). + Default: FelzenszwalbParams(grid_rows=100, grid_cols=100, ...) + + min_specimen_area_frac : float, default 0.01 + Minimum area of a specimen as fraction of image area. + + n_samples : int | None, default None + Number of specimens to keep. If provided, the n_samples largest components will be kept. + If not provided, the specimens will be filtered by area using Otsu thresholding on log10(area).
+ + auto_max_pixels : int, default 5_000_000 + Target maximum number of pixels (H*W) for the image when scale="auto". + + + Returns + ------- + np.ndarray | None + None if inplace=True (labels are stored in sdata.labels); otherwise an int32 label mask of shape (y, x) at full image resolution (0 = background) + + Notes + ----- + This function uses a simple pipeline: + - OTSU: Global Otsu thresholding (with background orientation from corners) + - FELZENSZWALB: Felzenszwalb superpixels -> per-superpixel Otsu (with background orientation) + Both methods are followed by morphology to solidify, area-based filtering to keep only real specimens, + and optional mask smoothing to refine boundaries. + """ + + # Set up background detection + background_detection_params = kwargs.get("background_detection_params", None) + if background_detection_params is None: + background_detection_params = BackgroundDetectionParams( + ymin_xmin_is_bg=corners_are_background, + ymax_xmin_is_bg=corners_are_background, + ymin_xmax_is_bg=corners_are_background, + ymax_xmax_is_bg=corners_are_background, + ) + + # Convert string method to enum + if isinstance(method, str): + try: + method = DETECT_TISSUE_METHOD[method.upper()] + except KeyError as e: + raise ValueError("method must be 'otsu' or 'felzenszwalb'") from e + + manual_scale = scale.lower() != "auto" + + if manual_scale: + # Respect the user's scale verbatim + img_src = _get_image_data(sdata, image_key, scale=scale) + else: + img_src = _get_image_data(sdata, image_key, scale="auto") + + img_src_h, img_src_w = _yx_from_shape(img_src.shape) + n_source_pixels = img_src_h * img_src_w + + # 1) deal with channel dimension + img_grey: xr.DataArray = _flatten_channels(img=img_src, channel_format=kwargs.get("channel_format", "infer")) + + # decide working resolution + auto_max_pixels = kwargs.get("auto_max_pixels", 5_000_000) + need_downscale = (not manual_scale) and (n_source_pixels > auto_max_pixels) + + if need_downscale: + # Compute the array via Dask (if dask-backed) and show a progress bar + logg.info("Downscaling for faster
computation.") + img_grey = _downscale_with_dask(img_grey=img_grey, target_pixels=auto_max_pixels) + else: + # No additional downscaling; use the smallest scale (or manual scale) as-is + img_grey = img_grey.values # may trigger compute without explicit progress bar + + # 2) first-pass foreground + if method == DETECT_TISSUE_METHOD.OTSU: + img_fg_mask = _segment_otsu(img_grey=img_grey, params=background_detection_params) + elif method == DETECT_TISSUE_METHOD.FELZENSZWALB: + labels = _segment_felzenszwalb( + img_grey=img_grey, + params=kwargs.get("felzenszwalb_params", FelzenszwalbParams()), + ) + img_fg_mask = _mask_from_labels_via_corners( + img_grey=img_grey, labels=labels, params=background_detection_params + ) + else: + raise ValueError(f"Method {method} not implemented") + + # 3) solidify + if close_holes_smaller_than_frac > 0: + img_fg_mask = _make_solid(img_fg_mask, close_holes_smaller_than_frac) + + # 4) keep only specimen-sized components (Otsu on areas) + img_fg_mask = _filter_by_area( + mask=img_fg_mask, + min_specimen_area_frac=kwargs.get("min_specimen_area_frac", 0.01), + n_samples=kwargs.get("n_samples", None), + ) + + # 5) smooth mask boundaries (optional) + img_fg_mask = _smooth_mask(img_fg_mask, mask_smoothing_cycles) + + # 6) Upscale to full resolution of the source image + target_shape = _get_target_upscale_shape(sdata, image_key) + scale_matrix = _get_scaling_matrix(img_fg_mask.shape, target_shape) + img_fg_mask_upscaled = da_affine( + img_fg_mask, + matrix=scale_matrix, + offset=(0.0, 0.0), + output_shape=target_shape, + order=0, + mode="constant", + cval=0, + output=np.int32, + ) + + if inplace: + if new_labels_key is None: + new_labels_key = f"{image_key}_tissue" + + source_scale_factors = _get_scale_factors(sdata.images[image_key]) + + sdata.labels[new_labels_key] = Labels2DModel.parse( + data=img_fg_mask_upscaled, + dims=("y", "x"), + transformations=get_transformation(sdata.images[image_key], get_all=True), + 
scale_factors=source_scale_factors, + ) + + return None + + return np.array(img_fg_mask_upscaled) + + +def _get_scaling_matrix(current_shape: tuple[int, int], target_shape: tuple[int, int]) -> np.ndarray: + """ + Get the scaling matrix for upscaling the mask back to the original image size. + """ + scale_y = 1 / (target_shape[0] / current_shape[0]) + scale_x = 1 / (target_shape[1] / current_shape[1]) + return np.array([[scale_y, 0.0], [0.0, scale_x]], dtype=float) + + +def _get_target_upscale_shape( + sdata: sd.SpatialData, + image_key: str, +) -> tuple[int, int]: + """ + Get the target shape for upscaling the mask back to the original image size. + """ + if not hasattr(sdata.images[image_key], "keys"): + return _yx_from_shape(sdata.images[image_key].shape) + + target_scale = list(sdata.images[image_key].keys())[0] + + return _yx_from_shape(sdata.images[image_key][target_scale].image.shape) + + +def _downscale_with_dask(img_grey: xr.DataArray, target_pixels: int) -> np.ndarray: + """ + Downscale (y,x) with Dask-backed xarray.coarsen (mean) until H*W <= target_pixels. + Returns a NumPy array of the *downscaled* image and its shape. Shows a Dask ProgressBar. 
+ """ + img_grey_h, img_grey_w = img_grey.shape + n_source_pixels = img_grey_h * img_grey_w + if n_source_pixels <= target_pixels: + # Nothing to do; still compute lazily with progress bar + return _dask_compute(_ensure_dask(img_grey)) + + # Desired continuous scale + scale = float(np.sqrt(target_pixels / float(n_source_pixels))) # 0 < s < 1 + target_h = max(1, int(img_grey_h * scale)) + target_w = max(1, int(img_grey_w * scale)) + + # Integer coarsen factors (mean-pooling); ensure we don't exceed target + coarsen_factor_y = max(1, int(np.ceil(img_grey_h / target_h))) + coarsen_factor_x = max(1, int(np.ceil(img_grey_w / target_w))) + + # Ensure Dask backing (if not already) + img_grey_small_da = ( + _ensure_dask(img_grey) + .coarsen(y=coarsen_factor_y, x=coarsen_factor_x, boundary="trim") + .mean() # anti-aliased downscale + ) + + # Compute the *downscaled* array only + img_grey_small = _dask_compute(img_grey_small_da) + + return np.asarray(img_grey_small) + + +def _ensure_dask(da: xr.DataArray) -> xr.DataArray: + """ + Ensure the DataArray is Dask-backed (chunked). If it's already Dask, return as-is. + """ + try: + import dask.array as dask_array + + if isinstance(da.data, dask_array.Array): + return da + # Chunk to reasonable tiles; adjust if you have known tile sizes + return da.chunk({"y": 2048, "x": 2048}) + except ImportError: + # Dask not available; just return original (compute will be eager) + return da + + +def _dask_compute(img_grey_da: xr.DataArray) -> np.ndarray: + """ + Compute an xarray DataArray (possibly Dask-backed) to a NumPy array with a Dask ProgressBar if available. 
+ """ + result: np.ndarray + try: + import dask.array as dask_array + from dask.diagnostics import ProgressBar + + if isinstance(img_grey_da.data, dask_array.Array): + with ProgressBar(): + return img_grey_da.data.compute() # Dask-backed: compute exactly once, under the progress bar + result = img_grey_da.values # not Dask-backed: plain conversion, nothing to compute lazily + except ImportError: + result = img_grey_da.values + return result + + +def _segment_otsu(img_grey: np.ndarray, params: BackgroundDetectionParams) -> np.ndarray: + """ + Otsu binarization with orientation from background corners: + - If corners (flagged) are brighter than global median -> foreground is darker (I <= t) + - Else -> foreground is brighter (I >= t) + If no corners are flagged, assume bright background (common for brightfield). + """ + img_f = img_as_float(img_grey) + t = threshold_otsu(img_f) + bright_bg = _background_is_bright(img_f, params) + return np.array((img_f <= t) if bright_bg else (img_f >= t)) + + +def _segment_felzenszwalb(img_grey: np.ndarray, params: FelzenszwalbParams) -> np.ndarray: + img_grey_h, img_grey_w = img_grey.shape + + # Parameters computed on the image resolution + short = min(img_grey_h, img_grey_w) + sigma = float(np.clip(params.sigma_frac * short, 1.0, 5.0)) + img_s = img_as_float(gaussian(img_grey, sigma=sigma)) + + target_regions = max(1, params.grid_rows * params.grid_cols) + target_area = (img_grey_h * img_grey_w) / target_regions + scale = float(max(1.0, params.scale_coef * target_area)) + min_size = int(max(1, params.min_size_coef * target_area)) + + return np.array( + felzenszwalb( + img_s, + scale=scale, + sigma=sigma, + min_size=min_size, + channel_axis=None, + ).astype(np.int32) + ) + + +def _mask_from_labels_via_corners( + img_grey: np.ndarray, labels: np.ndarray, params: BackgroundDetectionParams +) -> np.ndarray: + """ + Turn superpixels into a mask via Otsu on per-label mean intensity, oriented by corners.
+ """ + labels = labels.astype(np.int32) + n_labels = int(labels.max()) + if n_labels == 0: + return np.zeros_like(img_grey, bool) + + labels_flat = labels.ravel() + img_grey_flat = img_as_float(img_grey).ravel() + + counts = np.bincount(labels_flat, minlength=n_labels + 1).astype(np.float64) + sums = np.bincount(labels_flat, weights=img_grey_flat, minlength=n_labels + 1) + means = np.zeros(n_labels + 1, dtype=np.float64) + nz = counts > 0 + means[nz] = sums[nz] / counts[nz] + + valid_means = means[1:][means[1:] > 0] + thr = threshold_otsu(valid_means) if valid_means.size > 1 else float(valid_means.min()) - 1.0 + + bright_bg = _background_is_bright(img_as_float(img_grey), params) + keep = (means <= thr) if bright_bg else (means >= thr) + keep[0] = False + return np.array(keep[labels]) + + +def _background_is_bright(img_grey: np.ndarray, params: BackgroundDetectionParams) -> bool: + """ + Decide if background is bright using only the corners flagged True in `bg`. + If none are flagged, return True (bright background). + """ + H, W = img_grey.shape + ch = max(1, int(params.corner_size_pct * H)) + cw = max(1, int(params.corner_size_pct * W)) + + if not params.any_corner: + return True + + corner_mask = np.zeros((H, W), bool) + if params.ymin_xmin_is_bg: + corner_mask[:ch, :cw] = True + if params.ymin_xmax_is_bg: + corner_mask[:ch, -cw:] = True + if params.ymax_xmin_is_bg: + corner_mask[-ch:, :cw] = True + if params.ymax_xmax_is_bg: + corner_mask[-ch:, -cw:] = True + + if not corner_mask.any(): + return True + corner_mean = float(img_grey[corner_mask].mean()) + global_median = float(np.median(img_grey)) + return corner_mean >= global_median + + +def _make_solid(mask: np.ndarray, close_holes_smaller_than_frac: float = 0.01) -> np.ndarray: + """ + Make mask solid by connecting nearby regions and filling enclosed holes. 
+ + Parameters + ---------- + mask : np.ndarray + Binary mask to process + close_holes_smaller_than_frac : float, default 0.01 + Maximum hole area as fraction of image area. Holes larger than this + will not be filled. + """ + if mask.dtype != bool: + mask = mask.astype(bool) + + # Calculate maximum hole area in pixels + max_hole_area = int(close_holes_smaller_than_frac * mask.size) + + # Fill holes smaller than the threshold + return np.array(remove_small_holes(mask, area_threshold=max_hole_area)) + + +def _smooth_mask(mask: np.ndarray, cycles: int) -> np.ndarray: + """ + Apply morphological closing cycles to smooth the mask boundaries. + + Parameters + ---------- + mask : np.ndarray + Integer mask to smooth (0 = background, >0 = specimen labels) + cycles : int + Number of closing cycles to apply (0 = no smoothing) + + Returns + ------- + np.ndarray + Smoothed integer mask + """ + if cycles <= 0: + return mask + + # Convert to boolean for morphological operations + binary_mask = mask > 0 + + # Calculate adaptive radius based on image size + H, W = mask.shape + min_dim = min(H, W) + # Use 1-5 pixels radius depending on image size for more noticeable smoothing + radius = max(1, min(5, min_dim // 100)) + + # Apply smoothing with progressive radius increase + smoothed_binary = binary_mask.copy() + for i in range(cycles): + # Slightly increase radius with each cycle for more effective smoothing + current_radius = radius + i + smoothed_binary = binary_closing(smoothed_binary, disk(current_radius)) + + # Convert back to integer labels, preserving the original label values + result = np.zeros_like(mask, dtype=mask.dtype) + for label_id in np.unique(mask[mask > 0]): + label_mask = mask == label_id + # Apply smoothing to this specific label + smoothed_label = binary_closing(label_mask, disk(radius)) + result[smoothed_label] = label_id + + return result + + +def _filter_by_area( + mask: np.ndarray, + min_specimen_area_frac: float, + n_samples: int | None = None, +) -> 
np.ndarray: + """ + Keep only specimen-sized components, returning integer labels for multiple specimens. + + If n_samples is provided: + - Remove tiny artifacts (relative min-area). + - Keep the n_samples largest components (or all if fewer are present). + + Else: + - Remove tiny artifacts. + - Apply Otsu on log10(areas) to separate specimen-sized from small artifacts. + """ + labels = measure.label(mask, connectivity=2) + n = labels.max() + if n == 0: + return np.zeros_like(mask, dtype=np.int32) + + areas = np.bincount(labels.ravel(), minlength=n + 1)[1:].astype(np.int64) + ids = np.arange(1, n + 1) + + # Remove very small components (likely noise/artifacts) + H, W = mask.shape + min_area = max(1, int(min_specimen_area_frac * H * W)) + big_enough = areas >= min_area + + if not np.any(big_enough): + return np.zeros_like(mask, dtype=np.int32) + + areas_big = areas[big_enough] + ids_big = ids[big_enough] + + if n_samples is not None: + # Keep the n_samples largest components + order = np.argsort(areas_big)[::-1] + keep = ids_big[order[:n_samples]] + # Create a mapping from old labels to new labels + result = np.zeros_like(labels, dtype=np.int32) + for new_id, old_id in enumerate(keep, 1): + result[labels == old_id] = new_id + return result + + # Otsu on log(area) if no explicit sample count + la = np.log10(areas_big + 1e-9) + thr = filters.threshold_otsu(la) if la.size > 1 else la.min() - 1.0 + keep_ids = ids_big[la > thr] + + if keep_ids.size == 0: + return np.zeros_like(mask, dtype=np.int32) + + # Create a mapping from old labels to new labels + result = np.zeros_like(labels, dtype=np.int32) + for new_id, old_id in enumerate(keep_ids, 1): + result[labels == old_id] = new_id + + return result diff --git a/src/squidpy/exp/im/_utils.py b/src/squidpy/exp/im/_utils.py new file mode 100644 index 00000000..a0b537c5 --- /dev/null +++ b/src/squidpy/exp/im/_utils.py @@ -0,0 +1,140 @@ +from __future__ import annotations + +from typing import Literal + +import spatialdata as 
sd +import xarray as xr +from spatialdata._logging import logger as logg + + +def _get_image_data( + sdata: sd.SpatialData, + image_key: str, + scale: str, +) -> xr.DataArray: + """ + Extract image data from SpatialData object, handling both datatree and direct DataArray images. + + Parameters + ---------- + sdata : SpatialData + SpatialData object + image_key : str + Key in sdata.images + scale : str + Multiscale level, e.g. "scale0", or "auto" for the smallest available scale; an unknown key falls back to the last available scale with a warning + + Returns + ------- + xr.DataArray + Image data in (c, y, x) format + """ + img_node = sdata.images[image_key] + + # Check if the image is a datatree (has multiple scales) or a direct DataArray + if hasattr(img_node, "keys"): + available_scales = list(img_node.keys()) + + if scale == "auto": + scale = available_scales[-1] + elif scale not in available_scales: + requested_scale = scale + # Fall back to the last available scale, but report the scale the caller actually asked for + scale = available_scales[-1] + logg.warning(f"Scale '{requested_scale}' not found, using '{scale}'. Available scales: {available_scales}") + + img_da = img_node[scale].image + else: + # It's a direct DataArray (no scales) + img_da = img_node.image if hasattr(img_node, "image") else img_node + + return _ensure_cyx(img_da) + + +def _ensure_cyx(img_da: xr.DataArray) -> xr.DataArray: + """Ensure dims are (c, y, x). Adds a length-1 "c" if missing.""" + dims = list(img_da.dims) + if "y" not in dims or "x" not in dims: + raise ValueError(f'Expected dims to include "y" and "x". Found dims={dims}') + + # Handle case where dims are (c, y, x) - keep as is + if "c" in dims: + return img_da if dims[0] == "c" else img_da.transpose("c", "y", "x") + # If no "c" dimension, add one + return img_da.expand_dims({"c": [0]}).transpose("c", "y", "x") + + +def _flatten_channels( + img: xr.DataArray, + channel_format: Literal["infer", "rgb", "rgba", "multichannel"] = "infer", +) -> xr.DataArray: + """ + Takes an image of shape (c, y, x) and returns a 2D image of shape (y, x).
+ + Conversion logic: + - 1 channel: Returns greyscale (removes channel dimension) + - 3 channels + "rgb"/"infer": Uses RGB luminance formula + - 4 channels + "rgba": Uses RGB luminance formula (ignores alpha) + - 2 channels or 4+ channels + "infer": Automatically treated as multichannel + - "multichannel": Always uses mean across all channels + + The function is silent unless the channel_format is not "infer". + + Parameters + ---------- + img : xr.DataArray + Input image with shape (c, y, x) + channel_format : Literal["infer", "rgb", "rgba", "multichannel"] + How to interpret the channels: + - "infer": Automatically infer format based on number of channels + - "rgb": Force RGB treatment (requires exactly 3 channels) + - "rgba": Force RGBA treatment (requires exactly 4 channels) + - "multichannel": Force multichannel treatment (mean across all channels) + + Returns + ------- + xr.DataArray + Greyscale image with shape (y, x) + """ + n_channels = img.sizes["c"] + + # 1 channel: always return greyscale + if n_channels == 1: + return img.squeeze("c") + + # If user explicitly specifies multichannel, always use mean + if channel_format == "multichannel": + logg.info(f"Converting {n_channels}-channel image to greyscale using mean across all channels") + return img.mean(dim="c") + + # Handle explicit RGB specification + if channel_format == "rgb": + if n_channels != 3: + raise ValueError(f"Cannot treat {n_channels}-channel image as RGB (requires exactly 3 channels)") + logg.info("Converting RGB image to greyscale using luminance formula") + weights = xr.DataArray([0.299, 0.587, 0.114], dims=["c"], coords={"c": img.coords["c"]}) + return (img * weights).sum(dim="c") + + # Handle explicit RGBA specification + elif channel_format == "rgba": + if n_channels != 4: + raise ValueError(f"Cannot treat {n_channels}-channel image as RGBA (requires exactly 4 channels)") + logg.info("Converting RGBA image to greyscale using luminance formula (ignoring alpha)") + weights = 
xr.DataArray([0.299, 0.587, 0.114, 0.0], dims=["c"], coords={"c": img.coords["c"]}) + return (img * weights).sum(dim="c") + + # Infer mode - automatic detection based on channel count + elif channel_format == "infer": + if n_channels == 3: + # 3 channels + infer -> RGB luminance formula + weights = xr.DataArray([0.299, 0.587, 0.114], dims=["c"], coords={"c": img.coords["c"]}) + return (img * weights).sum(dim="c") + + else: + # 2 channels or 4+ channels + infer -> multichannel + return img.mean(dim="c") + + else: + raise ValueError( + f"Invalid channel_format: {channel_format}. Must be one of 'infer', 'rgb', 'rgba', 'multichannel'." + )