Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
33 changes: 30 additions & 3 deletions src/spatialdata/_core/spatialdata.py
Original file line number Diff line number Diff line change
Expand Up @@ -1179,6 +1179,7 @@ def write(
overwrite: bool = False,
consolidate_metadata: bool = True,
format: SpatialDataFormat | list[SpatialDataFormat] | None = None,
compressor: dict[Literal["lz4", "zstd"], int] | None = None,
) -> None:
"""
Write the `SpatialData` object to a Zarr store.
Expand All @@ -1204,7 +1205,16 @@ def write(
By default, the latest format is used for all elements, i.e.
:class:`~spatialdata._io.format.CurrentRasterFormat`, :class:`~spatialdata._io.format.CurrentShapesFormat`,
:class:`~spatialdata._io.format.CurrentPointsFormat`, :class:`~spatialdata._io.format.CurrentTablesFormat`.
compressor
A dictionary with as key the type of compression to use for images and labels and as value the compression
level which should be inclusive between 0 and 9. For compression, `lz4` and `zstd` are supported. If not
specified, the compression will be `lz4` with compression level 5. Bytes are automatically ordered for more
efficient compression.
"""
from spatialdata._io._utils import _validate_compressor_args

_validate_compressor_args(compressor)

if isinstance(file_path, str):
file_path = Path(file_path)
self._validate_can_safely_write_to_path(file_path, overwrite=overwrite)
Expand All @@ -1223,6 +1233,7 @@ def write(
element_name=element_name,
overwrite=False,
format=format,
compressor=compressor,
)

if self.path != file_path:
Expand All @@ -1241,6 +1252,7 @@ def _write_element(
element_name: str,
overwrite: bool,
format: SpatialDataFormat | list[SpatialDataFormat] | None = None,
compressor: dict[Literal["lz4", "zstd"], int] | None = None,
) -> None:
if not isinstance(zarr_container_path, Path):
raise ValueError(
Expand All @@ -1260,9 +1272,17 @@ def _write_element(
parsed = _parse_formats(formats=format)

if element_type == "images":
write_image(image=element, group=element_type_group, name=element_name, format=parsed["raster"])
write_image(
image=element,
group=element_type_group,
name=element_name,
format=parsed["raster"],
compressor=compressor,
)
elif element_type == "labels":
write_labels(labels=element, group=root_group, name=element_name, format=parsed["raster"])
write_labels(
labels=element, group=root_group, name=element_name, format=parsed["raster"], compressor=compressor
)
elif element_type == "points":
write_points(points=element, group=element_type_group, name=element_name, format=parsed["points"])
elif element_type == "shapes":
Expand All @@ -1277,6 +1297,7 @@ def write_element(
element_name: str | list[str],
overwrite: bool = False,
format: SpatialDataFormat | list[SpatialDataFormat] | None = None,
compressor: dict[Literal["lz4", "zstd"], int] | None = None,
) -> None:
"""
Write a single element, or a list of elements, to the Zarr store used for backing.
Expand All @@ -1292,6 +1313,11 @@ def write_element(
format
It is recommended to leave this parameter equal to `None`. See more details in the documentation of
`SpatialData.write()`.
compressor
A dictionary with as key the type of compression to use for images and labels and as value the compression
level which should be inclusive between 0 and 9. For compression, `lz4` and `zstd` are supported. If not
specified, the compression will be `lz4` with compression level 5. Bytes are automatically ordered for more
efficient compression.

Notes
-----
Expand All @@ -1301,7 +1327,7 @@ def write_element(
if isinstance(element_name, list):
for name in element_name:
assert isinstance(name, str)
self.write_element(name, overwrite=overwrite)
self.write_element(name, overwrite=overwrite, compressor=compressor)
return

check_valid_name(element_name)
Expand Down Expand Up @@ -1335,6 +1361,7 @@ def write_element(
element_name=element_name,
overwrite=overwrite,
format=format,
compressor=compressor,
)

def delete_element_from_disk(self, element_name: str | list[str]) -> None:
Expand Down
19 changes: 19 additions & 0 deletions src/spatialdata/_io/_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -437,3 +437,22 @@ def handle_read_errors(
else: # on_bad_files == BadFileHandleMethod.ERROR
# Let it raise exceptions
yield


def _validate_compressor_args(compressor_dict: dict[Literal["lz4", "zstd"], int] | None) -> None:
if compressor_dict:
if not isinstance(compressor_dict, dict):
raise TypeError(
f"Expected a dictionary with as key the type of compression to use for images and labels and "
f"as value the compression level which should be inclusive between 1 and 9. "
f"Got type: {type(compressor_dict)}"
)
if len(compressor_dict) != 1:
raise ValueError(
"Expected a dictionary with a single key indicating the type of compression, either 'lz4' or "
"'zstd' and an `int` inclusive between 1 and 9 as value representing the compression level."
)
if (compression := list(compressor_dict.keys())[0]) not in ["lz4", "zstd"]:
raise ValueError(f"Compression must either be `lz4` or `zstd`, got: {compression}.")
if not isinstance(value := list(compressor_dict.values())[0], int) or not (0 <= value <= 9):
raise ValueError(f"The compression level must be an integer inclusive between 0 and 9. Got: {value}")
Loading
Loading