diff --git a/src/spatialdata/_core/spatialdata.py b/src/spatialdata/_core/spatialdata.py
index 48f6386ca..6e66d7054 100644
--- a/src/spatialdata/_core/spatialdata.py
+++ b/src/spatialdata/_core/spatialdata.py
@@ -1179,6 +1179,7 @@ def write(
         overwrite: bool = False,
         consolidate_metadata: bool = True,
         format: SpatialDataFormat | list[SpatialDataFormat] | None = None,
+        compressor: dict[Literal["lz4", "zstd"], int] | None = None,
     ) -> None:
         """
         Write the `SpatialData` object to a Zarr store.
@@ -1204,7 +1205,16 @@ def write(
             By default, the latest format is used for all elements, i.e.
             :class:`~spatialdata._io.format.CurrentRasterFormat`,
             :class:`~spatialdata._io.format.CurrentShapesFormat`,
             :class:`~spatialdata._io.format.CurrentPointsFormat`,
             :class:`~spatialdata._io.format.CurrentTablesFormat`.
+        compressor
+            A dictionary with a single entry mapping the compression algorithm (`lz4` or `zstd`) to the
+            compression level, an integer between 0 and 9 inclusive. The compression is applied to images and
+            labels. If not specified, `lz4` with compression level 5 is used. Bytes are automatically shuffled
+            for more efficient compression.
         """
+        from spatialdata._io._utils import _validate_compressor_args
+
+        _validate_compressor_args(compressor)
+
         if isinstance(file_path, str):
             file_path = Path(file_path)
         self._validate_can_safely_write_to_path(file_path, overwrite=overwrite)
@@ -1223,6 +1233,7 @@ def write(
                 element_name=element_name,
                 overwrite=False,
                 format=format,
+                compressor=compressor,
             )
 
         if self.path != file_path:
@@ -1241,6 +1252,7 @@ def _write_element(
         element_name: str,
         overwrite: bool,
         format: SpatialDataFormat | list[SpatialDataFormat] | None = None,
+        compressor: dict[Literal["lz4", "zstd"], int] | None = None,
     ) -> None:
         if not isinstance(zarr_container_path, Path):
             raise ValueError(
@@ -1260,9 +1272,17 @@
         parsed = _parse_formats(formats=format)
 
         if element_type == "images":
-            write_image(image=element, group=element_type_group, name=element_name, format=parsed["raster"])
+            write_image(
+                image=element,
+                group=element_type_group,
+                name=element_name,
+                format=parsed["raster"],
+                compressor=compressor,
+            )
         elif element_type == "labels":
-            write_labels(labels=element, group=root_group, name=element_name, format=parsed["raster"])
+            write_labels(
+                labels=element, group=root_group, name=element_name, format=parsed["raster"], compressor=compressor
+            )
         elif element_type == "points":
             write_points(points=element, group=element_type_group, name=element_name, format=parsed["points"])
         elif element_type == "shapes":
@@ -1277,6 +1297,7 @@ def write_element(
         element_name: str | list[str],
         overwrite: bool = False,
         format: SpatialDataFormat | list[SpatialDataFormat] | None = None,
+        compressor: dict[Literal["lz4", "zstd"], int] | None = None,
     ) -> None:
         """
         Write a single element, or a list of elements, to the Zarr store used for backing.
@@ -1292,6 +1313,11 @@ def write_element(
         format
             It is recommended to leave this parameter equal to `None`. See more details in the documentation of
             `SpatialData.write()`.
+        compressor
+            A dictionary with a single entry mapping the compression algorithm (`lz4` or `zstd`) to the
+            compression level, an integer between 0 and 9 inclusive. The compression is applied to images and
+            labels. If not specified, `lz4` with compression level 5 is used. Bytes are automatically shuffled
+            for more efficient compression.
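+            For example, `compressor={"zstd": 7}` compresses images and labels with Zstandard at level 7.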
 
         Notes
         -----
@@ -1301,7 +1327,7 @@ def write_element(
         if isinstance(element_name, list):
             for name in element_name:
                 assert isinstance(name, str)
-                self.write_element(name, overwrite=overwrite)
+                self.write_element(name, overwrite=overwrite, format=format, compressor=compressor)
             return
 
         check_valid_name(element_name)
@@ -1335,6 +1361,7 @@ def write_element(
             element_name=element_name,
             overwrite=overwrite,
             format=format,
+            compressor=compressor,
         )
 
     def delete_element_from_disk(self, element_name: str | list[str]) -> None:
diff --git a/src/spatialdata/_io/_utils.py b/src/spatialdata/_io/_utils.py
index 5e8eb832d..c49603d1d 100644
--- a/src/spatialdata/_io/_utils.py
+++ b/src/spatialdata/_io/_utils.py
@@ -437,3 +437,22 @@ def handle_read_errors(
     else:  # on_bad_files == BadFileHandleMethod.ERROR
         # Let it raise exceptions
         yield
+
+
+def _validate_compressor_args(compressor_dict: dict[Literal["lz4", "zstd"], int] | None) -> None:
+    if compressor_dict:
+        if not isinstance(compressor_dict, dict):
+            raise TypeError(
+                f"Expected a dictionary mapping the compression type ('lz4' or 'zstd') to the compression "
+                f"level, an integer between 0 and 9 inclusive. "
+                f"Got type: {type(compressor_dict)}"
+            )
+        if len(compressor_dict) != 1:
+            raise ValueError(
+                "Expected a dictionary with a single key indicating the type of compression, either 'lz4' or "
+                "'zstd', and an `int` between 0 and 9 inclusive as value representing the compression level."
+            )
+        if (compression := next(iter(compressor_dict))) not in ["lz4", "zstd"]:
+            raise ValueError(f"Compression must either be `lz4` or `zstd`, got: {compression}.")
+        if not isinstance(value := next(iter(compressor_dict.values())), int) or not (0 <= value <= 9):
+            raise ValueError(f"The compression level must be an integer between 0 and 9 inclusive. Got: {value}.")
diff --git a/src/spatialdata/_io/io_raster.py b/src/spatialdata/_io/io_raster.py
index 541be3ead..fa5b16508 100644
--- a/src/spatialdata/_io/io_raster.py
+++ b/src/spatialdata/_io/io_raster.py
@@ -1,3 +1,4 @@
+from collections.abc import Callable
 from pathlib import Path
 from typing import Any, Literal
 
@@ -116,6 +117,213 @@ def _read_multiscale(store: str | Path, raster_type: Literal["image", "labels"])
     return compute_coordinates(si)
 
 
+def _get_group_for_writing_transformations(
+    raster_type: Literal["image", "labels"], group: zarr.Group, name: str
+) -> zarr.Group:
+    """Get the appropriate zarr group for writing transformations.
+
+    Parameters
+    ----------
+    raster_type
+        Type of raster data, either "image" or "labels"
+    group
+        Parent zarr group
+    name
+        Name of the element
+
+    Returns
+    -------
+    The zarr group where transformations should be written
+    """
+    if raster_type == "image":
+        return group.require_group(name)
+    return group["labels"][name]
+
+
+def _apply_compression(
+    storage_options: JSONDict | list[JSONDict], compressor: dict[Literal["lz4", "zstd"], int] | None
+) -> JSONDict | list[JSONDict]:
+    """Apply compression settings to storage options.
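+
+    A single Blosc codec is built from the (compression, level) pair and attached to each storage-options
+    dictionary under the "compressor" key, so all pyramid levels are written with the same codec.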
+
+    Parameters
+    ----------
+    storage_options
+        Storage options for zarr arrays
+    compressor
+        Compression settings as a dictionary with a single key-value pair
+
+    Returns
+    -------
+    Updated storage options with compression settings
+    """
+    from zarr.codecs import Blosc
+
+    if not compressor:
+        return storage_options
+
+    ((compression, compression_level),) = compressor.items()
+
+    if isinstance(storage_options, dict):
+        storage_options["compressor"] = Blosc(cname=compression, clevel=compression_level, shuffle=Blosc.SHUFFLE)
+    elif isinstance(storage_options, list):
+        for option in storage_options:
+            option["compressor"] = Blosc(cname=compression, clevel=compression_level, shuffle=Blosc.SHUFFLE)
+
+    return storage_options
+
+
+def _write_data_array(
+    raster_type: Literal["image", "labels"],
+    raster_data: DataArray,
+    group_data: zarr.Group,
+    format: Format,
+    storage_options: JSONDict | None,
+    compressor: dict[Literal["lz4", "zstd"], int] | None,
+    metadata: dict[str, Any],
+    get_transformations_group: Callable[[], zarr.Group],
+) -> None:
+    """Write a DataArray to a zarr group.
+
+    Parameters
+    ----------
+    raster_type
+        Type of raster data, either "image" or "labels"
+    raster_data
+        The DataArray to write
+    group_data
+        The zarr group to write to
+    format
+        The SpatialData raster format to use for writing
+    storage_options
+        Storage options for zarr arrays (to be passed to ome-zarr)
+    compressor
+        Compression settings as a dictionary with a single key-value (compression, compression level) pair
+    metadata
+        Additional metadata
+    get_transformations_group
+        Function that returns the group for writing transformations
+    """
+    data = raster_data.data
+    transformations = _get_transformations(raster_data)
+    input_axes: tuple[str, ...] = tuple(raster_data.dims)
+    chunks = raster_data.chunks
+    parsed_axes = _get_valid_axes(axes=list(input_axes), fmt=format)
+
+    # Set up storage options with chunks
+    if storage_options is not None:
+        if "chunks" not in storage_options and isinstance(storage_options, dict):
+            storage_options["chunks"] = chunks
+    else:
+        storage_options = {"chunks": chunks}
+
+    # Apply compression if specified
+    storage_options = _apply_compression(storage_options, compressor)
+
+    # Determine which write function to use
+    write_single_scale_ngff = write_image_ngff if raster_type == "image" else write_labels_ngff
+
+    # Scaler needs to be None since we are passing the data already downscaled for the multiscale case.
+    # The data is passed via **metadata because the argument of write_image_ngff is called `image` while the
+    # argument of write_labels_ngff is called `label`.
+    metadata[raster_type] = data
+
+    # Write the data
+    write_single_scale_ngff(
+        group=group_data,
+        scaler=None,
+        fmt=format,
+        axes=parsed_axes,
+        coordinate_transformations=None,
+        storage_options=storage_options,
+        **metadata,
+    )
+
+    # Write transformations
+    assert transformations is not None
+    overwrite_coordinate_transformations_raster(
+        group=get_transformations_group(), transformations=transformations, axes=input_axes
+    )
+
+
+def _write_data_tree(
+    raster_type: Literal["image", "labels"],
+    raster_data: DataTree,
+    group_data: zarr.Group,
+    format: Format,
+    storage_options: JSONDict | list[JSONDict] | None,
+    compressor: dict[Literal["lz4", "zstd"], int] | None,
+    metadata: dict[str, Any],
+    get_transformations_group: Callable[[], zarr.Group],
+) -> None:
+    """Write a DataTree to a zarr group.
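+
+    Only the transformations of the first scale (`scale0`) are saved. All pyramid levels are written lazily
+    and computed together afterwards, so Dask can optimize the computational graph across levels.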
+
+    Parameters
+    ----------
+    raster_type
+        Type of raster data, either "image" or "labels"
+    raster_data
+        The DataTree to write
+    group_data
+        The zarr group to write to
+    format
+        The SpatialData raster format to use for writing
+    storage_options
+        Storage options for zarr arrays (to be passed to ome-zarr)
+    compressor
+        Compression settings as a dictionary with a single key-value (compression, compression level) pair
+    metadata
+        Additional metadata
+    get_transformations_group
+        Function that returns the group for writing transformations
+    """
+    data = get_pyramid_levels(raster_data, attr="data")
+    list_of_input_axes: list[Any] = get_pyramid_levels(raster_data, attr="dims")
+    assert len(set(list_of_input_axes)) == 1
+    input_axes = list_of_input_axes[0]
+
+    # Saving only the transformations of the first scale
+    d = dict(raster_data["scale0"])
+    assert len(d) == 1
+    xdata = next(iter(d.values()))
+    transformations = _get_transformations_xarray(xdata)
+    assert transformations is not None
+    assert len(transformations) > 0
+
+    chunks = get_pyramid_levels(raster_data, "chunks")
+    parsed_axes = _get_valid_axes(axes=list(input_axes), fmt=format)
+
+    # Set up storage options with chunks
+    if storage_options is None:
+        storage_options = [{"chunks": chunk} for chunk in chunks]
+
+    # Apply compression if specified
+    storage_options = _apply_compression(storage_options, compressor)
+
+    # Determine which write function to use
+    write_multi_scale_ngff = write_multiscale_ngff if raster_type == "image" else write_multiscale_labels_ngff
+
+    # Write the data
+    dask_delayed = write_multi_scale_ngff(
+        pyramid=data,
+        group=group_data,
+        fmt=format,
+        axes=parsed_axes,
+        coordinate_transformations=None,
+        storage_options=storage_options,
+        **metadata,
+        compute=False,
+    )
+
+    # Compute all pyramid levels at once to allow Dask to optimize the computational graph.
+    da.compute(*dask_delayed)
+
+    # Write transformations
+    overwrite_coordinate_transformations_raster(
+        group=get_transformations_group(), transformations=transformations, axes=tuple(input_axes)
+    )
+
+
 def _write_raster(
     raster_type: Literal["image", "labels"],
     raster_data: DataArray | DataTree,
@@ -123,111 +331,102 @@
     name: str,
     format: Format = CurrentRasterFormat(),
     storage_options: JSONDict | list[JSONDict] | None = None,
+    compressor: dict[Literal["lz4", "zstd"], int] | None = None,
     label_metadata: JSONDict | None = None,
     **metadata: str | JSONDict | list[JSONDict],
 ) -> None:
-    assert raster_type in ["image", "labels"]
-    # the argument "name" and "label_metadata" are only used for labels (to be precise, name is used in
+    """Write raster data to a zarr group.
+
+    This function handles writing both image and label data, in both single-scale (DataArray)
+    and multi-scale (DataTree) formats.
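+
+    Validation, group preparation and omero channel metadata are handled here, while the actual array writing
+    is delegated to `_write_data_array` (single-scale) or `_write_data_tree` (multi-scale).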
+
+    Parameters
+    ----------
+    raster_type
+        Type of raster data, either "image" or "labels"
+    raster_data
+        The data to write, either a DataArray (single-scale) or DataTree (multi-scale)
+    group
+        The zarr group to write to
+    name
+        Name of the element
+    format
+        The raster format to use for writing
+    storage_options
+        Storage options for zarr arrays (to be passed to ome-zarr)
+    compressor
+        Compression settings as a dictionary with a single key-value (compression, compression level) pair
+    label_metadata
+        Metadata specific to labels
+    **metadata
+        Additional metadata
+    """
+    # Validate inputs
+    if raster_type not in ["image", "labels"]:
+        raise TypeError(f"Writing raster data is only supported for 'image' and 'labels'. Got: {raster_type}")
+
+    # The argument "name" and "label_metadata" are only used for labels (to be precise, name is used in
     # write_multiscale_ngff() when writing metadata, but name is not used in write_image_ngff(). Maybe this is bug of
     # ome-zarr-py. In any case, we don't need that metadata and we use the argument name so that when we write labels
     # the correct group is created by the ome-zarr-py APIs. For images we do it manually in the function
     # _get_group_for_writing_data()
-    if raster_type == "image":
-        assert label_metadata is None
-    else:
+    if raster_type == "image" and label_metadata is not None:
+        raise ValueError("If the raster_type is 'image', 'label_metadata' should be None.")
+
+    if raster_type == "labels":
         metadata["name"] = name
         metadata["label_metadata"] = label_metadata
 
-    write_single_scale_ngff = write_image_ngff if raster_type == "image" else write_labels_ngff
-    write_multi_scale_ngff = write_multiscale_ngff if raster_type == "image" else write_multiscale_labels_ngff
-
+    # Prepare the group for writing data
     group_data = group.require_group(name) if raster_type == "image" else group
 
-    def _get_group_for_writing_transformations() -> zarr.Group:
-        if raster_type == "image":
-            return group.require_group(name)
-        return group["labels"][name]
+    # Small helper that returns the group for writing transformations
+    def get_transformations_group() -> zarr.Group:
+        return _get_group_for_writing_transformations(raster_type, group, name)
 
-    # convert channel names to channel metadata in omero
+    # Convert channel names to channel metadata in omero for images
     if raster_type == "image":
         metadata["metadata"] = {"omero": {"channels": []}}
         channels = get_channel_names(raster_data)
         for c in channels:
             metadata["metadata"]["omero"]["channels"].append({"label": c})  # type: ignore[union-attr, index, call-overload]
 
+    # Write the data based on its type
     if isinstance(raster_data, DataArray):
-        data = raster_data.data
-        transformations = _get_transformations(raster_data)
-        input_axes: tuple[str, ...] = tuple(raster_data.dims)
-        chunks = raster_data.chunks
-        parsed_axes = _get_valid_axes(axes=list(input_axes), fmt=format)
-        if storage_options is not None:
-            if "chunks" not in storage_options and isinstance(storage_options, dict):
-                storage_options["chunks"] = chunks
-        else:
-            storage_options = {"chunks": chunks}
-        # Scaler needs to be None since we are passing the data already downscaled for the multiscale case.
-        # We need this because the argument of write_image_ngff is called image while the argument of
-        # write_labels_ngff is called label.
-        metadata[raster_type] = data
-        write_single_scale_ngff(
-            group=group_data,
-            scaler=None,
-            fmt=format,
-            axes=parsed_axes,
-            coordinate_transformations=None,
+        _write_data_array(
+            raster_type=raster_type,
+            raster_data=raster_data,
+            group_data=group_data,
+            format=format,
             storage_options=storage_options,
-            **metadata,
-        )
-        assert transformations is not None
-        overwrite_coordinate_transformations_raster(
-            group=_get_group_for_writing_transformations(), transformations=transformations, axes=input_axes
+            compressor=compressor,
+            metadata=metadata,
+            get_transformations_group=get_transformations_group,
         )
     elif isinstance(raster_data, DataTree):
-        data = get_pyramid_levels(raster_data, attr="data")
-        list_of_input_axes: list[Any] = get_pyramid_levels(raster_data, attr="dims")
-        assert len(set(list_of_input_axes)) == 1
-        input_axes = list_of_input_axes[0]
-        # saving only the transformations of the first scale
-        d = dict(raster_data["scale0"])
-        assert len(d) == 1
-        xdata = d.values().__iter__().__next__()
-        transformations = _get_transformations_xarray(xdata)
-        assert transformations is not None
-        assert len(transformations) > 0
-        chunks = get_pyramid_levels(raster_data, "chunks")
-        # coords = iterate_pyramid_levels(raster_data, "coords")
-        parsed_axes = _get_valid_axes(axes=list(input_axes), fmt=format)
-        storage_options = [{"chunks": chunk} for chunk in chunks]
-        dask_delayed = write_multi_scale_ngff(
-            pyramid=data,
-            group=group_data,
-            fmt=format,
-            axes=parsed_axes,
-            coordinate_transformations=None,
+        _write_data_tree(
+            raster_type=raster_type,
+            raster_data=raster_data,
+            group_data=group_data,
+            format=format,
             storage_options=storage_options,
-            **metadata,
-            compute=False,
-        )
-        # Compute all pyramid levels at once to allow Dask to optimize the computational graph.
-        da.compute(*dask_delayed)
-        assert transformations is not None
-        overwrite_coordinate_transformations_raster(
-            group=_get_group_for_writing_transformations(), transformations=transformations, axes=tuple(input_axes)
+            compressor=compressor,
+            metadata=metadata,
+            get_transformations_group=get_transformations_group,
         )
     else:
         raise ValueError("Not a valid labels object")
 
-    # as explained in a comment in format.py, since coordinate transformations are not part of NGFF yet, we need to have
+    # Write format version metadata
+    # As explained in a comment in format.py, since coordinate transformations are not part of NGFF yet, we need to have
     # our spatialdata extension also for raster type (eventually it will be dropped in favor of pure NGFF). Until then,
     # saving the NGFF version (i.e. 0.4) is not enough, and we need to also record which version of the spatialdata
     # format we are using for raster types
-    group = _get_group_for_writing_transformations()
+    group = get_transformations_group()
     if ATTRS_KEY not in group.attrs:
         group.attrs[ATTRS_KEY] = {}
     attrs = group.attrs[ATTRS_KEY]
     attrs["version"] = format.spatialdata_format_version
-    # triggers the write operation
+    # Triggers the write operation
     group.attrs[ATTRS_KEY] = attrs
 
@@ -237,6 +436,7 @@ def write_image(
     name: str,
     format: Format = CurrentRasterFormat(),
     storage_options: JSONDict | list[JSONDict] | None = None,
+    compressor: dict[Literal["lz4", "zstd"], int] | None = None,
     **metadata: str | JSONDict | list[JSONDict],
 ) -> None:
     _write_raster(
@@ -246,6 +446,7 @@ def write_image(
         name=name,
         format=format,
         storage_options=storage_options,
+        compressor=compressor,
         **metadata,
     )
 
@@ -257,6 +458,7 @@ def write_labels(
     format: Format = CurrentRasterFormat(),
     storage_options: JSONDict | list[JSONDict] | None = None,
     label_metadata: JSONDict | None = None,
+    compressor: dict[Literal["lz4", "zstd"], int] | None = None,
     **metadata: JSONDict,
 ) -> None:
     _write_raster(
@@ -266,6 +468,7 @@ def write_labels(
         name=name,
         format=format,
         storage_options=storage_options,
+        compressor=compressor,
         label_metadata=label_metadata,
         **metadata,
     )
diff --git a/tests/io/test_readwrite.py b/tests/io/test_readwrite.py
index ad8c66b4c..1e13d7704 100644
--- a/tests/io/test_readwrite.py
+++ b/tests/io/test_readwrite.py
@@ -2,11 +2,12 @@
 import tempfile
 from collections.abc import Callable
 from pathlib import Path
-from typing import Any
+from typing import Any, Literal
 
 import dask.dataframe as dd
 import numpy as np
 import pytest
+import zarr
 from anndata import AnnData
 from numpy.random import default_rng
 
@@ -82,6 +83,30 @@ def test_multiple_tables(self, tmp_path: str, tables: list[AnnData]) -> None:
         sdata_tables = SpatialData(tables={str(i): tables[i] for i in range(len(tables))})
         self._test_table(tmp_path, sdata_tables)
 
+    def test_compression_roundtrip(self, tmp_path: str, full_sdata: SpatialData) -> None:
+        tmpdir = Path(tmp_path) / "tmp.zarr"
+        with pytest.raises(TypeError, match="Expected a dictionary mapping"):
+            full_sdata.write(tmpdir, compressor="faulty")
+        with pytest.raises(ValueError, match="Expected a dictionary with a single"):
+            full_sdata.write(tmpdir, compressor={"zstd": 8, "other_item": 4})
+        with pytest.raises(ValueError, match="Compression must either"):
+            full_sdata.write(tmpdir, compressor={"faulty": 8})
+        with pytest.raises(ValueError, match="The compression level"):
+            full_sdata.write(tmpdir, compressor={"zstd": 10})
+
+        full_sdata.write(tmpdir, compressor={"zstd": 8})
+
+        # sourcery skip: no-loop-in-tests
+        for element in ["image2d", "image2d_multiscale"]:
+            compressor = zarr.open_group(tmpdir / "images", mode="r")[element]["0"].compressor
+            assert compressor.cname == "zstd"
+            assert compressor.clevel == 8
+
+        for element in ["labels2d", "labels2d_multiscale"]:
+            compressor = zarr.open_group(tmpdir / "labels", mode="r")[element]["0"].compressor
+            assert compressor.cname == "zstd"
+            assert compressor.clevel == 8
+
     def test_roundtrip(
         self,
         tmp_path: str,
@@ -252,6 +277,22 @@ def test_incremental_io_on_disk(
                 sdata.delete_element_from_disk(name)
             sdata.write_element(name)
 
+    @pytest.mark.parametrize("compressor", [{"lz4": 3}, {"zstd": 7}])
+    @pytest.mark.parametrize("element", [("images", "image2d"), ("labels", "labels2d")])
+    def test_write_element_compression(
+        self,
+        tmp_path: str,
+        full_sdata: SpatialData,
+        compressor: dict[Literal["lz4", "zstd"], int],
+        element: tuple[str, str],
+    ) -> None:
+        tmpdir = Path(tmp_path) / "compression.zarr"
+        sdata = SpatialData()
+        sdata.write(tmpdir)
+
+        sdata["element"] = full_sdata[element[1]]
+        sdata.write_element("element", compressor=compressor)
+
+        compression = zarr.open_group(tmpdir / element[0], mode="r")["element"]["0"].compressor
+        ((cname, clevel),) = compressor.items()
+        assert compression.cname == cname
+        assert compression.clevel == clevel
+
     def test_incremental_io_table_legacy(self, table_single_annotation: SpatialData) -> None:
         s = table_single_annotation
         t = s["table"][:10, :].copy()
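A minimal usage sketch of the new `compressor` argument, assuming the `blobs` demo dataset from `spatialdata.datasets` (with its `blobs_image` element name) and an arbitrary store path:

```python
from pathlib import Path

import zarr

from spatialdata.datasets import blobs

sdata = blobs()

# Write all elements; images and labels are compressed with Zstandard at level 7
# (bytes are shuffled automatically via Blosc for better compression ratios).
path = Path("blobs.zarr")
sdata.write(path, compressor={"zstd": 7})

# Inspect the codec that was actually stored for the full-resolution image array.
codec = zarr.open_group(path / "images", mode="r")["blobs_image"]["0"].compressor
assert codec.cname == "zstd"
assert codec.clevel == 7
```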