Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
19 commits
Select commit Hold shift + click to select a range
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
126 changes: 126 additions & 0 deletions src/spatialdata_io/readers/_utils/_image.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,126 @@
from collections.abc import Callable
from typing import Any

import dask.array as da
import numpy as np
from dask import delayed
from numpy.typing import NDArray


def _compute_chunk_sizes_positions(size: int, chunk: int, min_coord: int) -> tuple[NDArray[np.int_], NDArray[np.int_]]:
"""Calculate chunk sizes and positions for a given dimension and chunk size."""
# All chunks have the same size except for the last one
positions = np.arange(min_coord, min_coord + size, chunk)
lengths = np.minimum(chunk, min_coord + size - positions)

return positions, lengths


def _compute_chunks(
shape: tuple[int, int],
chunk_size: tuple[int, int],
min_coordinates: tuple[int, int] = (0, 0),
) -> NDArray[np.int_]:
"""Create all chunk specs for a given image and chunk size.

Creates specifications (x, y, width, height) with (x, y) being the upper left corner
of chunks of size chunk_size. Chunks at the edges correspond to the remainder of
chunk size and dimensions

Parameters
----------
shape : tuple[int, int]
Size of the image in (width, height).
chunk_size : tuple[int, int]
Size of individual tiles in (width, height).
min_coordinates : tuple[int, int], optional
Minimum coordinates (x, y) in the image, defaults to (0, 0).

Returns
-------
np.ndarray
Array of shape (n_tiles_x, n_tiles_y, 4). Each entry defines a tile
as (x, y, width, height).
"""
x_positions, widths = _compute_chunk_sizes_positions(shape[1], chunk_size[1], min_coord=min_coordinates[1])
y_positions, heights = _compute_chunk_sizes_positions(shape[0], chunk_size[0], min_coord=min_coordinates[0])

# Generate the tiles
tiles = np.array(
[
[[x, y, w, h] for x, w in zip(x_positions, widths, strict=True)]
for y, h in zip(y_positions, heights, strict=True)
],
dtype=int,
)
return tiles


def _read_chunks(
func: Callable[..., NDArray[np.int_]],
slide: Any,
coords: NDArray[np.int_],
n_channel: int,
dtype: np.dtype[Any],
**func_kwargs: Any,
) -> list[list[da.Array]]:
"""Abstract method to tile a large microscopy image.

Parameters
----------
func
Function to retrieve a rectangular tile from the slide image. Must take the
arguments:

- slide Full slide image
- x0: x (col) coordinate of upper left corner of chunk
- y0: y (row) coordinate of upper left corner of chunk
- width: Width of chunk
- height: Height of chunk

and should return the chunk as numpy array of shape (c, y, x)
slide
Slide image in lazyly loaded format compatible with func
coords
Coordinates of the upper left corner of the image in format (n_row_x, n_row_y, 4)
where the last dimension defines the rectangular tile in format (x, y, width, height).
n_row_x represents the number of chunks in x dimension and n_row_y the number of chunks
in y dimension.
n_channel
Number of channels in array
dtype
Data type of image
func_kwargs
Additional keyword arguments passed to func

Returns
-------
list[list[da.array]]
List (length: n_row_x) of lists (length: n_row_y) of chunks.
Represents all chunks of the full image.
"""
func_kwargs = func_kwargs if func_kwargs else {}

# Collect each delayed chunk as item in list of list
# Inner list becomes dim=-1 (cols/x)
# Outer list becomes dim=-2 (rows/y)
# see dask.array.block
chunks = [
[
da.from_delayed(
delayed(func)(
slide,
x0=coords[y, x, 0],
y0=coords[y, x, 1],
width=coords[y, x, 2],
height=coords[y, x, 3],
**func_kwargs,
),
dtype=dtype,
shape=(n_channel, *coords[y, x, [3, 2]]),
)
for x in range(coords.shape[1])
]
for y in range(coords.shape[0])
]
return chunks
82 changes: 77 additions & 5 deletions src/spatialdata_io/readers/generic.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,10 +2,13 @@

import warnings
from pathlib import Path
from typing import TYPE_CHECKING
from typing import TYPE_CHECKING, Protocol, TypeVar

import dask.array as da
import numpy as np
import tifffile
from dask_image.imread import imread
from geopandas import GeoDataFrame
from spatialdata._docs import docstring_parameter
from spatialdata.models import Image2DModel, ShapesModel
from spatialdata.models._utils import DEFAULT_COORDINATE_SYSTEM
Expand All @@ -15,13 +18,23 @@
from collections.abc import Sequence

from geopandas import GeoDataFrame
from numpy.typing import NDArray
from xarray import DataArray

from ._utils._image import _compute_chunks, _read_chunks

VALID_IMAGE_TYPES = [".tif", ".tiff", ".png", ".jpg", ".jpeg"]
VALID_SHAPE_TYPES = [".geojson"]
DEFAULT_CHUNKSIZE = (1000, 1000)

__all__ = ["generic", "geojson", "image", "VALID_IMAGE_TYPES", "VALID_SHAPE_TYPES"]

T = TypeVar("T", bound=np.generic) # Restrict to NumPy scalar types


class DaskArray(Protocol[T]):
dtype: np.dtype[T]


@docstring_parameter(
valid_image_types=", ".join(VALID_IMAGE_TYPES),
Expand Down Expand Up @@ -73,11 +86,70 @@ def geojson(input: Path, coordinate_system: str) -> GeoDataFrame:
return ShapesModel.parse(input, transformations={coordinate_system: Identity()})


def image(input: Path, data_axes: Sequence[str], coordinate_system: str) -> DataArray:
"""Reads an image file and returns a parsed Image2D spatial element."""
# this function is just a draft, the more general one will be available when
# https://github.com/scverse/spatialdata-io/pull/234 is merged
def _tiff_to_chunks(input: Path, axes_dim_mapping: dict[str, int]) -> list[list[DaskArray[np.int_]]]:
"""Chunkwise reader for tiff files.

Parameters
----------
input
Path to image
axes_dim_mapping
Mapping between dimension name (x, y, c) and index

Returns
-------
list[list[DaskArray]]
"""
# Lazy file reader
slide = tifffile.memmap(input)

# Transpose to cyx order
slide = np.transpose(slide, (axes_dim_mapping["c"], axes_dim_mapping["y"], axes_dim_mapping["x"]))

# Get dimensions in (x, y)
slide_dimensions = slide.shape[2], slide.shape[1]

# Get number of channels (c)
n_channel = slide.shape[0]

# Compute chunk coords
chunk_coords = _compute_chunks(slide_dimensions, chunk_size=DEFAULT_CHUNKSIZE, min_coordinates=(0, 0))

# Define reader func
def _reader_func(slide: NDArray[np.int_], x0: int, y0: int, width: int, height: int) -> NDArray[np.int_]:
return np.array(slide[:, y0 : y0 + height, x0 : x0 + width])

return _read_chunks(_reader_func, slide, coords=chunk_coords, n_channel=n_channel, dtype=slide.dtype)


def _dask_image_imread(input: Path, data_axes: Sequence[str]) -> da.Array:
image = imread(input)
if len(image.shape) == len(data_axes) + 1 and image.shape[0] == 1:
image = np.squeeze(image, axis=0)
return image


def image(input: Path, data_axes: Sequence[str], coordinate_system: str) -> DataArray:
"""Reads an image file and returns a parsed Image2D spatial element."""
# Map passed data axes to position of dimension
axes_dim_mapping = {axes: ndim for ndim, axes in enumerate(data_axes)}

if input.suffix in [".tiff", ".tif"]:
try:
chunks = _tiff_to_chunks(input, axes_dim_mapping=axes_dim_mapping)
image = da.block(chunks, allow_unknown_chunksizes=True)

# Edge case: Compressed images are not memory-mappable
except ValueError as e:
warnings.warn(
f"Exception occurred: {str(e)}\nPossible troubleshooting: image data are not memory-mappable, potentially due to compression. Trying to load the image into memory at once",
stacklevel=2,
)
image = _dask_image_imread(input=input, data_axes=data_axes)

elif input.suffix in [".png", ".jpg", ".jpeg"]:
image = _dask_image_imread(input=input, data_axes=data_axes)
else:
raise NotImplementedError(f"File format {input.suffix} not implemented")

return Image2DModel.parse(image, dims=data_axes, transformations={coordinate_system: Identity()})
65 changes: 65 additions & 0 deletions tests/readers/test_utils_image.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,65 @@
import numpy as np
import pytest
from numpy.typing import NDArray

from spatialdata_io.readers._utils._image import (
_compute_chunk_sizes_positions,
_compute_chunks,
)


@pytest.mark.parametrize(
("size", "chunk", "min_coordinate", "positions", "lengths"),
[
(300, 100, 0, np.array([0, 100, 200]), np.array([100, 100, 100])),
(300, 200, 0, np.array([0, 200]), np.array([200, 100])),
(300, 100, -100, np.array([-100, 0, 100]), np.array([100, 100, 100])),
(300, 200, -100, np.array([-100, 100]), np.array([200, 100])),
],
)
def test_compute_chunk_sizes_positions(
size: int,
chunk: int,
min_coordinate: int,
positions: NDArray[np.number],
lengths: NDArray[np.number],
) -> None:
computed_positions, computed_lengths = _compute_chunk_sizes_positions(size, chunk, min_coordinate)
assert (positions == computed_positions).all()
assert (lengths == computed_lengths).all()


@pytest.mark.parametrize(
("dimensions", "chunk_size", "min_coordinates", "result"),
[
# Regular grid 2x2
(
(2, 2),
(1, 1),
(0, 0),
np.array([[[0, 0, 1, 1], [1, 0, 1, 1]], [[0, 1, 1, 1], [1, 1, 1, 1]]]),
),
# Different tile sizes
(
(3, 3),
(2, 2),
(0, 0),
np.array([[[0, 0, 2, 2], [2, 0, 1, 2]], [[0, 2, 2, 1], [2, 2, 1, 1]]]),
),
(
(2, 2),
(1, 1),
(-1, 0),
np.array([[[0, -1, 1, 1], [1, -1, 1, 1]], [[0, 0, 1, 1], [1, 0, 1, 1]]]),
),
],
)
def test_compute_chunks(
dimensions: tuple[int, int],
chunk_size: tuple[int, int],
min_coordinates: tuple[int, int],
result: NDArray[np.number],
) -> None:
tiles = _compute_chunks(dimensions, chunk_size, min_coordinates)

assert (tiles == result).all()
39 changes: 39 additions & 0 deletions tests/test_generic.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,9 +9,12 @@
from PIL import Image
from spatialdata import SpatialData
from spatialdata.datasets import blobs
from tifffile import imread as tiffread
from tifffile import imwrite as tiffwrite

from spatialdata_io.__main__ import read_generic_wrapper
from spatialdata_io.converters.generic_to_zarr import generic_to_zarr
from spatialdata_io.readers.generic import image


@contextmanager
Expand All @@ -33,6 +36,42 @@ def save_temp_files() -> Generator[tuple[Path, Path, Path], None, None]:
yield jpg_path, geojson_path, Path(tmpdir)


@pytest.fixture(
scope="module",
params=[
{"axes": ("c", "y", "x"), "compression": None},
{"axes": ("x", "y", "c"), "compression": None},
{"axes": ("c", "y", "x"), "compression": "lzw"},
{"axes": ("x", "y", "c"), "compression": "lzw"},
],
)
def save_tiff_files(
request: pytest.FixtureRequest,
) -> Generator[tuple[Path, tuple[str], Path], None, None]:
with tempfile.TemporaryDirectory() as tmpdir:
axes = request.param["axes"]
compression = request.param["compression"]

sdata = blobs()
# save the image as tiff
x = sdata["blobs_image"].transpose(*axes).data.compute()
img = np.clip(x * 255, 0, 255).astype(np.uint8)

tiff_path = Path(tmpdir) / "blobs_image.tiff"
tiffwrite(tiff_path, img, compression=compression)

yield tiff_path, axes, Path(tmpdir)


def test_read_tiff(save_tiff_files: tuple[Path, tuple[str], Path]) -> None:
tiff_path, axes, _ = save_tiff_files
img = image(tiff_path, data_axes=axes, coordinate_system="global")

reference = tiffread(tiff_path)

assert (img.compute() == reference).all()


@pytest.mark.parametrize("cli", [True, False])
@pytest.mark.parametrize("element_name", [None, "test_element"])
def test_read_generic_image(runner: CliRunner, cli: bool, element_name: str | None) -> None:
Expand Down
Loading