Skip to content
Merged
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,7 @@ dependencies = [
"networkx",
"numba>=0.55.0",
"numpy",
"ome_zarr>=0.12.2",
"ome_zarr>=0.14.0",
"pandas",
"pooch",
"pyarrow",
Expand Down
102 changes: 91 additions & 11 deletions src/spatialdata/_io/io_raster.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,8 @@
from __future__ import annotations

from collections.abc import Sequence
from pathlib import Path
from typing import Any, Literal
from typing import Any, Literal, TypeGuard

import dask.array as da
import numpy as np
Expand Down Expand Up @@ -38,6 +39,88 @@
)


def _is_flat_int_sequence(value: object) -> TypeGuard[Sequence[int]]:
if isinstance(value, str | bytes):
return False
Comment on lines +44 to +45
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Why do we need this? Please either remove or document.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Removed here: 2450bd4

if not isinstance(value, Sequence):
return False
return all(isinstance(v, int) for v in value)


def _is_dask_chunk_grid(value: object) -> TypeGuard[Sequence[Sequence[int]]]:
if isinstance(value, str | bytes):
return False
Comment on lines +52 to +53
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Same for this check.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Removed here: 2450bd4

if not isinstance(value, Sequence):
return False
return len(value) > 0 and all(_is_flat_int_sequence(axis_chunks) for axis_chunks in value)


def _is_regular_dask_chunk_grid(chunk_grid: Sequence[Sequence[int]]) -> bool:
# Match Dask's private _check_regular_chunks() logic without depending on its internal API.
for axis_chunks in chunk_grid:
if len(axis_chunks) <= 1:
continue
if len(set(axis_chunks[:-1])) > 1:
return False
if axis_chunks[-1] > axis_chunks[0]:
return False
return True
Comment on lines +59 to +116
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I'd add a docstring with examples (or examples inline with the code) to show which inputs pass and which fail.

I would add the following:
triggers the continue in the first if:

  • [(4,)]
  • [()]

triggers the first return False

  • [(4, 4, 3, 4)]

triggers the second return False

  • [(4, 4, 4, 5)]

exits with the last return True

  • [(4, 4, 4, 4)], succeeds, all chunks equal
  • [(4, 4, 4, 1)], succeeds, final chunk is < of the initial one

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Also: I would add all the examples above in a test, for the function _is_regular_dask_chunk_grid().

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Added docstring and tests here: 2450bd4



def _chunks_to_zarr_chunks(chunks: object) -> tuple[int, ...] | int | None:
    """Convert a user-supplied ``chunks`` value to a Zarr chunk specification.

    Accepted inputs and their conversions:

    - an ``int`` is passed through unchanged (uniform chunk size);
    - a flat sequence of ints becomes a ``tuple`` (one chunk size per axis);
    - a regular Dask chunk grid (sequence of per-axis chunk sequences)
      collapses to the leading chunk size of each axis.

    Returns ``None`` when the value cannot be converted, i.e. for an
    irregular Dask chunk grid or any unrecognized input.
    """
    if isinstance(chunks, int):
        return chunks
    if _is_flat_int_sequence(chunks):
        return tuple(chunks)
    if _is_dask_chunk_grid(chunks):
        chunk_grid = tuple(tuple(axis_chunks) for axis_chunks in chunks)
        if _is_regular_dask_chunk_grid(chunk_grid):
            # A regular grid is fully described by each axis' leading chunk size.
            return tuple(axis_chunks[0] for axis_chunks in chunk_grid)
    # Irregular Dask grid or unrecognized input: signal "cannot convert".
    return None


def _normalize_explicit_chunks(chunks: object) -> tuple[int, ...] | int:
    """Validate an explicit ``chunks`` storage option and return its Zarr form.

    Raises
    ------
    ValueError
        If ``chunks`` is neither a Zarr chunk shape nor a regular Dask chunk
        grid (e.g. an irregular grid that would need rechunking first).
    """
    zarr_chunks = _chunks_to_zarr_chunks(chunks)
    if zarr_chunks is not None:
        return zarr_chunks
    raise ValueError(
        "storage_options['chunks'] must be a Zarr chunk shape or a regular Dask chunk grid. "
        "Irregular Dask chunk grids must be rechunked before writing or omitted."
    )


def _prepare_single_scale_storage_options(
storage_options: JSONDict | list[JSONDict] | None,
) -> JSONDict | list[JSONDict] | None:
if storage_options is None:
return None
if isinstance(storage_options, dict):
prepared = dict(storage_options)
if "chunks" in prepared:
prepared["chunks"] = _normalize_explicit_chunks(prepared["chunks"])
return prepared
return [dict(options) for options in storage_options]
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This function behaves like _prepare_multiscale_storage_options(), except that it does not normalize the list-of-storage-options case. Can we remove it and just use _prepare_multiscale_storage_options() (after renaming it to _prepare_storage_options())?

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I unified the two functions in 20696b5. Happy to hear what you think (in case we need two, we can revert, but I think we can proceed with one function).



def _prepare_multiscale_storage_options(
storage_options: JSONDict | list[JSONDict] | None,
) -> JSONDict | list[JSONDict] | None:
if storage_options is None:
return None
if isinstance(storage_options, dict):
prepared = dict(storage_options)
if "chunks" in prepared:
prepared["chunks"] = _normalize_explicit_chunks(prepared["chunks"])
return prepared

prepared_options = [dict(options) for options in storage_options]
for options in prepared_options:
if "chunks" in options:
options["chunks"] = _normalize_explicit_chunks(options["chunks"])
return prepared_options


def _read_multiscale(
store: str | Path, raster_type: Literal["image", "labels"], reader_format: Format
) -> DataArray | DataTree:
Expand Down Expand Up @@ -251,20 +334,18 @@ def _write_raster_dataarray(
if transformations is None:
raise ValueError(f"{element_name} does not have any transformations and can therefore not be written.")
input_axes: tuple[str, ...] = tuple(raster_data.dims)
chunks = raster_data.chunks
parsed_axes = _get_valid_axes(axes=list(input_axes), fmt=raster_format)
if storage_options is not None:
if "chunks" not in storage_options and isinstance(storage_options, dict):
storage_options["chunks"] = chunks
else:
storage_options = {"chunks": chunks}
# Scaler needs to be None since we are passing the data already downscaled for the multiscale case.
# We need this because the argument of write_image_ngff is called image while the argument of
storage_options = _prepare_single_scale_storage_options(storage_options)
# Explicitly disable pyramid generation for single-scale rasters. Recent ome-zarr versions default
# write_image()/write_labels() to scale_factors=(2, 4, 8, 16), which would otherwise write s0, s1, ...
# even when the input is a plain DataArray.
# We need this because the argument of write_image_ngff is called image while the argument of
# write_labels_ngff is called label.
metadata[raster_type] = data
ome_zarr_format = get_ome_zarr_format(raster_format)
write_single_scale_ngff(
group=group,
scale_factors=[],
scaler=None,
fmt=ome_zarr_format,
axes=parsed_axes,
Expand Down Expand Up @@ -322,10 +403,9 @@ def _write_raster_datatree(
transformations = _get_transformations_xarray(xdata)
if transformations is None:
raise ValueError(f"{element_name} does not have any transformations and can therefore not be written.")
chunks = get_pyramid_levels(raster_data, "chunks")

parsed_axes = _get_valid_axes(axes=list(input_axes), fmt=raster_format)
storage_options = [{"chunks": chunk} for chunk in chunks]
storage_options = _prepare_multiscale_storage_options(storage_options)
ome_zarr_format = get_ome_zarr_format(raster_format)
dask_delayed = write_multi_scale_ngff(
pyramid=data,
Expand Down
20 changes: 10 additions & 10 deletions tests/io/test_partial_read.py
Original file line number Diff line number Diff line change
Expand Up @@ -184,9 +184,9 @@ def sdata_with_corrupted_image_chunks_zarrv3(session_tmp_path: Path) -> PartialR
sdata.write(sdata_path)

corrupted = "blobs_image"
os.unlink(sdata_path / "images" / corrupted / "0" / "zarr.json") # it will hide the "0" array from the Zarr reader
os.rename(sdata_path / "images" / corrupted / "0", sdata_path / "images" / corrupted / "0_corrupted")
(sdata_path / "images" / corrupted / "0").touch()
os.unlink(sdata_path / "images" / corrupted / "s0" / "zarr.json") # it will hide the "0" array from the Zarr reader
os.rename(sdata_path / "images" / corrupted / "s0", sdata_path / "images" / corrupted / "s0_corrupted")
(sdata_path / "images" / corrupted / "s0").touch()

not_corrupted = [name for _, name, _ in sdata.gen_elements() if name != corrupted]

Expand All @@ -206,9 +206,9 @@ def sdata_with_corrupted_image_chunks_zarrv2(session_tmp_path: Path) -> PartialR
sdata.write(sdata_path, sdata_formats=SpatialDataContainerFormatV01())

corrupted = "blobs_image"
os.unlink(sdata_path / "images" / corrupted / "0" / ".zarray") # it will hide the "0" array from the Zarr reader
os.rename(sdata_path / "images" / corrupted / "0", sdata_path / "images" / corrupted / "0_corrupted")
(sdata_path / "images" / corrupted / "0").touch()
os.unlink(sdata_path / "images" / corrupted / "s0" / ".zarray") # it will hide the "0" array from the Zarr reader
os.rename(sdata_path / "images" / corrupted / "s0", sdata_path / "images" / corrupted / "s0_corrupted")
(sdata_path / "images" / corrupted / "s0").touch()
not_corrupted = [name for _, name, _ in sdata.gen_elements() if name != corrupted]

return PartialReadTestCase(
Expand Down Expand Up @@ -315,8 +315,8 @@ def sdata_with_missing_image_chunks_zarrv3(
sdata.write(sdata_path)

corrupted = "blobs_image"
os.unlink(sdata_path / "images" / corrupted / "0" / "zarr.json")
os.rename(sdata_path / "images" / corrupted / "0", sdata_path / "images" / corrupted / "0_corrupted")
os.unlink(sdata_path / "images" / corrupted / "s0" / "zarr.json")
os.rename(sdata_path / "images" / corrupted / "s0", sdata_path / "images" / corrupted / "s0_corrupted")

not_corrupted = [name for _, name, _ in sdata.gen_elements() if name != corrupted]

Expand All @@ -339,8 +339,8 @@ def sdata_with_missing_image_chunks_zarrv2(
sdata.write(sdata_path, sdata_formats=SpatialDataContainerFormatV01())

corrupted = "blobs_image"
os.unlink(sdata_path / "images" / corrupted / "0" / ".zarray")
os.rename(sdata_path / "images" / corrupted / "0", sdata_path / "images" / corrupted / "0_corrupted")
os.unlink(sdata_path / "images" / corrupted / "s0" / ".zarray")
os.rename(sdata_path / "images" / corrupted / "s0", sdata_path / "images" / corrupted / "s0_corrupted")

not_corrupted = [name for _, name, _ in sdata.gen_elements() if name != corrupted]

Expand Down
46 changes: 46 additions & 0 deletions tests/io/test_readwrite.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
from pathlib import Path
from typing import Any, Literal

import dask.array as da
import dask.dataframe as dd
import numpy as np
import pandas as pd
Expand All @@ -18,6 +19,7 @@
from packaging.version import Version
from shapely import MultiPolygon, Polygon
from upath import UPath
from xarray import DataArray
from zarr.errors import GroupNotFoundError

import spatialdata.config
Expand All @@ -30,6 +32,7 @@
SpatialDataContainerFormatType,
SpatialDataContainerFormatV01,
)
from spatialdata._io.io_raster import write_image
from spatialdata.datasets import blobs
from spatialdata.models import Image2DModel
from spatialdata.models._utils import get_channel_names
Expand Down Expand Up @@ -623,6 +626,49 @@ def test_bug_rechunking_after_queried_raster():
queried.write(f)


def test_write_irregular_dask_chunks_without_explicit_storage_options(tmp_path: Path) -> None:
    """Writing an image whose backing Dask array has irregular chunks, with no storage options."""
    # Chunking is irregular along y (300, 200, 300) and x (512, 488).
    irregular = da.from_array(RNG.random((3, 800, 1000)), chunks=((3,), (300, 200, 300), (512, 488)))
    parsed_image = Image2DModel.parse(irregular, dims=("c", "y", "x"))

    SpatialData(images={"image": parsed_image}).write(tmp_path / "data.zarr")
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I would find it more natural if this test was failing. Now that chunks = raster_data.chunks has been removed it, writing doesn't fail, but it ignores the chunks in the data. I think a natural behavior is that if storage option specifies chunks, these are used, otherwise the ones from the data (and if the `chunks from the data are irregular and no storage options are specified, an error would be raised).

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I now implemented this in 20696b5 changing the test so that it expects to fail.



def test_write_image_normalizes_explicit_regular_dask_chunk_grid(tmp_path: Path) -> None:
    """An explicit regular Dask chunk grid collapses to per-axis leading chunk sizes."""
    # Regular grid: on every axis the trailing chunk is <= the leading ones.
    regular_chunks = ((3,), (300, 300, 200), (512, 488))
    image = Image2DModel.parse(
        da.from_array(RNG.random((3, 800, 1000)), chunks=regular_chunks), dims=("c", "y", "x")
    )
    zarr_group = zarr.open_group(tmp_path / "image.zarr", mode="w")

    write_image(image, zarr_group, "image", storage_options={"chunks": image.data.chunks})

    assert zarr_group["s0"].chunks == (3, 300, 512)


def test_write_image_rejects_explicit_irregular_dask_chunk_grid(tmp_path: Path) -> None:
    """An explicit irregular Dask chunk grid in storage_options raises a ValueError."""
    irregular_chunks = ((3,), (300, 200, 300), (512, 488))
    image = Image2DModel.parse(
        da.from_array(RNG.random((3, 800, 1000)), chunks=irregular_chunks), dims=("c", "y", "x")
    )
    zarr_group = zarr.open_group(tmp_path / "image.zarr", mode="w")

    expected = "storage_options\\['chunks'\\] must be a Zarr chunk shape or a regular Dask chunk grid"
    with pytest.raises(ValueError, match=expected):
        write_image(image, zarr_group, "image", storage_options={"chunks": image.data.chunks})


def test_single_scale_image_roundtrip_stays_dataarray(tmp_path: Path) -> None:
    """A single-scale image round-trips as a DataArray with only an s0 group on disk."""
    zarr_path = tmp_path / "data.zarr"
    single_scale = Image2DModel.parse(RNG.random((3, 64, 64)), dims=("c", "y", "x"))

    SpatialData(images={"image": single_scale}).write(zarr_path)
    round_tripped = read_zarr(zarr_path)

    # The element must come back as a DataArray (single-scale), not a DataTree.
    assert isinstance(round_tripped["image"], DataArray)
    # On disk the image group must contain only the s0 scale, no pyramid levels.
    stored_group = zarr.open_group(zarr_path / "images" / "image", mode="r")
    assert list(stored_group.keys()) == ["s0"]


@pytest.mark.parametrize("sdata_container_format", SDATA_FORMATS)
def test_self_contained(full_sdata: SpatialData, sdata_container_format: SpatialDataContainerFormatType) -> None:
# data only in-memory, so the SpatialData object and all its elements are self-contained
Expand Down
Loading