|
1 | 1 | from __future__ import annotations |
2 | 2 |
|
| 3 | +from collections.abc import Sequence |
3 | 4 | from pathlib import Path |
4 | | -from typing import Any, Literal |
| 5 | +from typing import Any, Literal, TypeGuard |
5 | 6 |
|
6 | 7 | import dask.array as da |
7 | 8 | import numpy as np |
|
38 | 39 | ) |
39 | 40 |
|
40 | 41 |
|
| 42 | +def _is_flat_int_sequence(value: object) -> TypeGuard[Sequence[int]]: |
| 43 | + # e.g. "", "auto" or b"auto" |
| 44 | + if isinstance(value, str | bytes): |
| 45 | + return False |
| 46 | + if not isinstance(value, Sequence): |
| 47 | + return False |
| 48 | + return all(isinstance(v, int) for v in value) |
| 49 | + |
| 50 | + |
| 51 | +def _is_dask_chunk_grid(value: object) -> TypeGuard[Sequence[Sequence[int]]]: |
| 52 | + if isinstance(value, str | bytes): |
| 53 | + return False |
| 54 | + if not isinstance(value, Sequence): |
| 55 | + return False |
| 56 | + return len(value) > 0 and all(_is_flat_int_sequence(axis_chunks) for axis_chunks in value) |
| 57 | + |
| 58 | + |
| 59 | +def _is_regular_dask_chunk_grid(chunk_grid: Sequence[Sequence[int]]) -> bool: |
| 60 | + """Check whether a Dask chunk grid is regular (zarr-compatible). |
| 61 | +
|
| 62 | + A grid is regular when every axis has at most one unique chunk size among all but the last |
| 63 | + chunk, and the last chunk is not larger than the first. |
| 64 | +
|
| 65 | + Parameters |
| 66 | + ---------- |
| 67 | + chunk_grid |
| 68 | + Per-axis tuple of chunk sizes, for instance as returned by ``dask_array.chunks``. |
| 69 | +
|
| 70 | + Examples |
| 71 | + -------- |
| 72 | + Triggers ``continue`` on the first ``if`` (single or empty axis): |
| 73 | +
|
| 74 | + >>> _is_regular_dask_chunk_grid([(4,)]) # single chunk → True |
| 75 | + True |
| 76 | + >>> _is_regular_dask_chunk_grid([()]) # empty axis → True |
| 77 | + True |
| 78 | +
|
| 79 | + Triggers the first ``return False`` (non-uniform interior chunks): |
| 80 | +
|
| 81 | + >>> _is_regular_dask_chunk_grid([(4, 4, 3, 4)]) # interior sizes differ → False |
| 82 | + False |
| 83 | +
|
| 84 | + Triggers the second ``return False`` (last chunk larger than the first): |
| 85 | +
|
| 86 | + >>> _is_regular_dask_chunk_grid([(4, 4, 4, 5)]) # last > first → False |
| 87 | + False |
| 88 | +
|
| 89 | + Exits with ``return True``: |
| 90 | +
|
| 91 | + >>> _is_regular_dask_chunk_grid([(4, 4, 4, 4)]) # all equal → True |
| 92 | + True |
| 93 | + >>> _is_regular_dask_chunk_grid([(4, 4, 4, 1)]) # last < first → True |
| 94 | + True |
| 95 | +
|
| 96 | + Empty grid (loop never executes) → True: |
| 97 | +
|
| 98 | + >>> _is_regular_dask_chunk_grid([]) |
| 99 | + True |
| 100 | +
|
| 101 | + Multi-axis: all axes regular → True; one axis irregular → False: |
| 102 | +
|
| 103 | + >>> _is_regular_dask_chunk_grid([(4, 4, 4, 1), (3, 3, 2)]) |
| 104 | + True |
| 105 | + >>> _is_regular_dask_chunk_grid([(4, 4, 4, 1), (4, 4, 3, 4)]) |
| 106 | + False |
| 107 | + """ |
| 108 | + # Match Dask's private _check_regular_chunks() logic without depending on its internal API. |
| 109 | + for axis_chunks in chunk_grid: |
| 110 | + if len(axis_chunks) <= 1: |
| 111 | + continue |
| 112 | + if len(set(axis_chunks[:-1])) > 1: |
| 113 | + return False |
| 114 | + if axis_chunks[-1] > axis_chunks[0]: |
| 115 | + return False |
| 116 | + return True |
| 117 | + |
| 118 | + |
def _chunks_to_zarr_chunks(chunks: object) -> tuple[int, ...] | int | None:
    """Translate a chunk specification into a Zarr-compatible form, or ``None``.

    Accepts a single int, a flat per-axis chunk shape, or a regular Dask chunk
    grid; anything else — including an irregular grid — yields ``None``.
    """
    if isinstance(chunks, int):
        # A single size applies uniformly; Zarr accepts it as-is.
        return chunks
    if _is_flat_int_sequence(chunks):
        # Already one size per axis.
        return tuple(chunks)
    if not _is_dask_chunk_grid(chunks):
        return None
    normalized_grid = tuple(tuple(axis_chunks) for axis_chunks in chunks)
    if not _is_regular_dask_chunk_grid(normalized_grid):
        return None
    # For a regular grid, the first chunk of each axis is the uniform chunk size.
    return tuple(axis_chunks[0] for axis_chunks in normalized_grid)
| 130 | + |
| 131 | + |
def _normalize_explicit_chunks(chunks: object) -> tuple[int, ...] | int:
    """Resolve a user-supplied ``chunks`` value to a Zarr chunk spec.

    Raises
    ------
    ValueError
        If the value cannot be expressed as a Zarr chunk shape (e.g. an
        irregular Dask chunk grid).
    """
    resolved = _chunks_to_zarr_chunks(chunks)
    if resolved is not None:
        return resolved
    raise ValueError(
        'storage_options["chunks"] must resolve to a Zarr chunk shape or a regular Dask chunk grid. '
        "The current raster has irregular Dask chunks, which cannot be written to Zarr. "
        "To fix this, rechunk before writing, for example by passing regular chunks=... "
        "to Image2DModel.parse(...) / Labels2DModel.parse(...)."
    )
| 142 | + |
| 143 | + |
def _prepare_storage_options(
    storage_options: JSONDict | list[JSONDict] | None,
) -> JSONDict | list[JSONDict] | None:
    """Shallow-copy storage options and normalize any explicit "chunks" entry.

    Dicts (and lists of dicts) are copied so the caller's objects are never
    mutated; ``None`` passes through unchanged.
    """
    if storage_options is None:
        return None

    def _copy_and_normalize(options: JSONDict) -> JSONDict:
        copied = dict(options)
        if "chunks" in copied:
            copied["chunks"] = _normalize_explicit_chunks(copied["chunks"])
        return copied

    if isinstance(storage_options, dict):
        return _copy_and_normalize(storage_options)
    return [_copy_and_normalize(options) for options in storage_options]
| 160 | + |
| 161 | + |
41 | 162 | def _read_multiscale( |
42 | 163 | store: str | Path, raster_type: Literal["image", "labels"], reader_format: Format |
43 | 164 | ) -> DataArray | DataTree: |
@@ -251,20 +372,18 @@ def _write_raster_dataarray( |
251 | 372 | if transformations is None: |
252 | 373 | raise ValueError(f"{element_name} does not have any transformations and can therefore not be written.") |
253 | 374 | input_axes: tuple[str, ...] = tuple(raster_data.dims) |
254 | | - chunks = raster_data.chunks |
255 | 375 | parsed_axes = _get_valid_axes(axes=list(input_axes), fmt=raster_format) |
256 | | - if storage_options is not None: |
257 | | - if "chunks" not in storage_options and isinstance(storage_options, dict): |
258 | | - storage_options["chunks"] = chunks |
259 | | - else: |
260 | | - storage_options = {"chunks": chunks} |
261 | | - # Scaler needs to be None since we are passing the data already downscaled for the multiscale case. |
262 | | - # We need this because the argument of write_image_ngff is called image while the argument of |
| 376 | + storage_options = _prepare_storage_options(storage_options) |
| 377 | + # Explicitly disable pyramid generation for single-scale rasters. Recent ome-zarr versions default |
| 378 | + # write_image()/write_labels() to scale_factors=(2, 4, 8, 16), which would otherwise write s0, s1, ... |
| 379 | + # even when the input is a plain DataArray. |
| 380 | + # We need this because the argument of write_image_ngff is called image while the argument of |
263 | 381 | # write_labels_ngff is called label. |
264 | 382 | metadata[raster_type] = data |
265 | 383 | ome_zarr_format = get_ome_zarr_format(raster_format) |
266 | 384 | write_single_scale_ngff( |
267 | 385 | group=group, |
| 386 | + scale_factors=[], |
268 | 387 | scaler=None, |
269 | 388 | fmt=ome_zarr_format, |
270 | 389 | axes=parsed_axes, |
@@ -322,10 +441,9 @@ def _write_raster_datatree( |
322 | 441 | transformations = _get_transformations_xarray(xdata) |
323 | 442 | if transformations is None: |
324 | 443 | raise ValueError(f"{element_name} does not have any transformations and can therefore not be written.") |
325 | | - chunks = get_pyramid_levels(raster_data, "chunks") |
326 | 444 |
|
327 | 445 | parsed_axes = _get_valid_axes(axes=list(input_axes), fmt=raster_format) |
328 | | - storage_options = [{"chunks": chunk} for chunk in chunks] |
| 446 | + storage_options = _prepare_storage_options(storage_options) |
329 | 447 | ome_zarr_format = get_ome_zarr_format(raster_format) |
330 | 448 | dask_delayed = write_multi_scale_ngff( |
331 | 449 | pyramid=data, |
|
0 commit comments