-
Notifications
You must be signed in to change notification settings - Fork 86
ome zarr chunks #1092
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
ome zarr chunks #1092
Changes from 5 commits
e9ea783
c4e8608
da9eef3
1c060b3
85148d7
930922c
20696b5
2450bd4
3dd6121
8c85587
385dd2e
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -1,7 +1,8 @@ | ||
| from __future__ import annotations | ||
|
|
||
| from collections.abc import Sequence | ||
| from pathlib import Path | ||
| from typing import Any, Literal | ||
| from typing import Any, Literal, TypeGuard | ||
|
|
||
| import dask.array as da | ||
| import numpy as np | ||
|
|
@@ -38,6 +39,88 @@ | |
| ) | ||
|
|
||
|
|
||
| def _is_flat_int_sequence(value: object) -> TypeGuard[Sequence[int]]: | ||
| if isinstance(value, str | bytes): | ||
| return False | ||
| if not isinstance(value, Sequence): | ||
| return False | ||
| return all(isinstance(v, int) for v in value) | ||
|
|
||
|
|
||
def _is_dask_chunk_grid(value: object) -> TypeGuard[Sequence[Sequence[int]]]:
    """Return True when *value* looks like a Dask chunk grid.

    A chunk grid is a non-empty sequence holding, per axis, a sequence of int
    chunk sizes (e.g. ``((3,), (300, 200, 300))``). Strings/bytes and empty
    sequences are rejected.
    """
    if not isinstance(value, Sequence) or isinstance(value, (str, bytes)):
        return False
    if len(value) == 0:
        return False
    return all(_is_flat_int_sequence(axis_chunks) for axis_chunks in value)
|
|
||
|
|
||
| def _is_regular_dask_chunk_grid(chunk_grid: Sequence[Sequence[int]]) -> bool: | ||
| # Match Dask's private _check_regular_chunks() logic without depending on its internal API. | ||
| for axis_chunks in chunk_grid: | ||
| if len(axis_chunks) <= 1: | ||
| continue | ||
| if len(set(axis_chunks[:-1])) > 1: | ||
| return False | ||
| if axis_chunks[-1] > axis_chunks[0]: | ||
| return False | ||
| return True | ||
|
Comment on lines
+59
to
+116
Member
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. I'd add a docstring with examples (or examples in-line with the code) to show what fails and what does not. I would add the following:
triggers the first return False
triggers the second return False
exits with the last return True
Member
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Also: I would add all the examples above in a test for the function.
Member
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Added docstring and tests here: 2450bd4 |
||
|
|
||
|
|
||
def _chunks_to_zarr_chunks(chunks: object) -> tuple[int, ...] | int | None:
    """Translate a chunk specification into a Zarr chunk shape.

    Accepts an int, a flat int sequence (already a Zarr shape), or a regular
    Dask chunk grid (collapsed to the per-axis leading chunk size). Returns
    None for anything else, including irregular Dask grids.
    """
    if isinstance(chunks, int):
        return chunks
    if _is_flat_int_sequence(chunks):
        return tuple(chunks)
    if not _is_dask_chunk_grid(chunks):
        return None
    grid = tuple(tuple(axis_chunks) for axis_chunks in chunks)
    if not _is_regular_dask_chunk_grid(grid):
        return None
    # A regular grid is fully described by the first chunk of each axis.
    return tuple(axis_chunks[0] for axis_chunks in grid)
|
|
||
|
|
||
def _normalize_explicit_chunks(chunks: object) -> tuple[int, ...] | int:
    """Normalize a user-supplied chunk spec, raising ValueError when unsupported.

    Raises
    ------
    ValueError
        If *chunks* is neither a Zarr chunk shape nor a regular Dask chunk grid.
    """
    zarr_chunks = _chunks_to_zarr_chunks(chunks)
    if zarr_chunks is not None:
        return zarr_chunks
    raise ValueError(
        "storage_options['chunks'] must be a Zarr chunk shape or a regular Dask chunk grid. "
        "Irregular Dask chunk grids must be rechunked before writing or omitted."
    )
|
|
||
|
|
||
def _prepare_single_scale_storage_options(
    storage_options: JSONDict | list[JSONDict] | None,
) -> JSONDict | list[JSONDict] | None:
    """Copy storage options for a single-scale write, normalizing explicit chunks.

    Dicts and list entries are shallow-copied so the caller's input is never
    mutated. Fix: the list branch previously copied entries without normalizing
    their ``"chunks"``, unlike ``_prepare_multiscale_storage_options`` — an
    irregular explicit Dask grid could slip through on that path. Both branches
    now validate via ``_normalize_explicit_chunks``.

    Raises
    ------
    ValueError
        Propagated from ``_normalize_explicit_chunks`` for unsupported chunks.
    """
    if storage_options is None:
        return None
    if isinstance(storage_options, dict):
        prepared = dict(storage_options)
        if "chunks" in prepared:
            prepared["chunks"] = _normalize_explicit_chunks(prepared["chunks"])
        return prepared
    prepared_options = [dict(options) for options in storage_options]
    for options in prepared_options:
        if "chunks" in options:
            options["chunks"] = _normalize_explicit_chunks(options["chunks"])
    return prepared_options
|
||
|
|
||
|
|
||
def _prepare_multiscale_storage_options(
    storage_options: JSONDict | list[JSONDict] | None,
) -> JSONDict | list[JSONDict] | None:
    """Copy multiscale storage options, normalizing any explicit ``"chunks"``.

    Accepts None (returned unchanged), one dict applied to all scales, or a
    list of per-scale dicts. All dicts are shallow-copied so the caller's
    input is never mutated.
    """
    if storage_options is None:
        return None

    def _copy_and_normalize(options: JSONDict) -> JSONDict:
        copied = {**options}
        if "chunks" in copied:
            copied["chunks"] = _normalize_explicit_chunks(copied["chunks"])
        return copied

    if isinstance(storage_options, dict):
        return _copy_and_normalize(storage_options)
    return [_copy_and_normalize(options) for options in storage_options]
|
|
||
|
|
||
| def _read_multiscale( | ||
| store: str | Path, raster_type: Literal["image", "labels"], reader_format: Format | ||
| ) -> DataArray | DataTree: | ||
|
|
@@ -251,20 +334,18 @@ def _write_raster_dataarray( | |
| if transformations is None: | ||
| raise ValueError(f"{element_name} does not have any transformations and can therefore not be written.") | ||
| input_axes: tuple[str, ...] = tuple(raster_data.dims) | ||
| chunks = raster_data.chunks | ||
| parsed_axes = _get_valid_axes(axes=list(input_axes), fmt=raster_format) | ||
| if storage_options is not None: | ||
| if "chunks" not in storage_options and isinstance(storage_options, dict): | ||
| storage_options["chunks"] = chunks | ||
| else: | ||
| storage_options = {"chunks": chunks} | ||
| # Scaler needs to be None since we are passing the data already downscaled for the multiscale case. | ||
| # We need this because the argument of write_image_ngff is called image while the argument of | ||
| storage_options = _prepare_single_scale_storage_options(storage_options) | ||
| # Explicitly disable pyramid generation for single-scale rasters. Recent ome-zarr versions default | ||
| # write_image()/write_labels() to scale_factors=(2, 4, 8, 16), which would otherwise write s0, s1, ... | ||
| # even when the input is a plain DataArray. | ||
| # We need this because the argument of write_image_ngff is called image while the argument of | ||
| # write_labels_ngff is called label. | ||
| metadata[raster_type] = data | ||
| ome_zarr_format = get_ome_zarr_format(raster_format) | ||
| write_single_scale_ngff( | ||
| group=group, | ||
| scale_factors=[], | ||
| scaler=None, | ||
| fmt=ome_zarr_format, | ||
| axes=parsed_axes, | ||
|
|
@@ -322,10 +403,9 @@ def _write_raster_datatree( | |
| transformations = _get_transformations_xarray(xdata) | ||
| if transformations is None: | ||
| raise ValueError(f"{element_name} does not have any transformations and can therefore not be written.") | ||
| chunks = get_pyramid_levels(raster_data, "chunks") | ||
|
|
||
| parsed_axes = _get_valid_axes(axes=list(input_axes), fmt=raster_format) | ||
| storage_options = [{"chunks": chunk} for chunk in chunks] | ||
| storage_options = _prepare_multiscale_storage_options(storage_options) | ||
| ome_zarr_format = get_ome_zarr_format(raster_format) | ||
| dask_delayed = write_multi_scale_ngff( | ||
| pyramid=data, | ||
|
|
||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -7,6 +7,7 @@ | |
| from pathlib import Path | ||
| from typing import Any, Literal | ||
|
|
||
| import dask.array as da | ||
| import dask.dataframe as dd | ||
| import numpy as np | ||
| import pandas as pd | ||
|
|
@@ -18,6 +19,7 @@ | |
| from packaging.version import Version | ||
| from shapely import MultiPolygon, Polygon | ||
| from upath import UPath | ||
| from xarray import DataArray | ||
| from zarr.errors import GroupNotFoundError | ||
|
|
||
| import spatialdata.config | ||
|
|
@@ -30,6 +32,7 @@ | |
| SpatialDataContainerFormatType, | ||
| SpatialDataContainerFormatV01, | ||
| ) | ||
| from spatialdata._io.io_raster import write_image | ||
| from spatialdata.datasets import blobs | ||
| from spatialdata.models import Image2DModel | ||
| from spatialdata.models._utils import get_channel_names | ||
|
|
@@ -623,6 +626,49 @@ def test_bug_rechunking_after_queried_raster(): | |
| queried.write(f) | ||
|
|
||
|
|
||
def test_write_irregular_dask_chunks_without_explicit_storage_options(tmp_path: Path) -> None:
    """An irregularly chunked image must be writable when no storage_options are given."""
    array = da.from_array(RNG.random((3, 800, 1000)), chunks=((3,), (300, 200, 300), (512, 488)))
    parsed = Image2DModel.parse(array, dims=("c", "y", "x"))
    SpatialData(images={"image": parsed}).write(tmp_path / "data.zarr")
|
||
|
|
||
|
|
||
def test_write_image_normalizes_explicit_regular_dask_chunk_grid(tmp_path: Path) -> None:
    """A regular Dask chunk grid in storage_options collapses to one Zarr chunk shape."""
    array = da.from_array(RNG.random((3, 800, 1000)), chunks=((3,), (300, 300, 200), (512, 488)))
    parsed = Image2DModel.parse(array, dims=("c", "y", "x"))
    zarr_group = zarr.open_group(tmp_path / "image.zarr", mode="w")

    write_image(parsed, zarr_group, "image", storage_options={"chunks": parsed.data.chunks})

    # The leading chunk of each axis defines the uniform Zarr chunk shape.
    assert zarr_group["s0"].chunks == (3, 300, 512)
|
|
||
|
|
||
def test_write_image_rejects_explicit_irregular_dask_chunk_grid(tmp_path: Path) -> None:
    """An irregular Dask chunk grid passed explicitly must raise a clear ValueError."""
    array = da.from_array(RNG.random((3, 800, 1000)), chunks=((3,), (300, 200, 300), (512, 488)))
    parsed = Image2DModel.parse(array, dims=("c", "y", "x"))
    zarr_group = zarr.open_group(tmp_path / "image.zarr", mode="w")

    expected = "storage_options\\['chunks'\\] must be a Zarr chunk shape or a regular Dask chunk grid"
    with pytest.raises(ValueError, match=expected):
        write_image(parsed, zarr_group, "image", storage_options={"chunks": parsed.data.chunks})
|
|
||
|
|
||
def test_single_scale_image_roundtrip_stays_dataarray(tmp_path: Path) -> None:
    """A single-scale image round-trips as DataArray with exactly one pyramid level."""
    parsed = Image2DModel.parse(RNG.random((3, 64, 64)), dims=("c", "y", "x"))
    container = SpatialData(images={"image": parsed})
    zarr_path = tmp_path / "data.zarr"

    container.write(zarr_path)
    round_tripped = read_zarr(zarr_path)

    assert isinstance(round_tripped["image"], DataArray)
    stored = zarr.open_group(zarr_path / "images" / "image", mode="r")
    # Only s0 on disk: no extra scales were written for a plain DataArray.
    assert list(stored.keys()) == ["s0"]
|
|
||
|
|
||
| @pytest.mark.parametrize("sdata_container_format", SDATA_FORMATS) | ||
| def test_self_contained(full_sdata: SpatialData, sdata_container_format: SpatialDataContainerFormatType) -> None: | ||
| # data only in-memory, so the SpatialData object and all its elements are self-contained | ||
|
|
||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Why do we need this? Please either remove or document.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Removed here: 2450bd4