Skip to content

Commit 78f75ef

Browse files
ome zarr chunks (#1092)
* ome zarr chunks * set scale factors to empty list + fix unit tests * mypy * lowercase to fix unit test linux * bump ome zarr in pyproject toml * dask accessor is now always loaded * deduplicate storage option util; use chunks from data when not specified in storage options * simplify, document and test the chunk helper functions * guard against storage_options["chunks"]="" + Change ValueError * remove data argument from _prepare_storage_options() * remove data argument from _prepare_storage_options() --------- Co-authored-by: Luca Marconato <m.lucalmer@gmail.com>
1 parent 6f65caf commit 78f75ef

File tree

5 files changed

+262
-29
lines changed

5 files changed

+262
-29
lines changed

pyproject.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -35,7 +35,7 @@ dependencies = [
3535
"networkx",
3636
"numba>=0.55.0",
3737
"numpy",
38-
"ome_zarr>=0.12.2",
38+
"ome_zarr>=0.14.0",
3939
"pandas",
4040
"pooch",
4141
"pyarrow",

src/spatialdata/__init__.py

Lines changed: 2 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,8 @@
44
from importlib.metadata import version
55
from typing import TYPE_CHECKING, Any
66

7+
import spatialdata.models._accessor # noqa: F401
8+
79
__version__ = version("spatialdata")
810

911
_submodules = {
@@ -129,15 +131,8 @@
129131
"settings",
130132
]
131133

132-
_accessor_loaded = False
133-
134134

135135
def __getattr__(name: str) -> Any:
136-
global _accessor_loaded
137-
if not _accessor_loaded:
138-
_accessor_loaded = True
139-
import spatialdata.models._accessor # noqa: F401
140-
141136
if name in _submodules:
142137
return importlib.import_module(f"spatialdata.{name}")
143138
if name in _LAZY_IMPORTS:

src/spatialdata/_io/io_raster.py

Lines changed: 129 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,8 @@
11
from __future__ import annotations
22

3+
from collections.abc import Sequence
34
from pathlib import Path
4-
from typing import Any, Literal
5+
from typing import Any, Literal, TypeGuard
56

67
import dask.array as da
78
import numpy as np
@@ -38,6 +39,126 @@
3839
)
3940

4041

42+
def _is_flat_int_sequence(value: object) -> TypeGuard[Sequence[int]]:
43+
# e.g. "", "auto" or b"auto"
44+
if isinstance(value, str | bytes):
45+
return False
46+
if not isinstance(value, Sequence):
47+
return False
48+
return all(isinstance(v, int) for v in value)
49+
50+
51+
def _is_dask_chunk_grid(value: object) -> TypeGuard[Sequence[Sequence[int]]]:
    """Return True when ``value`` is a non-empty per-axis chunk grid (sequence of int sequences)."""
    # Reject strings outright; they are Sequences but never chunk grids.
    if isinstance(value, str | bytes) or not isinstance(value, Sequence):
        return False
    if not value:
        return False
    return all(_is_flat_int_sequence(axis_sizes) for axis_sizes in value)
57+
58+
59+
def _is_regular_dask_chunk_grid(chunk_grid: Sequence[Sequence[int]]) -> bool:
60+
"""Check whether a Dask chunk grid is regular (zarr-compatible).
61+
62+
A grid is regular when every axis has at most one unique chunk size among all but the last
63+
chunk, and the last chunk is not larger than the first.
64+
65+
Parameters
66+
----------
67+
chunk_grid
68+
Per-axis tuple of chunk sizes, for instance as returned by ``dask_array.chunks``.
69+
70+
Examples
71+
--------
72+
Triggers ``continue`` on the first ``if`` (single or empty axis):
73+
74+
>>> _is_regular_dask_chunk_grid([(4,)]) # single chunk → True
75+
True
76+
>>> _is_regular_dask_chunk_grid([()]) # empty axis → True
77+
True
78+
79+
Triggers the first ``return False`` (non-uniform interior chunks):
80+
81+
>>> _is_regular_dask_chunk_grid([(4, 4, 3, 4)]) # interior sizes differ → False
82+
False
83+
84+
Triggers the second ``return False`` (last chunk larger than the first):
85+
86+
>>> _is_regular_dask_chunk_grid([(4, 4, 4, 5)]) # last > first → False
87+
False
88+
89+
Exits with ``return True``:
90+
91+
>>> _is_regular_dask_chunk_grid([(4, 4, 4, 4)]) # all equal → True
92+
True
93+
>>> _is_regular_dask_chunk_grid([(4, 4, 4, 1)]) # last < first → True
94+
True
95+
96+
Empty grid (loop never executes) → True:
97+
98+
>>> _is_regular_dask_chunk_grid([])
99+
True
100+
101+
Multi-axis: all axes regular → True; one axis irregular → False:
102+
103+
>>> _is_regular_dask_chunk_grid([(4, 4, 4, 1), (3, 3, 2)])
104+
True
105+
>>> _is_regular_dask_chunk_grid([(4, 4, 4, 1), (4, 4, 3, 4)])
106+
False
107+
"""
108+
# Match Dask's private _check_regular_chunks() logic without depending on its internal API.
109+
for axis_chunks in chunk_grid:
110+
if len(axis_chunks) <= 1:
111+
continue
112+
if len(set(axis_chunks[:-1])) > 1:
113+
return False
114+
if axis_chunks[-1] > axis_chunks[0]:
115+
return False
116+
return True
117+
118+
119+
def _chunks_to_zarr_chunks(chunks: object) -> tuple[int, ...] | int | None:
    """Convert a chunk specification to a Zarr-compatible form.

    Accepts a single int, a flat sequence of ints (one size per axis), or a
    Dask-style per-axis chunk grid. Returns ``None`` when ``chunks`` has none
    of these shapes or when the grid is irregular and thus not writable.
    """
    if isinstance(chunks, int):
        return chunks
    if _is_flat_int_sequence(chunks):
        return tuple(chunks)
    if not _is_dask_chunk_grid(chunks):
        return None
    grid = tuple(tuple(axis_sizes) for axis_sizes in chunks)
    if not _is_regular_dask_chunk_grid(grid):
        return None
    # A regular grid is fully described by each axis's leading chunk size.
    return tuple(axis_sizes[0] for axis_sizes in grid)
130+
131+
132+
def _normalize_explicit_chunks(chunks: object) -> tuple[int, ...] | int:
    """Normalize a user-supplied ``storage_options["chunks"]`` value.

    Raises
    ------
    ValueError
        When ``chunks`` cannot be converted to a Zarr-compatible chunk shape
        (e.g. an irregular Dask chunk grid, or a string such as ``""``).
    """
    zarr_chunks = _chunks_to_zarr_chunks(chunks)
    if zarr_chunks is not None:
        return zarr_chunks
    raise ValueError(
        'storage_options["chunks"] must resolve to a Zarr chunk shape or a regular Dask chunk grid. '
        "The current raster has irregular Dask chunks, which cannot be written to Zarr. "
        "To fix this, rechunk before writing, for example by passing regular chunks=... "
        "to Image2DModel.parse(...) / Labels2DModel.parse(...)."
    )
142+
143+
144+
def _prepare_storage_options(
    storage_options: JSONDict | list[JSONDict] | None,
) -> JSONDict | list[JSONDict] | None:
    """Return a shallow copy of ``storage_options`` with any "chunks" entry normalized.

    ``None`` passes through unchanged; a single dict or a list of dicts is
    copied so the caller's input is never mutated.
    """
    if storage_options is None:
        return None

    def _normalized_copy(options: JSONDict) -> JSONDict:
        # Copy first so the original mapping is left untouched.
        copied = dict(options)
        if "chunks" in copied:
            copied["chunks"] = _normalize_explicit_chunks(copied["chunks"])
        return copied

    if isinstance(storage_options, dict):
        return _normalized_copy(storage_options)
    return [_normalized_copy(options) for options in storage_options]
160+
161+
41162
def _read_multiscale(
42163
store: str | Path, raster_type: Literal["image", "labels"], reader_format: Format
43164
) -> DataArray | DataTree:
@@ -251,20 +372,18 @@ def _write_raster_dataarray(
251372
if transformations is None:
252373
raise ValueError(f"{element_name} does not have any transformations and can therefore not be written.")
253374
input_axes: tuple[str, ...] = tuple(raster_data.dims)
254-
chunks = raster_data.chunks
255375
parsed_axes = _get_valid_axes(axes=list(input_axes), fmt=raster_format)
256-
if storage_options is not None:
257-
if "chunks" not in storage_options and isinstance(storage_options, dict):
258-
storage_options["chunks"] = chunks
259-
else:
260-
storage_options = {"chunks": chunks}
261-
# Scaler needs to be None since we are passing the data already downscaled for the multiscale case.
262-
# We need this because the argument of write_image_ngff is called image while the argument of
376+
storage_options = _prepare_storage_options(storage_options)
377+
# Explicitly disable pyramid generation for single-scale rasters. Recent ome-zarr versions default
378+
# write_image()/write_labels() to scale_factors=(2, 4, 8, 16), which would otherwise write s0, s1, ...
379+
# even when the input is a plain DataArray.
380+
# We need this because the argument of write_image_ngff is called image while the argument of
263381
# write_labels_ngff is called label.
264382
metadata[raster_type] = data
265383
ome_zarr_format = get_ome_zarr_format(raster_format)
266384
write_single_scale_ngff(
267385
group=group,
386+
scale_factors=[],
268387
scaler=None,
269388
fmt=ome_zarr_format,
270389
axes=parsed_axes,
@@ -322,10 +441,9 @@ def _write_raster_datatree(
322441
transformations = _get_transformations_xarray(xdata)
323442
if transformations is None:
324443
raise ValueError(f"{element_name} does not have any transformations and can therefore not be written.")
325-
chunks = get_pyramid_levels(raster_data, "chunks")
326444

327445
parsed_axes = _get_valid_axes(axes=list(input_axes), fmt=raster_format)
328-
storage_options = [{"chunks": chunk} for chunk in chunks]
446+
storage_options = _prepare_storage_options(storage_options)
329447
ome_zarr_format = get_ome_zarr_format(raster_format)
330448
dask_delayed = write_multi_scale_ngff(
331449
pyramid=data,

tests/io/test_partial_read.py

Lines changed: 10 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -184,9 +184,9 @@ def sdata_with_corrupted_image_chunks_zarrv3(session_tmp_path: Path) -> PartialR
184184
sdata.write(sdata_path)
185185

186186
corrupted = "blobs_image"
187-
os.unlink(sdata_path / "images" / corrupted / "0" / "zarr.json") # it will hide the "0" array from the Zarr reader
188-
os.rename(sdata_path / "images" / corrupted / "0", sdata_path / "images" / corrupted / "0_corrupted")
189-
(sdata_path / "images" / corrupted / "0").touch()
187+
os.unlink(sdata_path / "images" / corrupted / "s0" / "zarr.json") # it will hide the "0" array from the Zarr reader
188+
os.rename(sdata_path / "images" / corrupted / "s0", sdata_path / "images" / corrupted / "s0_corrupted")
189+
(sdata_path / "images" / corrupted / "s0").touch()
190190

191191
not_corrupted = [name for _, name, _ in sdata.gen_elements() if name != corrupted]
192192

@@ -206,9 +206,9 @@ def sdata_with_corrupted_image_chunks_zarrv2(session_tmp_path: Path) -> PartialR
206206
sdata.write(sdata_path, sdata_formats=SpatialDataContainerFormatV01())
207207

208208
corrupted = "blobs_image"
209-
os.unlink(sdata_path / "images" / corrupted / "0" / ".zarray") # it will hide the "0" array from the Zarr reader
210-
os.rename(sdata_path / "images" / corrupted / "0", sdata_path / "images" / corrupted / "0_corrupted")
211-
(sdata_path / "images" / corrupted / "0").touch()
209+
os.unlink(sdata_path / "images" / corrupted / "s0" / ".zarray") # it will hide the "0" array from the Zarr reader
210+
os.rename(sdata_path / "images" / corrupted / "s0", sdata_path / "images" / corrupted / "s0_corrupted")
211+
(sdata_path / "images" / corrupted / "s0").touch()
212212
not_corrupted = [name for _, name, _ in sdata.gen_elements() if name != corrupted]
213213

214214
return PartialReadTestCase(
@@ -315,8 +315,8 @@ def sdata_with_missing_image_chunks_zarrv3(
315315
sdata.write(sdata_path)
316316

317317
corrupted = "blobs_image"
318-
os.unlink(sdata_path / "images" / corrupted / "0" / "zarr.json")
319-
os.rename(sdata_path / "images" / corrupted / "0", sdata_path / "images" / corrupted / "0_corrupted")
318+
os.unlink(sdata_path / "images" / corrupted / "s0" / "zarr.json")
319+
os.rename(sdata_path / "images" / corrupted / "s0", sdata_path / "images" / corrupted / "s0_corrupted")
320320

321321
not_corrupted = [name for _, name, _ in sdata.gen_elements() if name != corrupted]
322322

@@ -339,8 +339,8 @@ def sdata_with_missing_image_chunks_zarrv2(
339339
sdata.write(sdata_path, sdata_formats=SpatialDataContainerFormatV01())
340340

341341
corrupted = "blobs_image"
342-
os.unlink(sdata_path / "images" / corrupted / "0" / ".zarray")
343-
os.rename(sdata_path / "images" / corrupted / "0", sdata_path / "images" / corrupted / "0_corrupted")
342+
os.unlink(sdata_path / "images" / corrupted / "s0" / ".zarray")
343+
os.rename(sdata_path / "images" / corrupted / "s0", sdata_path / "images" / corrupted / "s0_corrupted")
344344

345345
not_corrupted = [name for _, name, _ in sdata.gen_elements() if name != corrupted]
346346

0 commit comments

Comments
 (0)