From eaafdd1368b7486a4e05aff8ab939557f6aece0a Mon Sep 17 00:00:00 2001 From: Luca Marconato Date: Mon, 18 Aug 2025 13:38:52 +0200 Subject: [PATCH 001/126] changed pyproject to zarr>=3 --- pyproject.toml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pyproject.toml b/pyproject.toml index 1f16398d..46df93af 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -47,7 +47,7 @@ dependencies = [ "xarray>=2024.10.0", "xarray-schema", "xarray-spatial>=0.3.5", - "zarr<3", + "zarr>=3.0.0", ] [project.optional-dependencies] From 3047662487d34654a32184c6f13a20d899dddf78 Mon Sep 17 00:00:00 2001 From: Luca Marconato Date: Mon, 18 Aug 2025 14:28:29 +0200 Subject: [PATCH 002/126] adjust dependencies --- .github/workflows/test.yaml | 2 +- pyproject.toml | 3 ++- 2 files changed, 3 insertions(+), 2 deletions(-) diff --git a/.github/workflows/test.yaml b/.github/workflows/test.yaml index 7831b0b4..16c85fd5 100644 --- a/.github/workflows/test.yaml +++ b/.github/workflows/test.yaml @@ -18,7 +18,7 @@ jobs: strategy: fail-fast: false matrix: - python: ["3.10", "3.12"] + python: ["3.11", "3.13"] os: [ubuntu-latest] include: - os: macos-latest diff --git a/pyproject.toml b/pyproject.toml index 46df93af..c43a6d2c 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -33,7 +33,8 @@ dependencies = [ "networkx", "numba>=0.55.0", "numpy", - "ome_zarr>=0.8.4", + "ome_zarr>=0.12rc1", +# "ome_zarr>=0.8.4", "pandas", "pooch", "pyarrow", From 9f26ce0247f63cf6e7ed750fd71224a7525a7f2e Mon Sep 17 00:00:00 2001 From: Luca Marconato Date: Mon, 18 Aug 2025 16:10:03 +0200 Subject: [PATCH 003/126] wip --- src/spatialdata/_core/spatialdata.py | 18 ++++++++-------- src/spatialdata/_io/format.py | 4 ++-- src/spatialdata/_io/io_table.py | 9 ++++++-- src/spatialdata/_io/io_zarr.py | 32 ++++++++++++++++++---------- tests/io/test_partial_read.py | 16 ++++++++------ 5 files changed, 48 insertions(+), 31 deletions(-) diff --git a/src/spatialdata/_core/spatialdata.py b/src/spatialdata/_core/spatialdata.py index 48f6386c..2ca4fdee 100644 --- a/src/spatialdata/_core/spatialdata.py +++ b/src/spatialdata/_core/spatialdata.py @@ -623,7 +623,7 @@ def _get_groups_for_element( """ if not isinstance(zarr_path, Path): raise ValueError("zarr_path should be a Path object") - store = parse_url(zarr_path, mode="r+").store + store = parse_url(zarr_path, mode="r+", fmt=SpatialDataFormat).store root = zarr.group(store=store) if element_type not in ["images", "labels", "points", "polygons", "shapes", "tables"]: raise ValueError(f"Unknown element type {element_type}") @@ -646,7 +646,7 @@ def _group_for_element_exists(self, zarr_path: Path, element_type: str, element_ ------- True if the group exists, False otherwise. """ - store = parse_url(zarr_path, mode="r").store + store = parse_url(zarr_path, mode="r", fmt=SpatialDataFormat).store root = zarr.group(store=store) assert element_type in ["images", "labels", "points", "polygons", "shapes", "tables"] exists = element_type in root and element_name in root[element_type] @@ -1070,7 +1070,7 @@ def elements_paths_on_disk(self) -> list[str]: """ if self.path is None: raise ValueError("The SpatialData object is not backed by a Zarr store.") - store = parse_url(self.path, mode="r").store + store = parse_url(self.path, mode="r", fmt=SpatialDataFormat).store root = zarr.group(store=store) elements_in_zarr = [] @@ -1123,7 +1123,7 @@ def _validate_can_safely_write_to_path( raise ValueError(f"file_path must be a string or a Path object, type(file_path) = {type(file_path)}.") if os.path.exists(file_path): - if parse_url(file_path, mode="r") is None: + if parse_url(file_path, mode="r", fmt=SpatialDataFormat) is None: raise ValueError( "The target file path specified already exists, and it has been detected to not be a Zarr store. " "Overwriting non-Zarr stores is not supported to prevent accidental data loss." @@ -1210,7 +1210,7 @@ def write( self._validate_can_safely_write_to_path(file_path, overwrite=overwrite) self._validate_all_elements() - store = parse_url(file_path, mode="w").store + store = parse_url(file_path, mode="w", fmt=SpatialDataFormat).store zarr_group = zarr.group(store=store, overwrite=overwrite) self.write_attrs(zarr_group=zarr_group) store.close() @@ -1417,7 +1417,7 @@ def delete_element_from_disk(self, element_name: str | list[str]) -> None: ) # delete the element - store = parse_url(self.path, mode="r+").store + store = parse_url(self.path, mode="r+", fmt=SpatialDataFormat).store root = zarr.group(store=store) root[element_type].pop(element_name) store.close() @@ -1438,7 +1438,7 @@ def _check_element_not_on_disk_with_different_type(self, element_type: str, elem ) def write_consolidated_metadata(self) -> None: - store = parse_url(self.path, mode="r+").store + store = parse_url(self.path, mode="r+", fmt=SpatialDataFormat).store # consolidate metadata to more easily support remote reading bug in zarr. In reality, 'zmetadata' is written # instead of '.zmetadata' see discussion https://github.com/zarr-developers/zarr-python/issues/1121 zarr.consolidate_metadata(store, metadata_key=".zmetadata") @@ -1446,7 +1446,7 @@ def write_consolidated_metadata(self) -> None: def has_consolidated_metadata(self) -> bool: return_value = False - store = parse_url(self.path, mode="r").store + store = parse_url(self.path, mode="r", fmt=SpatialDataFormat).store if "zmetadata" in store: return_value = True store.close() @@ -1622,7 +1622,7 @@ def write_attrs(self, format: SpatialDataFormat | None = None, zarr_group: zarr. if zarr_group is None: assert self.is_backed(), "The SpatialData object must be backed by a Zarr store to write attrs." - store = parse_url(self.path, mode="r+").store + store = parse_url(self.path, mode="r+", fmt=SpatialDataFormat).store zarr_group = zarr.group(store=store, overwrite=False) version = parsed["SpatialData"].spatialdata_format_version diff --git a/src/spatialdata/_io/format.py b/src/spatialdata/_io/format.py index b98ee9bd..6222edbe 100644 --- a/src/spatialdata/_io/format.py +++ b/src/spatialdata/_io/format.py @@ -6,7 +6,7 @@ import ome_zarr.format import zarr from anndata import AnnData -from ome_zarr.format import CurrentFormat, Format, FormatV01, FormatV02, FormatV03, FormatV04 +from ome_zarr.format import Format, FormatV01, FormatV02, FormatV03, FormatV04 from pandas.api.types import CategoricalDtype from shapely import GeometryType @@ -46,7 +46,7 @@ def _parse_version(group: zarr.Group, expect_attrs_key: bool) -> str | None: return version -class SpatialDataFormat(CurrentFormat): +class SpatialDataFormat(FormatV04): pass diff --git a/src/spatialdata/_io/io_table.py b/src/spatialdata/_io/io_table.py index 023129d5..3335e093 100644 --- a/src/spatialdata/_io/io_table.py +++ b/src/spatialdata/_io/io_table.py @@ -10,8 +10,8 @@ from anndata import read_zarr as read_anndata_zarr from anndata._io.specs import write_elem as write_adata from ome_zarr.format import Format -from zarr.errors import ArrayNotFoundError +# from zarr.errors import ArrayNotFoundError # removed in zarr 3.0 from spatialdata._io._utils import BadFileHandleMethod, handle_read_errors from spatialdata._io.format import CurrentTablesFormat, TablesFormats, _parse_version from spatialdata._logging import logger @@ -53,7 +53,12 @@ def _read_table( with handle_read_errors( on_bad_files=on_bad_files, location=f"{subgroup.path}/{table_name}", - exc_types=(JSONDecodeError, KeyError, ValueError, ArrayNotFoundError), + exc_types=( + JSONDecodeError, + KeyError, + ValueError, + # ArrayNotFoundError, # removed in zarr 3.0 + ), ): tables[table_name] = read_anndata_zarr(f_elem_store) diff --git a/src/spatialdata/_io/io_zarr.py b/src/spatialdata/_io/io_zarr.py index 224ef112..0b1f5682 100644 --- a/src/spatialdata/_io/io_zarr.py +++ b/src/spatialdata/_io/io_zarr.py @@ -8,10 +8,14 @@ import zarr from anndata import AnnData from pyarrow import ArrowInvalid -from zarr.errors import ArrayNotFoundError, MetadataError +from zarr.errors import MetadataValidationError from spatialdata._core.spatialdata import SpatialData -from spatialdata._io._utils import BadFileHandleMethod, handle_read_errors, ome_zarr_logger +from spatialdata._io._utils import ( + BadFileHandleMethod, + handle_read_errors, + ome_zarr_logger, +) from spatialdata._io.io_points import _read_points from spatialdata._io.io_raster import _read_multiscale from spatialdata._io.io_shapes import _read_shapes @@ -88,7 +92,7 @@ def read_zarr( with handle_read_errors( on_bad_files, location="images", - exc_types=(JSONDecodeError, MetadataError), + exc_types=(JSONDecodeError, MetadataValidationError), ): group = f["images"] count = 0 @@ -105,7 +109,7 @@ def read_zarr( JSONDecodeError, # JSON parse error ValueError, # ome_zarr: Unable to read the NGFF file KeyError, # Missing JSON key - ArrayNotFoundError, # Image chunks missing + # ArrayNotFoundError, # Image chunks missing, removed in Zarr v3 TypeError, # instead of ArrayNotFoundError, with dask>=2024.10.0 zarr<=2.18.3 ), ): @@ -120,7 +124,7 @@ def read_zarr( with handle_read_errors( on_bad_files, location="labels", - exc_types=(JSONDecodeError, MetadataError), + exc_types=(JSONDecodeError, MetadataValidationError), ): group = f["labels"] count = 0 @@ -133,7 +137,13 @@ def read_zarr( with handle_read_errors( on_bad_files, location=f"{group.path}/{subgroup_name}", - exc_types=(JSONDecodeError, KeyError, ValueError, ArrayNotFoundError, TypeError), + exc_types=( + JSONDecodeError, + KeyError, + ValueError, + # ArrayNotFoundError, # removed in Zarr v3 + TypeError, + ), ): labels[subgroup_name] = _read_multiscale(f_elem_store, raster_type="labels") count += 1 @@ -144,7 +154,7 @@ def read_zarr( with handle_read_errors( on_bad_files, location="points", - exc_types=(JSONDecodeError, MetadataError), + exc_types=(JSONDecodeError, MetadataValidationError), ): group = f["points"] count = 0 @@ -167,7 +177,7 @@ def read_zarr( with handle_read_errors( on_bad_files, location="shapes", - exc_types=(JSONDecodeError, MetadataError), + exc_types=(JSONDecodeError, MetadataValidationError), ): group = f["shapes"] count = 0 @@ -184,7 +194,7 @@ def read_zarr( JSONDecodeError, ValueError, KeyError, - ArrayNotFoundError, + # ArrayNotFoundError, # removed in Zarr v3 ), ): shapes[subgroup_name] = _read_shapes(f_elem_store) @@ -194,7 +204,7 @@ def read_zarr( with handle_read_errors( on_bad_files, location="tables", - exc_types=(JSONDecodeError, MetadataError), + exc_types=(JSONDecodeError, MetadataValidationError), ): group = f["tables"] tables = _read_table(f_store_path, f, group, tables, on_bad_files=on_bad_files) @@ -210,7 +220,7 @@ def read_zarr( with handle_read_errors( on_bad_files, location=subgroup_name, - exc_types=(JSONDecodeError, MetadataError), + exc_types=(JSONDecodeError, MetadataValidationError), ): group = f[subgroup_name] tables = _read_table(f_store_path, f, group, tables, on_bad_files=on_bad_files) diff --git a/tests/io/test_partial_read.py b/tests/io/test_partial_read.py index 7c7cdbfa..3b32d056 100644 --- a/tests/io/test_partial_read.py +++ b/tests/io/test_partial_read.py @@ -16,7 +16,7 @@ import pytest import zarr from pyarrow import ArrowInvalid -from zarr.errors import ArrayNotFoundError, MetadataError +from zarr.errors import MetadataValidationError from spatialdata import SpatialData, read_zarr from spatialdata.datasets import blobs @@ -100,8 +100,8 @@ def sdata_with_corrupted_elem_type_zgroup(session_tmp_path: Path) -> PartialRead return PartialReadTestCase( path=sdata_path, expected_elements=not_corrupted, - expected_exceptions=(JSONDecodeError, MetadataError), - warnings_patterns=["labels: JSONDecodeError", "points: MetadataError"], + expected_exceptions=(JSONDecodeError, MetadataValidationError), + warnings_patterns=["labels: JSONDecodeError", "points: MetadataValidationError"], ) @@ -146,10 +146,11 @@ def sdata_with_corrupted_image_chunks(session_tmp_path: Path) -> PartialReadTest path=sdata_path, expected_elements=not_corrupted, expected_exceptions=( - ArrayNotFoundError, + # ArrayNotFoundError, # removed in Zarr 3.0 TypeError, # instead of ArrayNotFoundError, with dask>=2024.10.0 zarr<=2.18.3 ), - warnings_patterns=[rf"images/{corrupted}: (ArrayNotFoundError|TypeError)"], + warnings_patterns=[rf"images/{corrupted}: (TypeError)"], + # warnings_patterns=[rf"images/{corrupted}: (ArrayNotFoundError|TypeError)"], ) @@ -215,10 +216,11 @@ def sdata_with_missing_image_chunks( path=sdata_path, expected_elements=not_corrupted, expected_exceptions=( - ArrayNotFoundError, + # ArrayNotFoundError, # removed in Zarr v3 TypeError, # instead of ArrayNotFoundError, with dask>=2024.10.0 zarr<=2.18.3 ), - warnings_patterns=[rf"images/{corrupted}: (ArrayNotFoundError|TypeError)"], + # warnings_patterns=[rf"images/{corrupted}: (ArrayNotFoundError|TypeError)"], + warnings_patterns=[rf"images/{corrupted}: (TypeError)"], ) From 241bbfe863d61fac8bb923e9020cbfaa418a4707 Mon Sep 17 00:00:00 2001 From: Luca Marconato Date: Mon, 18 Aug 2025 18:22:29 +0200 Subject: [PATCH 004/126] more fixes --- src/spatialdata/_core/spatialdata.py | 63 ++++++++++++++++++---------- src/spatialdata/_io/io_points.py | 7 +++- src/spatialdata/_io/io_shapes.py | 6 ++- src/spatialdata/_io/io_zarr.py | 14 ++++--- src/spatialdata/models/_utils.py | 2 +- tests/io/test_format.py | 3 +- 6 files changed, 63 insertions(+), 32 deletions(-) diff --git a/src/spatialdata/_core/spatialdata.py b/src/spatialdata/_core/spatialdata.py index 2ca4fdee..c7df1c19 100644 --- a/src/spatialdata/_core/spatialdata.py +++ b/src/spatialdata/_core/spatialdata.py @@ -621,10 +621,13 @@ def _get_groups_for_element( ------- either the existing Zarr subgroup or a new one. """ + from spatialdata._io.format import SpatialDataFormat + if not isinstance(zarr_path, Path): raise ValueError("zarr_path should be a Path object") - store = parse_url(zarr_path, mode="r+", fmt=SpatialDataFormat).store - root = zarr.group(store=store) + store = SpatialDataFormat().init_store(str(zarr_path), mode="r+") + # store = parse_url(zarr_path, mode="r+", fmt=SpatialDataFormat()).store + root = zarr.open_group(store=store, mode="r+") if element_type not in ["images", "labels", "points", "polygons", "shapes", "tables"]: raise ValueError(f"Unknown element type {element_type}") element_type_group = root.require_group(element_type) @@ -646,8 +649,10 @@ def _group_for_element_exists(self, zarr_path: Path, element_type: str, element_ ------- True if the group exists, False otherwise. """ - store = parse_url(zarr_path, mode="r", fmt=SpatialDataFormat).store - root = zarr.group(store=store) + from spatialdata._io.format import SpatialDataFormat + + store = parse_url(zarr_path, mode="r", fmt=SpatialDataFormat()).store + root = zarr.open_group(store=store, mode="r") assert element_type in ["images", "labels", "points", "polygons", "shapes", "tables"] exists = element_type in root and element_name in root[element_type] store.close() @@ -1068,19 +1073,27 @@ def elements_paths_on_disk(self) -> list[str]: ------- A list of paths of the elements saved in the Zarr store. """ + from spatialdata._io.format import SpatialDataFormat + if self.path is None: raise ValueError("The SpatialData object is not backed by a Zarr store.") - store = parse_url(self.path, mode="r", fmt=SpatialDataFormat).store - root = zarr.group(store=store) + + store = parse_url(self.path, mode="r", fmt=SpatialDataFormat()).store + root = zarr.open_group(store=store, mode="r") elements_in_zarr = [] def find_groups(obj: zarr.Group, path: str) -> None: - # with the current implementation, a path of a zarr group if the path for an element if and only if its + # with the current implementation, a path of a zarr group is the path for an element if and only if its # string representation contains exactly one "/" if isinstance(obj, zarr.Group) and path.count("/") == 1: elements_in_zarr.append(path) - root.visit(lambda path: find_groups(root[path], path)) + for element_type in root: + if element_type in ["images", "labels", "points", "shapes", "tables"]: + for element_name in root[element_type]: + path = f"{element_type}/{element_name}" + elements_in_zarr.append(path) + # root.visit(lambda path: find_groups(root[path], path)) store.close() return elements_in_zarr @@ -1115,6 +1128,7 @@ def _validate_can_safely_write_to_path( saving_an_element: bool = False, ) -> None: from spatialdata._io._utils import _backed_elements_contained_in_path, _is_subfolder + from spatialdata._io.format import SpatialDataFormat if isinstance(file_path, str): file_path = Path(file_path) @@ -1123,7 +1137,7 @@ def _validate_can_safely_write_to_path( raise ValueError(f"file_path must be a string or a Path object, type(file_path) = {type(file_path)}.") if os.path.exists(file_path): - if parse_url(file_path, mode="r", fmt=SpatialDataFormat) is None: + if parse_url(file_path, mode="r", fmt=SpatialDataFormat()) is None: raise ValueError( "The target file path specified already exists, and it has been detected to not be a Zarr store. " "Overwriting non-Zarr stores is not supported to prevent accidental data loss." @@ -1205,13 +1219,15 @@ def write( :class:`~spatialdata._io.format.CurrentRasterFormat`, :class:`~spatialdata._io.format.CurrentShapesFormat`, :class:`~spatialdata._io.format.CurrentPointsFormat`, :class:`~spatialdata._io.format.CurrentTablesFormat`. """ + from spatialdata._io.format import SpatialDataFormat + if isinstance(file_path, str): file_path = Path(file_path) self._validate_can_safely_write_to_path(file_path, overwrite=overwrite) self._validate_all_elements() - store = parse_url(file_path, mode="w", fmt=SpatialDataFormat).store - zarr_group = zarr.group(store=store, overwrite=overwrite) + store = parse_url(file_path, mode="w", fmt=SpatialDataFormat()).store + zarr_group = zarr.open_group(store=store, mode="w" if overwrite else "a") self.write_attrs(zarr_group=zarr_group) store.close() @@ -1370,14 +1386,15 @@ def delete_element_from_disk(self, element_name: str | list[str]) -> None: environment (e.g. operating system, local vs network storage, file permissions, ...) and call this function appropriately (or implement a tailored solution), to prevent data loss. """ + from spatialdata._io._utils import _backed_elements_contained_in_path + from spatialdata._io.format import SpatialDataFormat + if isinstance(element_name, list): for name in element_name: assert isinstance(name, str) self.delete_element_from_disk(name) return - from spatialdata._io._utils import _backed_elements_contained_in_path - if self.path is None: raise ValueError("The SpatialData object is not backed by a Zarr store.") @@ -1417,8 +1434,8 @@ def delete_element_from_disk(self, element_name: str | list[str]) -> None: ) # delete the element - store = parse_url(self.path, mode="r+", fmt=SpatialDataFormat).store - root = zarr.group(store=store) + store = parse_url(self.path, mode="r+", fmt=SpatialDataFormat()).store + root = zarr.open_group(store=store, mode="r+") root[element_type].pop(element_name) store.close() @@ -1438,15 +1455,19 @@ def _check_element_not_on_disk_with_different_type(self, element_type: str, elem ) def write_consolidated_metadata(self) -> None: - store = parse_url(self.path, mode="r+", fmt=SpatialDataFormat).store + from spatialdata._io.format import SpatialDataFormat + + store = parse_url(self.path, mode="r+", fmt=SpatialDataFormat()).store # consolidate metadata to more easily support remote reading bug in zarr. In reality, 'zmetadata' is written # instead of '.zmetadata' see discussion https://github.com/zarr-developers/zarr-python/issues/1121 - zarr.consolidate_metadata(store, metadata_key=".zmetadata") + zarr.consolidate_metadata(store) store.close() def has_consolidated_metadata(self) -> bool: + from spatialdata._io.format import SpatialDataFormat + return_value = False - store = parse_url(self.path, mode="r", fmt=SpatialDataFormat).store + store = parse_url(self.path, mode="r", fmt=SpatialDataFormat()).store if "zmetadata" in store: return_value = True store.close() @@ -1614,7 +1635,7 @@ def _element_type_and_name_from_element_path(self, element_path: str) -> tuple[s return element_type, element_name def write_attrs(self, format: SpatialDataFormat | None = None, zarr_group: zarr.Group | None = None) -> None: - from spatialdata._io.format import _parse_formats + from spatialdata._io.format import SpatialDataFormat, _parse_formats parsed = _parse_formats(formats=format) @@ -1622,8 +1643,8 @@ def write_attrs(self, format: SpatialDataFormat | None = None, zarr_group: zarr. if zarr_group is None: assert self.is_backed(), "The SpatialData object must be backed by a Zarr store to write attrs." - store = parse_url(self.path, mode="r+", fmt=SpatialDataFormat).store - zarr_group = zarr.group(store=store, overwrite=False) + store = parse_url(self.path, mode="r+", fmt=SpatialDataFormat()).store + zarr_group = zarr.open_group(store=store, overwrite=False, mode="r+") version = parsed["SpatialData"].spatialdata_format_version version_specific_attrs = parsed["SpatialData"].attrs_to_dict() diff --git a/src/spatialdata/_io/io_points.py b/src/spatialdata/_io/io_points.py index 3106c847..3b88fbda 100644 --- a/src/spatialdata/_io/io_points.py +++ b/src/spatialdata/_io/io_points.py @@ -31,7 +31,8 @@ def _read_points( assert version is not None format = PointsFormats[version] - path = os.path.join(f._store.path, f.path, "points.parquet") + store_root = f.store_path.store.root + path = os.path.join(store_root, f.path, "points.parquet") # cache on remote file needed for parquet reader to work # TODO: allow reading in the metadata without caching all the data points = read_parquet("simplecache::" + path if path.startswith("http") else path) @@ -57,7 +58,9 @@ def write_points( t = _get_transformations(points) points_groups = group.require_group(name) - path = Path(points_groups._store.path) / points_groups.path / "points.parquet" + store_root = points_groups.store_path.store.root + group_path = points_groups.path + path = Path(store_root) / group_path / "points.parquet" # The following code iterates through all columns in the 'points' DataFrame. If the column's datatype is # 'category', it checks whether the categories of this column are known. If not, it explicitly converts the diff --git a/src/spatialdata/_io/io_shapes.py b/src/spatialdata/_io/io_shapes.py index c32ce1f3..3df1be45 100644 --- a/src/spatialdata/_io/io_shapes.py +++ b/src/spatialdata/_io/io_shapes.py @@ -50,7 +50,8 @@ def _read_shapes( geometry = from_ragged_array(typ, coords, offsets) geo_df = GeoDataFrame({"geometry": geometry}, index=index) elif isinstance(format, ShapesFormatV02): - path = Path(f._store.path) / f.path / "shapes.parquet" + store_root = f.store_path.store.root + path = Path(store_root) / f.path / "shapes.parquet" geo_df = read_parquet(path) else: raise ValueError( @@ -93,7 +94,8 @@ def write_shapes( attrs = format.attrs_to_dict(geometry) attrs["version"] = format.spatialdata_format_version elif isinstance(format, ShapesFormatV02): - path = Path(shapes_group._store.path) / shapes_group.path / "shapes.parquet" + store_root = shapes_group.store_path.store.root + path = Path(store_root) / shapes_group.path / "shapes.parquet" shapes.to_parquet(path) attrs = format.attrs_to_dict(shapes.attrs) diff --git a/src/spatialdata/_io/io_zarr.py b/src/spatialdata/_io/io_zarr.py index 0b1f5682..045d30c0 100644 --- a/src/spatialdata/_io/io_zarr.py +++ b/src/spatialdata/_io/io_zarr.py @@ -5,7 +5,7 @@ from pathlib import Path from typing import Literal -import zarr +import zarr.storage from anndata import AnnData from pyarrow import ArrowInvalid from zarr.errors import MetadataValidationError @@ -36,11 +36,15 @@ def _open_zarr_store(store: str | Path | zarr.Group) -> tuple[zarr.Group, str]: ------- A tuple of the zarr.Group object and the path to the store. """ - f = store if isinstance(store, zarr.Group) else zarr.open(store, mode="r") + f = store if isinstance(store, zarr.Group) else zarr.open_group(store, mode="r") # workaround: .zmetadata is being written as zmetadata (https://github.com/zarr-developers/zarr-python/issues/1121) - if isinstance(store, str | Path) and str(store).startswith("http") and len(f) == 0: - f = zarr.open_consolidated(store, mode="r", metadata_key="zmetadata") - f_store_path = f.store.store.path if isinstance(f.store, zarr.storage.ConsolidatedMetadataStore) else f.store.path + # not needed, consolidated metadata is always used if present + # if isinstance(store, str | Path) and str(store).startswith("http") and len(f) == 0: + # f = zarr.open_consolidated(store, mode="r", metadata_key="zmetadata") + # the metadata is accessible here: + # f.metadata.consolidated_metadata.metadata + f_store_path = f.store.root + # f_store_path = f.store.store.path if isinstance(f.store, zarr.storage.ConsolidatedMetadataStore) else f.store.path return f, f_store_path diff --git a/src/spatialdata/models/_utils.py b/src/spatialdata/models/_utils.py index eeaa7ecd..b6c31821 100644 --- a/src/spatialdata/models/_utils.py +++ b/src/spatialdata/models/_utils.py @@ -326,7 +326,7 @@ def _(data: DataArray) -> list[Any]: @get_channel_names.register def _(data: DataTree) -> list[Any]: name = list({list(data[i].data_vars.keys())[0] for i in data})[0] - channels = {tuple(data[i][name].coords["c"].values) for i in data} + channels = {tuple(data[i][name].coords["c"].values.tolist()) for i in data} if len(channels) > 1: raise ValueError(f"Channels are not consistent across scales: {channels}") return list(next(iter(channels))) diff --git a/tests/io/test_format.py b/tests/io/test_format.py index 3069e8fa..2eba5752 100644 --- a/tests/io/test_format.py +++ b/tests/io/test_format.py @@ -82,7 +82,8 @@ def test_format_shapes_v2( metadata[attrs_key].pop("version") assert metadata[attrs_key] == Shapes_f.attrs_to_dict({}) - @pytest.mark.parametrize("format", [RasterFormatV01, RasterFormatV02]) + @pytest.mark.parametrize("format", [RasterFormatV02]) + # @pytest.mark.parametrize("format", [RasterFormatV01, RasterFormatV02]) def test_format_raster_v1_v2(self, images, format: type[SpatialDataFormat]) -> None: with tempfile.TemporaryDirectory() as tmpdir: images.write(Path(tmpdir) / "images.zarr", format=format()) From bdb6da0659e7d9ece5a42497c13877f96f891a5d Mon Sep 17 00:00:00 2001 From: Wouter-Michiel Vierdag Date: Tue, 26 Aug 2025 08:46:05 +0200 Subject: [PATCH 005/126] update ome-zarr dep --- pyproject.toml | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/pyproject.toml b/pyproject.toml index c43a6d2c..643acb67 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -33,8 +33,7 @@ dependencies = [ "networkx", "numba>=0.55.0", "numpy", - "ome_zarr>=0.12rc1", -# "ome_zarr>=0.8.4", + "ome_zarr>=0.12.2", "pandas", "pooch", "pyarrow", From e3a53dfc83f6f9d2c5b5b798cb7bc2959c3250dd Mon Sep 17 00:00:00 2001 From: Wouter-Michiel Vierdag Date: Tue, 26 Aug 2025 14:45:21 +0200 Subject: [PATCH 006/126] Add zarr v3 formats --- src/spatialdata/_io/format.py | 196 +++++++++++++++++++++++----------- 1 file changed, 131 insertions(+), 65 deletions(-) diff --git a/src/spatialdata/_io/format.py b/src/spatialdata/_io/format.py index 6222edbe..85accf28 100644 --- a/src/spatialdata/_io/format.py +++ b/src/spatialdata/_io/format.py @@ -6,7 +6,7 @@ import ome_zarr.format import zarr from anndata import AnnData -from ome_zarr.format import Format, FormatV01, FormatV02, FormatV03, FormatV04 +from ome_zarr.format import Format, FormatV01, FormatV02, FormatV03, FormatV04, FormatV05 from pandas.api.types import CategoricalDtype from shapely import GeometryType @@ -17,6 +17,8 @@ Shapes_s = ShapesModel() Points_s = PointsModel() +# TODO: change for element in spatialdata_format_version for elements into something like element_container_version + def _parse_version(group: zarr.Group, expect_attrs_key: bool) -> str | None: """ @@ -46,27 +48,7 @@ def _parse_version(group: zarr.Group, expect_attrs_key: bool) -> str | None: return version -class SpatialDataFormat(FormatV04): - pass - - -class SpatialDataContainerFormatV01(SpatialDataFormat): - @property - def spatialdata_format_version(self) -> str: - return "0.1" - - def attrs_from_dict(self, metadata: dict[str, Any]) -> dict[str, Any]: - return {} - - def attrs_to_dict(self) -> dict[str, str | dict[str, Any]]: - from spatialdata import __version__ - - return {"spatialdata_software_version": __version__} - - -class RasterFormatV01(SpatialDataFormat): - """Formatter for raster data.""" - +class CoordinateMixinV01: def generate_coordinate_transformations(self, shapes: list[tuple[Any]]) -> None | list[list[dict[str, Any]]]: data_shape = shapes[0] coordinate_transformations: list[list[dict[str, Any]]] = [] @@ -116,6 +98,78 @@ def validate_coordinate_transformations( assert np.all([j0 == j1 for j0, j1 in zip(json0, json1, strict=True)]) + +class PointAttrsMixinV01: + def attrs_from_dict(self, metadata: dict[str, Any]) -> dict[str, dict[str, Any]]: + if Points_s.ATTRS_KEY not in metadata: + raise KeyError(f"Missing key {Points_s.ATTRS_KEY} in points metadata.") + metadata_ = metadata[Points_s.ATTRS_KEY] + assert self.spatialdata_format_version == metadata_["version"] + d = {} + if Points_s.FEATURE_KEY in metadata_: + d[Points_s.FEATURE_KEY] = metadata_[Points_s.FEATURE_KEY] + if Points_s.INSTANCE_KEY in metadata_: + d[Points_s.INSTANCE_KEY] = metadata_[Points_s.INSTANCE_KEY] + return d + + def attrs_to_dict(self, data: dict[str, Any]) -> dict[str, dict[str, Any]]: + d = {} + if Points_s.ATTRS_KEY in data: + if Points_s.INSTANCE_KEY in data[Points_s.ATTRS_KEY]: + d[Points_s.INSTANCE_KEY] = data[Points_s.ATTRS_KEY][Points_s.INSTANCE_KEY] + if Points_s.FEATURE_KEY in data[Points_s.ATTRS_KEY]: + d[Points_s.FEATURE_KEY] = data[Points_s.ATTRS_KEY][Points_s.FEATURE_KEY] + return d + + +class TableValidateMixinV01: + def validate_table( + self, + table: AnnData, + region_key: None | str = None, + instance_key: None | str = None, + ) -> None: + if not isinstance(table, AnnData): + raise TypeError(f"`table` must be `anndata.AnnData`, was {type(table)}.") + if region_key is not None and not isinstance(table.obs[region_key].dtype, CategoricalDtype): + raise ValueError( + f"`table.obs[region_key]` must be of type `categorical`, not `{type(table.obs[region_key])}`." + ) + if instance_key is not None and table.obs[instance_key].isnull().values.any(): + raise ValueError("`table.obs[instance_key]` must not contain null values, but it does.") + + +class SpatialDataContainerFormatV01(FormatV04): + @property + def spatialdata_format_version(self) -> str: + return "0.1" + + def attrs_from_dict(self, metadata: dict[str, Any]) -> dict[str, Any]: + return {} + + def attrs_to_dict(self) -> dict[str, str | dict[str, Any]]: + from spatialdata import __version__ + + return {"spatialdata_software_version": __version__} + + +class SpatialDataContainerFormatV02(FormatV05): + @property + def spatialdata_format_version(self) -> str: + return "0.2" + + def attrs_from_dict(self, metadata: dict[str, Any]) -> dict[str, Any]: + return {} + + def attrs_to_dict(self) -> dict[str, str | dict[str, Any]]: + from spatialdata import __version__ + + return {"spatialdata_software_version": __version__} + + +class RasterFormatV01(FormatV04, CoordinateMixinV01): + """Formatter for raster data.""" + # eventually we are fully compliant with NGFF and we can drop SPATIALDATA_FORMAT_VERSION and simply rely on # "version"; still, until the coordinate transformations make it into NGFF, we need to have our extension @property @@ -139,7 +193,19 @@ def version(self) -> str: return "0.4-dev-spatialdata" -class ShapesFormatV01(SpatialDataFormat): +class RasterFormatV03(FormatV05, CoordinateMixinV01): + @property + def spatialdata_format_version(self) -> str: + return "0.3" + + @property + def version(self) -> str: + # 0.1 -> 0.2 changed the version string for the NGFF format, from 0.4 to 0.6-dev-spatialdata as discussed here + # https://github.com/scverse/spatialdata/pull/849 + return "0.4-dev-spatialdata" + + +class ShapesFormatV01(FormatV04): """Formatter for shapes.""" @property @@ -165,7 +231,7 @@ def attrs_to_dict(self, geometry: GeometryType) -> dict[str, str | dict[str, Any return {Shapes_s.GEOS_KEY: {Shapes_s.NAME_KEY: geometry.name, Shapes_s.TYPE_KEY: geometry.value}} -class ShapesFormatV02(SpatialDataFormat): +class ShapesFormatV02(FormatV04): """Formatter for shapes.""" @property @@ -177,87 +243,87 @@ def attrs_to_dict(self, data: dict[str, Any]) -> dict[str, str | dict[str, Any]] return {} -class PointsFormatV01(SpatialDataFormat): +class ShapesFormatV03(FormatV05): + """Formatter for shapes.""" + + @property + def spatialdata_format_version(self) -> str: + return "0.3" + + # no need for attrs_from_dict as we are not saving metadata except for the coordinate transformations + def attrs_to_dict(self, data: dict[str, Any]) -> dict[str, str | dict[str, Any]]: + return {} + + +class PointsFormatV01(FormatV04, PointAttrsMixinV01): """Formatter for points.""" @property def spatialdata_format_version(self) -> str: return "0.1" - def attrs_from_dict(self, metadata: dict[str, Any]) -> dict[str, dict[str, Any]]: - if Points_s.ATTRS_KEY not in metadata: - raise KeyError(f"Missing key {Points_s.ATTRS_KEY} in points metadata.") - metadata_ = metadata[Points_s.ATTRS_KEY] - assert self.spatialdata_format_version == metadata_["version"] - d = {} - if Points_s.FEATURE_KEY in metadata_: - d[Points_s.FEATURE_KEY] = metadata_[Points_s.FEATURE_KEY] - if Points_s.INSTANCE_KEY in metadata_: - d[Points_s.INSTANCE_KEY] = metadata_[Points_s.INSTANCE_KEY] - return d - def attrs_to_dict(self, data: dict[str, Any]) -> dict[str, dict[str, Any]]: - d = {} - if Points_s.ATTRS_KEY in data: - if Points_s.INSTANCE_KEY in data[Points_s.ATTRS_KEY]: - d[Points_s.INSTANCE_KEY] = data[Points_s.ATTRS_KEY][Points_s.INSTANCE_KEY] - if Points_s.FEATURE_KEY in data[Points_s.ATTRS_KEY]: - d[Points_s.FEATURE_KEY] = data[Points_s.ATTRS_KEY][Points_s.FEATURE_KEY] - return d +class PointsFormatV02(FormatV05, PointAttrsMixinV01): + """Formatter for points.""" + @property + def spatialdata_format_version(self) -> str: + return "0.2" -class TablesFormatV01(SpatialDataFormat): + +class TablesFormatV01(FormatV04, TableValidateMixinV01): """Formatter for the table.""" @property def spatialdata_format_version(self) -> str: return "0.1" - def validate_table( - self, - table: AnnData, - region_key: None | str = None, - instance_key: None | str = None, - ) -> None: - if not isinstance(table, AnnData): - raise TypeError(f"`table` must be `anndata.AnnData`, was {type(table)}.") - if region_key is not None and not isinstance(table.obs[region_key].dtype, CategoricalDtype): - raise ValueError( - f"`table.obs[region_key]` must be of type `categorical`, not `{type(table.obs[region_key])}`." - ) - if instance_key is not None and table.obs[instance_key].isnull().values.any(): - raise ValueError("`table.obs[instance_key]` must not contain null values, but it does.") +class TablesFormatV02(FormatV05, TableValidateMixinV01): + """Formatter for the table.""" + + @property + def spatialdata_format_version(self) -> str: + return "0.2" -CurrentRasterFormat = RasterFormatV02 -CurrentShapesFormat = ShapesFormatV02 -CurrentPointsFormat = PointsFormatV01 -CurrentTablesFormat = TablesFormatV01 -CurrentSpatialDataContainerFormats = SpatialDataContainerFormatV01 + +CurrentRasterFormat = RasterFormatV03 +CurrentShapesFormat = ShapesFormatV03 +CurrentPointsFormat = PointsFormatV02 +CurrentTablesFormat = TablesFormatV02 +CurrentSpatialDataContainerFormats = SpatialDataContainerFormatV02 ShapesFormats = { "0.1": ShapesFormatV01(), "0.2": ShapesFormatV02(), + "0.3": ShapesFormatV03(), } PointsFormats = { "0.1": PointsFormatV01(), + "0.2": PointsFormatV02(), } TablesFormats = { "0.1": TablesFormatV01(), + "0.2": TablesFormatV02(), } RasterFormats = { "0.1": RasterFormatV01(), "0.2": RasterFormatV02(), + "0.3": RasterFormatV03(), } SpatialDataContainerFormats = { "0.1": SpatialDataContainerFormatV01(), + "0.2": SpatialDataContainerFormatV02(), } def format_implementations() -> Iterator[Format]: """Return an instance of each format implementation, newest to oldest.""" + yield RasterFormatV03() yield RasterFormatV02() - # yield RasterFormatV01() # same format string as FormatV04 + yield RasterFormatV01() # same format string as FormatV04 + + yield FormatV05() yield FormatV04() yield FormatV03() yield FormatV02() From 41b980f876f1572e3d7920ee991820a3c38eaa2e Mon Sep 17 00:00:00 2001 From: Luca Marconato Date: Tue, 26 Aug 2025 14:52:55 +0200 Subject: [PATCH 007/126] refactoring formats --- src/spatialdata/__init__.py | 3 +- src/spatialdata/_core/spatialdata.py | 157 +++++++++++++++++++++------ src/spatialdata/_io/__init__.py | 4 +- src/spatialdata/_io/format.py | 19 +++- tests/io/test_format.py | 68 +++++++++--- tests/io/test_versions.py | 28 ----- 6 files changed, 201 insertions(+), 78 deletions(-) delete mode 100644 tests/io/test_versions.py diff --git a/src/spatialdata/__init__.py b/src/spatialdata/__init__.py index 0b68391a..48ba26d0 100644 --- a/src/spatialdata/__init__.py +++ b/src/spatialdata/__init__.py @@ -44,6 +44,7 @@ "SpatialData", "get_extent", "get_centroids", + "SpatialDataFormatType", "read_zarr", "unpad_raster", "get_pyramid_levels", @@ -82,6 +83,6 @@ from spatialdata._core.query.spatial_query import bounding_box_query, polygon_query from spatialdata._core.spatialdata import SpatialData from spatialdata._io._utils import get_dask_backing_files, save_transformations -from spatialdata._io.format import SpatialDataFormat +from spatialdata._io.format import SpatialDataFormatType from spatialdata._io.io_zarr import read_zarr from spatialdata._utils import get_pyramid_levels, unpad_raster diff --git a/src/spatialdata/_core/spatialdata.py b/src/spatialdata/_core/spatialdata.py index c7df1c19..15f38281 100644 --- a/src/spatialdata/_core/spatialdata.py +++ b/src/spatialdata/_core/spatialdata.py @@ -55,7 +55,7 @@ if TYPE_CHECKING: from spatialdata._core.query.spatial_query import BaseSpatialRequest - from spatialdata._io.format import SpatialDataFormat + from spatialdata._io.format import SpatialDataContainerFormatType, SpatialDataFormat # schema for elements Label2D_s = Labels2DModel() @@ -236,7 +236,8 @@ def validate_table_in_spatialdata(self, table: AnnData) -> None: @staticmethod def from_elements_dict( - elements_dict: dict[str, SpatialElement | AnnData], attrs: Mapping[Any, Any] | None = None + elements_dict: dict[str, SpatialElement | AnnData], + attrs: Mapping[Any, Any] | None = None, ) -> SpatialData: """ Create a SpatialData object from a dict of elements. @@ -275,7 +276,9 @@ def get_annotated_regions(table: AnnData) -> list[str]: ------- The annotated regions. """ - from spatialdata.models.models import _get_region_metadata_from_region_key_column + from spatialdata.models.models import ( + _get_region_metadata_from_region_key_column, + ) return _get_region_metadata_from_region_key_column(table) @@ -628,7 +631,14 @@ def _get_groups_for_element( store = SpatialDataFormat().init_store(str(zarr_path), mode="r+") # store = parse_url(zarr_path, mode="r+", fmt=SpatialDataFormat()).store root = zarr.open_group(store=store, mode="r+") - if element_type not in ["images", "labels", "points", "polygons", "shapes", "tables"]: + if element_type not in [ + "images", + "labels", + "points", + "polygons", + "shapes", + "tables", + ]: raise ValueError(f"Unknown element type {element_type}") element_type_group = root.require_group(element_type) element_name_group = element_type_group.require_group(element_name) @@ -653,7 +663,14 @@ def _group_for_element_exists(self, zarr_path: Path, element_type: str, element_ store = parse_url(zarr_path, mode="r", fmt=SpatialDataFormat()).store root = zarr.open_group(store=store, mode="r") - assert element_type in ["images", "labels", "points", "polygons", "shapes", "tables"] + assert element_type in [ + "images", + "labels", + "points", + "polygons", + "shapes", + "tables", + ] exists = element_type in root and element_name in root[element_type] store.close() return exists @@ -689,7 +706,10 @@ def locate_element(self, element: SpatialElement) -> list[str]: @_deprecation_alias(filter_table="filter_tables", version="0.1.0") def filter_by_coordinate_system( - self, coordinate_system: str | list[str], filter_tables: bool = True, include_orphan_tables: bool = False + self, + coordinate_system: str | list[str], + filter_tables: bool = True, + include_orphan_tables: bool = False, ) -> SpatialData: """ Filter the SpatialData by one (or a list of) coordinate system. @@ -731,7 +751,11 @@ def filter_by_coordinate_system( elements[element_type][element_name] = element element_names_in_coordinate_system.append(element_name) tables = self._filter_tables( - set(), filter_tables, "cs", include_orphan_tables, element_names=element_names_in_coordinate_system + set(), + filter_tables, + "cs", + include_orphan_tables, + element_names=element_names_in_coordinate_system, ) return SpatialData(**elements, tables=tables, attrs=self.attrs) @@ -784,14 +808,18 @@ def _filter_tables( continue # each mode here requires paths or elements, using assert here to avoid mypy errors. if by == "cs": - from spatialdata._core.query.relational_query import _filter_table_by_element_names + from spatialdata._core.query.relational_query import ( + _filter_table_by_element_names, + ) assert element_names is not None table = _filter_table_by_element_names(table, element_names) if len(table) != 0: tables[table_name] = table elif by == "elements": - from spatialdata._core.query.relational_query import _filter_table_by_elements + from spatialdata._core.query.relational_query import ( + _filter_table_by_elements, + ) assert elements_dict is not None table = _filter_table_by_elements(table, elements_dict=elements_dict) @@ -816,7 +844,10 @@ def rename_coordinate_systems(self, rename_dict: dict[str, str]) -> None: The method does not allow to rename a coordinate system into an existing one, unless the existing one is also renamed in the same call. """ - from spatialdata.transformations.operations import get_transformation, set_transformation + from spatialdata.transformations.operations import ( + get_transformation, + set_transformation, + ) # check that the rename_dict is valid old_names = self.coordinate_systems @@ -860,7 +891,10 @@ def rename_coordinate_systems(self, rename_dict: dict[str, str]) -> None: @_deprecation_alias(element="element_name", version="0.3.0") def transform_element_to_coordinate_system( - self, element_name: str, target_coordinate_system: str, maintain_positioning: bool = False + self, + element_name: str, + target_coordinate_system: str, + maintain_positioning: bool = False, ) -> SpatialElement: """ Transform an element to a given coordinate system. @@ -912,7 +946,9 @@ def transform_element_to_coordinate_system( d[target_coordinate_system] = t to_remove = True transformed = transform( - element, to_coordinate_system=target_coordinate_system, maintain_positioning=maintain_positioning + element, + to_coordinate_system=target_coordinate_system, + maintain_positioning=maintain_positioning, ) if to_remove: del d[target_coordinate_system] @@ -971,7 +1007,9 @@ def transform_to_coordinate_system( for element_type, element_name, element in sdata.gen_elements(): if element_type != "tables": transformed = sdata.transform_element_to_coordinate_system( - element, target_coordinate_system, maintain_positioning=maintain_positioning + element, + target_coordinate_system, + maintain_positioning=maintain_positioning, ) if element_type not in elements: elements[element_type] = {} @@ -1127,7 +1165,10 @@ def _validate_can_safely_write_to_path( overwrite: bool = False, saving_an_element: bool = False, ) -> None: - from spatialdata._io._utils import _backed_elements_contained_in_path, _is_subfolder + from spatialdata._io._utils import ( + _backed_elements_contained_in_path, + _is_subfolder, + ) from spatialdata._io.format import SpatialDataFormat if isinstance(file_path, str): @@ -1268,23 +1309,56 @@ def _write_element( ) root_group, element_type_group, _ = self._get_groups_for_element( - zarr_path=zarr_container_path, element_type=element_type, element_name=element_name + zarr_path=zarr_container_path, + element_type=element_type, + element_name=element_name, + ) + from spatialdata._io import ( + write_image, + write_labels, + write_points, + write_shapes, + write_table, ) - from spatialdata._io import write_image, write_labels, write_points, write_shapes, write_table from spatialdata._io.format import _parse_formats parsed = _parse_formats(formats=format) if element_type == "images": - write_image(image=element, group=element_type_group, name=element_name, format=parsed["raster"]) + write_image( + image=element, + group=element_type_group, + name=element_name, + format=parsed["raster"], + ) elif element_type == "labels": - write_labels(labels=element, group=root_group, name=element_name, format=parsed["raster"]) + write_labels( + labels=element, + group=root_group, + name=element_name, + format=parsed["raster"], + ) elif element_type == "points": - write_points(points=element, group=element_type_group, name=element_name, format=parsed["points"]) + write_points( + points=element, + group=element_type_group, + name=element_name, + format=parsed["points"], + ) elif element_type == "shapes": - write_shapes(shapes=element, group=element_type_group, name=element_name, format=parsed["shapes"]) + write_shapes( + shapes=element, + group=element_type_group, + name=element_name, + format=parsed["shapes"], + ) elif element_type == "tables": - write_table(table=element, group=element_type_group, name=element_name, format=parsed["tables"]) + write_table( + table=element, + group=element_type_group, + name=element_name, + format=parsed["tables"], + ) else: raise ValueError(f"Unknown element type: {element_type}") @@ -1499,7 +1573,9 @@ def _validate_can_write_metadata_on_element(self, element_name: str) -> tuple[st # check if the element exists in the Zarr storage if not self._group_for_element_exists( - zarr_path=Path(self.path), element_type=element_type, element_name=element_name + zarr_path=Path(self.path), + element_type=element_type, + element_name=element_name, ): warnings.warn( f"Not saving the metadata to element {element_type}/{element_name} as it is" @@ -1550,7 +1626,9 @@ def write_channel_names(self, element_name: str | None = None) -> None: # Mypy does not understand that path is not None so we have the check in the conditional if element_type == "images" and self.path is not None: _, _, element_group = self._get_groups_for_element( - zarr_path=Path(self.path), element_type=element_type, element_name=element_name + zarr_path=Path(self.path), + element_type=element_type, + element_name=element_name, ) from spatialdata._io._utils import overwrite_channel_names @@ -1592,7 +1670,9 @@ def write_transformations(self, element_name: str | None = None) -> None: # Mypy does not understand that path is not None so we have a conditional assert self.path is not None _, _, element_group = self._get_groups_for_element( - zarr_path=Path(self.path), element_type=element_type, element_name=element_name + zarr_path=Path(self.path), + element_type=element_type, + element_name=element_name, ) axes = get_axes_names(element) if isinstance(element, DataArray | DataTree): @@ -1634,10 +1714,16 @@ def _element_type_and_name_from_element_path(self, element_path: str) -> tuple[s element_type, element_name = element_path.split("/") return element_type, element_name - def write_attrs(self, format: SpatialDataFormat | None = None, zarr_group: zarr.Group | None = None) -> None: + def write_attrs( + self, + format: SpatialDataContainerFormatType | None = None, + zarr_group: zarr.Group | None = None, + ) -> None: from spatialdata._io.format import SpatialDataFormat, _parse_formats parsed = _parse_formats(formats=format) + spatialdata_container_format = parsed["SpatialData"] + assert isinstance(spatialdata_container_format, SpatialDataContainerFormatType) store = None @@ -1646,8 +1732,8 @@ def write_attrs(self, format: SpatialDataFormat | None = None, zarr_group: zarr. store = parse_url(self.path, mode="r+", fmt=SpatialDataFormat()).store zarr_group = zarr.open_group(store=store, overwrite=False, mode="r+") - version = parsed["SpatialData"].spatialdata_format_version - version_specific_attrs = parsed["SpatialData"].attrs_to_dict() + version = spatialdata_container_format.spatialdata_format_version + version_specific_attrs = spatialdata_container_format.attrs_to_dict() attrs_to_write = {"spatialdata_attrs": {"version": version} | version_specific_attrs} | self.attrs try: @@ -2054,7 +2140,9 @@ def h(s: str) -> str: else: shape_str = ( "(" - + ", ".join([str(dim) if not isinstance(dim, Delayed) else "" for dim in v.shape]) + + ", ".join( + [(str(dim) if not isinstance(dim, Delayed) else "") for dim in v.shape] + ) + ")" ) descr += f"{h(attr + 'level1.1')}{k!r}: {descr_class} with shape: {shape_str} {dim_string}" @@ -2187,7 +2275,9 @@ def _gen_elements( for k, v in d.items(): yield element_type, k, v - def gen_spatial_elements(self) -> Generator[tuple[str, str, SpatialElement], None, None]: + def gen_spatial_elements( + self, + ) -> Generator[tuple[str, str, SpatialElement], None, None]: """ Generate spatial elements within the SpatialData object. @@ -2200,7 +2290,9 @@ def gen_spatial_elements(self) -> Generator[tuple[str, str, SpatialElement], Non """ return self._gen_elements() - def gen_elements(self) -> Generator[tuple[str, str, SpatialElement | AnnData], None, None]: + def gen_elements( + self, + ) -> Generator[tuple[str, str, SpatialElement | AnnData], None, None]: """ Generate elements within the SpatialData object. @@ -2319,7 +2411,10 @@ def init_from_elements( return cls(**elements_dict, attrs=attrs) def subset( - self, element_names: list[str], filter_tables: bool = True, include_orphan_tables: bool = False + self, + element_names: list[str], + filter_tables: bool = True, + include_orphan_tables: bool = False, ) -> SpatialData: """ Subset the SpatialData object. diff --git a/src/spatialdata/_io/__init__.py b/src/spatialdata/_io/__init__.py index 94d6816a..b0dc914e 100644 --- a/src/spatialdata/_io/__init__.py +++ b/src/spatialdata/_io/__init__.py @@ -1,5 +1,5 @@ from spatialdata._io._utils import get_dask_backing_files -from spatialdata._io.format import SpatialDataFormat +from spatialdata._io.format import SpatialDataFormatType from spatialdata._io.io_points import write_points from spatialdata._io.io_raster import write_image, write_labels from spatialdata._io.io_shapes import write_shapes @@ -11,6 +11,6 @@ "write_points", "write_shapes", "write_table", - "SpatialDataFormat", + "SpatialDataFormatType", "get_dask_backing_files", ] diff --git a/src/spatialdata/_io/format.py b/src/spatialdata/_io/format.py index 6222edbe..71d6a5cc 100644 --- a/src/spatialdata/_io/format.py +++ b/src/spatialdata/_io/format.py @@ -162,7 +162,12 @@ def attrs_from_dict(self, metadata: dict[str, Any]) -> GeometryType: return typ def attrs_to_dict(self, geometry: GeometryType) -> dict[str, str | dict[str, Any]]: - return {Shapes_s.GEOS_KEY: {Shapes_s.NAME_KEY: geometry.name, Shapes_s.TYPE_KEY: geometry.value}} + return { + Shapes_s.GEOS_KEY: { + Shapes_s.NAME_KEY: geometry.name, + Shapes_s.TYPE_KEY: geometry.value, + } + } class ShapesFormatV02(SpatialDataFormat): @@ -252,6 +257,14 @@ def validate_table( SpatialDataContainerFormats = { "0.1": SpatialDataContainerFormatV01(), } +ShapesFormatType = ShapesFormatV01 | ShapesFormatV02 +PointsFormatType = PointsFormatV01 +TablesFormatType = TablesFormatV01 +RasterFormatType = RasterFormatV01 | RasterFormatV02 +SpatialDataContainerFormatType = SpatialDataContainerFormatV01 +SpatialDataFormatType = ( + ShapesFormatType | PointsFormatType | TablesFormatType | RasterFormatType | SpatialDataContainerFormatType +) def format_implementations() -> Iterator[Format]: @@ -270,7 +283,9 @@ def format_implementations() -> Iterator[Format]: ome_zarr.format.format_implementations = format_implementations -def _parse_formats(formats: SpatialDataFormat | list[SpatialDataFormat] | None) -> dict[str, SpatialDataFormat]: +def _parse_formats( + formats: SpatialDataFormatType | list[SpatialDataFormatType] | None, +) -> dict[str, SpatialDataFormatType]: parsed = { "raster": CurrentRasterFormat(), "shapes": CurrentShapesFormat(), diff --git a/tests/io/test_format.py b/tests/io/test_format.py index 2eba5752..7998b97a 100644 --- a/tests/io/test_format.py +++ b/tests/io/test_format.py @@ -6,41 +6,48 @@ import pytest from shapely import GeometryType +from spatialdata import read_zarr from spatialdata._io.format import ( - CurrentPointsFormat, - CurrentShapesFormat, + PointsFormatType, + PointsFormatV01, RasterFormatV01, RasterFormatV02, + # CurrentPointsFormat, + # CurrentShapesFormat, ShapesFormatV01, + ShapesFormatV02, SpatialDataFormat, ) from spatialdata.models import PointsModel, ShapesModel +from spatialdata.testing import assert_spatial_data_objects_are_identical -Points_f = CurrentPointsFormat() -Shapes_f = CurrentShapesFormat() +# Points_f = CurrentPointsFormat() +# Shapes_f = CurrentShapesFormat() class TestFormat: """Test format.""" + @pytest.mark.parametrize("format", [PointsFormatV01()]) @pytest.mark.parametrize("attrs_key", [PointsModel.ATTRS_KEY]) @pytest.mark.parametrize("feature_key", [None, PointsModel.FEATURE_KEY]) @pytest.mark.parametrize("instance_key", [None, PointsModel.INSTANCE_KEY]) - def test_format_points( + def test_format_points_v1( self, + format: PointsFormatType, attrs_key: str | None, feature_key: str | None, instance_key: str | None, ) -> None: - metadata: dict[str, Any] = {attrs_key: {"version": Points_f.spatialdata_format_version}} + metadata: dict[str, Any] = {attrs_key: {"version": format.spatialdata_format_version}} format_metadata: dict[str, Any] = {attrs_key: {}} if feature_key is not None: metadata[attrs_key][feature_key] = "target" if instance_key is not None: metadata[attrs_key][instance_key] = "cell_id" - format_metadata[attrs_key] = Points_f.attrs_from_dict(metadata) + format_metadata[attrs_key] = format.attrs_from_dict(metadata) metadata[attrs_key].pop("version") - assert metadata[attrs_key] == Points_f.attrs_to_dict(format_metadata) + assert metadata[attrs_key] == format.attrs_to_dict(format_metadata) if feature_key is None and instance_key is None: assert len(format_metadata[attrs_key]) == len(metadata[attrs_key]) == 0 @@ -49,7 +56,7 @@ def test_format_points( @pytest.mark.parametrize("type_key", [ShapesModel.TYPE_KEY]) @pytest.mark.parametrize("name_key", [ShapesModel.NAME_KEY]) @pytest.mark.parametrize("shapes_type", [0, 3, 6]) - def test_format_shapes_v1( + def test_format_shape_v1( self, attrs_key: str, geos_key: str, @@ -77,13 +84,11 @@ def test_format_shapes_v2( self, attrs_key: str, ) -> None: - # not testing anything, maybe remove - metadata: dict[str, Any] = {attrs_key: {"version": Shapes_f.spatialdata_format_version}} + metadata: dict[str, Any] = {attrs_key: {"version": ShapesFormatV02().spatialdata_format_version}} metadata[attrs_key].pop("version") - assert metadata[attrs_key] == Shapes_f.attrs_to_dict({}) + assert metadata[attrs_key] == ShapesFormatV02().attrs_to_dict({}) - @pytest.mark.parametrize("format", [RasterFormatV02]) - # @pytest.mark.parametrize("format", [RasterFormatV01, RasterFormatV02]) + @pytest.mark.parametrize("format", [RasterFormatV01, RasterFormatV02]) def test_format_raster_v1_v2(self, images, format: type[SpatialDataFormat]) -> None: with tempfile.TemporaryDirectory() as tmpdir: images.write(Path(tmpdir) / "images.zarr", format=format()) @@ -96,3 +101,38 @@ def test_format_raster_v1_v2(self, images, format: type[SpatialDataFormat]) -> N else: assert format == RasterFormatV02 assert ngff_version == "0.4-dev-spatialdata" + + +class TestFormatConversions: + """Test format conversions between older formats and newer.""" + + def test_shapes_v1_to_v2(self, shapes): + with tempfile.TemporaryDirectory() as tmpdir: + f1 = Path(tmpdir) / "data1.zarr" + f2 = Path(tmpdir) / "data2.zarr" + + shapes.write(f1, format=ShapesFormatV01()) + shapes_read_v1 = read_zarr(f1) + assert_spatial_data_objects_are_identical(shapes, shapes_read_v1) + + shapes_read_v1.write(f2, format=ShapesFormatV02()) + shapes_read_v2 = read_zarr(f2) + assert_spatial_data_objects_are_identical(shapes, shapes_read_v2) + + def test_raster_v1_to_v2_to_v3(self, images): + with tempfile.TemporaryDirectory() as tmpdir: + f1 = Path(tmpdir) / "data1.zarr" + f2 = Path(tmpdir) / "data2.zarr" + # f3 = Path(tmpdir) / "data3.zarr" + + images.write(f1, format=RasterFormatV01()) + images_read_v1 = read_zarr(f1) + assert_spatial_data_objects_are_identical(images, images_read_v1) + + images_read_v1.write(f2, format=RasterFormatV02()) + images_read_v2 = read_zarr(f2) + assert_spatial_data_objects_are_identical(images, images_read_v2) + # + # images_read_v2.write(f3, format=RasterFormatV02()) + # images_read_v3 = read_zarr(f3) + # assert_spatial_data_objects_are_identical(images, images_read_v3) diff --git a/tests/io/test_versions.py b/tests/io/test_versions.py deleted file mode 100644 index 15b87b97..00000000 --- a/tests/io/test_versions.py +++ /dev/null @@ -1,28 +0,0 @@ -import tempfile -from pathlib import Path - -from spatialdata import read_zarr -from spatialdata._io.format import ShapesFormatV01, ShapesFormatV02 -from spatialdata.testing import assert_spatial_data_objects_are_identical - - -def test_shapes_v1_to_v2(shapes): - with tempfile.TemporaryDirectory() as tmpdir: - f0 = Path(tmpdir) / "data0.zarr" - f1 = Path(tmpdir) / "data1.zarr" - - # write shapes in version 1 - shapes.write(f0, format=ShapesFormatV01()) - - # reading from v1 works - shapes_read = read_zarr(f0) - - assert_spatial_data_objects_are_identical(shapes, shapes_read) - - # write shapes using the v2 version - shapes_read.write(f1, format=ShapesFormatV02()) - - # read again - shapes_read = read_zarr(f1) - - assert_spatial_data_objects_are_identical(shapes, shapes_read) From 67d2dd9b5f49db929773fd18f034edbdfd4b28ce Mon Sep 17 00:00:00 2001 From: Luca Marconato Date: Tue, 26 Aug 2025 16:07:18 +0200 Subject: [PATCH 008/126] wip replace parse_url() with _open_zarr_store() --- pyproject.toml | 3 +- src/spatialdata/_core/spatialdata.py | 28 +++++++--------- src/spatialdata/_io/_utils.py | 36 +++++++++++++++++++- src/spatialdata/_io/format.py | 50 ++++++++++++++++------------ src/spatialdata/_io/io_raster.py | 3 +- src/spatialdata/_types.py | 9 +++-- 6 files changed, 87 insertions(+), 42 deletions(-) diff --git a/pyproject.toml b/pyproject.toml index 643acb67..074d02b5 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -27,7 +27,7 @@ dependencies = [ "dask-image", "dask>=2024.4.1,<=2024.11.2", "datashader", - "fsspec", + "fsspec[s3,http]", "geopandas>=0.14", "multiscale_spatial_image>=2.0.3", "networkx", @@ -44,6 +44,7 @@ dependencies = [ "scikit-image", "scipy", "typing_extensions>=4.8.0", + "universal_pathlib>=0.2.6", "xarray>=2024.10.0", "xarray-schema", "xarray-spatial>=0.3.5", diff --git a/src/spatialdata/_core/spatialdata.py b/src/spatialdata/_core/spatialdata.py index 15f38281..91a3b9ce 100644 --- a/src/spatialdata/_core/spatialdata.py +++ b/src/spatialdata/_core/spatialdata.py @@ -55,7 +55,7 @@ if TYPE_CHECKING: from spatialdata._core.query.spatial_query import BaseSpatialRequest - from spatialdata._io.format import SpatialDataContainerFormatType, SpatialDataFormat + from spatialdata._io.format import SpatialDataContainerFormatType, SpatialDataFormatType # schema for elements Label2D_s = Labels2DModel() @@ -624,12 +624,11 @@ def _get_groups_for_element( ------- either the existing Zarr subgroup or a new one. """ - from spatialdata._io.format import SpatialDataFormat + from spatialdata._io._utils import _open_zarr_store if not isinstance(zarr_path, Path): raise ValueError("zarr_path should be a Path object") - store = SpatialDataFormat().init_store(str(zarr_path), mode="r+") - # store = parse_url(zarr_path, mode="r+", fmt=SpatialDataFormat()).store + store = _open_zarr_store(zarr_path, mode="r+") root = zarr.open_group(store=store, mode="r+") if element_type not in [ "images", @@ -659,9 +658,9 @@ def _group_for_element_exists(self, zarr_path: Path, element_type: str, element_ ------- True if the group exists, False otherwise. """ - from spatialdata._io.format import SpatialDataFormat + from spatialdata._io._utils import _open_zarr_store - store = parse_url(zarr_path, mode="r", fmt=SpatialDataFormat()).store + store = _open_zarr_store(zarr_path, mode="r") root = zarr.open_group(store=store, mode="r") assert element_type in [ "images", @@ -1111,12 +1110,12 @@ def elements_paths_on_disk(self) -> list[str]: ------- A list of paths of the elements saved in the Zarr store. """ - from spatialdata._io.format import SpatialDataFormat + from spatialdata._io._utils import _open_zarr_store if self.path is None: raise ValueError("The SpatialData object is not backed by a Zarr store.") - store = parse_url(self.path, mode="r", fmt=SpatialDataFormat()).store + store = _open_zarr_store(self.path, mode="r") root = zarr.open_group(store=store, mode="r") elements_in_zarr = [] @@ -1165,11 +1164,7 @@ def _validate_can_safely_write_to_path( overwrite: bool = False, saving_an_element: bool = False, ) -> None: - from spatialdata._io._utils import ( - _backed_elements_contained_in_path, - _is_subfolder, - ) - from spatialdata._io.format import SpatialDataFormat + from spatialdata._io._utils import _backed_elements_contained_in_path, _is_subfolder, _open_zarr_store if isinstance(file_path, str): file_path = Path(file_path) @@ -1178,6 +1173,7 @@ def _validate_can_safely_write_to_path( raise ValueError(f"file_path must be a string or a Path object, type(file_path) = {type(file_path)}.") if os.path.exists(file_path): + store = _open_zarr_store(file_path, mode="r") if parse_url(file_path, mode="r", fmt=SpatialDataFormat()) is None: raise ValueError( "The target file path specified already exists, and it has been detected to not be a Zarr store. " @@ -1233,7 +1229,7 @@ def write( file_path: str | Path, overwrite: bool = False, consolidate_metadata: bool = True, - format: SpatialDataFormat | list[SpatialDataFormat] | None = None, + format: SpatialDataFormatType | list[SpatialDataFormatType] | None = None, ) -> None: """ Write the `SpatialData` object to a Zarr store. @@ -1297,7 +1293,7 @@ def _write_element( element_type: str, element_name: str, overwrite: bool, - format: SpatialDataFormat | list[SpatialDataFormat] | None = None, + format: SpatialDataFormatType | list[SpatialDataFormatType] | None = None, ) -> None: if not isinstance(zarr_container_path, Path): raise ValueError( @@ -1366,7 +1362,7 @@ def write_element( self, element_name: str | list[str], overwrite: bool = False, - format: SpatialDataFormat | list[SpatialDataFormat] | None = None, + format: SpatialDataFormatType | list[SpatialDataFormatType] | None = None, ) -> None: """ Write a single element, or a list of elements, to the Zarr store used for backing. diff --git a/src/spatialdata/_io/_utils.py b/src/spatialdata/_io/_utils.py index bf70b3cc..02fe318a 100644 --- a/src/spatialdata/_io/_utils.py +++ b/src/spatialdata/_io/_utils.py @@ -15,14 +15,18 @@ from pathlib import Path from typing import Any, Literal -import zarr +import zarr.storage from anndata import AnnData from dask.array import Array as DaskArray from dask.dataframe import DataFrame as DaskDataFrame from geopandas import GeoDataFrame +from upath import UPath +from upath.implementations import PosixUPath, WindowsUPath from xarray import DataArray, DataTree +from zarr.storage import FSStore from spatialdata._core.spatialdata import SpatialData +from spatialdata._types import StoreLike from spatialdata._utils import get_pyramid_levels from spatialdata.models._utils import ( MappingToCoordinateSystem_t, @@ -388,6 +392,36 @@ def save_transformations(sdata: SpatialData) -> None: sdata.write_transformations() +def _open_zarr_store(path: StoreLike, **kwargs: Any) -> zarr.storage.BaseStore: + # TODO: ensure kwargs like mode are enforced everywhere and passed correctly to the store + if isinstance(path, str | Path): + # if the input is str or Path, map it to UPath + path = UPath(path) + if isinstance(path, PosixUPath | WindowsUPath): + # if the input is a local path, use DirectoryStore + return zarr.storage.DirectoryStore(path.path, dimension_separator="/") + if isinstance(path, zarr.Group): + # if the input is a zarr.Group, wrap it with a store + if isinstance(path.store, zarr.storage.DirectoryStore): + # create a simple FSStore if the store is a DirectoryStore with just the path + return FSStore(os.path.join(path.store.path, path.path), **kwargs) + if isinstance(path.store, FSStore): + # if the store within the zarr.Group is an FSStore, return it + # but extend the path of the store with that of the zarr.Group + return FSStore(path.store.path + "/" + path.path, fs=path.store.fs, **kwargs) + if isinstance(path.store, zarr.storage.ConsolidatedMetadataStore): + # if the store is a ConsolidatedMetadataStore, just return the underlying FSSpec store + return path.store.store + raise ValueError(f"Unsupported store type or zarr.Group: {type(path.store)}") + if isinstance(path, zarr.storage.StoreLike): + # if the input already a store, wrap it in an FSStore + return FSStore(path, **kwargs) + if isinstance(path, UPath): + # if input is a remote UPath, map it to an FSStore + return FSStore(path.path, fs=path.fs, **kwargs) + raise TypeError(f"Unsupported type: {type(path)}") + + class BadFileHandleMethod(Enum): ERROR = "error" WARN = "warn" diff --git a/src/spatialdata/_io/format.py b/src/spatialdata/_io/format.py index afcfb0b7..8636bd5e 100644 --- a/src/spatialdata/_io/format.py +++ b/src/spatialdata/_io/format.py @@ -6,7 +6,14 @@ import ome_zarr.format import zarr from anndata import AnnData -from ome_zarr.format import Format, FormatV01, FormatV02, FormatV03, FormatV04, FormatV05 +from ome_zarr.format import ( + Format, + FormatV01, + FormatV02, + FormatV03, + FormatV04, + FormatV05, +) from pandas.api.types import CategoricalDtype from shapely import GeometryType @@ -99,12 +106,12 @@ def validate_coordinate_transformations( assert np.all([j0 == j1 for j0, j1 in zip(json0, json1, strict=True)]) -class PointAttrsMixinV01: +class PointsAttrsMixinV01: def attrs_from_dict(self, metadata: dict[str, Any]) -> dict[str, dict[str, Any]]: if Points_s.ATTRS_KEY not in metadata: raise KeyError(f"Missing key {Points_s.ATTRS_KEY} in points metadata.") metadata_ = metadata[Points_s.ATTRS_KEY] - assert self.spatialdata_format_version == metadata_["version"] + assert self.spatialdata_format_version == metadata_["version"] # type: ignore[attr-defined] d = {} if Points_s.FEATURE_KEY in metadata_: d[Points_s.FEATURE_KEY] = metadata_[Points_s.FEATURE_KEY] @@ -260,7 +267,7 @@ def attrs_to_dict(self, data: dict[str, Any]) -> dict[str, str | dict[str, Any]] return {} -class PointsFormatV01(FormatV04, PointAttrsMixinV01): +class PointsFormatV01(FormatV04, PointsAttrsMixinV01): """Formatter for points.""" @property @@ -268,7 +275,7 @@ def spatialdata_format_version(self) -> str: return "0.1" -class PointsFormatV02(FormatV05, PointAttrsMixinV01): +class PointsFormatV02(FormatV05, PointsAttrsMixinV01): """Formatter for points.""" @property @@ -296,38 +303,39 @@ def spatialdata_format_version(self) -> str: CurrentShapesFormat = ShapesFormatV03 CurrentPointsFormat = PointsFormatV02 CurrentTablesFormat = TablesFormatV02 -CurrentSpatialDataContainerFormats = SpatialDataContainerFormatV02 +CurrentSpatialDataContainerFormat = SpatialDataContainerFormatV02 -ShapesFormats = { +ShapesFormatType = ShapesFormatV01 | ShapesFormatV02 | ShapesFormatV03 +PointsFormatType = PointsFormatV01 | PointsFormatV02 +TablesFormatType = TablesFormatV01 | TablesFormatV02 +RasterFormatType = RasterFormatV01 | RasterFormatV02 | RasterFormatV03 +SpatialDataContainerFormatType = SpatialDataContainerFormatV01 | SpatialDataContainerFormatV02 +SpatialDataFormatType = ( + ShapesFormatType | PointsFormatType | TablesFormatType | RasterFormatType | SpatialDataContainerFormatType +) + +ShapesFormats: dict[str, ShapesFormatType] = { "0.1": ShapesFormatV01(), "0.2": ShapesFormatV02(), "0.3": ShapesFormatV03(), } -PointsFormats = { +PointsFormats: dict[str, PointsFormatType] = { "0.1": PointsFormatV01(), "0.2": PointsFormatV02(), } -TablesFormats = { +TablesFormats: dict[str, TablesFormatType] = { "0.1": TablesFormatV01(), "0.2": TablesFormatV02(), } -RasterFormats = { +RasterFormats: dict[str, RasterFormatType] = { "0.1": RasterFormatV01(), "0.2": RasterFormatV02(), "0.3": RasterFormatV03(), } -SpatialDataContainerFormats = { +SpatialDataContainerFormats: dict[str, SpatialDataContainerFormatType] = { "0.1": SpatialDataContainerFormatV01(), "0.2": SpatialDataContainerFormatV02(), } -ShapesFormatType = ShapesFormatV01 | ShapesFormatV02 -PointsFormatType = PointsFormatV01 -TablesFormatType = TablesFormatV01 -RasterFormatType = RasterFormatV01 | RasterFormatV02 -SpatialDataContainerFormatType = SpatialDataContainerFormatV01 -SpatialDataFormatType = ( - ShapesFormatType | PointsFormatType | TablesFormatType | RasterFormatType | SpatialDataContainerFormatType -) def format_implementations() -> Iterator[Format]: @@ -352,12 +360,12 @@ def format_implementations() -> Iterator[Format]: def _parse_formats( formats: SpatialDataFormatType | list[SpatialDataFormatType] | None, ) -> dict[str, SpatialDataFormatType]: - parsed = { + parsed: dict[str, SpatialDataFormatType] = { "raster": CurrentRasterFormat(), "shapes": CurrentShapesFormat(), "points": CurrentPointsFormat(), "tables": CurrentTablesFormat(), - "SpatialData": CurrentSpatialDataContainerFormats(), + "SpatialData": CurrentSpatialDataContainerFormat(), } if formats is None: return parsed diff --git a/src/spatialdata/_io/io_raster.py b/src/spatialdata/_io/io_raster.py index 541be3ea..c0e239af 100644 --- a/src/spatialdata/_io/io_raster.py +++ b/src/spatialdata/_io/io_raster.py @@ -22,6 +22,7 @@ from spatialdata._io.format import ( CurrentRasterFormat, RasterFormats, + RasterFormatType, RasterFormatV01, _parse_version, ) @@ -44,7 +45,7 @@ def _read_multiscale(store: str | Path, raster_type: Literal["image", "labels"]) version = _parse_version(f, expect_attrs_key=True) # old spatialdata datasets don't have format metadata for raster elements; this line ensure backwards compatibility, # interpreting the lack of such information as the presence of the format v01 - format = RasterFormatV01() if version is None else RasterFormats[version] + format: RasterFormatType = RasterFormatV01() if version is None else RasterFormats[version] f.store.close() nodes: list[Node] = [] diff --git a/src/spatialdata/_types.py b/src/spatialdata/_types.py index 30d623a5..26fad13e 100644 --- a/src/spatialdata/_types.py +++ b/src/spatialdata/_types.py @@ -1,9 +1,12 @@ -from typing import Any +from pathlib import Path +from typing import Any, TypeAlias import numpy as np +import zarr.storage +from upath import UPath from xarray import DataArray, DataTree -__all__ = ["ArrayLike", "ColorLike", "DTypeLike", "Raster_T"] +__all__ = ["ArrayLike", "ColorLike", "DTypeLike", "Raster_T", "StoreLike"] from numpy.typing import DTypeLike, NDArray @@ -12,3 +15,5 @@ Raster_T = DataArray | DataTree ColorLike = tuple[float, ...] | str + +StoreLike: TypeAlias = str | Path | UPath | zarr.storage.StoreLike | zarr.Group From 5d0b0cfb34a787f5dcb84836b390f46e402e2995 Mon Sep 17 00:00:00 2001 From: Wouter-Michiel Vierdag Date: Wed, 27 Aug 2025 17:49:59 +0200 Subject: [PATCH 009/126] update way of writing transforms, update typehints --- docs/api/data_formats.md | 2 +- src/spatialdata/_core/spatialdata.py | 24 ++++++---------- src/spatialdata/_io/_utils.py | 41 ++++++++++++++++------------ src/spatialdata/_io/format.py | 2 +- src/spatialdata/_io/io_raster.py | 39 ++++++++++++-------------- src/spatialdata/_io/io_shapes.py | 5 ++-- src/spatialdata/_io/io_zarr.py | 2 +- tests/io/test_format.py | 4 +-- 8 files changed, 58 insertions(+), 61 deletions(-) diff --git a/docs/api/data_formats.md b/docs/api/data_formats.md index 825d2dfd..0bb72bf1 100644 --- a/docs/api/data_formats.md +++ b/docs/api/data_formats.md @@ -1,6 +1,6 @@ # Data formats (advanced) -The SpatialData format is defined as a set of versioned subclasses of `spatialdata._io.format.SpatialDataFormat`, one per type of element. +The SpatialData format is defined as a set of versioned subclasses of `spatialdata._io.format.SpatialDataFormatType`, one per type of element. These classes are useful to ensure backward compatibility whenever a major version change is introduced. We also provide pointers to the latest format. ```{eval-rst} diff --git a/src/spatialdata/_core/spatialdata.py b/src/spatialdata/_core/spatialdata.py index 91a3b9ce..c24d3543 100644 --- a/src/spatialdata/_core/spatialdata.py +++ b/src/spatialdata/_core/spatialdata.py @@ -16,6 +16,7 @@ from dask.dataframe import read_parquet from dask.delayed import Delayed from geopandas import GeoDataFrame +from ome_zarr.format import FormatV05 from ome_zarr.io import parse_url from ome_zarr.types import JSONDict from shapely import MultiPolygon, Polygon @@ -1173,8 +1174,8 @@ def _validate_can_safely_write_to_path( raise ValueError(f"file_path must be a string or a Path object, type(file_path) = {type(file_path)}.") if os.path.exists(file_path): - store = _open_zarr_store(file_path, mode="r") - if parse_url(file_path, mode="r", fmt=SpatialDataFormat()) is None: + _open_zarr_store(file_path, mode="r") + if parse_url(file_path, mode="r", fmt=FormatV05()) is None: raise ValueError( "The target file path specified already exists, and it has been detected to not be a Zarr store. " "Overwriting non-Zarr stores is not supported to prevent accidental data loss." @@ -1256,14 +1257,12 @@ def write( :class:`~spatialdata._io.format.CurrentRasterFormat`, :class:`~spatialdata._io.format.CurrentShapesFormat`, :class:`~spatialdata._io.format.CurrentPointsFormat`, :class:`~spatialdata._io.format.CurrentTablesFormat`. """ - from spatialdata._io.format import SpatialDataFormat - if isinstance(file_path, str): file_path = Path(file_path) self._validate_can_safely_write_to_path(file_path, overwrite=overwrite) self._validate_all_elements() - store = parse_url(file_path, mode="w", fmt=SpatialDataFormat()).store + store = parse_url(file_path, mode="w", fmt=FormatV05()).store zarr_group = zarr.open_group(store=store, mode="w" if overwrite else "a") self.write_attrs(zarr_group=zarr_group) store.close() @@ -1457,7 +1456,6 @@ def delete_element_from_disk(self, element_name: str | list[str]) -> None: appropriately (or implement a tailored solution), to prevent data loss. """ from spatialdata._io._utils import _backed_elements_contained_in_path - from spatialdata._io.format import SpatialDataFormat if isinstance(element_name, list): for name in element_name: @@ -1504,7 +1502,7 @@ def delete_element_from_disk(self, element_name: str | list[str]) -> None: ) # delete the element - store = parse_url(self.path, mode="r+", fmt=SpatialDataFormat()).store + store = parse_url(self.path, mode="r+", fmt=FormatV05()).store root = zarr.open_group(store=store, mode="r+") root[element_type].pop(element_name) store.close() @@ -1525,19 +1523,15 @@ def _check_element_not_on_disk_with_different_type(self, element_type: str, elem ) def write_consolidated_metadata(self) -> None: - from spatialdata._io.format import SpatialDataFormat - - store = parse_url(self.path, mode="r+", fmt=SpatialDataFormat()).store + store = parse_url(self.path, mode="r+", fmt=FormatV05()).store # consolidate metadata to more easily support remote reading bug in zarr. In reality, 'zmetadata' is written # instead of '.zmetadata' see discussion https://github.com/zarr-developers/zarr-python/issues/1121 zarr.consolidate_metadata(store) store.close() def has_consolidated_metadata(self) -> bool: - from spatialdata._io.format import SpatialDataFormat - return_value = False - store = parse_url(self.path, mode="r", fmt=SpatialDataFormat()).store + store = parse_url(self.path, mode="r", fmt=FormatV05()).store if "zmetadata" in store: return_value = True store.close() @@ -1715,7 +1709,7 @@ def write_attrs( format: SpatialDataContainerFormatType | None = None, zarr_group: zarr.Group | None = None, ) -> None: - from spatialdata._io.format import SpatialDataFormat, _parse_formats + from spatialdata._io.format import SpatialDataContainerFormatType, _parse_formats parsed = _parse_formats(formats=format) spatialdata_container_format = parsed["SpatialData"] @@ -1725,7 +1719,7 @@ def write_attrs( if zarr_group is None: assert self.is_backed(), "The SpatialData object must be backed by a Zarr store to write attrs." - store = parse_url(self.path, mode="r+", fmt=SpatialDataFormat()).store + store = parse_url(self.path, mode="r+", fmt=FormatV05()).store zarr_group = zarr.open_group(store=store, overwrite=False, mode="r+") version = spatialdata_container_format.spatialdata_format_version diff --git a/src/spatialdata/_io/_utils.py b/src/spatialdata/_io/_utils.py index 02fe318a..6c485443 100644 --- a/src/spatialdata/_io/_utils.py +++ b/src/spatialdata/_io/_utils.py @@ -15,15 +15,15 @@ from pathlib import Path from typing import Any, Literal -import zarr.storage +import zarr from anndata import AnnData from dask.array import Array as DaskArray from dask.dataframe import DataFrame as DaskDataFrame from geopandas import GeoDataFrame from upath import UPath -from upath.implementations import PosixUPath, WindowsUPath +from upath.implementations.local import PosixUPath, WindowsUPath from xarray import DataArray, DataTree -from zarr.storage import FSStore +from zarr.storage import FsspecStore, LocalStore from spatialdata._core.spatialdata import SpatialData from spatialdata._types import StoreLike @@ -99,16 +99,17 @@ def overwrite_coordinate_transformations_raster( ) coordinate_transformations = [t.to_dict() for t in ngff_transformations] # replace the metadata storage - multiscales = group.attrs["multiscales"] - assert len(multiscales) == 1 + if len_scales := len(multiscales := group.metadata.attributes["ome"]["multiscales"]) != 1: + raise ValueError(f"The length of multiscales metadata should be 1, found the length to be {len_scales}") multiscale = multiscales[0] # the transformation present in multiscale["datasets"] are the ones for the multiscale, so and we leave them intact # we update multiscale["coordinateTransformations"] and multiscale["coordinateSystems"] # see the first post of https://github.com/scverse/spatialdata/issues/39 for an overview # fix the io to follow the NGFF specs, see https://github.com/scverse/spatialdata/issues/114 + + # zarr v3 ome-zarr requires the coordinate transformations to be written this way, leaving one out won't work. multiscale["coordinateTransformations"] = coordinate_transformations - # multiscale["coordinateSystems"] = [t.output_coordinate_system_name for t in ngff_transformations] - group.attrs["multiscales"] = multiscales + group.attrs["coordinateTransformations"] = coordinate_transformations def overwrite_channel_names(group: zarr.Group, element: DataArray | DataTree) -> None: @@ -294,8 +295,9 @@ def _search_for_backing_files_recursively(subgraph: Any, files: list[str]) -> No name = k if name is not None: if name.startswith("original-from-zarr"): - path = v.store.path - files.append(os.path.realpath(path)) + # LocalStore.store does not have an attribute path, but we keep it like this for backward compat. + path = getattr(v.store, "path", None) if getattr(v.store, "path", None) else v.store.root + files.append(str(path)) elif name.startswith("read-parquet") or name.startswith("read_parquet"): if hasattr(v, "creation_info"): # https://github.com/dask/dask/blob/ff2488aec44d641696e0b7aa41ed9e995c710705/dask/dataframe/io/parquet/core.py#L625 @@ -372,6 +374,7 @@ def _is_element_self_contained( ) -> bool: if isinstance(element, DaskDataFrame): pass + # TODO when running test_save_transformations it seems that for the same element this is called multiple times return all(_backed_elements_contained_in_path(path=element_path, object=element)) @@ -397,28 +400,30 @@ def _open_zarr_store(path: StoreLike, **kwargs: Any) -> zarr.storage.BaseStore: if isinstance(path, str | Path): # if the input is str or Path, map it to UPath path = UPath(path) + if isinstance(path, PosixUPath | WindowsUPath): - # if the input is a local path, use DirectoryStore - return zarr.storage.DirectoryStore(path.path, dimension_separator="/") + # if the input is a local path, use LocalStore + return LocalStore(path.path) + if isinstance(path, zarr.Group): # if the input is a zarr.Group, wrap it with a store - if isinstance(path.store, zarr.storage.DirectoryStore): - # create a simple FSStore if the store is a DirectoryStore with just the path - return FSStore(os.path.join(path.store.path, path.path), **kwargs) - if isinstance(path.store, FSStore): + if isinstance(path.store, LocalStore): + # create a simple FSStore if the store is a LocalStore with just the path + return FsspecStore(os.path.join(path.store.path, path.path), **kwargs) + if isinstance(path.store, FsspecStore): # if the store within the zarr.Group is an FSStore, return it # but extend the path of the store with that of the zarr.Group - return FSStore(path.store.path + "/" + path.path, fs=path.store.fs, **kwargs) + return FsspecStore(path.store.path + "/" + path.path, fs=path.store.fs, **kwargs) if isinstance(path.store, zarr.storage.ConsolidatedMetadataStore): # if the store is a ConsolidatedMetadataStore, just return the underlying FSSpec store return path.store.store raise ValueError(f"Unsupported store type or zarr.Group: {type(path.store)}") if isinstance(path, zarr.storage.StoreLike): # if the input already a store, wrap it in an FSStore - return FSStore(path, **kwargs) + return FsspecStore(path, **kwargs) if isinstance(path, UPath): # if input is a remote UPath, map it to an FSStore - return FSStore(path.path, fs=path.fs, **kwargs) + return FsspecStore(path.path, fs=path.fs, **kwargs) raise TypeError(f"Unsupported type: {type(path)}") diff --git a/src/spatialdata/_io/format.py b/src/spatialdata/_io/format.py index 8636bd5e..fe069ac7 100644 --- a/src/spatialdata/_io/format.py +++ b/src/spatialdata/_io/format.py @@ -351,7 +351,7 @@ def format_implementations() -> Iterator[Format]: yield FormatV01() -# monkeypatch the ome_zarr.format module to include the SpatialDataFormat (we want to use the APIs from ome_zarr to +# monkeypatch the ome_zarr.format module to include the SpatialDataFormatType (we want to use the APIs from ome_zarr to # read, but signal that the format we are using is a dev version of NGFF, since it builds on some open PR that are # not released yet) ome_zarr.format.format_implementations = format_implementations diff --git a/src/spatialdata/_io/io_raster.py b/src/spatialdata/_io/io_raster.py index c0e239af..e2b3bd7f 100644 --- a/src/spatialdata/_io/io_raster.py +++ b/src/spatialdata/_io/io_raster.py @@ -6,7 +6,7 @@ import zarr from ome_zarr.format import Format from ome_zarr.io import ZarrLocation -from ome_zarr.reader import Label, Multiscales, Node, Reader +from ome_zarr.reader import Multiscales, Node, Reader from ome_zarr.types import JSONDict from ome_zarr.writer import _get_valid_axes from ome_zarr.writer import write_image as write_image_ngff @@ -21,10 +21,6 @@ ) from spatialdata._io.format import ( CurrentRasterFormat, - RasterFormats, - RasterFormatType, - RasterFormatV01, - _parse_version, ) from spatialdata._utils import get_pyramid_levels from spatialdata.models._utils import get_channel_names @@ -41,13 +37,6 @@ def _read_multiscale(store: str | Path, raster_type: Literal["image", "labels"]) assert isinstance(store, str | Path) assert raster_type in ["image", "labels"] - f = zarr.open(store, mode="r") - version = _parse_version(f, expect_attrs_key=True) - # old spatialdata datasets don't have format metadata for raster elements; this line ensure backwards compatibility, - # interpreting the lack of such information as the presence of the format v01 - format: RasterFormatType = RasterFormatV01() if version is None else RasterFormats[version] - f.store.close() - nodes: list[Node] = [] image_loc = ZarrLocation(store) if image_loc.exists(): @@ -55,12 +44,17 @@ def _read_multiscale(store: str | Path, raster_type: Literal["image", "labels"]) image_nodes = list(image_reader) if len(image_nodes): for node in image_nodes: - if np.any([isinstance(spec, Multiscales) for spec in node.specs]) and ( - raster_type == "image" - and np.all([not isinstance(spec, Label) for spec in node.specs]) - or raster_type == "labels" - and np.any([isinstance(spec, Label) for spec in node.specs]) - ): + # if np.any([isinstance(spec, Multiscales) for spec in node.specs]) and ( + # raster_type == "image" + # and np.all([not isinstance(spec, Label) for spec in node.specs]) + # or raster_type == "labels" + # and np.any([isinstance(spec, Label) for spec in node.specs]) + # ): + # Labels are not also Multiscales + if np.any([isinstance(spec, Multiscales) for spec in node.specs]) and raster_type in [ + "image", + "labels", + ]: nodes.append(node) if len(nodes) != 1: raise ValueError( @@ -71,6 +65,7 @@ def _read_multiscale(store: str | Path, raster_type: Literal["image", "labels"]) datasets = node.load(Multiscales).datasets multiscales = node.load(Multiscales).zarr.root_attrs["multiscales"] omero_metadata = node.load(Multiscales).zarr.root_attrs.get("omero", None) + # TODO: check if below is still valid legacy_channels_metadata = node.load(Multiscales).zarr.root_attrs.get("channels_metadata", None) # legacy v0.1 assert len(multiscales) == 1 # checking for multiscales[0]["coordinateTransformations"] would make fail @@ -92,7 +87,7 @@ def _read_multiscale(store: str | Path, raster_type: Literal["image", "labels"]) if len(datasets) > 1: multiscale_image = {} for i, d in enumerate(datasets): - data = node.load(Multiscales).array(resolution=d, version=format.version) + data = node.load(Multiscales).array(resolution=d) multiscale_image[f"scale{i}"] = Dataset( { "image": DataArray( @@ -106,7 +101,7 @@ def _read_multiscale(store: str | Path, raster_type: Literal["image", "labels"]) msi = DataTree.from_dict(multiscale_image) _set_transformations(msi, transformations) return compute_coordinates(msi) - data = node.load(Multiscales).array(resolution=datasets[0], version=format.version) + data = node.load(Multiscales).array(resolution=datasets[0]) si = DataArray( data, name="image", @@ -171,6 +166,7 @@ def _get_group_for_writing_transformations() -> zarr.Group: # We need this because the argument of write_image_ngff is called image while the argument of # write_labels_ngff is called label. metadata[raster_type] = data + # TODO: check purpose of _get_group_for_writing_transformations here as it seems to return same as group_data write_single_scale_ngff( group=group_data, scaler=None, @@ -180,7 +176,8 @@ def _get_group_for_writing_transformations() -> zarr.Group: storage_options=storage_options, **metadata, ) - assert transformations is not None + if not transformations: + raise ValueError(f"No transformations specified to be written for element {name}.") overwrite_coordinate_transformations_raster( group=_get_group_for_writing_transformations(), transformations=transformations, axes=input_axes ) diff --git a/src/spatialdata/_io/io_shapes.py b/src/spatialdata/_io/io_shapes.py index 3df1be45..56f995f3 100644 --- a/src/spatialdata/_io/io_shapes.py +++ b/src/spatialdata/_io/io_shapes.py @@ -17,6 +17,7 @@ ShapesFormats, ShapesFormatV01, ShapesFormatV02, + ShapesFormatV03, _parse_version, ) from spatialdata.models import ShapesModel, get_axes_names @@ -49,7 +50,7 @@ def _read_shapes( offsets = tuple(np.array(f[k]).flatten() for k in offsets_keys) geometry = from_ragged_array(typ, coords, offsets) geo_df = GeoDataFrame({"geometry": geometry}, index=index) - elif isinstance(format, ShapesFormatV02): + elif isinstance(format, ShapesFormatV02 | ShapesFormatV03): store_root = f.store_path.store.root path = Path(store_root) / f.path / "shapes.parquet" geo_df = read_parquet(path) @@ -93,7 +94,7 @@ def write_shapes( attrs = format.attrs_to_dict(geometry) attrs["version"] = format.spatialdata_format_version - elif isinstance(format, ShapesFormatV02): + elif isinstance(format, ShapesFormatV02 | ShapesFormatV03): store_root = shapes_group.store_path.store.root path = Path(store_root) / shapes_group.path / "shapes.parquet" shapes.to_parquet(path) diff --git a/src/spatialdata/_io/io_zarr.py b/src/spatialdata/_io/io_zarr.py index 045d30c0..31e9c353 100644 --- a/src/spatialdata/_io/io_zarr.py +++ b/src/spatialdata/_io/io_zarr.py @@ -137,7 +137,7 @@ def read_zarr( # skip hidden files like .zgroup or .zmetadata continue f_elem = group[subgroup_name] - f_elem_store = os.path.join(f_store_path, f_elem.path) + f_elem_store = f_store_path / f_elem.path with handle_read_errors( on_bad_files, location=f"{group.path}/{subgroup_name}", diff --git a/tests/io/test_format.py b/tests/io/test_format.py index 7998b97a..4c97014d 100644 --- a/tests/io/test_format.py +++ b/tests/io/test_format.py @@ -16,7 +16,7 @@ # CurrentShapesFormat, ShapesFormatV01, ShapesFormatV02, - SpatialDataFormat, + SpatialDataFormatType, ) from spatialdata.models import PointsModel, ShapesModel from spatialdata.testing import assert_spatial_data_objects_are_identical @@ -89,7 +89,7 @@ def test_format_shapes_v2( assert metadata[attrs_key] == ShapesFormatV02().attrs_to_dict({}) @pytest.mark.parametrize("format", [RasterFormatV01, RasterFormatV02]) - def test_format_raster_v1_v2(self, images, format: type[SpatialDataFormat]) -> None: + def test_format_raster_v1_v2(self, images, format: type[SpatialDataFormatType]) -> None: with tempfile.TemporaryDirectory() as tmpdir: images.write(Path(tmpdir) / "images.zarr", format=format()) zattrs_file = Path(tmpdir) / "images.zarr/images/image2d/.zattrs" From 2c486b5a9bfd5868650d5904f3925b5b5fcaa496 Mon Sep 17 00:00:00 2001 From: Wouter-Michiel Vierdag Date: Fri, 29 Aug 2025 09:18:37 +0200 Subject: [PATCH 010/126] attempt fix consolidated metadata, fix channel names write --- src/spatialdata/_core/spatialdata.py | 5 +++-- src/spatialdata/_io/_utils.py | 9 +-------- tests/io/test_pyramids_performance.py | 4 ++-- tests/io/test_readwrite.py | 2 ++ 4 files changed, 8 insertions(+), 12 deletions(-) diff --git a/src/spatialdata/_core/spatialdata.py b/src/spatialdata/_core/spatialdata.py index c24d3543..32300544 100644 --- a/src/spatialdata/_core/spatialdata.py +++ b/src/spatialdata/_core/spatialdata.py @@ -1504,7 +1504,7 @@ def delete_element_from_disk(self, element_name: str | list[str]) -> None: # delete the element store = parse_url(self.path, mode="r+", fmt=FormatV05()).store root = zarr.open_group(store=store, mode="r+") - root[element_type].pop(element_name) + del root[element_type][element_name] store.close() if self.has_consolidated_metadata(): @@ -1532,7 +1532,8 @@ def write_consolidated_metadata(self) -> None: def has_consolidated_metadata(self) -> bool: return_value = False store = parse_url(self.path, mode="r", fmt=FormatV05()).store - if "zmetadata" in store: + group = zarr.open_group(store, mode="r") + if getattr(group.metadata, "consolidated_metadata", None): return_value = True store.close() return return_value diff --git a/src/spatialdata/_io/_utils.py b/src/spatialdata/_io/_utils.py index 6c485443..1309797f 100644 --- a/src/spatialdata/_io/_utils.py +++ b/src/spatialdata/_io/_utils.py @@ -120,16 +120,9 @@ def overwrite_channel_names(group: zarr.Group, element: DataArray | DataTree) -> channel_names = element["scale0"]["image"].coords["c"].data.tolist() channel_metadata = [{"label": name} for name in channel_names] - omero_meta = group.attrs["omero"] + omero_meta = group.attrs["ome"]["omero"] omero_meta["channels"] = channel_metadata group.attrs["omero"] = omero_meta - multiscales_meta = group.attrs["multiscales"] - if len(multiscales_meta) != 1: - raise ValueError( - f"Multiscale metadata must be of length one but got length {len(multiscales_meta)}. Data mightbe corrupted." - ) - multiscales_meta[0]["metadata"]["omero"]["channels"] = channel_metadata - group.attrs["multiscales"] = multiscales_meta def _write_metadata( diff --git a/tests/io/test_pyramids_performance.py b/tests/io/test_pyramids_performance.py index f0ca31a2..bf7bd004 100644 --- a/tests/io/test_pyramids_performance.py +++ b/tests/io/test_pyramids_performance.py @@ -58,8 +58,8 @@ def test_write_image_multiscale_performance(sdata_with_image: SpatialData, tmp_p # (see issue https://github.com/scverse/spatialdata/issues/577). # Instead of measuring the time (which would have high variation if not using big datasets), # we watch the number of read and write accesses and compare to the theoretical number. - zarr_chunk_write_spy = mocker.spy(zarr.core.Array, "__setitem__") - zarr_chunk_read_spy = mocker.spy(zarr.core.Array, "__getitem__") + zarr_chunk_write_spy = mocker.spy(zarr.Array, "__setitem__") + zarr_chunk_read_spy = mocker.spy(zarr.Array, "__getitem__") image_name, image = next(iter(sdata_with_image.images.items())) element_type_group = zarr.group(store=tmp_path / "sdata.zarr", path="/images") diff --git a/tests/io/test_readwrite.py b/tests/io/test_readwrite.py index ad8c66b4..6fcb0a07 100644 --- a/tests/io/test_readwrite.py +++ b/tests/io/test_readwrite.py @@ -1,4 +1,5 @@ import os +import sys import tempfile from collections.abc import Callable from pathlib import Path @@ -774,6 +775,7 @@ def test_incremental_writing_valid_table_name_invalid_table(tmp_path: Path): invalid_sdata.write_element("valid_name") +@pytest.mark.skipif(sys.platform == "win32", reason="Renaming fails as windows path already sees the name as invalid.") def test_reading_invalid_name(tmp_path: Path): image_name, image = next(iter(_get_images().items())) labels_name, labels = next(iter(_get_labels().items())) From fa4622f219802793cf5bc11cb5f0c016b675b29b Mon Sep 17 00:00:00 2001 From: Wouter-Michiel Vierdag Date: Fri, 29 Aug 2025 15:47:59 +0200 Subject: [PATCH 011/126] fix partial read tests --- src/spatialdata/_io/io_raster.py | 16 ++++- src/spatialdata/_io/io_zarr.py | 85 ++++++++++------------- tests/conftest.py | 5 ++ tests/io/test_partial_read.py | 115 ++++++++----------------------- 4 files changed, 85 insertions(+), 136 deletions(-) diff --git a/src/spatialdata/_io/io_raster.py b/src/spatialdata/_io/io_raster.py index e2b3bd7f..520bb1c5 100644 --- a/src/spatialdata/_io/io_raster.py +++ b/src/spatialdata/_io/io_raster.py @@ -39,7 +39,7 @@ def _read_multiscale(store: str | Path, raster_type: Literal["image", "labels"]) nodes: list[Node] = [] image_loc = ZarrLocation(store) - if image_loc.exists(): + if exists := image_loc.exists(): image_reader = Reader(image_loc)() image_nodes = list(image_reader) if len(image_nodes): @@ -56,10 +56,20 @@ def _read_multiscale(store: str | Path, raster_type: Literal["image", "labels"]) "labels", ]: nodes.append(node) + else: + raise OSError( + f"Image location {image_loc} does not seem to exist. If it does, potentially the zarr.json file " + f"inside is corrupted or not present or the image files themselves are corrupted." + ) if len(nodes) != 1: + if exists: + raise OSError( + f"Image location {image_loc} exists, but len(nodes) = {len(nodes)}, expected 1. Element " + f"{image_loc.basename()} is potentially corrupted." + ) raise ValueError( - f"len(nodes) = {len(nodes)}, expected 1. Unable to read the NGFF file. Please report this " - f"bug and attach a minimal data example." + f"len(nodes) = {len(nodes)}, expected 1 and image location {image_loc} does not exist. Unable to read " + f"the NGFF file. Please report this bug and attach a minimal data example." ) node = nodes[0] datasets = node.load(Multiscales).datasets diff --git a/src/spatialdata/_io/io_zarr.py b/src/spatialdata/_io/io_zarr.py index 31e9c353..c04a914b 100644 --- a/src/spatialdata/_io/io_zarr.py +++ b/src/spatialdata/_io/io_zarr.py @@ -8,7 +8,7 @@ import zarr.storage from anndata import AnnData from pyarrow import ArrowInvalid -from zarr.errors import MetadataValidationError +from zarr.errors import ArrayNotFoundError, MetadataValidationError from spatialdata._core.spatialdata import SpatialData from spatialdata._io._utils import ( @@ -91,68 +91,59 @@ def read_zarr( selector = {"images", "labels", "points", "shapes", "tables", "table"} if not selection else set(selection or []) logger.debug(f"Reading selection {selector}") - # read multiscale images + # We raise OS errors instead for some read errors now as in zarr v3 with some corruptions nothing will be read. + # related to images / labels. if "images" in selector and "images" in f: - with handle_read_errors( - on_bad_files, - location="images", - exc_types=(JSONDecodeError, MetadataValidationError), - ): - group = f["images"] + group = f["images"] + count = 0 + for subgroup_name in group: + if Path(subgroup_name).name.startswith("."): + # skip hidden files like .zgroup or .zmetadata + continue + f_elem = group[subgroup_name] + f_elem_store = os.path.join(f_store_path, f_elem.path) + with handle_read_errors( + on_bad_files, + location=f"{group.path}/{subgroup_name}", + exc_types=( + JSONDecodeError, # JSON parse error + ValueError, # ome_zarr: Unable to read the NGFF file + KeyError, # Missing JSON key + ArrayNotFoundError, # Image chunks missing + TypeError, # instead of ArrayNotFoundError, with dask>=2024.10.0 zarr<=2.18.3 + ), + ): + element = _read_multiscale(f_elem_store, raster_type="image") + images[subgroup_name] = element + count += 1 + logger.debug(f"Found {count} elements in {group}") + + # read multiscale labels + with ome_zarr_logger(logging.ERROR): + if "labels" in selector and "labels" in f: + group = f["labels"] count = 0 for subgroup_name in group: if Path(subgroup_name).name.startswith("."): # skip hidden files like .zgroup or .zmetadata continue f_elem = group[subgroup_name] - f_elem_store = os.path.join(f_store_path, f_elem.path) + f_elem_store = f_store_path / f_elem.path with handle_read_errors( on_bad_files, location=f"{group.path}/{subgroup_name}", exc_types=( - JSONDecodeError, # JSON parse error - ValueError, # ome_zarr: Unable to read the NGFF file - KeyError, # Missing JSON key - # ArrayNotFoundError, # Image chunks missing, removed in Zarr v3 - TypeError, # instead of ArrayNotFoundError, with dask>=2024.10.0 zarr<=2.18.3 + JSONDecodeError, + KeyError, + ValueError, + ArrayNotFoundError, + TypeError, ), ): - element = _read_multiscale(f_elem_store, raster_type="image") - images[subgroup_name] = element + labels[subgroup_name] = _read_multiscale(f_elem_store, raster_type="labels") count += 1 logger.debug(f"Found {count} elements in {group}") - # read multiscale labels - with ome_zarr_logger(logging.ERROR): - if "labels" in selector and "labels" in f: - with handle_read_errors( - on_bad_files, - location="labels", - exc_types=(JSONDecodeError, MetadataValidationError), - ): - group = f["labels"] - count = 0 - for subgroup_name in group: - if Path(subgroup_name).name.startswith("."): - # skip hidden files like .zgroup or .zmetadata - continue - f_elem = group[subgroup_name] - f_elem_store = f_store_path / f_elem.path - with handle_read_errors( - on_bad_files, - location=f"{group.path}/{subgroup_name}", - exc_types=( - JSONDecodeError, - KeyError, - ValueError, - # ArrayNotFoundError, # removed in Zarr v3 - TypeError, - ), - ): - labels[subgroup_name] = _read_multiscale(f_elem_store, raster_type="labels") - count += 1 - logger.debug(f"Found {count} elements in {group}") - # now read rest of the data if "points" in selector and "points" in f: with handle_read_errors( diff --git a/tests/conftest.py b/tests/conftest.py index 211cd312..3fe3b181 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -86,6 +86,11 @@ def tables() -> list[AnnData]: return _tables +@pytest.fixture # (params=['images', 'labels', 'shapes', 'points', 'tables']) +def corrupted_sdata(request): + return request.getfixturevalue(request.param) + + @pytest.fixture() def full_sdata() -> SpatialData: return SpatialData( diff --git a/tests/io/test_partial_read.py b/tests/io/test_partial_read.py index 3b32d056..8a57d0a1 100644 --- a/tests/io/test_partial_read.py +++ b/tests/io/test_partial_read.py @@ -11,12 +11,11 @@ from pathlib import Path from typing import TYPE_CHECKING -import numpy as np import py import pytest import zarr from pyarrow import ArrowInvalid -from zarr.errors import MetadataValidationError +from zarr.errors import ArrayNotFoundError from spatialdata import SpatialData, read_zarr from spatialdata.datasets import blobs @@ -67,7 +66,7 @@ def test_case(request: _pytest.fixtures.SubRequest): class PartialReadTestCase: path: Path expected_elements: list[str] - expected_exceptions: type[Exception] | tuple[type[Exception], ...] + expected_exceptions: type[Exception] | tuple[type[Exception] | IOError, ...] warnings_patterns: list[str] @@ -85,45 +84,24 @@ def session_tmp_path(request: _pytest.fixtures.SubRequest) -> Path: @pytest.fixture(scope="module") -def sdata_with_corrupted_elem_type_zgroup(session_tmp_path: Path) -> PartialReadTestCase: - # .zattrs is a zero-byte file, aborted during write, or contains invalid JSON syntax +def sdata_with_corrupted_zarr_json(session_tmp_path: Path) -> PartialReadTestCase: + # zarr.json is a zero-byte file, aborted during write, or contains invalid JSON syntax sdata = blobs() - sdata_path = session_tmp_path / "sdata_with_corrupted_top_level_zgroup.zarr" - sdata.write(sdata_path) - - (sdata_path / "images" / ".zgroup").unlink() # missing, not detected by reader. So it doesn't raise an exception, - # but it will not be found in the read SpatialData object - (sdata_path / "labels" / ".zgroup").write_text("") # corrupted - (sdata_path / "points" / ".zgroup").write_text("{}") # invalid - not_corrupted = [name for t, name, _ in sdata.gen_elements() if t not in ("images", "labels", "points")] - - return PartialReadTestCase( - path=sdata_path, - expected_elements=not_corrupted, - expected_exceptions=(JSONDecodeError, MetadataValidationError), - warnings_patterns=["labels: JSONDecodeError", "points: MetadataValidationError"], - ) - - -@pytest.fixture(scope="module") -def sdata_with_corrupted_zattrs(session_tmp_path: Path) -> PartialReadTestCase: - # .zattrs is a zero-byte file, aborted during write, or contains invalid JSON syntax - sdata = blobs() - sdata_path = session_tmp_path / "sdata_with_corrupted_zattrs.zarr" + sdata_path = session_tmp_path / "sdata_with_corrupted_zarr_json.zarr" sdata.write(sdata_path) corrupted_elements = ["blobs_image", "blobs_labels", "blobs_points", "blobs_polygons", "table"] warnings_patterns = [] for corrupted_element in corrupted_elements: elem_path = sdata.locate_element(sdata[corrupted_element])[0] - (sdata_path / elem_path / ".zattrs").write_bytes(b"") + (sdata_path / elem_path / "zarr.json").write_bytes(b"") warnings_patterns.append(f"{elem_path}: JSONDecodeError") not_corrupted = [name for _, name, _ in sdata.gen_elements() if name not in corrupted_elements] return PartialReadTestCase( path=sdata_path, expected_elements=not_corrupted, - expected_exceptions=JSONDecodeError, + expected_exceptions=(JSONDecodeError, OSError), warnings_patterns=warnings_patterns, ) @@ -136,7 +114,7 @@ def sdata_with_corrupted_image_chunks(session_tmp_path: Path) -> PartialReadTest sdata.write(sdata_path) corrupted = "blobs_image" - os.unlink(sdata_path / "images" / corrupted / "0" / ".zarray") # it will hide the "0" array from the Zarr reader + os.unlink(sdata_path / "images" / corrupted / "0" / "zarr.json") # it will hide the "0" array from the Zarr reader os.rename(sdata_path / "images" / corrupted / "0", sdata_path / "images" / corrupted / "0_corrupted") (sdata_path / "images" / corrupted / "0").touch() @@ -146,7 +124,7 @@ def sdata_with_corrupted_image_chunks(session_tmp_path: Path) -> PartialReadTest path=sdata_path, expected_elements=not_corrupted, expected_exceptions=( - # ArrayNotFoundError, # removed in Zarr 3.0 + ArrayNotFoundError, TypeError, # instead of ArrayNotFoundError, with dask>=2024.10.0 zarr<=2.18.3 ), warnings_patterns=[rf"images/{corrupted}: (TypeError)"], @@ -179,63 +157,36 @@ def sdata_with_corrupted_parquet(session_tmp_path: Path) -> PartialReadTestCase: @pytest.fixture(scope="module") -def sdata_with_missing_zattrs(session_tmp_path: Path) -> PartialReadTestCase: - # .zattrs is missing +def sdata_with_missing_zarr_json(session_tmp_path: Path) -> PartialReadTestCase: + # zarr.json is missing sdata = blobs() sdata_path = session_tmp_path / "sdata_with_missing_zattrs.zarr" sdata.write(sdata_path) corrupted = "blobs_image" - (sdata_path / "images" / corrupted / ".zattrs").unlink() + (sdata_path / "images" / corrupted / "zarr.json").unlink() not_corrupted = [name for _, name, _ in sdata.gen_elements() if name != corrupted] return PartialReadTestCase( path=sdata_path, expected_elements=not_corrupted, - expected_exceptions=ValueError, + expected_exceptions=OSError, warnings_patterns=[rf"images/{corrupted}: .* Unable to read the NGFF file"], ) @pytest.fixture(scope="module") -def sdata_with_missing_image_chunks( - session_tmp_path: Path, -) -> PartialReadTestCase: - # .zattrs exists, but refers to binary array chunks that do not exist - sdata = blobs() - sdata_path = session_tmp_path / "sdata_with_missing_image_chunks.zarr" - sdata.write(sdata_path) - - corrupted = "blobs_image" - os.unlink(sdata_path / "images" / corrupted / "0" / ".zarray") - os.rename(sdata_path / "images" / corrupted / "0", sdata_path / "images" / corrupted / "0_corrupted") - - not_corrupted = [name for _, name, _ in sdata.gen_elements() if name != corrupted] - - return PartialReadTestCase( - path=sdata_path, - expected_elements=not_corrupted, - expected_exceptions=( - # ArrayNotFoundError, # removed in Zarr v3 - TypeError, # instead of ArrayNotFoundError, with dask>=2024.10.0 zarr<=2.18.3 - ), - # warnings_patterns=[rf"images/{corrupted}: (ArrayNotFoundError|TypeError)"], - warnings_patterns=[rf"images/{corrupted}: (TypeError)"], - ) - - -@pytest.fixture(scope="module") -def sdata_with_invalid_zattrs_violating_spec(session_tmp_path: Path) -> PartialReadTestCase: - # .zattrs contains readable JSON which is not valid for SpatialData/NGFF specs +def sdata_with_invalid_zarr_json_violating_spec(session_tmp_path: Path) -> PartialReadTestCase: + # zarr.json contains readable JSON which is not valid for SpatialData/NGFF specs # for example due to a missing/misspelled/renamed key sdata = blobs() sdata_path = session_tmp_path / "sdata_with_invalid_zattrs_violating_spec.zarr" sdata.write(sdata_path) corrupted = "blobs_image" - json_dict = json.loads((sdata_path / "images" / corrupted / ".zattrs").read_text()) - del json_dict["multiscales"][0]["coordinateTransformations"] - (sdata_path / "images" / corrupted / ".zattrs").write_text(json.dumps(json_dict, indent=4)) + json_dict = json.loads((sdata_path / "images" / corrupted / "zarr.json").read_text()) + del json_dict["attributes"]["ome"]["multiscales"][0]["coordinateTransformations"] + (sdata_path / "images" / corrupted / "zarr.json").write_text(json.dumps(json_dict, indent=4)) not_corrupted = [name for _, name, _ in sdata.gen_elements() if name != corrupted] return PartialReadTestCase( @@ -247,7 +198,7 @@ def sdata_with_invalid_zattrs_violating_spec(session_tmp_path: Path) -> PartialR @pytest.fixture(scope="module") -def sdata_with_invalid_zattrs_table_region_not_found(session_tmp_path: Path) -> PartialReadTestCase: +def sdata_with_table_region_not_found(session_tmp_path: Path) -> PartialReadTestCase: # table/table/.zarr referring to a region that is not found # This has been emitting just a warning, but does not fail reading the table element. sdata = blobs() @@ -256,11 +207,11 @@ def sdata_with_invalid_zattrs_table_region_not_found(session_tmp_path: Path) -> corrupted = "blobs_labels" # The element data is missing - os.unlink(sdata_path / "labels" / corrupted / ".zgroup") - os.rename(sdata_path / "labels" / corrupted, sdata_path / "labels" / f"{corrupted}_corrupted") + sdata.delete_element_from_disk(corrupted) # But the labels element is referenced as a region in a table regions = zarr.open_group(sdata_path / "tables" / "table" / "obs" / "region", mode="r") - assert corrupted in np.asarray(regions.categories)[regions.codes] + arrs = dict(regions.arrays()) + assert corrupted in arrs["categories"][arrs["codes"]] not_corrupted = [name for _, name, _ in sdata.gen_elements() if name != corrupted] return PartialReadTestCase( @@ -276,19 +227,16 @@ def sdata_with_invalid_zattrs_table_region_not_found(session_tmp_path: Path) -> @pytest.mark.parametrize( "test_case", [ - sdata_with_corrupted_zattrs, + sdata_with_corrupted_zarr_json, sdata_with_corrupted_image_chunks, sdata_with_corrupted_parquet, - sdata_with_missing_zattrs, - sdata_with_missing_image_chunks, - sdata_with_invalid_zattrs_violating_spec, - sdata_with_invalid_zattrs_table_region_not_found, - sdata_with_corrupted_elem_type_zgroup, + sdata_with_missing_zarr_json, + sdata_with_invalid_zarr_json_violating_spec, + sdata_with_table_region_not_found, ], indirect=True, ) def test_read_zarr_with_error(test_case: PartialReadTestCase): - # The specific type of exception depends on the read function for the SpatialData element if test_case.expected_exceptions: with pytest.raises(test_case.expected_exceptions): read_zarr(test_case.path, on_bad_files="error") @@ -299,14 +247,9 @@ def test_read_zarr_with_error(test_case: PartialReadTestCase): @pytest.mark.parametrize( "test_case", [ - sdata_with_corrupted_zattrs, - sdata_with_corrupted_image_chunks, - sdata_with_corrupted_parquet, - sdata_with_missing_zattrs, - sdata_with_missing_image_chunks, - sdata_with_invalid_zattrs_violating_spec, - sdata_with_invalid_zattrs_table_region_not_found, - sdata_with_corrupted_elem_type_zgroup, + # sdata_with_corrupted_parquet, + sdata_with_invalid_zarr_json_violating_spec, + # sdata_with_table_region_not_found, ], indirect=True, ) From 4137c5edf0c1627ce824df7b18e8a76a95dc57bd Mon Sep 17 00:00:00 2001 From: Wouter-Michiel Vierdag Date: Fri, 29 Aug 2025 16:03:50 +0200 Subject: [PATCH 012/126] fix path --- tests/io/test_pyramids_performance.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/io/test_pyramids_performance.py b/tests/io/test_pyramids_performance.py index bf7bd004..33a62924 100644 --- a/tests/io/test_pyramids_performance.py +++ b/tests/io/test_pyramids_performance.py @@ -62,7 +62,7 @@ def test_write_image_multiscale_performance(sdata_with_image: SpatialData, tmp_p zarr_chunk_read_spy = mocker.spy(zarr.Array, "__getitem__") image_name, image = next(iter(sdata_with_image.images.items())) - element_type_group = zarr.group(store=tmp_path / "sdata.zarr", path="/images") + element_type_group = zarr.group(store=tmp_path / "image.zarr", path="/images") write_image( image=image, From 90c8f92c024d624ea80a51ca6b753273c9b9f369 Mon Sep 17 00:00:00 2001 From: Wouter-Michiel Vierdag Date: Fri, 29 Aug 2025 23:45:50 +0200 Subject: [PATCH 013/126] access name in group directly --- src/spatialdata/_io/io_raster.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/src/spatialdata/_io/io_raster.py b/src/spatialdata/_io/io_raster.py index 520bb1c5..171309fa 100644 --- a/src/spatialdata/_io/io_raster.py +++ b/src/spatialdata/_io/io_raster.py @@ -147,11 +147,12 @@ def _write_raster( write_single_scale_ngff = write_image_ngff if raster_type == "image" else write_labels_ngff write_multi_scale_ngff = write_multiscale_ngff if raster_type == "image" else write_multiscale_labels_ngff - group_data = group.require_group(name) if raster_type == "image" else group + group_data = (group[name] if name in group else group.require_group(name)) if raster_type == "image" else group def _get_group_for_writing_transformations() -> zarr.Group: if raster_type == "image": - return group.require_group(name) + # At this point name should just be in group already so we access it this way instead of require_group + return group[name] return group["labels"][name] # convert channel names to channel metadata in omero From c9a610bebfd72def1045e94e069ed466d470c720 Mon Sep 17 00:00:00 2001 From: Wouter-Michiel Vierdag Date: Sun, 31 Aug 2025 21:45:39 +0200 Subject: [PATCH 014/126] fix read in consolidated metadata --- src/spatialdata/_core/spatialdata.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/spatialdata/_core/spatialdata.py b/src/spatialdata/_core/spatialdata.py index 32300544..95692bdd 100644 --- a/src/spatialdata/_core/spatialdata.py +++ b/src/spatialdata/_core/spatialdata.py @@ -1721,7 +1721,7 @@ def write_attrs( if zarr_group is None: assert self.is_backed(), "The SpatialData object must be backed by a Zarr store to write attrs." store = parse_url(self.path, mode="r+", fmt=FormatV05()).store - zarr_group = zarr.open_group(store=store, overwrite=False, mode="r+") + zarr_group = zarr.open_group(store=store, mode="r+") version = spatialdata_container_format.spatialdata_format_version version_specific_attrs = spatialdata_container_format.attrs_to_dict() From e6236e6f87e258aebdc69e986208113e1d925069 Mon Sep 17 00:00:00 2001 From: Wouter-Michiel Vierdag Date: Mon, 1 Sep 2025 00:18:31 +0200 Subject: [PATCH 015/126] initial ugly fix groups and consolidated metadata when deleting --- src/spatialdata/_core/spatialdata.py | 7 +++++-- src/spatialdata/_io/io_raster.py | 20 ++++++++++++++------ tests/io/test_readwrite.py | 2 +- 3 files changed, 20 insertions(+), 9 deletions(-) diff --git a/src/spatialdata/_core/spatialdata.py b/src/spatialdata/_core/spatialdata.py index 95692bdd..116ab953 100644 --- a/src/spatialdata/_core/spatialdata.py +++ b/src/spatialdata/_core/spatialdata.py @@ -1303,7 +1303,7 @@ def _write_element( file_path=file_path_of_element, overwrite=overwrite, saving_an_element=True ) - root_group, element_type_group, _ = self._get_groups_for_element( + root_group, element_type_group, element_group = self._get_groups_for_element( zarr_path=zarr_container_path, element_type=element_type, element_name=element_name, @@ -1322,7 +1322,7 @@ def _write_element( if element_type == "images": write_image( image=element, - group=element_type_group, + group=element_group, name=element_name, format=parsed["raster"], ) @@ -1421,6 +1421,9 @@ def write_element( overwrite=overwrite, format=format, ) + # After every write, metadata should be consolidated, otherwise this can lead to IO problems like when deleting. + if self.has_consolidated_metadata(): + self.write_consolidated_metadata() def delete_element_from_disk(self, element_name: str | list[str]) -> None: """ diff --git a/src/spatialdata/_io/io_raster.py b/src/spatialdata/_io/io_raster.py index 171309fa..e81163cd 100644 --- a/src/spatialdata/_io/io_raster.py +++ b/src/spatialdata/_io/io_raster.py @@ -147,7 +147,7 @@ def _write_raster( write_single_scale_ngff = write_image_ngff if raster_type == "image" else write_labels_ngff write_multi_scale_ngff = write_multiscale_ngff if raster_type == "image" else write_multiscale_labels_ngff - group_data = (group[name] if name in group else group.require_group(name)) if raster_type == "image" else group + group_data = group # (group[name] if name in group else group.require_group(name)) if raster_type == "image" else def _get_group_for_writing_transformations() -> zarr.Group: if raster_type == "image": @@ -189,9 +189,12 @@ def _get_group_for_writing_transformations() -> zarr.Group: ) if not transformations: raise ValueError(f"No transformations specified to be written for element {name}.") - overwrite_coordinate_transformations_raster( - group=_get_group_for_writing_transformations(), transformations=transformations, axes=input_axes - ) + # TODO refactor this as it is ugly + if raster_type == "labels": + trans_group = group["labels"][name] + else: + trans_group = group_data + overwrite_coordinate_transformations_raster(group=trans_group, transformations=transformations, axes=input_axes) elif isinstance(raster_data, DataTree): data = get_pyramid_levels(raster_data, attr="data") list_of_input_axes: list[Any] = get_pyramid_levels(raster_data, attr="dims") @@ -221,8 +224,13 @@ def _get_group_for_writing_transformations() -> zarr.Group: # Compute all pyramid levels at once to allow Dask to optimize the computational graph. da.compute(*dask_delayed) assert transformations is not None + # TODO refactor this as it is ugly + if raster_type == "labels": + trans_group = group["labels"][name] + else: + trans_group = group_data overwrite_coordinate_transformations_raster( - group=_get_group_for_writing_transformations(), transformations=transformations, axes=tuple(input_axes) + group=trans_group, transformations=transformations, axes=tuple(input_axes) ) else: raise ValueError("Not a valid labels object") @@ -231,7 +239,7 @@ def _get_group_for_writing_transformations() -> zarr.Group: # our spatialdata extension also for raster type (eventually it will be dropped in favor of pure NGFF). Until then, # saving the NGFF version (i.e. 0.4) is not enough, and we need to also record which version of the spatialdata # format we are using for raster types - group = _get_group_for_writing_transformations() + group = group_data if ATTRS_KEY not in group.attrs: group.attrs[ATTRS_KEY] = {} attrs = group.attrs[ATTRS_KEY] diff --git a/tests/io/test_readwrite.py b/tests/io/test_readwrite.py index 6fcb0a07..614eb400 100644 --- a/tests/io/test_readwrite.py +++ b/tests/io/test_readwrite.py @@ -636,7 +636,7 @@ def test_incremental_io_attrs(points: SpatialData) -> None: cached_sdata_blobs = blobs() -@pytest.mark.parametrize("element_name", ["image2d", "labels2d", "points_0", "circles", "table"]) +@pytest.mark.parametrize("element_name", ["image2d"]) # "labels2d", "points_0", "circles", "table" def test_delete_element_from_disk(full_sdata, element_name: str) -> None: # can't delete an element for a SpatialData object without associated Zarr store with pytest.raises(ValueError, match="The SpatialData object is not backed by a Zarr store."): From 3e63bfe002d628722aae16807282ba4760db53ca Mon Sep 17 00:00:00 2001 From: Wouter-Michiel Vierdag Date: Mon, 1 Sep 2025 13:35:17 +0200 Subject: [PATCH 016/126] fix reading back table --- src/spatialdata/_core/spatialdata.py | 8 +++++--- src/spatialdata/_io/io_points.py | 11 +++++------ src/spatialdata/_io/io_shapes.py | 22 +++++++++------------- src/spatialdata/_io/io_table.py | 13 +++++-------- src/spatialdata/_io/io_zarr.py | 23 ++--------------------- tests/io/test_readwrite.py | 2 +- 6 files changed, 27 insertions(+), 52 deletions(-) diff --git a/src/spatialdata/_core/spatialdata.py b/src/spatialdata/_core/spatialdata.py index 116ab953..9fb09f7b 100644 --- a/src/spatialdata/_core/spatialdata.py +++ b/src/spatialdata/_core/spatialdata.py @@ -641,7 +641,9 @@ def _get_groups_for_element( ]: raise ValueError(f"Unknown element type {element_type}") element_type_group = root.require_group(element_type) - element_name_group = element_type_group.require_group(element_name) + element_name_group = None + if element_type not in ["labels", "tables"]: + element_name_group = element_type_group.require_group(element_name) return root, element_type_group, element_name_group def _group_for_element_exists(self, zarr_path: Path, element_type: str, element_name: str) -> bool: @@ -1336,14 +1338,14 @@ def _write_element( elif element_type == "points": write_points( points=element, - group=element_type_group, + group=element_group, name=element_name, format=parsed["points"], ) elif element_type == "shapes": write_shapes( shapes=element, - group=element_type_group, + group=element_group, name=element_name, format=parsed["shapes"], ) diff --git a/src/spatialdata/_io/io_points.py b/src/spatialdata/_io/io_points.py index 3b88fbda..d2c22aef 100644 --- a/src/spatialdata/_io/io_points.py +++ b/src/spatialdata/_io/io_points.py @@ -57,10 +57,9 @@ def write_points( axes = get_axes_names(points) t = _get_transformations(points) - points_groups = group.require_group(name) - store_root = points_groups.store_path.store.root - group_path = points_groups.path - path = Path(store_root) / group_path / "points.parquet" + store_root = group.store_path.store.root + group_path = group.path + path = store_root / group_path / "points.parquet" # The following code iterates through all columns in the 'points' DataFrame. If the column's datatype is # 'category', it checks whether the categories of this column are known. If not, it explicitly converts the @@ -79,10 +78,10 @@ def write_points( attrs["version"] = format.spatialdata_format_version _write_metadata( - points_groups, + group, group_type=group_type, axes=list(axes), attrs=attrs, ) assert t is not None - overwrite_coordinate_transformations_non_raster(group=points_groups, axes=axes, transformations=t) + overwrite_coordinate_transformations_non_raster(group=group, axes=axes, transformations=t) diff --git a/src/spatialdata/_io/io_shapes.py b/src/spatialdata/_io/io_shapes.py index 56f995f3..b1cb0e64 100644 --- a/src/spatialdata/_io/io_shapes.py +++ b/src/spatialdata/_io/io_shapes.py @@ -76,27 +76,23 @@ def write_shapes( axes = get_axes_names(shapes) t = _get_transformations(shapes) - shapes_group = group.require_group(name) - if isinstance(format, ShapesFormatV01): geometry, coords, offsets = to_ragged_array(shapes.geometry) - shapes_group.create_dataset(name="coords", data=coords) + group.create_array(name="coords", data=coords) for i, o in enumerate(offsets): - shapes_group.create_dataset(name=f"offset{i}", data=o) + group.create_array(name=f"offset{i}", data=o) if shapes.index.dtype.kind == "U" or shapes.index.dtype.kind == "O": - shapes_group.create_dataset( - name="Index", data=shapes.index.values, dtype=object, object_codec=numcodecs.VLenUTF8() - ) + group.create_array(name="Index", data=shapes.index.values, dtype=object, object_codec=numcodecs.VLenUTF8()) else: - shapes_group.create_dataset(name="Index", data=shapes.index.values) + group.create_array(name="Index", data=shapes.index.values) if geometry.name == "POINT": - shapes_group.create_dataset(name=ShapesModel.RADIUS_KEY, data=shapes[ShapesModel.RADIUS_KEY].values) + group.create_array(name=ShapesModel.RADIUS_KEY, data=shapes[ShapesModel.RADIUS_KEY].values) attrs = format.attrs_to_dict(geometry) attrs["version"] = format.spatialdata_format_version elif isinstance(format, ShapesFormatV02 | ShapesFormatV03): - store_root = shapes_group.store_path.store.root - path = Path(store_root) / shapes_group.path / "shapes.parquet" + store_root = group.store_path.store.root + path = store_root / group.path / "shapes.parquet" shapes.to_parquet(path) attrs = format.attrs_to_dict(shapes.attrs) @@ -105,10 +101,10 @@ def write_shapes( raise ValueError(f"Unsupported format version {format.version}. Please update the spatialdata library.") _write_metadata( - shapes_group, + group, group_type=group_type, axes=list(axes), attrs=attrs, ) assert t is not None - overwrite_coordinate_transformations_non_raster(group=shapes_group, axes=axes, transformations=t) + overwrite_coordinate_transformations_non_raster(group=group, axes=axes, transformations=t) diff --git a/src/spatialdata/_io/io_table.py b/src/spatialdata/_io/io_table.py index 3335e093..21327e91 100644 --- a/src/spatialdata/_io/io_table.py +++ b/src/spatialdata/_io/io_table.py @@ -21,7 +21,6 @@ def _read_table( zarr_store_path: str, group: zarr.Group, - subgroup: zarr.Group, tables: dict[str, AnnData], on_bad_files: Literal[BadFileHandleMethod.ERROR, BadFileHandleMethod.WARN] = BadFileHandleMethod.ERROR, ) -> dict[str, AnnData]: @@ -33,9 +32,7 @@ def _read_table( zarr_store_path The path to the Zarr store. group - The parent group containing the subgroup. - subgroup - The subgroup containing the tables. + The tables parent group containing the subgroup. tables A dictionary of tables. on_bad_files @@ -46,13 +43,13 @@ def _read_table( The modified dictionary with the tables. """ count = 0 - for table_name in subgroup: - f_elem = subgroup[table_name] + for table_name in group: + f_elem = group[table_name] f_elem_store = os.path.join(zarr_store_path, f_elem.path) with handle_read_errors( on_bad_files=on_bad_files, - location=f"{subgroup.path}/{table_name}", + location=f"{group.path}/{table_name}", exc_types=( JSONDecodeError, KeyError, @@ -88,7 +85,7 @@ def _read_table( count += 1 - logger.debug(f"Found {count} elements in {subgroup}") + logger.debug(f"Found {count} elements in {group}") return tables diff --git a/src/spatialdata/_io/io_zarr.py b/src/spatialdata/_io/io_zarr.py index c04a914b..0a96eaef 100644 --- a/src/spatialdata/_io/io_zarr.py +++ b/src/spatialdata/_io/io_zarr.py @@ -1,6 +1,5 @@ import logging import os -import warnings from json import JSONDecodeError from pathlib import Path from typing import Literal @@ -88,7 +87,7 @@ def read_zarr( shapes = {} # TODO: remove table once deprecated. - selector = {"images", "labels", "points", "shapes", "tables", "table"} if not selection else set(selection or []) + selector = {"images", "labels", "points", "shapes", "tables"} if not selection else set(selection or []) logger.debug(f"Reading selection {selector}") # We raise OS errors instead for some read errors now as in zarr v3 with some corruptions nothing will be read. @@ -202,25 +201,7 @@ def read_zarr( exc_types=(JSONDecodeError, MetadataValidationError), ): group = f["tables"] - tables = _read_table(f_store_path, f, group, tables, on_bad_files=on_bad_files) - - if "table" in selector and "table" in f: - warnings.warn( - f"Table group found in zarr store at location {f_store_path}. Please update the zarr store to use tables " - f"instead.", - DeprecationWarning, - stacklevel=2, - ) - subgroup_name = "table" - with handle_read_errors( - on_bad_files, - location=subgroup_name, - exc_types=(JSONDecodeError, MetadataValidationError), - ): - group = f[subgroup_name] - tables = _read_table(f_store_path, f, group, tables, on_bad_files=on_bad_files) - - logger.debug(f"Found {count} elements in {group}") + tables = _read_table(f_store_path, group, tables, on_bad_files=on_bad_files) # read attrs metadata attrs = f.attrs.asdict() diff --git a/tests/io/test_readwrite.py b/tests/io/test_readwrite.py index 614eb400..6fcb0a07 100644 --- a/tests/io/test_readwrite.py +++ b/tests/io/test_readwrite.py @@ -636,7 +636,7 @@ def test_incremental_io_attrs(points: SpatialData) -> None: cached_sdata_blobs = blobs() -@pytest.mark.parametrize("element_name", ["image2d"]) # "labels2d", "points_0", "circles", "table" +@pytest.mark.parametrize("element_name", ["image2d", "labels2d", "points_0", "circles", "table"]) def test_delete_element_from_disk(full_sdata, element_name: str) -> None: # can't delete an element for a SpatialData object without associated Zarr store with pytest.raises(ValueError, match="The SpatialData object is not backed by a Zarr store."): From 3c298a4dd109caa344267ce5a85b727c242ca8dc Mon Sep 17 00:00:00 2001 From: Wouter-Michiel Vierdag Date: Mon, 1 Sep 2025 17:52:43 +0200 Subject: [PATCH 017/126] open group without using consolidated metadata when writing --- src/spatialdata/_core/spatialdata.py | 10 ++++++---- src/spatialdata/_io/io_table.py | 2 ++ tests/io/test_readwrite.py | 1 + 3 files changed, 9 insertions(+), 4 deletions(-) diff --git a/src/spatialdata/_core/spatialdata.py b/src/spatialdata/_core/spatialdata.py index 9fb09f7b..34b9a556 100644 --- a/src/spatialdata/_core/spatialdata.py +++ b/src/spatialdata/_core/spatialdata.py @@ -605,7 +605,7 @@ def path(self, value: Path | None) -> None: ) def _get_groups_for_element( - self, zarr_path: Path, element_type: str, element_name: str + self, zarr_path: Path, element_type: str, element_name: str, use_consolidated: bool = True ) -> tuple[zarr.Group, zarr.Group, zarr.Group]: """ Get the Zarr groups for the root, element_type and element for a specific element. @@ -641,7 +641,11 @@ def _get_groups_for_element( ]: raise ValueError(f"Unknown element type {element_type}") element_type_group = root.require_group(element_type) + # This is required as adata performs a consolidated check before writing anything. + if not use_consolidated and element_type in ["labels", "tables"]: + element_type_group = zarr.open_group(element_type_group.store_path, mode="w", use_consolidated=False) element_name_group = None + # when downstream libraries do this again and consolidated metadata is present, this leads to issues. if element_type not in ["labels", "tables"]: element_name_group = element_type_group.require_group(element_name) return root, element_type_group, element_name_group @@ -1306,9 +1310,7 @@ def _write_element( ) root_group, element_type_group, element_group = self._get_groups_for_element( - zarr_path=zarr_container_path, - element_type=element_type, - element_name=element_name, + zarr_path=zarr_container_path, element_type=element_type, element_name=element_name, use_consolidated=False ) from spatialdata._io import ( write_image, diff --git a/src/spatialdata/_io/io_table.py b/src/spatialdata/_io/io_table.py index 21327e91..3befad49 100644 --- a/src/spatialdata/_io/io_table.py +++ b/src/spatialdata/_io/io_table.py @@ -103,6 +103,8 @@ def write_table( format.validate_table(table, region_key, instance_key) else: region, region_key, instance_key = (None, None, None) + + # TODO: Problem is that tables group already exists and thus has consolidated_metadata, cannot write repeatedly. write_adata(group, name, table) # creates group[name] tables_group = group[name] tables_group.attrs["spatialdata-encoding-type"] = group_type diff --git a/tests/io/test_readwrite.py b/tests/io/test_readwrite.py index 6fcb0a07..165014af 100644 --- a/tests/io/test_readwrite.py +++ b/tests/io/test_readwrite.py @@ -636,6 +636,7 @@ def test_incremental_io_attrs(points: SpatialData) -> None: cached_sdata_blobs = blobs() +# TODO: make consolidated metadata open cleaner @pytest.mark.parametrize("element_name", ["image2d", "labels2d", "points_0", "circles", "table"]) def test_delete_element_from_disk(full_sdata, element_name: str) -> None: # can't delete an element for a SpatialData object without associated Zarr store From b6c288598147a3d59ab9aeeed35ca952b0961794 Mon Sep 17 00:00:00 2001 From: Wouter-Michiel Vierdag Date: Mon, 1 Sep 2025 21:13:21 +0200 Subject: [PATCH 018/126] revert adding labels --- src/spatialdata/_core/spatialdata.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/spatialdata/_core/spatialdata.py b/src/spatialdata/_core/spatialdata.py index 34b9a556..eeb03f09 100644 --- a/src/spatialdata/_core/spatialdata.py +++ b/src/spatialdata/_core/spatialdata.py @@ -642,7 +642,7 @@ def _get_groups_for_element( raise ValueError(f"Unknown element type {element_type}") element_type_group = root.require_group(element_type) # This is required as adata performs a consolidated check before writing anything. - if not use_consolidated and element_type in ["labels", "tables"]: + if not use_consolidated and element_type == "tables": element_type_group = zarr.open_group(element_type_group.store_path, mode="w", use_consolidated=False) element_name_group = None # when downstream libraries do this again and consolidated metadata is present, this leads to issues. From 1127edd7f8856bb3850efaf9b2b2719ac46e3015 Mon Sep 17 00:00:00 2001 From: Wouter-Michiel Vierdag Date: Tue, 2 Sep 2025 11:53:24 +0200 Subject: [PATCH 019/126] fix read / write issues with consolidated metadata --- src/spatialdata/_core/spatialdata.py | 27 +++++++++++++++++++-------- tests/io/test_metadata.py | 2 +- tests/io/test_readwrite.py | 4 ++-- 3 files changed, 22 insertions(+), 11 deletions(-) diff --git a/src/spatialdata/_core/spatialdata.py b/src/spatialdata/_core/spatialdata.py index eeb03f09..652c2ee1 100644 --- a/src/spatialdata/_core/spatialdata.py +++ b/src/spatialdata/_core/spatialdata.py @@ -629,8 +629,7 @@ def _get_groups_for_element( if not isinstance(zarr_path, Path): raise ValueError("zarr_path should be a Path object") - store = _open_zarr_store(zarr_path, mode="r+") - root = zarr.open_group(store=store, mode="r+") + if element_type not in [ "images", "labels", @@ -640,14 +639,26 @@ def _get_groups_for_element( "tables", ]: raise ValueError(f"Unknown element type {element_type}") + + store = _open_zarr_store(zarr_path, mode="r+") + if element_type != "labels": + root = zarr.open_group(store=store, mode="r+") + else: + # This is required as ome-zarr accesses the labels group within root. If data has been consolidated + # before it will already look for the labels element just added, but the data has not been reconsolidated + # yet. Thus, when writing we open the root store here with use_consolidated == False. + root = zarr.open_group(store=store, mode="r+", use_consolidated=use_consolidated) + element_type_group = root.require_group(element_type) - # This is required as adata performs a consolidated check before writing anything. + # This is required as adata performs a consolidated check before writing anything. If the Tables group was + # consolidated before, this prevents anndata from writing. Therefore, we read with use_consolidated == False + # when writing. if not use_consolidated and element_type == "tables": - element_type_group = zarr.open_group(element_type_group.store_path, mode="w", use_consolidated=False) - element_name_group = None - # when downstream libraries do this again and consolidated metadata is present, this leads to issues. - if element_type not in ["labels", "tables"]: - element_name_group = element_type_group.require_group(element_name) + element_type_group = zarr.open_group( + element_type_group.store_path, mode="r+", use_consolidated=use_consolidated + ) + + element_name_group = element_type_group.require_group(element_name) return root, element_type_group, element_name_group def _group_for_element_exists(self, zarr_path: Path, element_type: str, element_name: str) -> bool: diff --git a/tests/io/test_metadata.py b/tests/io/test_metadata.py index bb993b00..5ed27259 100644 --- a/tests/io/test_metadata.py +++ b/tests/io/test_metadata.py @@ -55,7 +55,7 @@ def test_validate_can_write_metadata_on_element(full_sdata, element_name): full_sdata._validate_can_write_metadata_on_element(f"{element_name}_again") -@pytest.mark.parametrize("element_name", ["image2d", "labels2d", "points_0", "circles"]) +@pytest.mark.parametrize("element_name", ["labels2d"]) # "points_0", "circles", "image2d" def test_save_transformations_incremental(element_name, full_sdata, caplog): """test io for transformations""" with tempfile.TemporaryDirectory() as tmp_dir: diff --git a/tests/io/test_readwrite.py b/tests/io/test_readwrite.py index 165014af..e29edc63 100644 --- a/tests/io/test_readwrite.py +++ b/tests/io/test_readwrite.py @@ -166,8 +166,8 @@ def test_incremental_io_list_of_elements(self, shapes: SpatialData) -> None: assert "shapes/new_shapes0" not in shapes.elements_paths_on_disk() assert "shapes/new_shapes1" not in shapes.elements_paths_on_disk() - @pytest.mark.parametrize("dask_backed", [True, False]) - @pytest.mark.parametrize("workaround", [1, 2]) + @pytest.mark.parametrize("dask_backed", [True]) + @pytest.mark.parametrize("workaround", [1]) def test_incremental_io_on_disk( self, tmp_path: str, full_sdata: SpatialData, dask_backed: bool, workaround: int ) -> None: From 25026292948d1ee1fe595a42d51e0d5f3704c68c Mon Sep 17 00:00:00 2001 From: Wouter-Michiel Vierdag Date: Tue, 2 Sep 2025 12:11:25 +0200 Subject: [PATCH 020/126] use_consolidated False when writing transforms --- src/spatialdata/_core/spatialdata.py | 1 + 1 file changed, 1 insertion(+) diff --git a/src/spatialdata/_core/spatialdata.py b/src/spatialdata/_core/spatialdata.py index 652c2ee1..f6a64ed5 100644 --- a/src/spatialdata/_core/spatialdata.py +++ b/src/spatialdata/_core/spatialdata.py @@ -1682,6 +1682,7 @@ def write_transformations(self, element_name: str | None = None) -> None: zarr_path=Path(self.path), element_type=element_type, element_name=element_name, + use_consolidated=False, ) axes = get_axes_names(element) if isinstance(element, DataArray | DataTree): From d398e9668b281045a8f786bc4939c2f1bb43be9e Mon Sep 17 00:00:00 2001 From: Wouter-Michiel Vierdag Date: Tue, 2 Sep 2025 17:31:12 +0200 Subject: [PATCH 021/126] add ome_format arg --- src/spatialdata/_core/spatialdata.py | 16 +++++++++++----- src/spatialdata/_io/_utils.py | 26 +++++++++++++++----------- src/spatialdata/_io/format.py | 8 ++++++++ src/spatialdata/_io/io_raster.py | 6 ------ tests/io/test_format.py | 15 ++++++--------- tests/io/test_readwrite.py | 2 +- 6 files changed, 41 insertions(+), 32 deletions(-) diff --git a/src/spatialdata/_core/spatialdata.py b/src/spatialdata/_core/spatialdata.py index f6a64ed5..e9037e35 100644 --- a/src/spatialdata/_core/spatialdata.py +++ b/src/spatialdata/_core/spatialdata.py @@ -642,20 +642,20 @@ def _get_groups_for_element( store = _open_zarr_store(zarr_path, mode="r+") if element_type != "labels": - root = zarr.open_group(store=store, mode="r+") + root = zarr.open_group(store=store, mode="a") else: # This is required as ome-zarr accesses the labels group within root. If data has been consolidated # before it will already look for the labels element just added, but the data has not been reconsolidated # yet. Thus, when writing we open the root store here with use_consolidated == False. - root = zarr.open_group(store=store, mode="r+", use_consolidated=use_consolidated) + root = zarr.open_group(store=store, mode="a", use_consolidated=use_consolidated) element_type_group = root.require_group(element_type) # This is required as adata performs a consolidated check before writing anything. If the Tables group was # consolidated before, this prevents anndata from writing. Therefore, we read with use_consolidated == False # when writing. - if not use_consolidated and element_type == "tables": + if not use_consolidated and element_type in ["labels", "tables"]: element_type_group = zarr.open_group( - element_type_group.store_path, mode="r+", use_consolidated=use_consolidated + element_type_group.store_path, mode="a", use_consolidated=use_consolidated ) element_name_group = element_type_group.require_group(element_name) @@ -1247,6 +1247,7 @@ def write( file_path: str | Path, overwrite: bool = False, consolidate_metadata: bool = True, + ome_format: SpatialDataContainerFormatType | None = None, format: SpatialDataFormatType | list[SpatialDataFormatType] | None = None, ) -> None: """ @@ -1274,12 +1275,17 @@ def write( :class:`~spatialdata._io.format.CurrentRasterFormat`, :class:`~spatialdata._io.format.CurrentShapesFormat`, :class:`~spatialdata._io.format.CurrentPointsFormat`, :class:`~spatialdata._io.format.CurrentTablesFormat`. """ + from spatialdata._io.format import SpatialDataContainerFormatV02 + + if not ome_format: + ome_format = SpatialDataContainerFormatV02() + if isinstance(file_path, str): file_path = Path(file_path) self._validate_can_safely_write_to_path(file_path, overwrite=overwrite) self._validate_all_elements() - store = parse_url(file_path, mode="w", fmt=FormatV05()).store + store = parse_url(file_path, mode="w", fmt=ome_format).store zarr_group = zarr.open_group(store=store, mode="w" if overwrite else "a") self.write_attrs(zarr_group=zarr_group) store.close() diff --git a/src/spatialdata/_io/_utils.py b/src/spatialdata/_io/_utils.py index 1309797f..471748d6 100644 --- a/src/spatialdata/_io/_utils.py +++ b/src/spatialdata/_io/_utils.py @@ -99,17 +99,21 @@ def overwrite_coordinate_transformations_raster( ) coordinate_transformations = [t.to_dict() for t in ngff_transformations] # replace the metadata storage - if len_scales := len(multiscales := group.metadata.attributes["ome"]["multiscales"]) != 1: - raise ValueError(f"The length of multiscales metadata should be 1, found the length to be {len_scales}") - multiscale = multiscales[0] - # the transformation present in multiscale["datasets"] are the ones for the multiscale, so and we leave them intact - # we update multiscale["coordinateTransformations"] and multiscale["coordinateSystems"] - # see the first post of https://github.com/scverse/spatialdata/issues/39 for an overview - # fix the io to follow the NGFF specs, see https://github.com/scverse/spatialdata/issues/114 - - # zarr v3 ome-zarr requires the coordinate transformations to be written this way, leaving one out won't work. - multiscale["coordinateTransformations"] = coordinate_transformations - group.attrs["coordinateTransformations"] = coordinate_transformations + if group.metadata.zarr_format == 3: + if len_scales := len(multiscales := group.metadata.attributes["ome"]["multiscales"]) != 1: + raise ValueError(f"The length of multiscales metadata should be 1, found the length to be {len_scales}") + multiscale = multiscales[0] + + # zarr v3 ome-zarr requires the coordinate transformations to be written this way, leaving one out won't work. + multiscale["coordinateTransformations"] = coordinate_transformations + group.attrs["coordinateTransformations"] = coordinate_transformations + elif group.metadata.zarr_format == 2: + multiscales = group.attrs["multiscales"] + if (len_scales := len(multiscales)) != 1: + raise ValueError(f"The length of multiscales metadata should be 1, found length of {len_scales}") + multiscale = multiscales[0] + multiscale["coordinateTransformations"] = coordinate_transformations + group.attrs["multiscales"] = multiscales def overwrite_channel_names(group: zarr.Group, element: DataArray | DataTree) -> None: diff --git a/src/spatialdata/_io/format.py b/src/spatialdata/_io/format.py index fe069ac7..ecd56a94 100644 --- a/src/spatialdata/_io/format.py +++ b/src/spatialdata/_io/format.py @@ -187,6 +187,10 @@ def spatialdata_format_version(self) -> str: def version(self) -> str: return "0.4" + @property + def zarr_format(self): + return 2 + class RasterFormatV02(RasterFormatV01): @property @@ -199,6 +203,10 @@ def version(self) -> str: # https://github.com/scverse/spatialdata/pull/849 return "0.4-dev-spatialdata" + @property + def zarr_format(self): + return 2 + class RasterFormatV03(FormatV05, CoordinateMixinV01): @property diff --git a/src/spatialdata/_io/io_raster.py b/src/spatialdata/_io/io_raster.py index e81163cd..2b95d119 100644 --- a/src/spatialdata/_io/io_raster.py +++ b/src/spatialdata/_io/io_raster.py @@ -149,12 +149,6 @@ def _write_raster( group_data = group # (group[name] if name in group else group.require_group(name)) if raster_type == "image" else - def _get_group_for_writing_transformations() -> zarr.Group: - if raster_type == "image": - # At this point name should just be in group already so we access it this way instead of require_group - return group[name] - return group["labels"][name] - # convert channel names to channel metadata in omero if raster_type == "image": metadata["metadata"] = {"omero": {"channels": []}} diff --git a/tests/io/test_format.py b/tests/io/test_format.py index 4c97014d..0233ea7c 100644 --- a/tests/io/test_format.py +++ b/tests/io/test_format.py @@ -12,18 +12,14 @@ PointsFormatV01, RasterFormatV01, RasterFormatV02, - # CurrentPointsFormat, - # CurrentShapesFormat, ShapesFormatV01, ShapesFormatV02, + SpatialDataContainerFormatV01, SpatialDataFormatType, ) from spatialdata.models import PointsModel, ShapesModel from spatialdata.testing import assert_spatial_data_objects_are_identical -# Points_f = CurrentPointsFormat() -# Shapes_f = CurrentShapesFormat() - class TestFormat: """Test format.""" @@ -88,19 +84,20 @@ def test_format_shapes_v2( metadata[attrs_key].pop("version") assert metadata[attrs_key] == ShapesFormatV02().attrs_to_dict({}) - @pytest.mark.parametrize("format", [RasterFormatV01, RasterFormatV02]) + @pytest.mark.parametrize("format", [RasterFormatV02]) def test_format_raster_v1_v2(self, images, format: type[SpatialDataFormatType]) -> None: with tempfile.TemporaryDirectory() as tmpdir: - images.write(Path(tmpdir) / "images.zarr", format=format()) + images.write(Path(tmpdir) / "images.zarr", ome_format=SpatialDataContainerFormatV01(), format=format()) zattrs_file = Path(tmpdir) / "images.zarr/images/image2d/.zattrs" with open(zattrs_file) as infile: zattrs = json.load(infile) - ngff_version = zattrs["multiscales"][0]["version"] if format == RasterFormatV01: + ngff_version = zattrs["multiscales"][0]["version"] assert ngff_version == "0.4" else: assert format == RasterFormatV02 - assert ngff_version == "0.4-dev-spatialdata" + # TODO: check whether this required change is due to bug in ome-zarr + assert zattrs["version"] == "0.4-dev-spatialdata" class TestFormatConversions: diff --git a/tests/io/test_readwrite.py b/tests/io/test_readwrite.py index e29edc63..9552294f 100644 --- a/tests/io/test_readwrite.py +++ b/tests/io/test_readwrite.py @@ -637,7 +637,7 @@ def test_incremental_io_attrs(points: SpatialData) -> None: # TODO: make consolidated metadata open cleaner -@pytest.mark.parametrize("element_name", ["image2d", "labels2d", "points_0", "circles", "table"]) +@pytest.mark.parametrize("element_name", ["labels2d"]) # "image2d" , "points_0", "circles", "table" def test_delete_element_from_disk(full_sdata, element_name: str) -> None: # can't delete an element for a SpatialData object without associated Zarr store with pytest.raises(ValueError, match="The SpatialData object is not backed by a Zarr store."): From 0fd4ccef12926e9767bcb5598cf5e09d2573a2ac Mon Sep 17 00:00:00 2001 From: Wouter-Michiel Vierdag Date: Wed, 3 Sep 2025 12:49:11 +0200 Subject: [PATCH 022/126] check valid element formats in container format --- src/spatialdata/_core/spatialdata.py | 39 ++++++++++--------- src/spatialdata/_io/_utils.py | 3 +- src/spatialdata/_io/format.py | 58 +++++++++++++++++++++++++--- tests/io/test_format.py | 33 +++++++++------- tests/io/test_metadata.py | 2 +- tests/io/test_readwrite.py | 6 +-- 6 files changed, 98 insertions(+), 43 deletions(-) diff --git a/src/spatialdata/_core/spatialdata.py b/src/spatialdata/_core/spatialdata.py index e9037e35..ef75e24f 100644 --- a/src/spatialdata/_core/spatialdata.py +++ b/src/spatialdata/_core/spatialdata.py @@ -332,7 +332,7 @@ def get_instance_key_column(table: AnnData) -> pd.Series: raise KeyError(f"{instance_key} is set as instance key column. However the column is not found in table.obs.") def set_channel_names(self, element_name: str, channel_names: str | list[str], write: bool = False) -> None: - """Set the channel names for a image `SpatialElement` in the `SpatialData` object. + """Set the channel names for an image `SpatialElement` in the `SpatialData` object. This method assumes that the `SpatialData` object and the element are already stored on disk as it will also overwrite the channel names metadata on disk. In case either the `SpatialData` object or the @@ -1247,8 +1247,8 @@ def write( file_path: str | Path, overwrite: bool = False, consolidate_metadata: bool = True, - ome_format: SpatialDataContainerFormatType | None = None, - format: SpatialDataFormatType | list[SpatialDataFormatType] | None = None, + # sdata_format: SpatialDataContainerFormatType | None = None, + sdata_formats: SpatialDataFormatType | list[SpatialDataFormatType] | None = None, ) -> None: """ Write the `SpatialData` object to a Zarr store. @@ -1275,17 +1275,16 @@ def write( :class:`~spatialdata._io.format.CurrentRasterFormat`, :class:`~spatialdata._io.format.CurrentShapesFormat`, :class:`~spatialdata._io.format.CurrentPointsFormat`, :class:`~spatialdata._io.format.CurrentTablesFormat`. """ - from spatialdata._io.format import SpatialDataContainerFormatV02 + from spatialdata._io.format import _parse_formats - if not ome_format: - ome_format = SpatialDataContainerFormatV02() + parsed = _parse_formats(sdata_formats) if isinstance(file_path, str): file_path = Path(file_path) self._validate_can_safely_write_to_path(file_path, overwrite=overwrite) self._validate_all_elements() - store = parse_url(file_path, mode="w", fmt=ome_format).store + store = parse_url(file_path, mode="w", fmt=parsed["SpatialData"]).store zarr_group = zarr.open_group(store=store, mode="w" if overwrite else "a") self.write_attrs(zarr_group=zarr_group) store.close() @@ -1296,8 +1295,8 @@ def write( zarr_container_path=file_path, element_type=element_type, element_name=element_name, - overwrite=False, - format=format, + overwrite=overwrite, + parsed_formats=parsed, ) if self.path != file_path: @@ -1315,7 +1314,7 @@ def _write_element( element_type: str, element_name: str, overwrite: bool, - format: SpatialDataFormatType | list[SpatialDataFormatType] | None = None, + parsed_formats: dict[str, SpatialDataFormatType] | None = None, ) -> None: if not isinstance(zarr_container_path, Path): raise ValueError( @@ -1338,42 +1337,43 @@ def _write_element( ) from spatialdata._io.format import _parse_formats - parsed = _parse_formats(formats=format) + if parsed_formats is None: + parsed_formats = _parse_formats(formats=parsed_formats) if element_type == "images": write_image( image=element, group=element_group, name=element_name, - format=parsed["raster"], + format=parsed_formats["raster"], ) elif element_type == "labels": write_labels( labels=element, group=root_group, name=element_name, - format=parsed["raster"], + format=parsed_formats["raster"], ) elif element_type == "points": write_points( points=element, group=element_group, name=element_name, - format=parsed["points"], + format=parsed_formats["points"], ) elif element_type == "shapes": write_shapes( shapes=element, group=element_group, name=element_name, - format=parsed["shapes"], + format=parsed_formats["shapes"], ) elif element_type == "tables": write_table( table=element, group=element_type_group, name=element_name, - format=parsed["tables"], + format=parsed_formats["tables"], ) else: raise ValueError(f"Unknown element type: {element_type}") @@ -1382,7 +1382,7 @@ def write_element( self, element_name: str | list[str], overwrite: bool = False, - format: SpatialDataFormatType | list[SpatialDataFormatType] | None = None, + sdata_formats: SpatialDataFormatType | list[SpatialDataFormatType] | None = None, ) -> None: """ Write a single element, or a list of elements, to the Zarr store used for backing. @@ -1410,6 +1410,9 @@ def write_element( self.write_element(name, overwrite=overwrite) return + from spatialdata._io.format import _parse_formats + + parsed_formats = _parse_formats(formats=sdata_formats) check_valid_name(element_name) self._validate_element_names_are_unique() element = self.get(element_name) @@ -1440,7 +1443,7 @@ def write_element( element_type=element_type, element_name=element_name, overwrite=overwrite, - format=format, + parsed_formats=parsed_formats, ) # After every write, metadata should be consolidated, otherwise this can lead to IO problems like when deleting. if self.has_consolidated_metadata(): diff --git a/src/spatialdata/_io/_utils.py b/src/spatialdata/_io/_utils.py index 471748d6..d4c7c50e 100644 --- a/src/spatialdata/_io/_utils.py +++ b/src/spatialdata/_io/_utils.py @@ -100,7 +100,8 @@ def overwrite_coordinate_transformations_raster( coordinate_transformations = [t.to_dict() for t in ngff_transformations] # replace the metadata storage if group.metadata.zarr_format == 3: - if len_scales := len(multiscales := group.metadata.attributes["ome"]["multiscales"]) != 1: + if len(multiscales := group.metadata.attributes["ome"]["multiscales"]) != 1: + len_scales = len(multiscales) raise ValueError(f"The length of multiscales metadata should be 1, found the length to be {len_scales}") multiscale = multiscales[0] diff --git a/src/spatialdata/_io/format.py b/src/spatialdata/_io/format.py index ecd56a94..a5746f79 100644 --- a/src/spatialdata/_io/format.py +++ b/src/spatialdata/_io/format.py @@ -1,5 +1,6 @@ from __future__ import annotations +import warnings from collections.abc import Iterator from typing import Any @@ -187,9 +188,9 @@ def spatialdata_format_version(self) -> str: def version(self) -> str: return "0.4" - @property - def zarr_format(self): - return 2 + # @property + # def zarr_format(self): + # return 2 class RasterFormatV02(RasterFormatV01): @@ -203,9 +204,9 @@ def version(self) -> str: # https://github.com/scverse/spatialdata/pull/849 return "0.4-dev-spatialdata" - @property - def zarr_format(self): - return 2 + # @property + # def zarr_format(self): + # return 2 class RasterFormatV03(FormatV05, CoordinateMixinV01): @@ -344,6 +345,28 @@ def spatialdata_format_version(self) -> str: "0.1": SpatialDataContainerFormatV01(), "0.2": SpatialDataContainerFormatV02(), } +ContainerFormatValidElements = { + SpatialDataContainerFormatV01().__str__(): [ + RasterFormatV01().__str__(), + RasterFormatV02().__str__(), + PointsFormatV01().__str__(), + ShapesFormatV01().__str__(), + ShapesFormatV02().__str__(), + TablesFormatV01().__str__(), + ], + SpatialDataContainerFormatV02().__str__(): [ + RasterFormatV03().__str__(), + PointsFormatV02().__str__(), + ShapesFormatV03().__str__(), + TablesFormatV02().__str__(), + ], +} +ContainerV01DefaultTypes: dict[str, SpatialDataFormatType] = { + "raster": RasterFormatV02(), + "shapes": ShapesFormatV02(), + "points": PointsFormatV01(), + "tables": TablesFormatV01(), +} def format_implementations() -> Iterator[Format]: @@ -413,4 +436,27 @@ def _check_modified(element_type: str) -> None: parsed["SpatialData"] = fmt else: raise ValueError(f"Unsupported format {fmt}") + + if parsed["SpatialData"].__str__() == "SpatialDataContainerFormatV01": + warnings.warn( + "SpatialData format defined to be 'SpatialDataContainerFormatV01'. Defaulting undefined element " + "formats to element formats valid for 'SpatialDataContainerFormatV01'.", + UserWarning, + stacklevel=2, + ) + for el_type, value in modified.items(): + if el_type != "SpatialData" and not value: + parsed[el_type] = ContainerV01DefaultTypes[el_type] + + if any( + (invalid := el_format.__str__()) not in ContainerFormatValidElements[parsed["SpatialData"].__str__()] + for el_type, el_format in parsed.items() + if el_type != "SpatialData" + ): + raise ValueError( + f"Unsupported format '{invalid}' for SpatialDataContainerFormat '{parsed['SpatialData'].__str__()}'. " + f"Please ensure all element formats are either of these: " + f"'{' '.join(f for f in ContainerFormatValidElements[parsed['SpatialData'].__str__()])}'" + ) + return parsed diff --git a/tests/io/test_format.py b/tests/io/test_format.py index 0233ea7c..cc7ef22b 100644 --- a/tests/io/test_format.py +++ b/tests/io/test_format.py @@ -12,9 +12,11 @@ PointsFormatV01, RasterFormatV01, RasterFormatV02, + RasterFormatV03, ShapesFormatV01, ShapesFormatV02, SpatialDataContainerFormatV01, + SpatialDataContainerFormatV02, SpatialDataFormatType, ) from spatialdata.models import PointsModel, ShapesModel @@ -84,18 +86,18 @@ def test_format_shapes_v2( metadata[attrs_key].pop("version") assert metadata[attrs_key] == ShapesFormatV02().attrs_to_dict({}) - @pytest.mark.parametrize("format", [RasterFormatV02]) - def test_format_raster_v1_v2(self, images, format: type[SpatialDataFormatType]) -> None: + @pytest.mark.parametrize("rformat", [RasterFormatV01, RasterFormatV02]) + def test_format_raster_v1_v2(self, images, rformat: type[SpatialDataFormatType]) -> None: with tempfile.TemporaryDirectory() as tmpdir: - images.write(Path(tmpdir) / "images.zarr", ome_format=SpatialDataContainerFormatV01(), format=format()) + images.write(Path(tmpdir) / "images.zarr", sdata_formats=[SpatialDataContainerFormatV01(), rformat()]) zattrs_file = Path(tmpdir) / "images.zarr/images/image2d/.zattrs" with open(zattrs_file) as infile: zattrs = json.load(infile) - if format == RasterFormatV01: + if rformat == RasterFormatV01: ngff_version = zattrs["multiscales"][0]["version"] assert ngff_version == "0.4" else: - assert format == RasterFormatV02 + assert rformat == RasterFormatV02 # TODO: check whether this required change is due to bug in ome-zarr assert zattrs["version"] == "0.4-dev-spatialdata" @@ -108,11 +110,11 @@ def test_shapes_v1_to_v2(self, shapes): f1 = Path(tmpdir) / "data1.zarr" f2 = Path(tmpdir) / "data2.zarr" - shapes.write(f1, format=ShapesFormatV01()) + shapes.write(f1, sdata_formats=[ShapesFormatV01(), SpatialDataContainerFormatV01()]) shapes_read_v1 = read_zarr(f1) assert_spatial_data_objects_are_identical(shapes, shapes_read_v1) - shapes_read_v1.write(f2, format=ShapesFormatV02()) + shapes_read_v1.write(f2, sdata_formats=[ShapesFormatV02(), SpatialDataContainerFormatV01()]) shapes_read_v2 = read_zarr(f2) assert_spatial_data_objects_are_identical(shapes, shapes_read_v2) @@ -120,16 +122,19 @@ def test_raster_v1_to_v2_to_v3(self, images): with tempfile.TemporaryDirectory() as tmpdir: f1 = Path(tmpdir) / "data1.zarr" f2 = Path(tmpdir) / "data2.zarr" - # f3 = Path(tmpdir) / "data3.zarr" + f3 = Path(tmpdir) / "data3.zarr" - images.write(f1, format=RasterFormatV01()) + with pytest.raises(ValueError, match="Unsupported format"): + images.write(f1, sdata_formats=RasterFormatV01()) + + images.write(f1, sdata_formats=[RasterFormatV01(), SpatialDataContainerFormatV01()]) images_read_v1 = read_zarr(f1) assert_spatial_data_objects_are_identical(images, images_read_v1) - images_read_v1.write(f2, format=RasterFormatV02()) + images_read_v1.write(f2, sdata_formats=[RasterFormatV02(), SpatialDataContainerFormatV01()]) images_read_v2 = read_zarr(f2) assert_spatial_data_objects_are_identical(images, images_read_v2) - # - # images_read_v2.write(f3, format=RasterFormatV02()) - # images_read_v3 = read_zarr(f3) - # assert_spatial_data_objects_are_identical(images, images_read_v3) + + images_read_v2.write(f3, sdata_formats=[RasterFormatV03(), SpatialDataContainerFormatV02()]) + images_read_v3 = read_zarr(f3) + assert_spatial_data_objects_are_identical(images, images_read_v3) diff --git a/tests/io/test_metadata.py b/tests/io/test_metadata.py index 5ed27259..bb993b00 100644 --- a/tests/io/test_metadata.py +++ b/tests/io/test_metadata.py @@ -55,7 +55,7 @@ def test_validate_can_write_metadata_on_element(full_sdata, element_name): full_sdata._validate_can_write_metadata_on_element(f"{element_name}_again") -@pytest.mark.parametrize("element_name", ["labels2d"]) # "points_0", "circles", "image2d" +@pytest.mark.parametrize("element_name", ["image2d", "labels2d", "points_0", "circles"]) def test_save_transformations_incremental(element_name, full_sdata, caplog): """test io for transformations""" with tempfile.TemporaryDirectory() as tmp_dir: diff --git a/tests/io/test_readwrite.py b/tests/io/test_readwrite.py index 9552294f..eb30021d 100644 --- a/tests/io/test_readwrite.py +++ b/tests/io/test_readwrite.py @@ -166,8 +166,8 @@ def test_incremental_io_list_of_elements(self, shapes: SpatialData) -> None: assert "shapes/new_shapes0" not in shapes.elements_paths_on_disk() assert "shapes/new_shapes1" not in shapes.elements_paths_on_disk() - @pytest.mark.parametrize("dask_backed", [True]) - @pytest.mark.parametrize("workaround", [1]) + @pytest.mark.parametrize("dask_backed", [True, False]) + @pytest.mark.parametrize("workaround", [1, 2]) def test_incremental_io_on_disk( self, tmp_path: str, full_sdata: SpatialData, dask_backed: bool, workaround: int ) -> None: @@ -637,7 +637,7 @@ def test_incremental_io_attrs(points: SpatialData) -> None: # TODO: make consolidated metadata open cleaner -@pytest.mark.parametrize("element_name", ["labels2d"]) # "image2d" , "points_0", "circles", "table" +@pytest.mark.parametrize("element_name", ["labels2d"]) # "image2d" , "points_0", "circles", "table" "labels2d" def test_delete_element_from_disk(full_sdata, element_name: str) -> None: # can't delete an element for a SpatialData object without associated Zarr store with pytest.raises(ValueError, match="The SpatialData object is not backed by a Zarr store."): From 484489e7cdb9aca7029ed8a577ad22a3bc251490 Mon Sep 17 00:00:00 2001 From: Wouter-Michiel Vierdag Date: Wed, 3 Sep 2025 13:04:12 +0200 Subject: [PATCH 023/126] uncomment code arrayNotFoundError --- src/spatialdata/_io/io_zarr.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/spatialdata/_io/io_zarr.py b/src/spatialdata/_io/io_zarr.py index 0a96eaef..ee752234 100644 --- a/src/spatialdata/_io/io_zarr.py +++ b/src/spatialdata/_io/io_zarr.py @@ -188,7 +188,7 @@ def read_zarr( JSONDecodeError, ValueError, KeyError, - # ArrayNotFoundError, # removed in Zarr v3 + ArrayNotFoundError, ), ): shapes[subgroup_name] = _read_shapes(f_elem_store) From ff4cfeebe25b100e81464e94b62732561edbe45b Mon Sep 17 00:00:00 2001 From: Wouter-Michiel Vierdag Date: Wed, 3 Sep 2025 13:06:07 +0200 Subject: [PATCH 024/126] remove future annotations import --- src/spatialdata/_io/format.py | 4 ---- 1 file changed, 4 deletions(-) diff --git a/src/spatialdata/_io/format.py b/src/spatialdata/_io/format.py index a5746f79..b0868364 100644 --- a/src/spatialdata/_io/format.py +++ b/src/spatialdata/_io/format.py @@ -1,5 +1,3 @@ -from __future__ import annotations - import warnings from collections.abc import Iterator from typing import Any @@ -25,8 +23,6 @@ Shapes_s = ShapesModel() Points_s = PointsModel() -# TODO: change for element in spatialdata_format_version for elements into something like element_container_version - def _parse_version(group: zarr.Group, expect_attrs_key: bool) -> str | None: """ From 89b53f0f61e9c42c6985ba916aa9721ebc352001 Mon Sep 17 00:00:00 2001 From: Wouter-Michiel Vierdag Date: Wed, 3 Sep 2025 13:23:30 +0200 Subject: [PATCH 025/126] update workflow mac version --- .github/workflows/test.yaml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.github/workflows/test.yaml b/.github/workflows/test.yaml index 16c85fd5..5d6e0ad8 100644 --- a/.github/workflows/test.yaml +++ b/.github/workflows/test.yaml @@ -22,7 +22,7 @@ jobs: os: [ubuntu-latest] include: - os: macos-latest - python: "3.10" + python: "3.11" - os: macos-latest python: "3.12" pip-flags: "--pre" From 607580a996a894166a70b05efa965c6d20725031 Mon Sep 17 00:00:00 2001 From: Wouter-Michiel Vierdag Date: Wed, 3 Sep 2025 13:23:59 +0200 Subject: [PATCH 026/126] drop python 3.10 version is not supported by zarr v3.0.8 --- pyproject.toml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pyproject.toml b/pyproject.toml index 074d02b5..eb045bc4 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -15,7 +15,7 @@ maintainers = [ urls.Documentation = "https://spatialdata.scverse.org/en/latest" urls.Source = "https://github.com/scverse/spatialdata.git" urls.Home-page = "https://github.com/scverse/spatialdata.git" -requires-python = ">=3.10" +requires-python = ">=3.11" dynamic= [ "version" # allow version to be set by git tags ] From 75e656b2f5d8249ddc8931a8e02d300f4a0dab8a Mon Sep 17 00:00:00 2001 From: Wouter-Michiel Vierdag Date: Wed, 3 Sep 2025 13:26:53 +0200 Subject: [PATCH 027/126] change target python version and readthedocs python version --- .readthedocs.yaml | 2 +- pyproject.toml | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/.readthedocs.yaml b/.readthedocs.yaml index ab6cb4fb..acecf90e 100644 --- a/.readthedocs.yaml +++ b/.readthedocs.yaml @@ -3,7 +3,7 @@ version: 2 build: os: ubuntu-20.04 tools: - python: "3.10" + python: "3.11" sphinx: configuration: docs/conf.py fail_on_warning: true diff --git a/pyproject.toml b/pyproject.toml index eb045bc4..353ad330 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -132,7 +132,7 @@ exclude = [ ] line-length = 120 -target-version = "py310" +target-version = "py311" [tool.ruff.lint] ignore = [ From b977b368990c9c22e8e7fb934d3e9ca9b47142c2 Mon Sep 17 00:00:00 2001 From: Wouter-Michiel Vierdag Date: Wed, 3 Sep 2025 13:35:29 +0200 Subject: [PATCH 028/126] minor updates --- src/spatialdata/_io/_utils.py | 2 +- src/spatialdata/_io/io_raster.py | 11 +---------- 2 files changed, 2 insertions(+), 11 deletions(-) diff --git a/src/spatialdata/_io/_utils.py b/src/spatialdata/_io/_utils.py index a4f45fb4..28a41182 100644 --- a/src/spatialdata/_io/_utils.py +++ b/src/spatialdata/_io/_utils.py @@ -388,7 +388,7 @@ def save_transformations(sdata: SpatialData) -> None: sdata.write_transformations() -def _open_zarr_store(path: StoreLike, **kwargs: Any) -> zarr.storage.BaseStore: +def _open_zarr_store(path: StoreLike, **kwargs: Any) -> zarr.storage.StoreLike: # TODO: ensure kwargs like mode are enforced everywhere and passed correctly to the store if isinstance(path, str | Path): # if the input is str or Path, map it to UPath diff --git a/src/spatialdata/_io/io_raster.py b/src/spatialdata/_io/io_raster.py index 2b95d119..0fab3099 100644 --- a/src/spatialdata/_io/io_raster.py +++ b/src/spatialdata/_io/io_raster.py @@ -156,6 +156,7 @@ def _write_raster( for c in channels: metadata["metadata"]["omero"]["channels"].append({"label": c}) # type: ignore[union-attr, index, call-overload] + trans_group = group["labels"][name] if raster_type == "labels" else group_data if isinstance(raster_data, DataArray): data = raster_data.data transformations = _get_transformations(raster_data) @@ -183,11 +184,6 @@ def _write_raster( ) if not transformations: raise ValueError(f"No transformations specified to be written for element {name}.") - # TODO refactor this as it is ugly - if raster_type == "labels": - trans_group = group["labels"][name] - else: - trans_group = group_data overwrite_coordinate_transformations_raster(group=trans_group, transformations=transformations, axes=input_axes) elif isinstance(raster_data, DataTree): data = get_pyramid_levels(raster_data, attr="data") @@ -218,11 +214,6 @@ def _write_raster( # Compute all pyramid levels at once to allow Dask to optimize the computational graph. da.compute(*dask_delayed) assert transformations is not None - # TODO refactor this as it is ugly - if raster_type == "labels": - trans_group = group["labels"][name] - else: - trans_group = group_data overwrite_coordinate_transformations_raster( group=trans_group, transformations=transformations, axes=tuple(input_axes) ) From 0cfd29fd1831fd597208c4e68db34cb6771dbbbb Mon Sep 17 00:00:00 2001 From: Wouter-Michiel Vierdag Date: Wed, 3 Sep 2025 13:36:54 +0200 Subject: [PATCH 029/126] update Self import --- .mypy.ini | 2 +- src/spatialdata/transformations/ngff/ngff_transformations.py | 3 +-- 2 files changed, 2 insertions(+), 3 deletions(-) diff --git a/.mypy.ini b/.mypy.ini index 0eee2044..78edd09c 100644 --- a/.mypy.ini +++ b/.mypy.ini @@ -1,5 +1,5 @@ [mypy] -python_version = 3.10 +python_version = 3.11 ignore_errors = False warn_redundant_casts = True diff --git a/src/spatialdata/transformations/ngff/ngff_transformations.py b/src/spatialdata/transformations/ngff/ngff_transformations.py index 4e63b91c..9cc2602e 100644 --- a/src/spatialdata/transformations/ngff/ngff_transformations.py +++ b/src/spatialdata/transformations/ngff/ngff_transformations.py @@ -1,10 +1,9 @@ import math from abc import ABC, abstractmethod from numbers import Number -from typing import Any +from typing import Any, Self import numpy as np -from typing_extensions import Self from spatialdata._types import ArrayLike from spatialdata.transformations.ngff.ngff_coordinate_system import NgffCoordinateSystem From 3d0c3eb5cc51c4bfa4bb8c8e2f89db63f6fc0ef0 Mon Sep 17 00:00:00 2001 From: Wouter-Michiel Vierdag Date: Wed, 3 Sep 2025 13:52:11 +0200 Subject: [PATCH 030/126] fix ome errors --- src/spatialdata/_core/spatialdata.py | 1 - src/spatialdata/_io/io_raster.py | 8 +++++++- 2 files changed, 7 insertions(+), 2 deletions(-) diff --git a/src/spatialdata/_core/spatialdata.py b/src/spatialdata/_core/spatialdata.py index ef75e24f..612db783 100644 --- a/src/spatialdata/_core/spatialdata.py +++ b/src/spatialdata/_core/spatialdata.py @@ -1247,7 +1247,6 @@ def write( file_path: str | Path, overwrite: bool = False, consolidate_metadata: bool = True, - # sdata_format: SpatialDataContainerFormatType | None = None, sdata_formats: SpatialDataFormatType | list[SpatialDataFormatType] | None = None, ) -> None: """ diff --git a/src/spatialdata/_io/io_raster.py b/src/spatialdata/_io/io_raster.py index 0fab3099..20206089 100644 --- a/src/spatialdata/_io/io_raster.py +++ b/src/spatialdata/_io/io_raster.py @@ -156,7 +156,7 @@ def _write_raster( for c in channels: metadata["metadata"]["omero"]["channels"].append({"label": c}) # type: ignore[union-attr, index, call-overload] - trans_group = group["labels"][name] if raster_type == "labels" else group_data + # TODO refactor as function is way too big if isinstance(raster_data, DataArray): data = raster_data.data transformations = _get_transformations(raster_data) @@ -184,6 +184,9 @@ def _write_raster( ) if not transformations: raise ValueError(f"No transformations specified to be written for element {name}.") + + # Cannot move before conditional as group_data is updated when writing ngff scales + trans_group = group["labels"][name] if raster_type == "labels" else group_data overwrite_coordinate_transformations_raster(group=trans_group, transformations=transformations, axes=input_axes) elif isinstance(raster_data, DataTree): data = get_pyramid_levels(raster_data, attr="data") @@ -214,6 +217,9 @@ def _write_raster( # Compute all pyramid levels at once to allow Dask to optimize the computational graph. da.compute(*dask_delayed) assert transformations is not None + + # Cannot move before conditional as group_data is updated when writing ngff scales + trans_group = group["labels"][name] if raster_type == "labels" else group_data overwrite_coordinate_transformations_raster( group=trans_group, transformations=transformations, axes=tuple(input_axes) ) From fc7efda0b31dfc89c8589c90fb2437a5970c439f Mon Sep 17 00:00:00 2001 From: Wouter-Michiel Vierdag Date: Wed, 3 Sep 2025 14:08:05 +0200 Subject: [PATCH 031/126] add windows workflow --- .github/workflows/test.yaml | 2 ++ src/spatialdata/_io/io_raster.py | 7 +++---- 2 files changed, 5 insertions(+), 4 deletions(-) diff --git a/.github/workflows/test.yaml b/.github/workflows/test.yaml index 5d6e0ad8..1003fddf 100644 --- a/.github/workflows/test.yaml +++ b/.github/workflows/test.yaml @@ -27,6 +27,8 @@ jobs: python: "3.12" pip-flags: "--pre" name: "Python 3.12 (pre-release)" + - os: windows-latest + python: "3.11" env: OS: ${{ matrix.os }} diff --git a/src/spatialdata/_io/io_raster.py b/src/spatialdata/_io/io_raster.py index 20206089..80281344 100644 --- a/src/spatialdata/_io/io_raster.py +++ b/src/spatialdata/_io/io_raster.py @@ -198,10 +198,10 @@ def _write_raster( assert len(d) == 1 xdata = d.values().__iter__().__next__() transformations = _get_transformations_xarray(xdata) - assert transformations is not None - assert len(transformations) > 0 + if not transformations: + raise ValueError(f"No transformations specified to be written for element {name}.") chunks = get_pyramid_levels(raster_data, "chunks") - # coords = iterate_pyramid_levels(raster_data, "coords") + parsed_axes = _get_valid_axes(axes=list(input_axes), fmt=format) storage_options = [{"chunks": chunk} for chunk in chunks] dask_delayed = write_multi_scale_ngff( @@ -216,7 +216,6 @@ def _write_raster( ) # Compute all pyramid levels at once to allow Dask to optimize the computational graph. da.compute(*dask_delayed) - assert transformations is not None # Cannot move before conditional as group_data is updated when writing ngff scales trans_group = group["labels"][name] if raster_type == "labels" else group_data From 46c7b34b7224d16743d12e679324c1ddb3856c24 Mon Sep 17 00:00:00 2001 From: Wouter-Michiel Vierdag Date: Wed, 3 Sep 2025 15:18:57 +0200 Subject: [PATCH 032/126] prevent consolidation labels group when deleting element --- src/spatialdata/_core/spatialdata.py | 2 +- tests/io/test_readwrite.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/src/spatialdata/_core/spatialdata.py b/src/spatialdata/_core/spatialdata.py index 612db783..d731600b 100644 --- a/src/spatialdata/_core/spatialdata.py +++ b/src/spatialdata/_core/spatialdata.py @@ -1529,7 +1529,7 @@ def delete_element_from_disk(self, element_name: str | list[str]) -> None: # delete the element store = parse_url(self.path, mode="r+", fmt=FormatV05()).store - root = zarr.open_group(store=store, mode="r+") + root = zarr.open_group(store=store, mode="r+", use_consolidated=False) del root[element_type][element_name] store.close() diff --git a/tests/io/test_readwrite.py b/tests/io/test_readwrite.py index eb30021d..165014af 100644 --- a/tests/io/test_readwrite.py +++ b/tests/io/test_readwrite.py @@ -637,7 +637,7 @@ def test_incremental_io_attrs(points: SpatialData) -> None: # TODO: make consolidated metadata open cleaner -@pytest.mark.parametrize("element_name", ["labels2d"]) # "image2d" , "points_0", "circles", "table" "labels2d" +@pytest.mark.parametrize("element_name", ["image2d", "labels2d", "points_0", "circles", "table"]) def test_delete_element_from_disk(full_sdata, element_name: str) -> None: # can't delete an element for a SpatialData object without associated Zarr store with pytest.raises(ValueError, match="The SpatialData object is not backed by a Zarr store."): From 38f29fedbbc2bf466e6e918a34c3c914543cbbcb Mon Sep 17 00:00:00 2001 From: Wouter-Michiel Vierdag Date: Wed, 3 Sep 2025 15:35:25 +0200 Subject: [PATCH 033/126] add shapes test --- tests/io/test_format.py | 8 +++++++- 1 file changed, 7 insertions(+), 1 deletion(-) diff --git a/tests/io/test_format.py b/tests/io/test_format.py index cc7ef22b..ce9a4669 100644 --- a/tests/io/test_format.py +++ b/tests/io/test_format.py @@ -15,6 +15,7 @@ RasterFormatV03, ShapesFormatV01, ShapesFormatV02, + ShapesFormatV03, SpatialDataContainerFormatV01, SpatialDataContainerFormatV02, SpatialDataFormatType, @@ -105,10 +106,11 @@ def test_format_raster_v1_v2(self, images, rformat: type[SpatialDataFormatType]) class TestFormatConversions: """Test format conversions between older formats and newer.""" - def test_shapes_v1_to_v2(self, shapes): + def test_shapes_v1_to_v2_to_v3(self, shapes): with tempfile.TemporaryDirectory() as tmpdir: f1 = Path(tmpdir) / "data1.zarr" f2 = Path(tmpdir) / "data2.zarr" + f3 = Path(tmpdir) / "data3.zarr" shapes.write(f1, sdata_formats=[ShapesFormatV01(), SpatialDataContainerFormatV01()]) shapes_read_v1 = read_zarr(f1) @@ -118,6 +120,10 @@ def test_shapes_v1_to_v2(self, shapes): shapes_read_v2 = read_zarr(f2) assert_spatial_data_objects_are_identical(shapes, shapes_read_v2) + shapes_read_v1.write(f3, sdata_formats=[ShapesFormatV03(), SpatialDataContainerFormatV02()]) + shapes_read_v3 = read_zarr(f3) + assert_spatial_data_objects_are_identical(shapes, shapes_read_v3) + def test_raster_v1_to_v2_to_v3(self, images): with tempfile.TemporaryDirectory() as tmpdir: f1 = Path(tmpdir) / "data1.zarr" From a88c7e7a0abb17be21ff74f54fa2a453cd397102 Mon Sep 17 00:00:00 2001 From: Wouter-Michiel Vierdag Date: Wed, 3 Sep 2025 15:45:52 +0200 Subject: [PATCH 034/126] fix shape conversion --- tests/io/test_format.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/io/test_format.py b/tests/io/test_format.py index ce9a4669..cf373119 100644 --- a/tests/io/test_format.py +++ b/tests/io/test_format.py @@ -120,7 +120,7 @@ def test_shapes_v1_to_v2_to_v3(self, shapes): shapes_read_v2 = read_zarr(f2) assert_spatial_data_objects_are_identical(shapes, shapes_read_v2) - shapes_read_v1.write(f3, sdata_formats=[ShapesFormatV03(), SpatialDataContainerFormatV02()]) + shapes_read_v2.write(f3, sdata_formats=[ShapesFormatV03(), SpatialDataContainerFormatV02()]) shapes_read_v3 = read_zarr(f3) assert_spatial_data_objects_are_identical(shapes, shapes_read_v3) From 16aa5299cc5733b117db038c9cbc3c5b9a17d528 Mon Sep 17 00:00:00 2001 From: Wouter-Michiel Vierdag Date: Wed, 3 Sep 2025 16:24:25 +0200 Subject: [PATCH 035/126] update dask dependency because of zarr v3 --- pyproject.toml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pyproject.toml b/pyproject.toml index 353ad330..614fe6b2 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -25,7 +25,7 @@ dependencies = [ "anndata>=0.9.1", "click", "dask-image", - "dask>=2024.4.1,<=2024.11.2", + "dask>=2024.10.0,<=2024.11.2", "datashader", "fsspec[s3,http]", "geopandas>=0.14", From bdea210a8b7daf6ee118ca3255fd37583355d65d Mon Sep 17 00:00:00 2001 From: Wouter-Michiel Vierdag Date: Wed, 3 Sep 2025 17:31:08 +0200 Subject: [PATCH 036/126] fix multipoly --- src/spatialdata/_io/io_shapes.py | 14 +++++++++----- 1 file changed, 9 insertions(+), 5 deletions(-) diff --git a/src/spatialdata/_io/io_shapes.py b/src/spatialdata/_io/io_shapes.py index b1cb0e64..43d294eb 100644 --- a/src/spatialdata/_io/io_shapes.py +++ b/src/spatialdata/_io/io_shapes.py @@ -35,28 +35,32 @@ def _read_shapes( f = zarr.open(store, mode="r") version = _parse_version(f, expect_attrs_key=True) assert version is not None - format = ShapesFormats[version] + shape_format = ShapesFormats[version] - if isinstance(format, ShapesFormatV01): + if isinstance(shape_format, ShapesFormatV01): coords = np.array(f["coords"]) index = np.array(f["Index"]) - typ = format.attrs_from_dict(f.attrs.asdict()) + typ = shape_format.attrs_from_dict(f.attrs.asdict()) if typ.name == "POINT": radius = np.array(f["radius"]) geometry = from_ragged_array(typ, coords) geo_df = GeoDataFrame({"geometry": geometry, "radius": radius}, index=index) else: offsets_keys = [k for k in f if k.startswith("offset")] + + # We do this because of async reading not necessarily leading to ordered offset keys. + # We can't use sorted because if offsets are higher than 11 we get 1, 11, 2 + offsets_keys = [f"offset{i}" for i in range(len(offsets_keys))] offsets = tuple(np.array(f[k]).flatten() for k in offsets_keys) geometry = from_ragged_array(typ, coords, offsets) geo_df = GeoDataFrame({"geometry": geometry}, index=index) - elif isinstance(format, ShapesFormatV02 | ShapesFormatV03): + elif isinstance(shape_format, ShapesFormatV02 | ShapesFormatV03): store_root = f.store_path.store.root path = Path(store_root) / f.path / "shapes.parquet" geo_df = read_parquet(path) else: raise ValueError( - f"Unsupported shapes format {format} from version {version}. Please update the spatialdata library." + f"Unsupported shapes format {shape_format} from version {version}. Please update the spatialdata library." ) transformations = _get_transformations_from_ngff_dict(f.attrs.asdict()["coordinateTransformations"]) From ebc5080e710045d5e9c65b23deea9383261d5be9 Mon Sep 17 00:00:00 2001 From: Wouter-Michiel Vierdag Date: Wed, 3 Sep 2025 17:49:11 +0200 Subject: [PATCH 037/126] fix invalid read name test --- tests/io/test_readwrite.py | 11 +++++++---- 1 file changed, 7 insertions(+), 4 deletions(-) diff --git a/tests/io/test_readwrite.py b/tests/io/test_readwrite.py index 165014af..93de3391 100644 --- a/tests/io/test_readwrite.py +++ b/tests/io/test_readwrite.py @@ -1,5 +1,4 @@ import os -import sys import tempfile from collections.abc import Callable from pathlib import Path @@ -776,7 +775,6 @@ def test_incremental_writing_valid_table_name_invalid_table(tmp_path: Path): invalid_sdata.write_element("valid_name") -@pytest.mark.skipif(sys.platform == "win32", reason="Renaming fails as windows path already sees the name as invalid.") def test_reading_invalid_name(tmp_path: Path): image_name, image = next(iter(_get_images().items())) labels_name, labels = next(iter(_get_labels().items())) @@ -793,14 +791,19 @@ def test_reading_invalid_name(tmp_path: Path): valid_sdata.write(tmp_path / "data.zarr") # Circumvent validation at construction time and check validation happens again at writing time. (tmp_path / "data.zarr/points" / points_name).rename(tmp_path / "data.zarr/points" / "has whitespace") - (tmp_path / "data.zarr/shapes" / shapes_name).rename(tmp_path / "data.zarr/shapes" / "non-alnum_#$%&()*+,?@") + # This one is not allowed on windows + if os.name != "nt": + (tmp_path / "data.zarr/shapes" / shapes_name).rename(tmp_path / "data.zarr/shapes" / "non-alnum_#$%&()*+,?@") + # We do this as the key of the element is otherwise not in the consolidated metadata, leading to an error. + valid_sdata.write_consolidated_metadata() with pytest.raises(ValidationError, match="Cannot construct SpatialData") as exc_info: read_zarr(tmp_path / "data.zarr") actual_message = str(exc_info.value) assert "points/has whitespace" in actual_message - assert "shapes/non-alnum_#$%&()*+,?@" in actual_message + if os.name != "nt": + assert "shapes/non-alnum_#$%&()*+,?@" in actual_message assert ( "For renaming, please see the discussion here https://github.com/scverse/spatialdata/discussions/707" in actual_message From 6ba53e420fdcb737dd0cd81d87934d272fe6f89b Mon Sep 17 00:00:00 2001 From: Wouter-Michiel Vierdag Date: Wed, 3 Sep 2025 19:10:40 +0200 Subject: [PATCH 038/126] use UPath --- tests/io/test_utils.py | 9 +++++---- 1 file changed, 5 insertions(+), 4 deletions(-) diff --git a/tests/io/test_utils.py b/tests/io/test_utils.py index f9778f5c..54a5a4a6 100644 --- a/tests/io/test_utils.py +++ b/tests/io/test_utils.py @@ -4,6 +4,7 @@ import dask.dataframe as dd import pytest +from upath import UPath from spatialdata import read_zarr from spatialdata._io._utils import get_dask_backing_files, handle_read_errors @@ -37,8 +38,8 @@ def test_backing_files_images(images): computational graph """ with tempfile.TemporaryDirectory() as tmp_dir: - f0 = os.path.join(tmp_dir, "images0.zarr") - f1 = os.path.join(tmp_dir, "images1.zarr") + f0 = UPath(tmp_dir) / "images0.zarr" + f1 = UPath(tmp_dir) / "images1.zarr" images.write(f0) images.write(f1) images0 = read_zarr(f0) @@ -49,7 +50,7 @@ def test_backing_files_images(images): im1 = images1.images["image2d"] im2 = im0 + im1 files = get_dask_backing_files(im2) - expected_zarr_locations = [os.path.realpath(os.path.join(f, "images/image2d")) for f in [f0, f1]] + expected_zarr_locations = [str(f / "images" / "image2d") for f in [f0, f1]] assert set(files) == set(expected_zarr_locations) # multiscale @@ -57,7 +58,7 @@ def test_backing_files_images(images): im4 = images1.images["image2d_multiscale"] im5 = im3 + im4 files = get_dask_backing_files(im5) - expected_zarr_locations = [os.path.realpath(os.path.join(f, "images/image2d_multiscale")) for f in [f0, f1]] + expected_zarr_locations = [str(f / "images" / "image2d_multiscale") for f in [f0, f1]] assert set(files) == set(expected_zarr_locations) From 960823216d25178b7ce68fb086c924de20791b74 Mon Sep 17 00:00:00 2001 From: Wouter-Michiel Vierdag Date: Wed, 3 Sep 2025 19:26:17 +0200 Subject: [PATCH 039/126] fix paths --- tests/io/test_utils.py | 12 ++++++------ 1 file changed, 6 insertions(+), 6 deletions(-) diff --git a/tests/io/test_utils.py b/tests/io/test_utils.py index 54a5a4a6..7508990d 100644 --- a/tests/io/test_utils.py +++ b/tests/io/test_utils.py @@ -99,8 +99,8 @@ def test_backing_files_combining_points_and_images(points, images): from examining its computational graph """ with tempfile.TemporaryDirectory() as tmp_dir: - f0 = os.path.join(tmp_dir, "points0.zarr") - f1 = os.path.join(tmp_dir, "images1.zarr") + f0 = UPath(tmp_dir) / "points0.zarr" + f1 = UPath(tmp_dir) / "images1.zarr" points.write(f0) images.write(f1) points0 = read_zarr(f0) @@ -113,12 +113,12 @@ def test_backing_files_combining_points_and_images(points, images): im2 = v + im1 files = get_dask_backing_files(im2) expected_zarr_locations_old = [ - os.path.realpath(os.path.join(f0, "points/points_0/points.parquet")), - os.path.realpath(os.path.join(f1, "images/image2d")), + str(f0 / "points" / "points_0" / "points.parquet"), + str(f1 / "images" / "image2d"), ] expected_zarr_locations_new = [ - os.path.realpath(os.path.join(f0, "points/points_0/points.parquet/part.0.parquet")), - os.path.realpath(os.path.join(f1, "images/image2d")), + str(f0 / "points" / "points_0" / "points.parquet" / "part.0.parquet"), + str(f1 / "images" / "image2d"), ] assert set(files) == set(expected_zarr_locations_old) or set(files) == set(expected_zarr_locations_new) From dd5bc23883a239e7330321ad11f56eb9e5728bc6 Mon Sep 17 00:00:00 2001 From: Wouter-Michiel Vierdag Date: Wed, 3 Sep 2025 19:40:17 +0200 Subject: [PATCH 040/126] resolve paths --- src/spatialdata/_io/_utils.py | 4 ++-- tests/io/test_utils.py | 12 ++++++------ 2 files changed, 8 insertions(+), 8 deletions(-) diff --git a/src/spatialdata/_io/_utils.py b/src/spatialdata/_io/_utils.py index 28a41182..02b05447 100644 --- a/src/spatialdata/_io/_utils.py +++ b/src/spatialdata/_io/_utils.py @@ -290,7 +290,7 @@ def _search_for_backing_files_recursively(subgraph: Any, files: list[str]) -> No if name.startswith("original-from-zarr"): # LocalStore.store does not have an attribute path, but we keep it like this for backward compat. path = getattr(v.store, "path", None) if getattr(v.store, "path", None) else v.store.root - files.append(str(path)) + files.append(str(UPath(path).resolve())) elif name.startswith("read-parquet") or name.startswith("read_parquet"): if hasattr(v, "creation_info"): # https://github.com/dask/dask/blob/ff2488aec44d641696e0b7aa41ed9e995c710705/dask/dataframe/io/parquet/core.py#L625 @@ -301,7 +301,7 @@ def _search_for_backing_files_recursively(subgraph: Any, files: list[str]) -> No f"report this bug." ) parquet_file = t[0] - files.append(os.path.realpath(parquet_file)) + files.append(str(UPath(parquet_file).resolve())) elif isinstance(v, tuple) and len(v) > 1 and isinstance(v[1], dict) and "piece" in v[1]: # https://github.com/dask/dask/blob/ff2488aec44d641696e0b7aa41ed9e995c710705/dask/dataframe/io/parquet/core.py#L870 parquet_file, check0, check1 = v[1]["piece"] diff --git a/tests/io/test_utils.py b/tests/io/test_utils.py index 7508990d..0a430704 100644 --- a/tests/io/test_utils.py +++ b/tests/io/test_utils.py @@ -50,7 +50,7 @@ def test_backing_files_images(images): im1 = images1.images["image2d"] im2 = im0 + im1 files = get_dask_backing_files(im2) - expected_zarr_locations = [str(f / "images" / "image2d") for f in [f0, f1]] + expected_zarr_locations = [str((f / "images" / "image2d").resolve()) for f in [f0, f1]] assert set(files) == set(expected_zarr_locations) # multiscale @@ -58,7 +58,7 @@ def test_backing_files_images(images): im4 = images1.images["image2d_multiscale"] im5 = im3 + im4 files = get_dask_backing_files(im5) - expected_zarr_locations = [str(f / "images" / "image2d_multiscale") for f in [f0, f1]] + expected_zarr_locations = [str((f / "images" / "image2d_multiscale").resolve()) for f in [f0, f1]] assert set(files) == set(expected_zarr_locations) @@ -113,12 +113,12 @@ def test_backing_files_combining_points_and_images(points, images): im2 = v + im1 files = get_dask_backing_files(im2) expected_zarr_locations_old = [ - str(f0 / "points" / "points_0" / "points.parquet"), - str(f1 / "images" / "image2d"), + str((f0 / "points" / "points_0" / "points.parquet").resolve()), + str((f1 / "images" / "image2d").resolve()), ] expected_zarr_locations_new = [ - str(f0 / "points" / "points_0" / "points.parquet" / "part.0.parquet"), - str(f1 / "images" / "image2d"), + str((f0 / "points" / "points_0" / "points.parquet" / "part.0.parquet").resolve()), + str((f1 / "images" / "image2d").resolve()), ] assert set(files) == set(expected_zarr_locations_old) or set(files) == set(expected_zarr_locations_new) From 8d788e0f246408a8cd0f9790cd87c58a18c10931 Mon Sep 17 00:00:00 2001 From: Wouter-Michiel Vierdag Date: Thu, 4 Sep 2025 09:45:09 +0200 Subject: [PATCH 041/126] refactor overwrite transformations --- src/spatialdata/_io/_utils.py | 79 ++++++++++++++++++++++++++++------- 1 file changed, 65 insertions(+), 14 deletions(-) diff --git a/src/spatialdata/_io/_utils.py b/src/spatialdata/_io/_utils.py index 02b05447..92caf049 100644 --- a/src/spatialdata/_io/_utils.py +++ b/src/spatialdata/_io/_utils.py @@ -80,6 +80,24 @@ def overwrite_coordinate_transformations_non_raster( def overwrite_coordinate_transformations_raster( group: zarr.Group, axes: tuple[ValidAxis_t, ...], transformations: MappingToCoordinateSystem_t ) -> None: + """Write transformations of raster elements to disk. + + This function supports both writing of transformations for raster elements stored using zarr v3 and v2. + For the case of zarr v3, there is already a 'coordinateTransformations' from ome-zarr in the metadata of + the group. However, we store our transformations in the first element of the 'multiscales' of the attributes + in the group metadata. This is subject to change. + In the case of zarr v2 the existing 'coordinateTransformations' from ome-zarr is overwritten. + + Parameters + ---------- + group: zarr.Group + The zarr group containing the raster element for which to write the transformations, e.g. the zarr group + containing sdata['image2d']. + axes: tuple[ValidAxis_t, ...] + The list with axes names in the same order as the dimensions of the raster element. + transformations + Mapping between names of the coordinate system and the transformations. + """ _validate_mapping_to_coordinate_system_type(transformations) # prepare the transformations in the dict representation ngff_transformations = [] @@ -95,21 +113,54 @@ def overwrite_coordinate_transformations_raster( coordinate_transformations = [t.to_dict() for t in ngff_transformations] # replace the metadata storage if group.metadata.zarr_format == 3: - if len(multiscales := group.metadata.attributes["ome"]["multiscales"]) != 1: - len_scales = len(multiscales) - raise ValueError(f"The length of multiscales metadata should be 1, found the length to be {len_scales}") - multiscale = multiscales[0] - - # zarr v3 ome-zarr requires the coordinate transformations to be written this way, leaving one out won't work. - multiscale["coordinateTransformations"] = coordinate_transformations - group.attrs["coordinateTransformations"] = coordinate_transformations + _write_coordinate_transformations_raster_zarrv3(group, coordinate_transformations) elif group.metadata.zarr_format == 2: - multiscales = group.attrs["multiscales"] - if (len_scales := len(multiscales)) != 1: - raise ValueError(f"The length of multiscales metadata should be 1, found length of {len_scales}") - multiscale = multiscales[0] - multiscale["coordinateTransformations"] = coordinate_transformations - group.attrs["multiscales"] = multiscales + _overwrite_coordinate_transformations_raster_zarrv2(group, coordinate_transformations) + + +# TODO: check type coordinate_transformations here +def _write_coordinate_transformations_raster_zarrv3( + group: zarr.Group, coordinate_transformations: list[dict[str, BaseTransformation]] +) -> None: + """Write transformations of raster elements to disk in zarr v3. + + Parameters + ---------- + group: zarr.Group + The zarr group containing the raster element for which to write the transformations, e.g. the zarr group + containing sdata['image2d']. + coordinate_transformations: list[dict[str, BaseTransformation]] + List of NGFF transformation representations as dictionaries. + """ + if len(multiscales := group.metadata.attributes["ome"]["multiscales"]) != 1: + len_scales = len(multiscales) + raise ValueError(f"The length of multiscales metadata should be 1, found the length to be {len_scales}") + multiscale = multiscales[0] + + # zarr v3 ome-zarr requires the coordinate transformations to be written this way, leaving one out won't work. + multiscale["coordinateTransformations"] = coordinate_transformations + group.attrs["coordinateTransformations"] = coordinate_transformations + + +def _overwrite_coordinate_transformations_raster_zarrv2( + group: zarr.Group, coordinate_transformations: list[dict[str, BaseTransformation]] +) -> None: + """Overwrite transformations of raster elements on disk in zarr v2. + + Parameters + ---------- + group: zarr.Group + The zarr group containing the raster element for which to write the transformations, e.g. the zarr group + containing sdata['image2d']. + coordinate_transformations: list[dict[str, BaseTransformation]] + List of NGFF transformation representations as dictionaries. + """ + multiscales = group.attrs["multiscales"] + if (len_scales := len(multiscales)) != 1: + raise ValueError(f"The length of multiscales metadata should be 1, found length of {len_scales}") + multiscale = multiscales[0] + multiscale["coordinateTransformations"] = coordinate_transformations + group.attrs["multiscales"] = multiscales def overwrite_channel_names(group: zarr.Group, element: DataArray | DataTree) -> None: From a48ba239bf9eeb8a4da56e8df3bc1c7fc8eb5427 Mon Sep 17 00:00:00 2001 From: Wouter-Michiel Vierdag Date: Thu, 4 Sep 2025 10:47:41 +0200 Subject: [PATCH 042/126] refactor --- src/spatialdata/_core/spatialdata.py | 111 +++--------------- src/spatialdata/_io/io_zarr.py | 87 ++++++++++++++ .../operations/test_spatialdata_operations.py | 14 +-- 3 files changed, 108 insertions(+), 104 deletions(-) diff --git a/src/spatialdata/_core/spatialdata.py b/src/spatialdata/_core/spatialdata.py index d731600b..323988cc 100644 --- a/src/spatialdata/_core/spatialdata.py +++ b/src/spatialdata/_core/spatialdata.py @@ -122,7 +122,6 @@ class SpatialData: annotation directly. """ - @_deprecation_alias(table="tables", version="0.1.0") def __init__( self, images: dict[str, Raster_T] | None = None, @@ -604,94 +603,6 @@ def path(self, value: Path | None) -> None: f" the implications of working with SpatialData objects that are not self-contained." ) - def _get_groups_for_element( - self, zarr_path: Path, element_type: str, element_name: str, use_consolidated: bool = True - ) -> tuple[zarr.Group, zarr.Group, zarr.Group]: - """ - Get the Zarr groups for the root, element_type and element for a specific element. - - The store must exist, but creates the element type group and the element group if they don't exist. - - Parameters - ---------- - zarr_path - The path to the Zarr storage. - element_type - type of the element; must be in ["images", "labels", "points", "polygons", "shapes", "tables"]. - element_name - name of the element - - Returns - ------- - either the existing Zarr subgroup or a new one. - """ - from spatialdata._io._utils import _open_zarr_store - - if not isinstance(zarr_path, Path): - raise ValueError("zarr_path should be a Path object") - - if element_type not in [ - "images", - "labels", - "points", - "polygons", - "shapes", - "tables", - ]: - raise ValueError(f"Unknown element type {element_type}") - - store = _open_zarr_store(zarr_path, mode="r+") - if element_type != "labels": - root = zarr.open_group(store=store, mode="a") - else: - # This is required as ome-zarr accesses the labels group within root. If data has been consolidated - # before it will already look for the labels element just added, but the data has not been reconsolidated - # yet. Thus, when writing we open the root store here with use_consolidated == False. - root = zarr.open_group(store=store, mode="a", use_consolidated=use_consolidated) - - element_type_group = root.require_group(element_type) - # This is required as adata performs a consolidated check before writing anything. If the Tables group was - # consolidated before, this prevents anndata from writing. Therefore, we read with use_consolidated == False - # when writing. - if not use_consolidated and element_type in ["labels", "tables"]: - element_type_group = zarr.open_group( - element_type_group.store_path, mode="a", use_consolidated=use_consolidated - ) - - element_name_group = element_type_group.require_group(element_name) - return root, element_type_group, element_name_group - - def _group_for_element_exists(self, zarr_path: Path, element_type: str, element_name: str) -> bool: - """ - Check if the group for an element exists. - - Parameters - ---------- - element_type - type of the element; must be in ["images", "labels", "points", "polygons", "shapes", "tables"]. - element_name - name of the element - - Returns - ------- - True if the group exists, False otherwise. - """ - from spatialdata._io._utils import _open_zarr_store - - store = _open_zarr_store(zarr_path, mode="r") - root = zarr.open_group(store=store, mode="r") - assert element_type in [ - "images", - "labels", - "points", - "polygons", - "shapes", - "tables", - ] - exists = element_type in root and element_name in root[element_type] - store.close() - return exists - def locate_element(self, element: SpatialElement) -> list[str]: """ Locate a SpatialElement within the SpatialData object and returns its Zarr paths relative to the root. @@ -721,7 +632,6 @@ def locate_element(self, element: SpatialElement) -> list[str]: raise ValueError("Found an element name with a '/' character. This is not allowed.") return [f"{found_element_type[i]}/{found_element_name[i]}" for i in range(len(found))] - @_deprecation_alias(filter_table="filter_tables", version="0.1.0") def filter_by_coordinate_system( self, coordinate_system: str | list[str], @@ -1315,6 +1225,8 @@ def _write_element( overwrite: bool, parsed_formats: dict[str, SpatialDataFormatType] | None = None, ) -> None: + from spatialdata._io.io_zarr import _get_groups_for_element + if not isinstance(zarr_container_path, Path): raise ValueError( f"zarr_container_path must be a Path object, type(zarr_container_path) = {type(zarr_container_path)}." @@ -1324,7 +1236,7 @@ def _write_element( file_path=file_path_of_element, overwrite=overwrite, saving_an_element=True ) - root_group, element_type_group, element_group = self._get_groups_for_element( + root_group, element_type_group, element_group = _get_groups_for_element( zarr_path=zarr_container_path, element_type=element_type, element_name=element_name, use_consolidated=False ) from spatialdata._io import ( @@ -1567,6 +1479,7 @@ def has_consolidated_metadata(self) -> bool: def _validate_can_write_metadata_on_element(self, element_name: str) -> tuple[str, SpatialElement | AnnData] | None: """Validate if metadata can be written on an element, returns None if it cannot be written.""" from spatialdata._io._utils import _is_element_self_contained + from spatialdata._io.io_zarr import _group_for_element_exists # check the element exists in the SpatialData object element = self.get(element_name) @@ -1589,7 +1502,7 @@ def _validate_can_write_metadata_on_element(self, element_name: str) -> tuple[st self._check_element_not_on_disk_with_different_type(element_type=element_type, element_name=element_name) # check if the element exists in the Zarr storage - if not self._group_for_element_exists( + if not _group_for_element_exists( zarr_path=Path(self.path), element_type=element_type, element_name=element_name, @@ -1623,6 +1536,8 @@ def write_channel_names(self, element_name: str | None = None) -> None: The name of the element to write the channel names of. If None, write the channel names of all image elements. """ + from spatialdata._io.io_zarr import _get_groups_for_element + if element_name is not None: check_valid_name(element_name) if element_name not in self: @@ -1642,10 +1557,8 @@ def write_channel_names(self, element_name: str | None = None) -> None: # Mypy does not understand that path is not None so we have the check in the conditional if element_type == "images" and self.path is not None: - _, _, element_group = self._get_groups_for_element( - zarr_path=Path(self.path), - element_type=element_type, - element_name=element_name, + _, _, element_group = _get_groups_for_element( + zarr_path=Path(self.path), element_type=element_type, element_name=element_name, use_consolidated=False ) from spatialdata._io._utils import overwrite_channel_names @@ -1663,6 +1576,8 @@ def write_transformations(self, element_name: str | None = None) -> None: element_name The name of the element to write. If None, write the transformations of all elements. """ + from spatialdata._io.io_zarr import _get_groups_for_element + if element_name is not None: check_valid_name(element_name) if element_name not in self: @@ -1686,7 +1601,7 @@ def write_transformations(self, element_name: str | None = None) -> None: # Mypy does not understand that path is not None so we have a conditional assert self.path is not None - _, _, element_group = self._get_groups_for_element( + _, _, element_group = _get_groups_for_element( zarr_path=Path(self.path), element_type=element_type, element_name=element_name, @@ -1805,6 +1720,8 @@ def write_metadata( if write_attrs: self.write_attrs() + # TODO: discuss when has_consolidated_metadata that we should just consolidate it because after a writing + # operation the consolidated store could otherwise be out of sync. if consolidate_metadata is None and self.has_consolidated_metadata(): consolidate_metadata = True if consolidate_metadata: diff --git a/src/spatialdata/_io/io_zarr.py b/src/spatialdata/_io/io_zarr.py index ee752234..85304a1f 100644 --- a/src/spatialdata/_io/io_zarr.py +++ b/src/spatialdata/_io/io_zarr.py @@ -22,6 +22,8 @@ from spatialdata._logging import logger +# TODO: remove with incoming remote read / write PR +# Not removing this now as it requires substantial extra refactor beyond scope of zarrv3 PR. def _open_zarr_store(store: str | Path | zarr.Group) -> tuple[zarr.Group, str]: """ Open a zarr store (on-disk or remote) and return the zarr.Group object and the path to the store. @@ -222,3 +224,88 @@ def read_zarr( ) sdata.path = Path(store) return sdata + + +def _get_groups_for_element( + zarr_path: Path, element_type: str, element_name: str, use_consolidated: bool = True +) -> tuple[zarr.Group, zarr.Group, zarr.Group]: + """ + Get the Zarr groups for the root, element_type and element for a specific element. + + The store must exist, but creates the element type group and the element group if they don't exist. + + Parameters + ---------- + zarr_path + The path to the Zarr storage. + element_type + type of the element; must be in ["images", "labels", "points", "polygons", "shapes", "tables"]. + element_name + name of the element + use_consolidated + whether to open zarr groups using consolidated metadata. This should be false when writing as we open + zarr groups multiple times when writing an element. If the consolidated metadata store is out of sync with + what is written on disk this leads to errors. + + Returns + ------- + either the existing Zarr subgroup or a new one. + """ + if not isinstance(zarr_path, Path): + raise ValueError("zarr_path should be a Path object") + + if element_type not in [ + "images", + "labels", + "points", + "polygons", + "shapes", + "tables", + ]: + raise ValueError(f"Unknown element type {element_type}") + # TODO: remove local import after remote PR + from spatialdata._io._utils import _open_zarr_store + + store = _open_zarr_store(zarr_path, mode="r+") + + # When writing, use_consolidated must be set to False. Otherwise, the metadata store + # can get out of sync with newly added elements (e.g., labels), leading to errors. + root = zarr.open_group(store=store, mode="a", use_consolidated=use_consolidated) + element_type_group = root.require_group(element_type) + element_type_group = zarr.open_group(element_type_group.store_path, mode="a", use_consolidated=use_consolidated) + + element_name_group = element_type_group.require_group(element_name) + return root, element_type_group, element_name_group + + +def _group_for_element_exists(zarr_path: Path, element_type: str, element_name: str) -> bool: + """ + Check if the group for an element exists. + + Parameters + ---------- + element_type + type of the element; must be in ["images", "labels", "points", "polygons", "shapes", "tables"]. + element_name + name of the element + + Returns + ------- + True if the group exists, False otherwise. + """ + # TODO: remove local import after remote PR + from spatialdata._io._utils import _open_zarr_store + + store = _open_zarr_store(zarr_path, mode="r") + root = zarr.open_group(store=store, mode="r") + assert element_type in [ + "images", + "labels", + "points", + "polygons", + "shapes", + "tables", + ] + exists = element_type in root and element_name in root[element_type] + store.close() + return exists diff --git a/tests/core/operations/test_spatialdata_operations.py b/tests/core/operations/test_spatialdata_operations.py index db413af3..a436d343 100644 --- a/tests/core/operations/test_spatialdata_operations.py +++ b/tests/core/operations/test_spatialdata_operations.py @@ -148,7 +148,7 @@ def test_element_type_from_element_name(points: SpatialData) -> None: def test_filter_by_coordinate_system(full_sdata: SpatialData) -> None: - sdata = full_sdata.filter_by_coordinate_system(coordinate_system="global", filter_table=False) + sdata = full_sdata.filter_by_coordinate_system(coordinate_system="global", filter_tables=False) assert_spatial_data_objects_are_identical(sdata, full_sdata) scale = Scale([2.0], axes=("x",)) @@ -156,12 +156,12 @@ def test_filter_by_coordinate_system(full_sdata: SpatialData) -> None: set_transformation(full_sdata.shapes["circles"], Identity(), "my_space0") set_transformation(full_sdata.shapes["poly"], Identity(), "my_space1") - sdata_my_space = full_sdata.filter_by_coordinate_system(coordinate_system="my_space0", filter_table=False) + sdata_my_space = full_sdata.filter_by_coordinate_system(coordinate_system="my_space0", filter_tables=False) assert len(list(sdata_my_space.gen_elements())) == 3 assert_elements_dict_are_identical(sdata_my_space.tables, full_sdata.tables) sdata_my_space1 = full_sdata.filter_by_coordinate_system( - coordinate_system=["my_space0", "my_space1", "my_space2"], filter_table=False + coordinate_system=["my_space0", "my_space1", "my_space2"], filter_tables=False ) assert len(list(sdata_my_space1.gen_elements())) == 4 @@ -187,7 +187,7 @@ def test_filter_by_coordinate_system_also_table(full_sdata: SpatialData) -> None filtered_sdata0 = full_sdata.filter_by_coordinate_system(coordinate_system="my_space0") filtered_sdata1 = full_sdata.filter_by_coordinate_system(coordinate_system="my_space1") - filtered_sdata2 = full_sdata.filter_by_coordinate_system(coordinate_system="my_space0", filter_table=False) + filtered_sdata2 = full_sdata.filter_by_coordinate_system(coordinate_system="my_space0", filter_tables=False) assert len(filtered_sdata0["table"]) + len(filtered_sdata1["table"]) == len(full_sdata["table"]) assert len(filtered_sdata2["table"]) == len(full_sdata["table"]) @@ -363,10 +363,10 @@ def test_concatenate_sdatas(full_sdata: SpatialData) -> None: set_transformation(full_sdata.shapes["circles"], Identity(), "my_space0") set_transformation(full_sdata.shapes["poly"], Identity(), "my_space1") - filtered = full_sdata.filter_by_coordinate_system(coordinate_system=["my_space0", "my_space1"], filter_table=False) + filtered = full_sdata.filter_by_coordinate_system(coordinate_system=["my_space0", "my_space1"], filter_tables=False) assert len(list(filtered.gen_elements())) == 3 - filtered0 = filtered.filter_by_coordinate_system(coordinate_system="my_space0", filter_table=False) - filtered1 = filtered.filter_by_coordinate_system(coordinate_system="my_space1", filter_table=False) + filtered0 = filtered.filter_by_coordinate_system(coordinate_system="my_space0", filter_tables=False) + filtered1 = filtered.filter_by_coordinate_system(coordinate_system="my_space1", filter_tables=False) # this is needed cause we can't handle regions with same name. # TODO: fix this new_region = "sample2" From 6c5ee98b7c0ccda198a61128e92ab9608aaf450f Mon Sep 17 00:00:00 2001 From: Wouter-Michiel Vierdag Date: Thu, 4 Sep 2025 11:01:40 +0200 Subject: [PATCH 043/126] further refactor, add docstrings --- src/spatialdata/_io/_utils.py | 19 ++++++++++++++++++- src/spatialdata/_io/io_points.py | 5 ++--- tests/core/operations/test_transform.py | 2 +- 3 files changed, 21 insertions(+), 5 deletions(-) diff --git a/src/spatialdata/_io/_utils.py b/src/spatialdata/_io/_utils.py index 92caf049..b3f4dd3a 100644 --- a/src/spatialdata/_io/_utils.py +++ b/src/spatialdata/_io/_utils.py @@ -63,6 +63,18 @@ def _get_transformations_from_ngff_dict( def overwrite_coordinate_transformations_non_raster( group: zarr.Group, axes: tuple[ValidAxis_t, ...], transformations: MappingToCoordinateSystem_t ) -> None: + """Write coordinate transformations of non-raster element to disk. + + Parameters + ---------- + group: zarr.Group + The zarr group containing the non-raster element for which to write the transformations, e.g. the zarr group + containing sdata['points']. + axes: tuple[ValidAxis_t, ...] + The list with axes names in the same order as the coordinates of the non-raster element. + transformations: MappingToCoordinateSystem_t + Mapping between names of the coordinate system and the transformations. + """ _validate_mapping_to_coordinate_system_type(transformations) ngff_transformations = [] for target_coordinate_system, t in transformations.items(): @@ -95,7 +107,7 @@ def overwrite_coordinate_transformations_raster( containing sdata['image2d']. axes: tuple[ValidAxis_t, ...] The list with axes names in the same order as the dimensions of the raster element. - transformations + transformations: MappingToCoordinateSystem_t Mapping between names of the coordinate system and the transformations. """ _validate_mapping_to_coordinate_system_type(transformations) @@ -147,6 +159,11 @@ def _overwrite_coordinate_transformations_raster_zarrv2( ) -> None: """Overwrite transformations of raster elements on disk in zarr v2. + The transformation present in multiscale["datasets"] are the ones for the multiscale, so and we leave them intact + we update multiscale["coordinateTransformations"] and multiscale["coordinateSystems"] + see the first post of https://github.com/scverse/spatialdata/issues/39 for an overview + fix the io to follow the NGFF specs, see https://github.com/scverse/spatialdata/issues/114 + Parameters ---------- group: zarr.Group diff --git a/src/spatialdata/_io/io_points.py b/src/spatialdata/_io/io_points.py index d2c22aef..6d1d44c5 100644 --- a/src/spatialdata/_io/io_points.py +++ b/src/spatialdata/_io/io_points.py @@ -1,4 +1,3 @@ -import os from collections.abc import MutableMapping from pathlib import Path @@ -32,10 +31,10 @@ def _read_points( format = PointsFormats[version] store_root = f.store_path.store.root - path = os.path.join(store_root, f.path, "points.parquet") + path = store_root / f.path / "points.parquet" # cache on remote file needed for parquet reader to work # TODO: allow reading in the metadata without caching all the data - points = read_parquet("simplecache::" + path if path.startswith("http") else path) + points = read_parquet("simplecache::" + str(path) if str(path).startswith("http") else path) assert isinstance(points, DaskDataFrame) transformations = _get_transformations_from_ngff_dict(f.attrs.asdict()["coordinateTransformations"]) diff --git a/tests/core/operations/test_transform.py b/tests/core/operations/test_transform.py index e8031820..81959e2f 100644 --- a/tests/core/operations/test_transform.py +++ b/tests/core/operations/test_transform.py @@ -587,7 +587,7 @@ def test_transform_elements_and_entire_spatial_data_object_multi_hop( labels=dict(full_sdata.labels), points=dict(full_sdata.points), shapes=dict(full_sdata.shapes), - table=full_sdata["table"], + tables=full_sdata["table"], ) temp["transformed_element"] = transformed_element transformation = get_transformation_between_coordinate_systems( From 25aca868a04ce54b8e41e5e53a6d7e3dfb564f88 Mon Sep 17 00:00:00 2001 From: Wouter-Michiel Vierdag Date: Thu, 4 Sep 2025 14:25:08 +0200 Subject: [PATCH 044/126] refactor io_raster --- src/spatialdata/_core/spatialdata.py | 1 - src/spatialdata/_io/io_points.py | 24 +++- src/spatialdata/_io/io_raster.py | 195 ++++++++++++++------------- src/spatialdata/models/models.py | 10 ++ 4 files changed, 127 insertions(+), 103 deletions(-) diff --git a/src/spatialdata/_core/spatialdata.py b/src/spatialdata/_core/spatialdata.py index 323988cc..d10ec455 100644 --- a/src/spatialdata/_core/spatialdata.py +++ b/src/spatialdata/_core/spatialdata.py @@ -1269,7 +1269,6 @@ def _write_element( write_points( points=element, group=element_group, - name=element_name, format=parsed_formats["points"], ) elif element_type == "shapes": diff --git a/src/spatialdata/_io/io_points.py b/src/spatialdata/_io/io_points.py index 6d1d44c5..6c5a2d3a 100644 --- a/src/spatialdata/_io/io_points.py +++ b/src/spatialdata/_io/io_points.py @@ -49,16 +49,27 @@ def _read_points( def write_points( points: DaskDataFrame, group: zarr.Group, - name: str, group_type: str = "ngff:points", format: Format = CurrentPointsFormat(), ) -> None: + """Write a points element to a zarr store. + + Parameters + ---------- + points: DaskDataFrame + The dataframe of the points element. + group: zarr.Group + The zarr group to in the 'points' zarr group to write the points element to. + group_type: str + The type of the element + format: + The format of the points element used to store it. + """ axes = get_axes_names(points) - t = _get_transformations(points) + transformations = _get_transformations(points) store_root = group.store_path.store.root - group_path = group.path - path = store_root / group_path / "points.parquet" + path = store_root / group.path / "points.parquet" # The following code iterates through all columns in the 'points' DataFrame. If the column's datatype is # 'category', it checks whether the categories of this column are known. If not, it explicitly converts the @@ -82,5 +93,6 @@ def write_points( axes=list(axes), attrs=attrs, ) - assert t is not None - overwrite_coordinate_transformations_non_raster(group=group, axes=axes, transformations=t) + if transformations is None: + raise ValueError(f"No transformations specified for element '{group.basename}'. Cannot write.") + overwrite_coordinate_transformations_non_raster(group=group, axes=axes, transformations=transformations) diff --git a/src/spatialdata/_io/io_raster.py b/src/spatialdata/_io/io_raster.py index 80281344..feffea90 100644 --- a/src/spatialdata/_io/io_raster.py +++ b/src/spatialdata/_io/io_raster.py @@ -44,13 +44,7 @@ def _read_multiscale(store: str | Path, raster_type: Literal["image", "labels"]) image_nodes = list(image_reader) if len(image_nodes): for node in image_nodes: - # if np.any([isinstance(spec, Multiscales) for spec in node.specs]) and ( - # raster_type == "image" - # and np.all([not isinstance(spec, Label) for spec in node.specs]) - # or raster_type == "labels" - # and np.any([isinstance(spec, Label) for spec in node.specs]) - # ): - # Labels are not also Multiscales + # Labels are now also Multiscales in newer version of ome-zarr-py if np.any([isinstance(spec, Multiscales) for spec in node.specs]) and raster_type in [ "image", "labels", @@ -84,8 +78,6 @@ def _read_multiscale(store: str | Path, raster_type: Literal["image", "labels"]) # and for instance in the xenium example encoded_ngff_transformations = multiscales[0]["coordinateTransformations"] transformations = _get_transformations_from_ngff_dict(encoded_ngff_transformations) - # TODO: what to do with name? For now remove? - # name = os.path.basename(node.metadata["name"]) # if image, read channels metadata channels: list[Any] | None = None if raster_type == "image": @@ -127,28 +119,20 @@ def _write_raster( raster_data: DataArray | DataTree, group: zarr.Group, name: str, - format: Format = CurrentRasterFormat(), + raster_format: Format, storage_options: JSONDict | list[JSONDict] | None = None, label_metadata: JSONDict | None = None, **metadata: str | JSONDict | list[JSONDict], ) -> None: - assert raster_type in ["image", "labels"] - # the argument "name" and "label_metadata" are only used for labels (to be precise, name is used in - # write_multiscale_ngff() when writing metadata, but name is not used in write_image_ngff(). Maybe this is bug of - # ome-zarr-py. In any case, we don't need that metadata and we use the argument name so that when we write labels - # the correct group is created by the ome-zarr-py APIs. For images we do it manually in the function - # _get_group_for_writing_data() - if raster_type == "image": - assert label_metadata is None - else: + if raster_type not in ["image", "labels"]: + raise ValueError(f"{raster_type} is not a valid raster type. Must be 'image' or 'labels'.") + # "name" and "label_metadata" are only used for labels. "name" is written in write_multiscale_ngff() but ignored in + # write_image_ngff() (possibly an ome-zarr-py bug). We only use "name" to ensure correct group access in the + # ome-zarr API. + if raster_type == "labels": metadata["name"] = name metadata["label_metadata"] = label_metadata - write_single_scale_ngff = write_image_ngff if raster_type == "image" else write_labels_ngff - write_multi_scale_ngff = write_multiscale_ngff if raster_type == "image" else write_multiscale_labels_ngff - - group_data = group # (group[name] if name in group else group.require_group(name)) if raster_type == "image" else - # convert channel names to channel metadata in omero if raster_type == "image": metadata["metadata"] = {"omero": {"channels": []}} @@ -156,88 +140,107 @@ def _write_raster( for c in channels: metadata["metadata"]["omero"]["channels"].append({"label": c}) # type: ignore[union-attr, index, call-overload] - # TODO refactor as function is way too big if isinstance(raster_data, DataArray): - data = raster_data.data - transformations = _get_transformations(raster_data) - input_axes: tuple[str, ...] = tuple(raster_data.dims) - chunks = raster_data.chunks - parsed_axes = _get_valid_axes(axes=list(input_axes), fmt=format) - if storage_options is not None: - if "chunks" not in storage_options and isinstance(storage_options, dict): - storage_options["chunks"] = chunks - else: - storage_options = {"chunks": chunks} - # Scaler needs to be None since we are passing the data already downscaled for the multiscale case. - # We need this because the argument of write_image_ngff is called image while the argument of - # write_labels_ngff is called label. - metadata[raster_type] = data - # TODO: check purpose of _get_group_for_writing_transformations here as it seems to return same as group_data - write_single_scale_ngff( - group=group_data, - scaler=None, - fmt=format, - axes=parsed_axes, - coordinate_transformations=None, - storage_options=storage_options, - **metadata, - ) - if not transformations: - raise ValueError(f"No transformations specified to be written for element {name}.") - - # Cannot move before conditional as group_data is updated when writing ngff scales - trans_group = group["labels"][name] if raster_type == "labels" else group_data - overwrite_coordinate_transformations_raster(group=trans_group, transformations=transformations, axes=input_axes) + _write_raster_dataarray(raster_type, group, name, raster_data, raster_format, storage_options, **metadata) elif isinstance(raster_data, DataTree): - data = get_pyramid_levels(raster_data, attr="data") - list_of_input_axes: list[Any] = get_pyramid_levels(raster_data, attr="dims") - assert len(set(list_of_input_axes)) == 1 - input_axes = list_of_input_axes[0] - # saving only the transformations of the first scale - d = dict(raster_data["scale0"]) - assert len(d) == 1 - xdata = d.values().__iter__().__next__() - transformations = _get_transformations_xarray(xdata) - if not transformations: - raise ValueError(f"No transformations specified to be written for element {name}.") - chunks = get_pyramid_levels(raster_data, "chunks") - - parsed_axes = _get_valid_axes(axes=list(input_axes), fmt=format) - storage_options = [{"chunks": chunk} for chunk in chunks] - dask_delayed = write_multi_scale_ngff( - pyramid=data, - group=group_data, - fmt=format, - axes=parsed_axes, - coordinate_transformations=None, - storage_options=storage_options, - **metadata, - compute=False, - ) - # Compute all pyramid levels at once to allow Dask to optimize the computational graph. - da.compute(*dask_delayed) - - # Cannot move before conditional as group_data is updated when writing ngff scales - trans_group = group["labels"][name] if raster_type == "labels" else group_data - overwrite_coordinate_transformations_raster( - group=trans_group, transformations=transformations, axes=tuple(input_axes) - ) + _write_raster_datatree(raster_type, group, name, raster_data, raster_format, storage_options, **metadata) else: raise ValueError("Not a valid labels object") - # as explained in a comment in format.py, since coordinate transformations are not part of NGFF yet, we need to have - # our spatialdata extension also for raster type (eventually it will be dropped in favor of pure NGFF). Until then, - # saving the NGFF version (i.e. 0.4) is not enough, and we need to also record which version of the spatialdata - # format we are using for raster types - group = group_data + # Since NGFF does not yet support coordinate transformations, we need a SpatialData extension for rasters. This will + # be dropped once NGFF supports it. For now, saving the NGFF version (0.4) is not enough—we must also record the + # SpatialData format version. if ATTRS_KEY not in group.attrs: group.attrs[ATTRS_KEY] = {} attrs = group.attrs[ATTRS_KEY] - attrs["version"] = format.spatialdata_format_version + attrs["version"] = raster_format.spatialdata_format_version # triggers the write operation group.attrs[ATTRS_KEY] = attrs +def _write_raster_dataarray( + raster_type: Literal["image", "labels"], + group: zarr.Group, + name: str, + raster_data: DataArray, + raster_format: Format, + storage_options: JSONDict | list[JSONDict] | None = None, + **metadata: str | JSONDict | list[JSONDict], +) -> None: + write_single_scale_ngff = write_image_ngff if raster_type == "image" else write_labels_ngff + + data = raster_data.data + transformations = _get_transformations(raster_data) + if transformations is None: + raise ValueError(f"{name} does not have any transformations and can therefore not be written.") + input_axes: tuple[str, ...] = tuple(raster_data.dims) + chunks = raster_data.chunks + parsed_axes = _get_valid_axes(axes=list(input_axes), fmt=raster_format) + if storage_options is not None: + if "chunks" not in storage_options and isinstance(storage_options, dict): + storage_options["chunks"] = chunks + else: + storage_options = {"chunks": chunks} + # Scaler needs to be None since we are passing the data already downscaled for the multiscale case. + metadata[raster_type] = data + write_single_scale_ngff( + group=group, + scaler=None, + fmt=raster_format, + axes=parsed_axes, + coordinate_transformations=None, + storage_options=storage_options, + **metadata, + ) + + trans_group = group["labels"][name] if raster_type == "labels" else group + overwrite_coordinate_transformations_raster(group=trans_group, transformations=transformations, axes=input_axes) + + +def _write_raster_datatree( + raster_type: Literal["image", "labels"], + group: zarr.Group, + name: str, + raster_data: DataArray, + raster_format: Format, + storage_options: JSONDict | list[JSONDict] | None = None, + **metadata: str | JSONDict | list[JSONDict], +) -> None: + write_multi_scale_ngff = write_multiscale_ngff if raster_type == "image" else write_multiscale_labels_ngff + data = get_pyramid_levels(raster_data, attr="data") + list_of_input_axes: list[Any] = get_pyramid_levels(raster_data, attr="dims") + assert len(set(list_of_input_axes)) == 1 + input_axes = list_of_input_axes[0] + # saving only the transformations of the first scale + d = dict(raster_data["scale0"]) + assert len(d) == 1 + xdata = d.values().__iter__().__next__() + transformations = _get_transformations_xarray(xdata) + if transformations is None: + raise ValueError(f"{name} does not have any transformations and can therefore not be written.") + chunks = get_pyramid_levels(raster_data, "chunks") + + parsed_axes = _get_valid_axes(axes=list(input_axes), fmt=raster_format) + storage_options = [{"chunks": chunk} for chunk in chunks] + dask_delayed = write_multi_scale_ngff( + pyramid=data, + group=group, + fmt=raster_format, + axes=parsed_axes, + coordinate_transformations=None, + storage_options=storage_options, + **metadata, + compute=False, + ) + # Compute all pyramid levels at once to allow Dask to optimize the computational graph. + da.compute(*dask_delayed) + + trans_group = group["labels"][name] if raster_type == "labels" else group + overwrite_coordinate_transformations_raster( + group=trans_group, transformations=transformations, axes=tuple(input_axes) + ) + + def write_image( image: DataArray | DataTree, group: zarr.Group, @@ -251,7 +254,7 @@ def write_image( raster_data=image, group=group, name=name, - format=format, + raster_format=format, storage_options=storage_options, **metadata, ) @@ -271,7 +274,7 @@ def write_labels( raster_data=labels, group=group, name=name, - format=format, + raster_format=format, storage_options=storage_options, label_metadata=label_metadata, **metadata, diff --git a/src/spatialdata/models/models.py b/src/spatialdata/models/models.py index 7aeb0b2c..b9a43992 100644 --- a/src/spatialdata/models/models.py +++ b/src/spatialdata/models/models.py @@ -260,6 +260,7 @@ def validate(self, data: Any) -> None: def _(self, data: DataArray) -> None: super().validate(data) self._check_chunk_size_not_too_large(data) + self._check_transforms_present(data) @validate.register(DataTree) def _(self, data: DataTree) -> None: @@ -273,6 +274,15 @@ def _(self, data: DataTree) -> None: for d in data: super().validate(data[d][name]) self._check_chunk_size_not_too_large(data) + self._check_transforms_present(data) + + def _check_transforms_present(self, data: DataArray | DataTree) -> None: + parsed_transform = _get_transformations(data) + if parsed_transform is None: + raise ValueError( + f"No transformation found for `{data}`. At least one transformation is required for " + f"raster elements, e.g. images, labels." + ) def _check_chunk_size_not_too_large(self, data: DataArray | DataTree) -> None: if isinstance(data, DataArray): From 934c2bfed1e1b50ecad88bac914975031c675c22 Mon Sep 17 00:00:00 2001 From: Wouter-Michiel Vierdag Date: Thu, 4 Sep 2025 17:34:13 +0200 Subject: [PATCH 045/126] several refactors io --- src/spatialdata/_core/spatialdata.py | 9 ++- src/spatialdata/_io/io_points.py | 12 ++-- src/spatialdata/_io/io_raster.py | 81 ++++++++++++++++++--- src/spatialdata/_io/io_shapes.py | 104 +++++++++++++++++++-------- 4 files changed, 155 insertions(+), 51 deletions(-) diff --git a/src/spatialdata/_core/spatialdata.py b/src/spatialdata/_core/spatialdata.py index d10ec455..d8d14ca6 100644 --- a/src/spatialdata/_core/spatialdata.py +++ b/src/spatialdata/_core/spatialdata.py @@ -1256,27 +1256,26 @@ def _write_element( image=element, group=element_group, name=element_name, - format=parsed_formats["raster"], + element_format=parsed_formats["raster"], ) elif element_type == "labels": write_labels( labels=element, group=root_group, name=element_name, - format=parsed_formats["raster"], + element_format=parsed_formats["raster"], ) elif element_type == "points": write_points( points=element, group=element_group, - format=parsed_formats["points"], + element_format=parsed_formats["points"], ) elif element_type == "shapes": write_shapes( shapes=element, group=element_group, - name=element_name, - format=parsed_formats["shapes"], + element_format=parsed_formats["shapes"], ) elif element_type == "tables": write_table( diff --git a/src/spatialdata/_io/io_points.py b/src/spatialdata/_io/io_points.py index 6c5a2d3a..cbf02384 100644 --- a/src/spatialdata/_io/io_points.py +++ b/src/spatialdata/_io/io_points.py @@ -50,7 +50,7 @@ def write_points( points: DaskDataFrame, group: zarr.Group, group_type: str = "ngff:points", - format: Format = CurrentPointsFormat(), + element_format: Format = CurrentPointsFormat(), ) -> None: """Write a points element to a zarr store. @@ -59,10 +59,10 @@ def write_points( points: DaskDataFrame The dataframe of the points element. group: zarr.Group - The zarr group to in the 'points' zarr group to write the points element to. + The zarr group in the 'points' zarr group to write the points element to. group_type: str - The type of the element - format: + The type of the element. + element_format: Format The format of the points element used to store it. """ axes = get_axes_names(points) @@ -84,8 +84,8 @@ def write_points( points.to_parquet(path) - attrs = format.attrs_to_dict(points.attrs) - attrs["version"] = format.spatialdata_format_version + attrs = element_format.attrs_to_dict(points.attrs) + attrs["version"] = element_format.spatialdata_format_version _write_metadata( group, diff --git a/src/spatialdata/_io/io_raster.py b/src/spatialdata/_io/io_raster.py index feffea90..a006f3bc 100644 --- a/src/spatialdata/_io/io_raster.py +++ b/src/spatialdata/_io/io_raster.py @@ -124,6 +124,27 @@ def _write_raster( label_metadata: JSONDict | None = None, **metadata: str | JSONDict | list[JSONDict], ) -> None: + """Write raster data to disk. + + Parameters + ---------- + raster_type: Literal["image", "labels"] + Whether the raster data pertains to a image or labels 'SpatialElement`. + raster_data: DataArray | DataTree + The raster data to write. + group: zarr.Group + The zarr group in the 'image' or 'labels' zarr group to write the raster data to. + name: str + The name of the raster element. + raster_format: Format + The format used to write the raster data. + storage_options: JSONDict | list[JSONDict] | None + Additional options for writing the raster data, like chunks and compression. + label_metadata: JSONDict | None + Label metadata which can only be defined when writing 'labels'. + metadata: str | JSONDict | list[JSONDict] + Additional metadata for the raster element + """ if raster_type not in ["image", "labels"]: raise ValueError(f"{raster_type} is not a valid raster type. Must be 'image' or 'labels'.") # "name" and "label_metadata" are only used for labels. "name" is written in write_multiscale_ngff() but ignored in @@ -161,18 +182,37 @@ def _write_raster( def _write_raster_dataarray( raster_type: Literal["image", "labels"], group: zarr.Group, - name: str, + element_name: str, raster_data: DataArray, raster_format: Format, storage_options: JSONDict | list[JSONDict] | None = None, **metadata: str | JSONDict | list[JSONDict], ) -> None: + """Write raster data of type DataArray to disk. + + Parameters + ---------- + raster_type: Literal["image", "labels"] + Whether the raster data pertains to a image or labels 'SpatialElement`. + group: zarr.Group + The zarr group in the 'image' or 'labels' zarr group to write the raster data to. + element_name: str + The name of the raster element. + raster_data: DataArray + The raster data to write. + raster_format: Format + The format used to write the raster data. + storage_options: JSONDict | list[JSONDict] | None + Additional options for writing the raster data, like chunks and compression. + metadata: str | JSONDict | list[JSONDict] + Additional metadata for the raster element + """ write_single_scale_ngff = write_image_ngff if raster_type == "image" else write_labels_ngff data = raster_data.data transformations = _get_transformations(raster_data) if transformations is None: - raise ValueError(f"{name} does not have any transformations and can therefore not be written.") + raise ValueError(f"{element_name} does not have any transformations and can therefore not be written.") input_axes: tuple[str, ...] = tuple(raster_data.dims) chunks = raster_data.chunks parsed_axes = _get_valid_axes(axes=list(input_axes), fmt=raster_format) @@ -193,19 +233,38 @@ def _write_raster_dataarray( **metadata, ) - trans_group = group["labels"][name] if raster_type == "labels" else group + trans_group = group["labels"][element_name] if raster_type == "labels" else group overwrite_coordinate_transformations_raster(group=trans_group, transformations=transformations, axes=input_axes) def _write_raster_datatree( raster_type: Literal["image", "labels"], group: zarr.Group, - name: str, - raster_data: DataArray, + element_name: str, + raster_data: DataTree, raster_format: Format, storage_options: JSONDict | list[JSONDict] | None = None, **metadata: str | JSONDict | list[JSONDict], ) -> None: + """Write raster data of type DataTree to disk. + + Parameters + ---------- + raster_type: Literal["image", "labels"] + Whether the raster data pertains to a image or labels 'SpatialElement`. + group: zarr.Group + The zarr group in the 'image' or 'labels' zarr group to write the raster data to. + element_name: str + The name of the raster element. + raster_data: DataTree + The raster data to write. + raster_format: Format + The format used to write the raster data. + storage_options: JSONDict | list[JSONDict] | None + Additional options for writing the raster data, like chunks and compression. + metadata: str | JSONDict | list[JSONDict] + Additional metadata for the raster element + """ write_multi_scale_ngff = write_multiscale_ngff if raster_type == "image" else write_multiscale_labels_ngff data = get_pyramid_levels(raster_data, attr="data") list_of_input_axes: list[Any] = get_pyramid_levels(raster_data, attr="dims") @@ -217,7 +276,7 @@ def _write_raster_datatree( xdata = d.values().__iter__().__next__() transformations = _get_transformations_xarray(xdata) if transformations is None: - raise ValueError(f"{name} does not have any transformations and can therefore not be written.") + raise ValueError(f"{element_name} does not have any transformations and can therefore not be written.") chunks = get_pyramid_levels(raster_data, "chunks") parsed_axes = _get_valid_axes(axes=list(input_axes), fmt=raster_format) @@ -235,7 +294,7 @@ def _write_raster_datatree( # Compute all pyramid levels at once to allow Dask to optimize the computational graph. da.compute(*dask_delayed) - trans_group = group["labels"][name] if raster_type == "labels" else group + trans_group = group["labels"][element_name] if raster_type == "labels" else group overwrite_coordinate_transformations_raster( group=trans_group, transformations=transformations, axes=tuple(input_axes) ) @@ -245,7 +304,7 @@ def write_image( image: DataArray | DataTree, group: zarr.Group, name: str, - format: Format = CurrentRasterFormat(), + element_format: Format = CurrentRasterFormat(), storage_options: JSONDict | list[JSONDict] | None = None, **metadata: str | JSONDict | list[JSONDict], ) -> None: @@ -254,7 +313,7 @@ def write_image( raster_data=image, group=group, name=name, - raster_format=format, + raster_format=element_format, storage_options=storage_options, **metadata, ) @@ -264,7 +323,7 @@ def write_labels( labels: DataArray | DataTree, group: zarr.Group, name: str, - format: Format = CurrentRasterFormat(), + element_format: Format = CurrentRasterFormat(), storage_options: JSONDict | list[JSONDict] | None = None, label_metadata: JSONDict | None = None, **metadata: JSONDict, @@ -274,7 +333,7 @@ def write_labels( raster_data=labels, group=group, name=name, - raster_format=format, + raster_format=element_format, storage_options=storage_options, label_metadata=label_metadata, **metadata, diff --git a/src/spatialdata/_io/io_shapes.py b/src/spatialdata/_io/io_shapes.py index 43d294eb..50527e77 100644 --- a/src/spatialdata/_io/io_shapes.py +++ b/src/spatialdata/_io/io_shapes.py @@ -1,5 +1,6 @@ from collections.abc import MutableMapping from pathlib import Path +from typing import Any import numpy as np import zarr @@ -71,38 +72,32 @@ def _read_shapes( def write_shapes( shapes: GeoDataFrame, group: zarr.Group, - name: str, group_type: str = "ngff:shapes", - format: Format = CurrentShapesFormat(), + element_format: Format = CurrentShapesFormat(), ) -> None: - import numcodecs + """Write shapes to spatialdata zarr store. + Parameters + ---------- + shapes: GeoDataFrame + The shapes dataframe + group: zarr.Group + The zarr group in the 'shapes' zarr group to write the shapes element to. + group_type: str + The type of the element. + element_format: Format + The format of the shapes element used to store it. + """ axes = get_axes_names(shapes) - t = _get_transformations(shapes) - - if isinstance(format, ShapesFormatV01): - geometry, coords, offsets = to_ragged_array(shapes.geometry) - group.create_array(name="coords", data=coords) - for i, o in enumerate(offsets): - group.create_array(name=f"offset{i}", data=o) - if shapes.index.dtype.kind == "U" or shapes.index.dtype.kind == "O": - group.create_array(name="Index", data=shapes.index.values, dtype=object, object_codec=numcodecs.VLenUTF8()) - else: - group.create_array(name="Index", data=shapes.index.values) - if geometry.name == "POINT": - group.create_array(name=ShapesModel.RADIUS_KEY, data=shapes[ShapesModel.RADIUS_KEY].values) - - attrs = format.attrs_to_dict(geometry) - attrs["version"] = format.spatialdata_format_version - elif isinstance(format, ShapesFormatV02 | ShapesFormatV03): - store_root = group.store_path.store.root - path = store_root / group.path / "shapes.parquet" - shapes.to_parquet(path) - - attrs = format.attrs_to_dict(shapes.attrs) - attrs["version"] = format.spatialdata_format_version + transformations = _get_transformations(shapes) + if transformations is None: + raise ValueError(f"{group.basename} does not have any transformations and can therefore not be written.") + if isinstance(element_format, ShapesFormatV01): + attrs = _write_shapes_v01(shapes, group, element_format) + elif isinstance(element_format, ShapesFormatV02 | ShapesFormatV03): + attrs = _write_shapes_v02_v03(shapes, group, element_format) else: - raise ValueError(f"Unsupported format version {format.version}. Please update the spatialdata library.") + raise ValueError(f"Unsupported format version {element_format.version}. Please update the spatialdata library.") _write_metadata( group, @@ -110,5 +105,56 @@ def write_shapes( axes=list(axes), attrs=attrs, ) - assert t is not None - overwrite_coordinate_transformations_non_raster(group=group, axes=axes, transformations=t) + + overwrite_coordinate_transformations_non_raster(group=group, axes=axes, transformations=transformations) + + +def _write_shapes_v01(shapes: GeoDataFrame, group: zarr.Group, element_format: Format) -> Any: + """Write shapes to spatialdata zarr store using format ShapesFormatV01. + + Parameters + ---------- + shapes: GeoDataFrame + The shapes dataframe + group: zarr.Group + The zarr group in the 'shapes' zarr group to write the shapes element to. + element_format: Format + The format of the shapes element used to store it. + """ + import numcodecs + + geometry, coords, offsets = to_ragged_array(shapes.geometry) + group.create_array(name="coords", data=coords) + for i, o in enumerate(offsets): + group.create_array(name=f"offset{i}", data=o) + if shapes.index.dtype.kind == "U" or shapes.index.dtype.kind == "O": + group.create_array(name="Index", data=shapes.index.values, dtype=object, object_codec=numcodecs.VLenUTF8()) + else: + group.create_array(name="Index", data=shapes.index.values) + if geometry.name == "POINT": + group.create_array(name=ShapesModel.RADIUS_KEY, data=shapes[ShapesModel.RADIUS_KEY].values) + + attrs = element_format.attrs_to_dict(geometry) + attrs["version"] = element_format.spatialdata_format_version + return attrs + + +def _write_shapes_v02_v03(shapes: GeoDataFrame, group: zarr.Group, element_format: Format) -> Any: + """Write shapes to spatialdata zarr store using format ShapesFormatV02 or ShapesFormatV02. + + Parameters + ---------- + shapes: GeoDataFrame + The shapes dataframe + group: zarr.Group + The zarr group in the 'shapes' zarr group to write the shapes element to. + element_format: Format + The format of the shapes element used to store it. + """ + store_root = group.store_path.store.root + path = store_root / group.path / "shapes.parquet" + shapes.to_parquet(path) + + attrs = element_format.attrs_to_dict(shapes.attrs) + attrs["version"] = element_format.spatialdata_format_version + return attrs From ee78d3060d301a9ede0cda650061e40d0386cd16 Mon Sep 17 00:00:00 2001 From: Wouter-Michiel Vierdag Date: Fri, 5 Sep 2025 10:02:33 +0200 Subject: [PATCH 046/126] war on warnings --- src/spatialdata/_core/_deepcopy.py | 2 +- src/spatialdata/_core/concatenate.py | 4 +- src/spatialdata/_core/operations/_utils.py | 2 +- src/spatialdata/_core/operations/aggregate.py | 2 +- .../_core/query/relational_query.py | 2 +- src/spatialdata/_core/query/spatial_query.py | 25 ----- src/spatialdata/_core/spatialdata.py | 91 ++----------------- src/spatialdata/_io/io_points.py | 4 +- src/spatialdata/_io/io_table.py | 13 +-- src/spatialdata/dataloader/datasets.py | 4 +- src/spatialdata/models/__init__.py | 3 - src/spatialdata/models/_utils.py | 28 ------ tests/conftest.py | 2 +- tests/core/operations/test_rasterize.py | 2 +- tests/core/operations/test_rasterize_bins.py | 4 +- .../operations/test_spatialdata_operations.py | 4 +- tests/core/operations/test_transform.py | 4 +- tests/core/query/test_spatial_query.py | 6 +- tests/core/test_centroids.py | 2 +- tests/io/test_format.py | 10 +- tests/io/test_multi_table.py | 24 ----- tests/io/test_pyramids_performance.py | 2 +- tests/io/test_readwrite.py | 21 ----- tests/models/test_models.py | 4 +- 24 files changed, 44 insertions(+), 221 deletions(-) diff --git a/src/spatialdata/_core/_deepcopy.py b/src/spatialdata/_core/_deepcopy.py index d55db756..6a5b4336 100644 --- a/src/spatialdata/_core/_deepcopy.py +++ b/src/spatialdata/_core/_deepcopy.py @@ -46,7 +46,7 @@ def _(sdata: SpatialData) -> SpatialData: for _, element_name, element in sdata.gen_elements(): elements_dict[element_name] = deepcopy(element) deepcopied_attrs = _deepcopy(sdata.attrs) - return SpatialData.from_elements_dict(elements_dict, attrs=deepcopied_attrs) + return SpatialData.init_from_elements(elements_dict, attrs=deepcopied_attrs) @deepcopy.register(DataArray) diff --git a/src/spatialdata/_core/concatenate.py b/src/spatialdata/_core/concatenate.py index 953e4f2f..68e3d17e 100644 --- a/src/spatialdata/_core/concatenate.py +++ b/src/spatialdata/_core/concatenate.py @@ -252,6 +252,8 @@ def _fix_ensure_unique_element_names( tables_by_sdata.append(tables) sdatas_fixed = [] for elements, tables in zip(elements_by_sdata, tables_by_sdata, strict=True): - sdata = SpatialData.init_from_elements(elements, tables=tables) + if tables is not None: + elements.update(tables) + sdata = SpatialData.init_from_elements(elements) sdatas_fixed.append(sdata) return sdatas_fixed diff --git a/src/spatialdata/_core/operations/_utils.py b/src/spatialdata/_core/operations/_utils.py index 2879af45..d3c438ab 100644 --- a/src/spatialdata/_core/operations/_utils.py +++ b/src/spatialdata/_core/operations/_utils.py @@ -136,7 +136,7 @@ def transform_to_data_extent( set_transformation(el, transformation={coordinate_system: Identity()}, set_all=True) for k, v in sdata.tables.items(): sdata_to_return_elements[k] = v.copy() - return SpatialData.from_elements_dict(sdata_to_return_elements, attrs=sdata.attrs) + return SpatialData.init_from_elements(sdata_to_return_elements, attrs=sdata.attrs) def _parse_element( diff --git a/src/spatialdata/_core/operations/aggregate.py b/src/spatialdata/_core/operations/aggregate.py index 83a3d6f9..809b9e52 100644 --- a/src/spatialdata/_core/operations/aggregate.py +++ b/src/spatialdata/_core/operations/aggregate.py @@ -242,7 +242,7 @@ def _create_sdata_from_table_and_shapes( if deepcopy: shapes = _deepcopy(shapes) - return SpatialData.from_elements_dict({shapes_name: shapes, table_name: table}) + return SpatialData.init_from_elements({shapes_name: shapes, table_name: table}) def _aggregate_image_by_labels( diff --git a/src/spatialdata/_core/query/relational_query.py b/src/spatialdata/_core/query/relational_query.py index 0803158c..b0f9978d 100644 --- a/src/spatialdata/_core/query/relational_query.py +++ b/src/spatialdata/_core/query/relational_query.py @@ -679,7 +679,7 @@ def join_spatialelement_table( if sdata is not None: elements_dict = _create_sdata_elements_dict_for_join(sdata, spatial_element_names) else: - derived_sdata = SpatialData.from_elements_dict(dict(zip(spatial_element_names, spatial_elements, strict=True))) + derived_sdata = SpatialData.init_from_elements(dict(zip(spatial_element_names, spatial_elements, strict=True))) element_types = ["labels", "shapes", "points"] elements_dict = defaultdict(lambda: defaultdict(dict)) for element_type in element_types: diff --git a/src/spatialdata/_core/query/spatial_query.py b/src/spatialdata/_core/query/spatial_query.py index 0bb639d2..72e9377b 100644 --- a/src/spatialdata/_core/query/spatial_query.py +++ b/src/spatialdata/_core/query/spatial_query.py @@ -797,19 +797,6 @@ def _( return queried_polygons -# TODO: we can replace the manually triggered deprecation warning heres with the decorator from Wouter -def _check_deprecated_kwargs(kwargs: dict[str, Any]) -> None: - deprecated_args = ["shapes", "points", "images", "labels"] - for arg in deprecated_args: - if arg in kwargs and kwargs[arg] is False: - warnings.warn( - f"The '{arg}' argument is deprecated and will be removed in one of the next following releases. Please " - f"filter the SpatialData object before calling this function.", - DeprecationWarning, - stacklevel=2, - ) - - @singledispatch def polygon_query( element: SpatialElement | SpatialData, @@ -817,10 +804,6 @@ def polygon_query( target_coordinate_system: str, filter_table: bool = True, clip: bool = False, - shapes: bool = True, - points: bool = True, - images: bool = True, - labels: bool = True, ) -> SpatialElement | SpatialData | None: """ Query a SpatialData object or a SpatialElement by a polygon or multipolygon. @@ -879,12 +862,7 @@ def _( target_coordinate_system: str, filter_table: bool = True, clip: bool = False, - shapes: bool = True, - points: bool = True, - images: bool = True, - labels: bool = True, ) -> SpatialData: - _check_deprecated_kwargs({"shapes": shapes, "points": points, "images": images, "labels": labels}) new_elements = {} for element_type in ["points", "images", "labels", "shapes"]: elements = getattr(sdata, element_type) @@ -911,7 +889,6 @@ def _( return_request_only: bool = False, **kwargs: Any, ) -> DataArray | DataTree | None: - _check_deprecated_kwargs(kwargs) gdf = GeoDataFrame(geometry=[polygon]) min_x, min_y, max_x, max_y = gdf.bounds.values.flatten().tolist() return bounding_box_query( @@ -933,7 +910,6 @@ def _( ) -> DaskDataFrame | None: from spatialdata.transformations import get_transformation, set_transformation - _check_deprecated_kwargs(kwargs) polygon_gdf = _get_polygon_in_intrinsic_coordinates(points, target_coordinate_system, polygon) points_gdf = points_dask_dataframe_to_geopandas(points, suppress_z_warning=True) @@ -966,7 +942,6 @@ def _( ) -> GeoDataFrame | None: from spatialdata.transformations import get_transformation, set_transformation - _check_deprecated_kwargs(kwargs) polygon_gdf = _get_polygon_in_intrinsic_coordinates(element, target_coordinate_system, polygon) polygon = polygon_gdf["geometry"].iloc[0] diff --git a/src/spatialdata/_core/spatialdata.py b/src/spatialdata/_core/spatialdata.py index d8d14ca6..e511183f 100644 --- a/src/spatialdata/_core/spatialdata.py +++ b/src/spatialdata/_core/spatialdata.py @@ -33,7 +33,6 @@ from spatialdata._logging import logger from spatialdata._types import ArrayLike, Raster_T from spatialdata._utils import ( - _deprecation_alias, _error_message_add_element, ) from spatialdata.models import ( @@ -234,34 +233,6 @@ def validate_table_in_spatialdata(self, table: AnnData) -> None: f"the annotated element ({dtype})." ) - @staticmethod - def from_elements_dict( - elements_dict: dict[str, SpatialElement | AnnData], - attrs: Mapping[Any, Any] | None = None, - ) -> SpatialData: - """ - Create a SpatialData object from a dict of elements. - - Parameters - ---------- - elements_dict - Dict of elements. The keys are the names of the elements and the values are the elements. - A table can be present in the dict, but only at most one; its name is not used and can be anything. - attrs - Additional attributes to store in the SpatialData object. - - Returns - ------- - The SpatialData object. - """ - warnings.warn( - 'This method is deprecated and will be removed in a future release. Use "SpatialData.init_from_elements(' - ')" instead. For the moment, such methods will be automatically called.', - DeprecationWarning, - stacklevel=2, - ) - return SpatialData.init_from_elements(elements=elements_dict, attrs=attrs) - @staticmethod def get_annotated_regions(table: AnnData) -> list[str]: """ @@ -816,7 +787,6 @@ def rename_coordinate_systems(self, rename_dict: dict[str, str]) -> None: # set the new transformations set_transformation(element=element, transformation=new_transformations, set_all=True) - @_deprecation_alias(element="element_name", version="0.3.0") def transform_element_to_coordinate_system( self, element_name: str, @@ -1282,7 +1252,7 @@ def _write_element( table=element, group=element_type_group, name=element_name, - format=parsed_formats["tables"], + element_format=parsed_formats["tables"], ) else: raise ValueError(f"Unknown element type: {element_type}") @@ -1304,7 +1274,7 @@ def write_element( The name(s) of the element(s) to write. overwrite If True, overwrite the element if it already exists. - format + sdata_formats It is recommended to leave this parameter equal to `None`. See more details in the documentation of `SpatialData.write()`. @@ -1823,51 +1793,6 @@ def tables(self, tables: dict[str, AnnData]) -> None: TableModel().validate(v) self._tables[k] = v - @property - def table(self) -> None | AnnData: - """ - Return table with name table from tables if it exists. - - Returns - ------- - The table. - """ - warnings.warn( - "Table accessor will be deprecated with SpatialData version 0.1, use sdata.tables instead.", - DeprecationWarning, - stacklevel=2, - ) - # Isinstance will still return table if anndata has 0 rows. - if isinstance(self.tables.get("table"), AnnData): - return self.tables["table"] - return None - - @table.setter - def table(self, table: AnnData) -> None: - warnings.warn( - "Table setter will be deprecated with SpatialData version 0.1, use tables instead.", - DeprecationWarning, - stacklevel=2, - ) - TableModel().validate(table) - if self.tables.get("table") is not None: - raise ValueError("The table already exists. Use del sdata.tables['table'] to remove it first.") - self.tables["table"] = table - - @table.deleter - def table(self) -> None: - """Delete the table.""" - warnings.warn( - "del sdata.table will be deprecated with SpatialData version 0.1, use del sdata.tables['table'] instead.", - DeprecationWarning, - stacklevel=2, - ) - if self.tables.get("table"): - del self.tables["table"] - else: - # More informative than the error in the zarr library. - raise KeyError("table with name 'table' not present in the SpatialData object.") - @staticmethod def read(file_path: Path | str, selection: tuple[str] | None = None) -> SpatialData: """ @@ -2185,14 +2110,14 @@ def _gen_spatial_element_values(self) -> Generator[SpatialElement, None, None]: yield from d.values() def _gen_elements( - self, include_table: bool = False + self, include_tables: bool = False ) -> Generator[tuple[str, str, SpatialElement | AnnData], None, None]: """ Generate elements contained in the SpatialData instance. Parameters ---------- - include_table + include_tables Whether to also generate table elements. Returns @@ -2201,7 +2126,7 @@ def _gen_elements( itself. """ element_types = ["images", "labels", "points", "shapes"] - if include_table: + if include_tables: element_types.append("tables") for element_type in element_types: d = getattr(SpatialData, element_type).fget(self) @@ -2235,7 +2160,7 @@ def gen_elements( ------- A generator that yields tuples containing the name, description, and element objects themselves. """ - return self._gen_elements(include_table=True) + return self._gen_elements(include_tables=True) def _validate_element_names_are_unique(self) -> None: """ @@ -2370,7 +2295,7 @@ def subset( """ elements_dict: dict[str, SpatialElement] = {} names_tables_to_keep: set[str] = set() - for element_type, element_name, element in self._gen_elements(include_table=True): + for element_type, element_name, element in self._gen_elements(include_tables=True): if element_name in element_names: if element_type != "tables": elements_dict.setdefault(element_type, {})[element_name] = element @@ -2403,7 +2328,7 @@ def __getitem__(self, item: str) -> SpatialElement | AnnData: def __contains__(self, key: str) -> bool: element_dict = { - element_name: element_value for _, element_name, element_value in self._gen_elements(include_table=True) + element_name: element_value for _, element_name, element_value in self._gen_elements(include_tables=True) } return key in element_dict diff --git a/src/spatialdata/_io/io_points.py b/src/spatialdata/_io/io_points.py index cbf02384..1d6c9d2d 100644 --- a/src/spatialdata/_io/io_points.py +++ b/src/spatialdata/_io/io_points.py @@ -28,7 +28,7 @@ def _read_points( version = _parse_version(f, expect_attrs_key=True) assert version is not None - format = PointsFormats[version] + points_format = PointsFormats[version] store_root = f.store_path.store.root path = store_root / f.path / "points.parquet" @@ -40,7 +40,7 @@ def _read_points( transformations = _get_transformations_from_ngff_dict(f.attrs.asdict()["coordinateTransformations"]) _set_transformations(points, transformations) - attrs = format.attrs_from_dict(f.attrs.asdict()) + attrs = points_format.attrs_from_dict(f.attrs.asdict()) if len(attrs): points.attrs["spatialdata_attrs"] = attrs return points diff --git a/src/spatialdata/_io/io_table.py b/src/spatialdata/_io/io_table.py index 8776d0d6..24c19271 100644 --- a/src/spatialdata/_io/io_table.py +++ b/src/spatialdata/_io/io_table.py @@ -8,6 +8,7 @@ from anndata import read_zarr as read_anndata_zarr from anndata._io.specs import write_elem as write_adata from ome_zarr.format import Format +from zarr.errors import ArrayNotFoundError # from zarr.errors import ArrayNotFoundError # removed in zarr 3.0 from spatialdata._io._utils import BadFileHandleMethod, handle_read_errors @@ -52,7 +53,7 @@ def _read_table( JSONDecodeError, KeyError, ValueError, - # ArrayNotFoundError, # removed in zarr 3.0 + ArrayNotFoundError, ), ): tables[table_name] = read_anndata_zarr(f_elem_store) @@ -65,6 +66,7 @@ def _read_table( _ = TablesFormats[version] f.store.close() + # TODO: implement per read logic of format # # replace with format from above # version = "0.1" # format = TablesFormats[version] @@ -92,21 +94,20 @@ def write_table( group: zarr.Group, name: str, group_type: str = "ngff:regions_table", - format: Format = CurrentTablesFormat(), + element_format: Format = CurrentTablesFormat(), ) -> None: if TableModel.ATTRS_KEY in table.uns: region = table.uns["spatialdata_attrs"]["region"] region_key = table.uns["spatialdata_attrs"].get("region_key", None) instance_key = table.uns["spatialdata_attrs"].get("instance_key", None) - format.validate_table(table, region_key, instance_key) + element_format.validate_table(table, region_key, instance_key) else: region, region_key, instance_key = (None, None, None) - # TODO: Problem is that tables group already exists and thus has consolidated_metadata, cannot write repeatedly. - write_adata(group, name, table) # creates group[name] + write_adata(group, name, table) tables_group = group[name] tables_group.attrs["spatialdata-encoding-type"] = group_type tables_group.attrs["region"] = region tables_group.attrs["region_key"] = region_key tables_group.attrs["instance_key"] = instance_key - tables_group.attrs["version"] = format.spatialdata_format_version + tables_group.attrs["version"] = element_format.spatialdata_format_version diff --git a/src/spatialdata/dataloader/datasets.py b/src/spatialdata/dataloader/datasets.py index 4dc43323..4bf14a91 100644 --- a/src/spatialdata/dataloader/datasets.py +++ b/src/spatialdata/dataloader/datasets.py @@ -258,9 +258,7 @@ def _preprocess( if table_name is not None: table_subset = filtered_table[filtered_table.obs[region_key] == region_name] - circles_sdata = SpatialData.init_from_elements( - {region_name: circles}, tables={"table": table_subset.copy()} - ) + circles_sdata = SpatialData.init_from_elements({region_name: circles, "table": table_subset.copy()}) _, table = join_spatialelement_table( sdata=circles_sdata, spatial_element_names=region_name, diff --git a/src/spatialdata/models/__init__.py b/src/spatialdata/models/__init__.py index 3c86fa0e..ba064e0a 100644 --- a/src/spatialdata/models/__init__.py +++ b/src/spatialdata/models/__init__.py @@ -8,7 +8,6 @@ force_2d, get_axes_names, get_channel_names, - get_channels, get_spatial_axes, points_dask_dataframe_to_geopandas, points_geopandas_to_dask_dataframe, @@ -50,9 +49,7 @@ "points_dask_dataframe_to_geopandas", "check_target_region_column_symmetry", "get_table_keys", - "get_channels", "get_channel_names", "set_channel_names", "force_2d", - "RasterSchema", ] diff --git a/src/spatialdata/models/_utils.py b/src/spatialdata/models/_utils.py index b6c31821..db5ebac9 100644 --- a/src/spatialdata/models/_utils.py +++ b/src/spatialdata/models/_utils.py @@ -290,34 +290,6 @@ def get_channel_names(data: Any) -> list[Any]: raise ValueError(f"Cannot get channels from {type(data)}") -def get_channels(data: Any) -> list[Any]: - """Get channels from data for an image element (both single and multiscale). - - [Deprecation] This function will be deprecated in version 0.3.0. Please use - `get_channel_names`. - - Parameters - ---------- - data - data to get channels from - - Returns - ------- - List of channels - - Notes - ----- - For multiscale images, the channels are validated to be consistent across scales. - """ - warnings.warn( - "The function 'get_channels' is deprecated and will be removed in version 0.3.0. " - "Please use 'get_channel_names' instead.", - DeprecationWarning, - stacklevel=2, # Adjust the stack level to point to the caller - ) - return get_channel_names(data) - - @get_channel_names.register def _(data: DataArray) -> list[Any]: return data.coords["c"].values.tolist() # type: ignore[no-any-return] diff --git a/tests/conftest.py b/tests/conftest.py index e0599128..df69dc0e 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -445,7 +445,7 @@ def _make_sdata_for_testing_querying_and_aggretation() -> SpatialData: table = TableModel.parse( table, region=["values_circles", "values_polygons"], region_key="region", instance_key="instance_id" ) - sdata.table = table + sdata["table"] = table return sdata diff --git a/tests/core/operations/test_rasterize.py b/tests/core/operations/test_rasterize.py index 261768e1..50be3f45 100644 --- a/tests/core/operations/test_rasterize.py +++ b/tests/core/operations/test_rasterize.py @@ -134,7 +134,7 @@ def test_rasterize_labels_value_key_specified(): region_key="region", instance_key="instance_id", ) - sdata = SpatialData.init_from_elements({element_name: raster}, tables={table_name: table}) + sdata = SpatialData.init_from_elements({element_name: raster, table_name: table}) result = rasterize( data=element_name, sdata=sdata, diff --git a/tests/core/operations/test_rasterize_bins.py b/tests/core/operations/test_rasterize_bins.py index 6596325c..d7e6172b 100644 --- a/tests/core/operations/test_rasterize_bins.py +++ b/tests/core/operations/test_rasterize_bins.py @@ -63,7 +63,7 @@ def test_rasterize_bins(geometry: str, value_key: str | list[str] | None, return table = TableModel.parse( AnnData(X=X, var=var, obs=obs), region="points", region_key="region", instance_key="instance_id" ) - sdata = SpatialData.init_from_elements({"points": points}, tables={"table": table}) + sdata = SpatialData.init_from_elements({"points": points, "table": table}) rasterized = rasterize_bins( sdata=sdata, bins="points", @@ -130,7 +130,7 @@ def _get_sdata(n: int): region_key="region", instance_key="instance_id", ) - return SpatialData.init_from_elements({"points": points}, tables={"table": table}) + return SpatialData.init_from_elements({"points": points, "table": table}) # sdata with not enough bins (2*2) to estimate transformation sdata = _get_sdata(n=2) diff --git a/tests/core/operations/test_spatialdata_operations.py b/tests/core/operations/test_spatialdata_operations.py index a436d343..0a8b965b 100644 --- a/tests/core/operations/test_spatialdata_operations.py +++ b/tests/core/operations/test_spatialdata_operations.py @@ -174,7 +174,7 @@ def test_filter_by_coordinate_system_also_table(full_sdata: SpatialData) -> None adata = full_sdata["table"] del adata.uns[TableModel.ATTRS_KEY] del full_sdata.tables["table"] - full_sdata.table = TableModel.parse( + full_sdata["table"] = TableModel.parse( adata, region=["circles", "poly"], region_key="annotated_shapes", @@ -520,7 +520,7 @@ def test_init_from_elements(full_sdata: SpatialData) -> None: for element_type in ["images", "labels", "points", "shapes", "tables"]: assert set(getattr(sdata, element_type).keys()) == set(getattr(full_sdata, element_type).keys()) - all_elements = {name: el for _, name, el in full_sdata._gen_elements(include_table=True)} + all_elements = {name: el for _, name, el in full_sdata._gen_elements(include_tables=True)} sdata = SpatialData.init_from_elements(all_elements) for element_type in ["images", "labels", "points", "shapes", "tables"]: assert set(getattr(sdata, element_type).keys()) == set(getattr(full_sdata, element_type).keys()) diff --git a/tests/core/operations/test_transform.py b/tests/core/operations/test_transform.py index 81959e2f..13f4bbc3 100644 --- a/tests/core/operations/test_transform.py +++ b/tests/core/operations/test_transform.py @@ -578,9 +578,9 @@ def test_transform_elements_and_entire_spatial_data_object_multi_hop( if "global" in d: remove_transformation(element, "global") - for element in full_sdata._gen_spatial_element_values(): + for _, name, element in full_sdata._gen_elements(include_tables=False): transformed_element = full_sdata.transform_element_to_coordinate_system( - element, "multi_hop_space", maintain_positioning=maintain_positioning + name, "multi_hop_space", maintain_positioning=maintain_positioning ) temp = SpatialData( images=dict(full_sdata.images), diff --git a/tests/core/query/test_spatial_query.py b/tests/core/query/test_spatial_query.py index 65905a97..d4fcd6da 100644 --- a/tests/core/query/test_spatial_query.py +++ b/tests/core/query/test_spatial_query.py @@ -555,20 +555,18 @@ def test_polygon_query_with_multipolygon(sdata_query_aggregation): values_sdata, polygon=polygon, target_coordinate_system="global", - shapes=True, - points=False, ) assert len(queried["values_polygons"]) == 4 assert len(queried["values_circles"]) == 4 assert len(queried["table"]) == 8 - multipolygon = GeoDataFrame(geometry=[polygon, circle_pol]).unary_union + multipolygon = GeoDataFrame(geometry=[polygon, circle_pol]).union_all() queried = polygon_query(values_sdata, polygon=multipolygon, target_coordinate_system="global") assert len(queried["values_polygons"]) == 8 assert len(queried["values_circles"]) == 8 assert len(queried["table"]) == 16 - multipolygon = GeoDataFrame(geometry=[polygon, polygon]).unary_union + multipolygon = GeoDataFrame(geometry=[polygon, polygon]).union_all() queried = polygon_query(values_sdata, polygon=multipolygon, target_coordinate_system="global") assert len(queried["values_polygons"]) == 4 assert len(queried["values_circles"]) == 4 diff --git a/tests/core/test_centroids.py b/tests/core/test_centroids.py index b8ecc441..5011872b 100644 --- a/tests/core/test_centroids.py +++ b/tests/core/test_centroids.py @@ -157,7 +157,7 @@ def test_get_centroids_labels( assert np.array_equal(centroids.index.values, labels_indices) if not return_background: - assert 0 not in centroids.index + assert not (centroids.index == 0).any() if coordinate_system == "global": assert np.array_equal(centroids.compute().values, expected_centroids.values) diff --git a/tests/io/test_format.py b/tests/io/test_format.py index cf373119..aeb8762e 100644 --- a/tests/io/test_format.py +++ b/tests/io/test_format.py @@ -27,26 +27,26 @@ class TestFormat: """Test format.""" - @pytest.mark.parametrize("format", [PointsFormatV01()]) + @pytest.mark.parametrize("element_format", [PointsFormatV01()]) @pytest.mark.parametrize("attrs_key", [PointsModel.ATTRS_KEY]) @pytest.mark.parametrize("feature_key", [None, PointsModel.FEATURE_KEY]) @pytest.mark.parametrize("instance_key", [None, PointsModel.INSTANCE_KEY]) def test_format_points_v1( self, - format: PointsFormatType, + element_format: PointsFormatType, attrs_key: str | None, feature_key: str | None, instance_key: str | None, ) -> None: - metadata: dict[str, Any] = {attrs_key: {"version": format.spatialdata_format_version}} + metadata: dict[str, Any] = {attrs_key: {"version": element_format.spatialdata_format_version}} format_metadata: dict[str, Any] = {attrs_key: {}} if feature_key is not None: metadata[attrs_key][feature_key] = "target" if instance_key is not None: metadata[attrs_key][instance_key] = "cell_id" - format_metadata[attrs_key] = format.attrs_from_dict(metadata) + format_metadata[attrs_key] = element_format.attrs_from_dict(metadata) metadata[attrs_key].pop("version") - assert metadata[attrs_key] == format.attrs_to_dict(format_metadata) + assert metadata[attrs_key] == element_format.attrs_to_dict(format_metadata) if feature_key is None and instance_key is None: assert len(format_metadata[attrs_key]) == len(metadata[attrs_key]) == 0 diff --git a/tests/io/test_multi_table.py b/tests/io/test_multi_table.py index dd43cfa8..7191f2f8 100644 --- a/tests/io/test_multi_table.py +++ b/tests/io/test_multi_table.py @@ -127,30 +127,6 @@ def test_set_table_annotates_spatialelement(self, full_sdata, tmp_path): ) full_sdata.write(tmpdir) - def test_old_accessor_deprecation(self, full_sdata, tmp_path): - # To test self._backed - tmpdir = Path(tmp_path) / "tmp.zarr" - full_sdata.write(tmpdir) - adata0 = _get_table(region="polygon") - - with pytest.warns(DeprecationWarning): - _ = full_sdata.table - with pytest.raises(ValueError): - full_sdata.table = adata0 - with pytest.warns(DeprecationWarning): - del full_sdata.table - with pytest.raises(KeyError): - del full_sdata["table"] - with pytest.warns(DeprecationWarning): - full_sdata.table = adata0 # this gets placed in sdata['table'] - - assert_equal(adata0, full_sdata["table"]) - - del full_sdata["table"] - - full_sdata.tables["my_new_table0"] = adata0 - assert full_sdata.get("table") is None - @pytest.mark.parametrize("region", ["test_shapes", "non_existing"]) def test_single_table(self, tmp_path: str, region: str): tmpdir = Path(tmp_path) / "tmp.zarr" diff --git a/tests/io/test_pyramids_performance.py b/tests/io/test_pyramids_performance.py index 33a62924..0bfcc2c4 100644 --- a/tests/io/test_pyramids_performance.py +++ b/tests/io/test_pyramids_performance.py @@ -68,7 +68,7 @@ def test_write_image_multiscale_performance(sdata_with_image: SpatialData, tmp_p image=image, group=element_type_group, name=image_name, - format=CurrentRasterFormat(), + element_format=CurrentRasterFormat(), ) # The number of chunks of scale level 0 diff --git a/tests/io/test_readwrite.py b/tests/io/test_readwrite.py index 93de3391..b20d8c66 100644 --- a/tests/io/test_readwrite.py +++ b/tests/io/test_readwrite.py @@ -252,27 +252,6 @@ def test_incremental_io_on_disk( sdata.delete_element_from_disk(name) sdata.write_element(name) - def test_incremental_io_table_legacy(self, table_single_annotation: SpatialData) -> None: - s = table_single_annotation - t = s["table"][:10, :].copy() - with pytest.raises(ValueError): - s.table = t - del s["table"] - s.table = t - - with tempfile.TemporaryDirectory() as td: - f = os.path.join(td, "data.zarr") - s.write(f) - s2 = SpatialData.read(f) - assert len(s2["table"]) == len(t) - del s2["table"] - s2.table = s["table"] - assert len(s2["table"]) == len(s["table"]) - f2 = os.path.join(td, "data2.zarr") - s2.write(f2) - s3 = SpatialData.read(f2) - assert len(s3["table"]) == len(s2["table"]) - def test_io_and_lazy_loading_points(self, points): with tempfile.TemporaryDirectory() as td: f = os.path.join(td, "data.zarr") diff --git a/tests/models/test_models.py b/tests/models/test_models.py index 5b27f3ab..fddb08ac 100644 --- a/tests/models/test_models.py +++ b/tests/models/test_models.py @@ -238,10 +238,10 @@ def test_shapes_model(self, model: ShapesModel, path: Path) -> None: assert poly.equals(other_poly) if ShapesModel.RADIUS_KEY in poly.columns: - poly[ShapesModel.RADIUS_KEY].iloc[0] = -1 + poly.loc[0, ShapesModel.RADIUS_KEY] = -1 with pytest.raises(ValueError, match="Radii of circles must be positive."): ShapesModel.validate(poly) - poly[ShapesModel.RADIUS_KEY].iloc[0] = 0 + poly.loc[0, ShapesModel.RADIUS_KEY] = 0 with pytest.raises(ValueError, match="Radii of circles must be positive."): ShapesModel.validate(poly) From 42923c4570b1617b71ea3c57b19509019a8ecb29 Mon Sep 17 00:00:00 2001 From: Wouter-Michiel Vierdag Date: Sat, 6 Sep 2025 14:10:01 +0200 Subject: [PATCH 047/126] checks backward compatibility --- src/spatialdata/_core/spatialdata.py | 47 +++++------------- src/spatialdata/_io/format.py | 8 --- src/spatialdata/_io/io_raster.py | 41 +++++++++++----- src/spatialdata/_io/io_zarr.py | 23 ++++----- tests/io/test_format.py | 73 +++++++++++++++++++++++++++- tests/io/test_readwrite.py | 31 ++++++++++++ 6 files changed, 156 insertions(+), 67 deletions(-) diff --git a/src/spatialdata/_core/spatialdata.py b/src/spatialdata/_core/spatialdata.py index e511183f..40b399c8 100644 --- a/src/spatialdata/_core/spatialdata.py +++ b/src/spatialdata/_core/spatialdata.py @@ -1429,11 +1429,9 @@ def _check_element_not_on_disk_with_different_type(self, element_type: str, elem ) def write_consolidated_metadata(self) -> None: - store = parse_url(self.path, mode="r+", fmt=FormatV05()).store - # consolidate metadata to more easily support remote reading bug in zarr. In reality, 'zmetadata' is written - # instead of '.zmetadata' see discussion https://github.com/zarr-developers/zarr-python/issues/1121 - zarr.consolidate_metadata(store) - store.close() + from spatialdata._io.io_zarr import _write_consolidated_metadata + + _write_consolidated_metadata(self.path) def has_consolidated_metadata(self) -> bool: return_value = False @@ -1794,7 +1792,9 @@ def tables(self, tables: dict[str, AnnData]) -> None: self._tables[k] = v @staticmethod - def read(file_path: Path | str, selection: tuple[str] | None = None) -> SpatialData: + def read( + file_path: Path | str, selection: tuple[str] | None = None, reconsolidate_metadata: bool = False + ) -> SpatialData: """ Read a SpatialData object from a Zarr storage (on-disk or remote). @@ -1804,6 +1804,8 @@ def read(file_path: Path | str, selection: tuple[str] | None = None) -> SpatialD The path or URL to the Zarr storage. selection The elements to read (images, labels, points, shapes, table). If None, all elements are read. + reconsolidate_metadata + If the consolidated metadata store got corrupted this can lead to errors when trying to read the data. Returns ------- @@ -1811,6 +1813,11 @@ def read(file_path: Path | str, selection: tuple[str] | None = None) -> SpatialD """ from spatialdata import read_zarr + if reconsolidate_metadata: + from spatialdata._io.io_zarr import _write_consolidated_metadata + + _write_consolidated_metadata(file_path) + return read_zarr(file_path, selection=selection) def add_image( @@ -1823,34 +1830,6 @@ def add_image( """Deprecated. Use `sdata[name] = image` instead.""" # noqa: D401 _error_message_add_element() - def add_labels( - self, - name: str, - labels: DataArray | DataTree, - storage_options: JSONDict | list[JSONDict] | None = None, - overwrite: bool = False, - ) -> None: - """Deprecated. Use `sdata[name] = labels` instead.""" # noqa: D401 - _error_message_add_element() - - def add_points( - self, - name: str, - points: DaskDataFrame, - overwrite: bool = False, - ) -> None: - """Deprecated. Use `sdata[name] = points` instead.""" # noqa: D401 - _error_message_add_element() - - def add_shapes( - self, - name: str, - shapes: GeoDataFrame, - overwrite: bool = False, - ) -> None: - """Deprecated. Use `sdata[name] = shapes` instead.""" # noqa: D401 - _error_message_add_element() - @property def images(self) -> Images: """Return images as a Dict of name to image data.""" diff --git a/src/spatialdata/_io/format.py b/src/spatialdata/_io/format.py index a10755ae..2cb1867d 100644 --- a/src/spatialdata/_io/format.py +++ b/src/spatialdata/_io/format.py @@ -182,10 +182,6 @@ def spatialdata_format_version(self) -> str: def version(self) -> str: return "0.4" - # @property - # def zarr_format(self): - # return 2 - class RasterFormatV02(RasterFormatV01): @property @@ -198,10 +194,6 @@ def version(self) -> str: # https://github.com/scverse/spatialdata/pull/849 return "0.4-dev-spatialdata" - # @property - # def zarr_format(self): - # return 2 - class RasterFormatV03(FormatV05, CoordinateMixinV01): @property diff --git a/src/spatialdata/_io/io_raster.py b/src/spatialdata/_io/io_raster.py index a006f3bc..5a1c69f4 100644 --- a/src/spatialdata/_io/io_raster.py +++ b/src/spatialdata/_io/io_raster.py @@ -6,7 +6,7 @@ import zarr from ome_zarr.format import Format from ome_zarr.io import ZarrLocation -from ome_zarr.reader import Multiscales, Node, Reader +from ome_zarr.reader import Label, Multiscales, Node, Reader from ome_zarr.types import JSONDict from ome_zarr.writer import _get_valid_axes from ome_zarr.writer import write_image as write_image_ngff @@ -33,6 +33,22 @@ ) +def _get_nodes_zarr_v3(image_nodes: list[Node], nodes: list[Node]) -> list[Node]: + if len(image_nodes): + for node in image_nodes: + # Labels are now also Multiscales in newer version of ome-zarr-py + if np.any([isinstance(spec, Multiscales) for spec in node.specs]): + nodes.append(node) + return nodes + + +def _get_label_nodes_zarr_v2(image_nodes: list[Node], nodes: list[Node]) -> list[Node]: + for node in image_nodes: + if np.any([isinstance(spec, Label) for spec in node.specs]): + nodes.append(node) + return nodes + + def _read_multiscale(store: str | Path, raster_type: Literal["image", "labels"]) -> DataArray | DataTree: assert isinstance(store, str | Path) assert raster_type in ["image", "labels"] @@ -42,14 +58,7 @@ def _read_multiscale(store: str | Path, raster_type: Literal["image", "labels"]) if exists := image_loc.exists(): image_reader = Reader(image_loc)() image_nodes = list(image_reader) - if len(image_nodes): - for node in image_nodes: - # Labels are now also Multiscales in newer version of ome-zarr-py - if np.any([isinstance(spec, Multiscales) for spec in node.specs]) and raster_type in [ - "image", - "labels", - ]: - nodes.append(node) + nodes = _get_nodes_zarr_v3(image_nodes, nodes) else: raise OSError( f"Image location {image_loc} does not seem to exist. If it does, potentially the zarr.json file " @@ -57,14 +66,18 @@ def _read_multiscale(store: str | Path, raster_type: Literal["image", "labels"]) ) if len(nodes) != 1: if exists: + nodes = _get_label_nodes_zarr_v2(image_nodes, nodes) + else: + raise ValueError( + f"len(nodes) = {len(nodes)}, expected 1 and image location {image_loc} does not exist. Unable to read " + f"the NGFF file. Please report this bug and attach a minimal data example." + ) + if len(nodes) != 1: raise OSError( f"Image location {image_loc} exists, but len(nodes) = {len(nodes)}, expected 1. Element " f"{image_loc.basename()} is potentially corrupted." ) - raise ValueError( - f"len(nodes) = {len(nodes)}, expected 1 and image location {image_loc} does not exist. Unable to read " - f"the NGFF file. Please report this bug and attach a minimal data example." - ) + node = nodes[0] datasets = node.load(Multiscales).datasets multiscales = node.load(Multiscales).zarr.root_attrs["multiscales"] @@ -171,6 +184,8 @@ def _write_raster( # Since NGFF does not yet support coordinate transformations, we need a SpatialData extension for rasters. This will # be dropped once NGFF supports it. For now, saving the NGFF version (0.4) is not enough—we must also record the # SpatialData format version. + + group = group["labels"][name] if raster_type == "labels" else group if ATTRS_KEY not in group.attrs: group.attrs[ATTRS_KEY] = {} attrs = group.attrs[ATTRS_KEY] diff --git a/src/spatialdata/_io/io_zarr.py b/src/spatialdata/_io/io_zarr.py index 85304a1f..d548c2f0 100644 --- a/src/spatialdata/_io/io_zarr.py +++ b/src/spatialdata/_io/io_zarr.py @@ -24,7 +24,9 @@ # TODO: remove with incoming remote read / write PR # Not removing this now as it requires substantial extra refactor beyond scope of zarrv3 PR. -def _open_zarr_store(store: str | Path | zarr.Group) -> tuple[zarr.Group, str]: +def _open_zarr_store( + store: str | Path | zarr.Group, mode: Literal["r", "r+", "a", "w", "w-"] = "r", use_consolidated: bool | None = None +) -> tuple[zarr.Group, str]: """ Open a zarr store (on-disk or remote) and return the zarr.Group object and the path to the store. @@ -37,16 +39,8 @@ def _open_zarr_store(store: str | Path | zarr.Group) -> tuple[zarr.Group, str]: ------- A tuple of the zarr.Group object and the path to the store. """ - f = store if isinstance(store, zarr.Group) else zarr.open_group(store, mode="r") - # workaround: .zmetadata is being written as zmetadata (https://github.com/zarr-developers/zarr-python/issues/1121) - # not needed, consolidated metadata is always used if present - # if isinstance(store, str | Path) and str(store).startswith("http") and len(f) == 0: - # f = zarr.open_consolidated(store, mode="r", metadata_key="zmetadata") - # the metadata is accessible here: - # f.metadata.consolidated_metadata.metadata - f_store_path = f.store.root - # f_store_path = f.store.store.path if isinstance(f.store, zarr.storage.ConsolidatedMetadataStore) else f.store.path - return f, f_store_path + f = store if isinstance(store, zarr.Group) else zarr.open_group(store, mode=mode, use_consolidated=use_consolidated) + return f, f.store.root def read_zarr( @@ -309,3 +303,10 @@ def _group_for_element_exists(zarr_path: Path, element_type: str, element_name: exists = element_type in root and element_name in root[element_type] store.close() return exists + + +def _write_consolidated_metadata(path: Path | str | None) -> None: + if path is not None: + f, f_store_path = _open_zarr_store(path, mode="r+", use_consolidated=False) + zarr.consolidate_metadata(f.store) + f.store.close() diff --git a/tests/io/test_format.py b/tests/io/test_format.py index aeb8762e..4b3525e4 100644 --- a/tests/io/test_format.py +++ b/tests/io/test_format.py @@ -10,6 +10,7 @@ from spatialdata._io.format import ( PointsFormatType, PointsFormatV01, + PointsFormatV02, RasterFormatV01, RasterFormatV02, RasterFormatV03, @@ -19,6 +20,8 @@ SpatialDataContainerFormatV01, SpatialDataContainerFormatV02, SpatialDataFormatType, + TablesFormatV01, + TablesFormatV02, ) from spatialdata.models import PointsModel, ShapesModel from spatialdata.testing import assert_spatial_data_objects_are_identical @@ -115,16 +118,19 @@ def test_shapes_v1_to_v2_to_v3(self, shapes): shapes.write(f1, sdata_formats=[ShapesFormatV01(), SpatialDataContainerFormatV01()]) shapes_read_v1 = read_zarr(f1) assert_spatial_data_objects_are_identical(shapes, shapes_read_v1) + assert shapes_read_v1.is_self_contained() shapes_read_v1.write(f2, sdata_formats=[ShapesFormatV02(), SpatialDataContainerFormatV01()]) shapes_read_v2 = read_zarr(f2) assert_spatial_data_objects_are_identical(shapes, shapes_read_v2) + assert shapes_read_v2.is_self_contained() shapes_read_v2.write(f3, sdata_formats=[ShapesFormatV03(), SpatialDataContainerFormatV02()]) shapes_read_v3 = read_zarr(f3) assert_spatial_data_objects_are_identical(shapes, shapes_read_v3) + assert shapes_read_v3.is_self_contained() - def test_raster_v1_to_v2_to_v3(self, images): + def test_raster_images_v1_to_v2_to_v3(self, images): with tempfile.TemporaryDirectory() as tmpdir: f1 = Path(tmpdir) / "data1.zarr" f2 = Path(tmpdir) / "data2.zarr" @@ -136,11 +142,76 @@ def test_raster_v1_to_v2_to_v3(self, images): images.write(f1, sdata_formats=[RasterFormatV01(), SpatialDataContainerFormatV01()]) images_read_v1 = read_zarr(f1) assert_spatial_data_objects_are_identical(images, images_read_v1) + assert images_read_v1.is_self_contained() images_read_v1.write(f2, sdata_formats=[RasterFormatV02(), SpatialDataContainerFormatV01()]) images_read_v2 = read_zarr(f2) assert_spatial_data_objects_are_identical(images, images_read_v2) + assert images_read_v2.is_self_contained() images_read_v2.write(f3, sdata_formats=[RasterFormatV03(), SpatialDataContainerFormatV02()]) images_read_v3 = read_zarr(f3) assert_spatial_data_objects_are_identical(images, images_read_v3) + assert images_read_v3.is_self_contained() + + def test_raster_labels_v1_to_v2_to_v3(self, labels): + with tempfile.TemporaryDirectory() as tmpdir: + f1 = Path(tmpdir) / "data1.zarr" + f2 = Path(tmpdir) / "data2.zarr" + f3 = Path(tmpdir) / "data3.zarr" + + labels.write(f1, sdata_formats=[RasterFormatV01(), SpatialDataContainerFormatV01()]) + labels_read_v1 = read_zarr(f1) + assert_spatial_data_objects_are_identical(labels, labels_read_v1) + assert labels_read_v1.is_self_contained() + + labels_read_v1.write(f2, sdata_formats=[RasterFormatV02(), SpatialDataContainerFormatV01()]) + labels_read_v2 = read_zarr(f2) + assert_spatial_data_objects_are_identical(labels, labels_read_v2) + assert labels_read_v2.is_self_contained() + + labels_read_v2.write(f3, sdata_formats=[RasterFormatV03(), SpatialDataContainerFormatV02()]) + labels_read_v3 = read_zarr(f3) + assert_spatial_data_objects_are_identical(labels, labels_read_v3) + assert labels_read_v3.is_self_contained() + + def test_points_v1_to_v2(self, points): + with tempfile.TemporaryDirectory() as tmpdir: + f1 = Path(tmpdir) / "data1.zarr" + f2 = Path(tmpdir) / "data2.zarr" + + points.write(f1, sdata_formats=[PointsFormatV01(), SpatialDataContainerFormatV01()]) + points_read_v1 = read_zarr(f1) + assert_spatial_data_objects_are_identical(points, points_read_v1) + + points_read_v1.write(f2, sdata_formats=[PointsFormatV02(), SpatialDataContainerFormatV02()]) + points_read_v2 = read_zarr(f2) + assert_spatial_data_objects_are_identical(points, points_read_v2) + + def test_tables_v1_to_v2(self, table_multiple_annotations): + with tempfile.TemporaryDirectory() as tmpdir: + f1 = Path(tmpdir) / "data1.zarr" + f2 = Path(tmpdir) / "data2.zarr" + + table_multiple_annotations.write(f1, sdata_formats=[TablesFormatV01(), SpatialDataContainerFormatV01()]) + table_read_v1 = read_zarr(f1) + assert_spatial_data_objects_are_identical(table_multiple_annotations, table_read_v1) + + table_read_v1.write(f2, sdata_formats=[TablesFormatV02(), SpatialDataContainerFormatV02()]) + table_read_v2 = read_zarr(f2) + assert_spatial_data_objects_are_identical(table_multiple_annotations, table_read_v2) + + def test_container_v1_to_v2(self, full_sdata): + with tempfile.TemporaryDirectory() as tmpdir: + f1 = Path(tmpdir) / "data1.zarr" + f2 = Path(tmpdir) / "data2.zarr" + + full_sdata.write(f1, sdata_formats=[SpatialDataContainerFormatV01()]) + sdata_read_v1 = read_zarr(f1) + assert_spatial_data_objects_are_identical(full_sdata, sdata_read_v1) + assert sdata_read_v1.is_self_contained() + + sdata_read_v1.write(f2, sdata_formats=[SpatialDataContainerFormatV02()]) + sdata_read_v2 = read_zarr(f2) + assert_spatial_data_objects_are_identical(full_sdata, sdata_read_v2) + assert sdata_read_v2.is_self_contained() diff --git a/tests/io/test_readwrite.py b/tests/io/test_readwrite.py index b20d8c66..fed63382 100644 --- a/tests/io/test_readwrite.py +++ b/tests/io/test_readwrite.py @@ -1,3 +1,4 @@ +import json import os import tempfile from collections.abc import Callable @@ -7,8 +8,10 @@ import dask.dataframe as dd import numpy as np import pytest +import zarr from anndata import AnnData from numpy.random import default_rng +from zarr.errors import GroupNotFoundError from spatialdata import SpatialData, deepcopy, read_zarr from spatialdata._core.validation import ValidationError @@ -787,3 +790,31 @@ def test_reading_invalid_name(tmp_path: Path): "For renaming, please see the discussion here https://github.com/scverse/spatialdata/discussions/707" in actual_message ) + + +def test_write_store_unconsolidated_and_read(full_sdata): + with tempfile.TemporaryDirectory() as tmpdir: + path = Path(tmpdir) / "data.zarr" + full_sdata.write(path, consolidate_metadata=False) + + group = zarr.open_group(path, mode="r") + assert group.metadata.consolidated_metadata is None + second_read = SpatialData.read(path) + assert_spatial_data_objects_are_identical(full_sdata, second_read) + + +def test_can_read_sdata_with_reconsolidation(full_sdata): + with tempfile.TemporaryDirectory() as tmpdir: + path = Path(tmpdir) / "data.zarr" + full_sdata.write(path) + + json_path = path / "zarr.json" + json_dict = json.loads(json_path.read_text()) + del json_dict["consolidated_metadata"]["metadata"]["images/image2d"] + json_path.write_text(json.dumps(json_dict, indent=4)) + + with pytest.raises(GroupNotFoundError): + SpatialData.read(path) + + new_sdata = SpatialData.read(path, reconsolidate_metadata=True) + assert_spatial_data_objects_are_identical(full_sdata, new_sdata) From f11d683ce0e68eb8f297aaaa3d2cb85408db7daa Mon Sep 17 00:00:00 2001 From: Wouter-Michiel Vierdag Date: Sat, 6 Sep 2025 14:48:20 +0200 Subject: [PATCH 048/126] correct test --- tests/io/test_format.py | 5 ++--- 1 file changed, 2 insertions(+), 3 deletions(-) diff --git a/tests/io/test_format.py b/tests/io/test_format.py index 4b3525e4..334cda16 100644 --- a/tests/io/test_format.py +++ b/tests/io/test_format.py @@ -97,13 +97,12 @@ def test_format_raster_v1_v2(self, images, rformat: type[SpatialDataFormatType]) zattrs_file = Path(tmpdir) / "images.zarr/images/image2d/.zattrs" with open(zattrs_file) as infile: zattrs = json.load(infile) + ngff_version = zattrs["multiscales"][0]["version"] if rformat == RasterFormatV01: - ngff_version = zattrs["multiscales"][0]["version"] assert ngff_version == "0.4" else: assert rformat == RasterFormatV02 - # TODO: check whether this required change is due to bug in ome-zarr - assert zattrs["version"] == "0.4-dev-spatialdata" + assert ngff_version == "0.4-dev-spatialdata" class TestFormatConversions: From 2b71939e680f7eb3f219c68a0bb7e48de794e1cf Mon Sep 17 00:00:00 2001 From: Wouter-Michiel Vierdag Date: Sat, 6 Sep 2025 15:59:37 +0200 Subject: [PATCH 049/126] further reduce warnings --- src/spatialdata/_core/operations/aggregate.py | 2 +- src/spatialdata/_core/operations/vectorize.py | 2 +- tests/core/operations/test_rasterize_bins.py | 2 +- tests/core/operations/test_spatialdata_operations.py | 2 +- tests/core/test_centroids.py | 8 ++++---- 5 files changed, 8 insertions(+), 8 deletions(-) diff --git a/src/spatialdata/_core/operations/aggregate.py b/src/spatialdata/_core/operations/aggregate.py index 809b9e52..f4e5f591 100644 --- a/src/spatialdata/_core/operations/aggregate.py +++ b/src/spatialdata/_core/operations/aggregate.py @@ -440,7 +440,7 @@ def _aggregate_shapes( vk = value_key[0] if fractions_of_values is not None: joined[ONES_COLUMN] = fractions_of_values - aggregated = joined.groupby([INDEX, vk])[ONES_COLUMN].agg(agg_func).reset_index() + aggregated = joined.groupby([INDEX, vk], observed=False)[ONES_COLUMN].agg(agg_func).reset_index() aggregated_values = aggregated[ONES_COLUMN].values else: if fractions_of_values is not None: diff --git a/src/spatialdata/_core/operations/vectorize.py b/src/spatialdata/_core/operations/vectorize.py index a4f904dc..40d3e31f 100644 --- a/src/spatialdata/_core/operations/vectorize.py +++ b/src/spatialdata/_core/operations/vectorize.py @@ -272,7 +272,7 @@ def _(gdf: GeoDataFrame, buffer_resolution: int = 16) -> GeoDataFrame: if isinstance(gdf.geometry.iloc[0], Point): buffered_df = gdf.copy() buffered_df["geometry"] = buffered_df.apply( - lambda row: row.geometry.buffer(row[ShapesModel.RADIUS_KEY], resolution=buffer_resolution), axis=1 + lambda row: row.geometry.buffer(row[ShapesModel.RADIUS_KEY], quad_segs=buffer_resolution), axis=1 ) # Ensure the GeoDataFrame recognizes the updated geometry column diff --git a/tests/core/operations/test_rasterize_bins.py b/tests/core/operations/test_rasterize_bins.py index d7e6172b..4488b546 100644 --- a/tests/core/operations/test_rasterize_bins.py +++ b/tests/core/operations/test_rasterize_bins.py @@ -165,7 +165,7 @@ def _get_sdata(n: int): # table annotating multiple elements regions = table.obs["region"].copy() regions = regions.cat.add_categories(["shapes"]) - regions[0] = "shapes" + regions.iloc[0] = "shapes" sdata["shapes"] = sdata["points"] table.obs["region"] = regions with pytest.raises( diff --git a/tests/core/operations/test_spatialdata_operations.py b/tests/core/operations/test_spatialdata_operations.py index 0a8b965b..fa8eff79 100644 --- a/tests/core/operations/test_spatialdata_operations.py +++ b/tests/core/operations/test_spatialdata_operations.py @@ -516,7 +516,7 @@ def test_no_shared_transformations() -> None: def test_init_from_elements(full_sdata: SpatialData) -> None: # this first code block needs to be removed when the tables argument is removed from init_from_elements() all_elements = {name: el for _, name, el in full_sdata._gen_elements()} - sdata = SpatialData.init_from_elements(all_elements, tables=full_sdata["table"]) + sdata = SpatialData.init_from_elements(all_elements | {"table": full_sdata["table"]}) for element_type in ["images", "labels", "points", "shapes", "tables"]: assert set(getattr(sdata, element_type).keys()) == set(getattr(full_sdata, element_type).keys()) diff --git a/tests/core/test_centroids.py b/tests/core/test_centroids.py index 5011872b..54fead60 100644 --- a/tests/core/test_centroids.py +++ b/tests/core/test_centroids.py @@ -52,7 +52,7 @@ def test_get_centroids_points(points, coordinate_system: str, is_3d: bool): assert centroids.columns.tolist() == list(axes) # check the index is preserved - assert np.array_equal(centroids.index.values, element.index.values) + assert np.array_equal(centroids.index.compute().values, element.index.values) # the centroids should not contain extra columns assert "genes" in element.columns and "genes" not in centroids.columns @@ -63,7 +63,7 @@ def test_get_centroids_points(points, coordinate_system: str, is_3d: bool): # let's check the values if coordinate_system == "global": - assert np.array_equal(centroids.compute().values, element[list(axes)].compute().values) + assert np.array_equal(centroids.compute().values, element[list(axes)].values) else: matrix = affine.to_affine_matrix(input_axes=axes, output_axes=axes) centroids_untransformed = element[list(axes)].compute().values @@ -82,7 +82,7 @@ def test_get_centroids_shapes(shapes, coordinate_system: str, shapes_name: str): set_transformation(element, transformation=affine, to_coordinate_system=coordinate_system) centroids = get_centroids(element, coordinate_system=coordinate_system) - assert np.array_equal(centroids.index.values, element.index.values) + assert np.array_equal(centroids.index.compute().values, element.index.values) if shapes_name == "circles": xy = element.geometry.get_coordinates().values @@ -154,7 +154,7 @@ def test_get_centroids_labels( centroids = get_centroids(element, coordinate_system=coordinate_system, return_background=return_background) labels_indices = get_element_instances(element, return_background=return_background) - assert np.array_equal(centroids.index.values, labels_indices) + assert np.array_equal(centroids.index.compute().values, labels_indices) if not return_background: assert not (centroids.index == 0).any() From 94971cba7f4c0a70e348c97f8f3ec99e519a9073 Mon Sep 17 00:00:00 2001 From: Wouter-Michiel Vierdag Date: Sat, 6 Sep 2025 16:11:21 +0200 Subject: [PATCH 050/126] remove log with no useful info --- src/spatialdata/models/models.py | 1 - 1 file changed, 1 deletion(-) diff --git a/src/spatialdata/models/models.py b/src/spatialdata/models/models.py index b9a43992..9aa5dee2 100644 --- a/src/spatialdata/models/models.py +++ b/src/spatialdata/models/models.py @@ -170,7 +170,6 @@ def parse( raise ValueError( f"`dims`: {dims} does not match `data.dims`: {data.dims}, please specify the dims only once." ) - logger.info("`dims` is specified redundantly: found also inside `data`.") else: dims = data.dims # but if dims don't match the model's dims, throw error From 5672e19c9c03982f6ff9c0b31b82d69b467293f8 Mon Sep 17 00:00:00 2001 From: Wouter-Michiel Vierdag Date: Sat, 6 Sep 2025 16:22:23 +0200 Subject: [PATCH 051/126] remove log as it is stated in doc string --- src/spatialdata/models/models.py | 1 - 1 file changed, 1 deletion(-) diff --git a/src/spatialdata/models/models.py b/src/spatialdata/models/models.py index 9aa5dee2..e98dd0df 100644 --- a/src/spatialdata/models/models.py +++ b/src/spatialdata/models/models.py @@ -199,7 +199,6 @@ def parse( data = data.transpose(*[_reindex(d) for d in cls.dims.dims]) else: raise ValueError(f"Unsupported data type: {type(data)}.") - logger.info(f"Transposing `data` of type: {type(data)} to {cls.dims.dims}.") except ValueError as e: raise ValueError( f"Cannot transpose arrays to match `dims`: {dims}.", From 6a670344bc63ec8e6fd8b842b61c8630131d594a Mon Sep 17 00:00:00 2001 From: Wouter-Michiel Vierdag Date: Sat, 6 Sep 2025 16:47:24 +0200 Subject: [PATCH 052/126] move log to info in docstring --- src/spatialdata/models/models.py | 7 ++++++- 1 file changed, 6 insertions(+), 1 deletion(-) diff --git a/src/spatialdata/models/models.py b/src/spatialdata/models/models.py index e98dd0df..52ad27d8 100644 --- a/src/spatialdata/models/models.py +++ b/src/spatialdata/models/models.py @@ -156,6 +156,12 @@ def parse( You can also pass the `rgb` argument to `kwargs` to automatically set the `c_coords` to `["r", "g", "b"]`. Please refer to :func:`to_spatial_image` for more information. Note: if you set `rgb=None` in `kwargs`, 3-4 channel images will be interpreted automatically as RGB(A) images. + + **Setting axes / dims** + In case of the data being a numpy or dask array, there are no named axes yet. In this case, we first try to + use the dimensions specified by the user in the `dims` argument of `.parse`. These dimensions are potentially + transposed. See the description of the `dims` argument above. If `dims` is not specified, the dims are set + to (c)(z)yx, dependent on the number of dimensions of the data. """ if transformations: transformations = transformations.copy() @@ -182,7 +188,6 @@ def parse( data = from_array(data) if dims is None: dims = cls.dims.dims - logger.info(f"no axes information specified in the object, setting `dims` to: {dims}") else: if len(set(dims).symmetric_difference(cls.dims.dims)) > 0: raise ValueError(f"Wrong `dims`: {dims}. Expected {cls.dims.dims}.") From 68ce29fa49fc93195af5e0466c70ad60b1b743b8 Mon Sep 17 00:00:00 2001 From: Wouter-Michiel Vierdag Date: Sun, 7 Sep 2025 09:40:01 +0200 Subject: [PATCH 053/126] remove deprecated warning --- src/spatialdata/_core/spatialdata.py | 17 +++-------------- 1 file changed, 3 insertions(+), 14 deletions(-) diff --git a/src/spatialdata/_core/spatialdata.py b/src/spatialdata/_core/spatialdata.py index 40b399c8..44d00522 100644 --- a/src/spatialdata/_core/spatialdata.py +++ b/src/spatialdata/_core/spatialdata.py @@ -820,18 +820,7 @@ def transform_element_to_coordinate_system( set_transformation, ) - # TODO remove after deprecation - if not isinstance(element_name, str): - warnings.warn( - "Passing a SpatialElement is as element will be deprecated in SpatialData v0.3.0. Pass" - "element_name as string to silence this warning.", - DeprecationWarning, - stacklevel=2, - ) - element = element_name - else: - element = self.get(element_name) - + element = self.get(element_name) t = get_transformation_between_coordinate_systems(self, element, target_coordinate_system) if maintain_positioning: transformed = transform(element, transformation=t, maintain_positioning=maintain_positioning) @@ -901,10 +890,10 @@ def transform_to_coordinate_system( """ sdata = self.filter_by_coordinate_system(target_coordinate_system, filter_tables=False) elements: dict[str, dict[str, SpatialElement]] = {} - for element_type, element_name, element in sdata.gen_elements(): + for element_type, element_name, _ in sdata.gen_elements(): if element_type != "tables": transformed = sdata.transform_element_to_coordinate_system( - element, + element_name, target_coordinate_system, maintain_positioning=maintain_positioning, ) From 82f12b025d165b75794ceba7cdf98b7bdf4d2e35 Mon Sep 17 00:00:00 2001 From: Wouter-Michiel Vierdag Date: Sun, 7 Sep 2025 15:01:32 +0200 Subject: [PATCH 054/126] get rid of categorical and str casting warnings --- src/spatialdata/_core/operations/aggregate.py | 2 +- tests/conftest.py | 7 ++++++- .../operations/test_spatialdata_operations.py | 20 +++++++++++++------ tests/core/test_centroids.py | 6 +++--- 4 files changed, 24 insertions(+), 11 deletions(-) diff --git a/src/spatialdata/_core/operations/aggregate.py b/src/spatialdata/_core/operations/aggregate.py index f4e5f591..0a107483 100644 --- a/src/spatialdata/_core/operations/aggregate.py +++ b/src/spatialdata/_core/operations/aggregate.py @@ -232,7 +232,7 @@ def _create_sdata_from_table_and_shapes( f"Instance key column dtype in table resulting from aggregation cannot be cast to the dtype of" f"element {shapes_name}.index" ) from err - table.obs[region_key] = shapes_name + table.obs[region_key] = pd.Categorical([shapes_name] * len(table)) table = TableModel.parse(table, region=shapes_name, region_key=region_key, instance_key=instance_key) # labels case, needs conversion from str to int diff --git a/tests/conftest.py b/tests/conftest.py index df69dc0e..01e46e68 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -293,7 +293,10 @@ def _get_table( region_key: None | str = "region", instance_key: None | str = "instance_id", ) -> AnnData: - adata = AnnData(RNG.normal(size=(100, 10)), obs=pd.DataFrame(RNG.normal(size=(100, 3)), columns=["a", "b", "c"])) + adata = AnnData( + RNG.normal(size=(100, 10)), + obs=pd.DataFrame(RNG.normal(size=(100, 3)), columns=["a", "b", "c"], index=[f"{i}" for i in range(100)]), + ) if not all(var for var in (region, region_key, instance_key)): return TableModel.parse(adata=adata) adata.obs[instance_key] = np.arange(adata.n_obs) @@ -301,6 +304,7 @@ def _get_table( adata.obs[region_key] = region elif isinstance(region, list): adata.obs[region_key] = RNG.choice(region, size=adata.n_obs) + adata.obs[region_key] = adata.obs[region_key].astype("category") return TableModel.parse(adata=adata, region=region, region_key=region_key, instance_key=instance_key) @@ -442,6 +446,7 @@ def _make_sdata_for_testing_querying_and_aggretation() -> SpatialData: ), var=pd.DataFrame(index=["numerical_in_var"]), ) + table.obs["region"] = table.obs["region"].astype("category") table = TableModel.parse( table, region=["values_circles", "values_polygons"], region_key="region", instance_key="instance_id" ) diff --git a/tests/core/operations/test_spatialdata_operations.py b/tests/core/operations/test_spatialdata_operations.py index fa8eff79..7b3f4932 100644 --- a/tests/core/operations/test_spatialdata_operations.py +++ b/tests/core/operations/test_spatialdata_operations.py @@ -1,6 +1,7 @@ import math import numpy as np +import pandas as pd import pytest from anndata import AnnData from geopandas import GeoDataFrame @@ -170,7 +171,9 @@ def test_filter_by_coordinate_system_also_table(full_sdata: SpatialData) -> None from spatialdata.models import TableModel rng = np.random.default_rng(seed=0) - full_sdata["table"].obs["annotated_shapes"] = rng.choice(["circles", "poly"], size=full_sdata["table"].shape[0]) + full_sdata["table"].obs["annotated_shapes"] = pd.Categorical( + rng.choice(["circles", "poly"], size=full_sdata["table"].shape[0]) + ) adata = full_sdata["table"] del adata.uns[TableModel.ATTRS_KEY] del full_sdata.tables["table"] @@ -322,13 +325,13 @@ def test_concatenate_custom_table_metadata() -> None: shapes1 = _get_shapes() n = len(shapes0["poly"]) table0 = TableModel.parse( - AnnData(obs={"my_region": ["poly0"] * n, "my_instance_id": list(range(n))}), + AnnData(obs={"my_region": pd.Categorical(["poly0"] * n), "my_instance_id": list(range(n))}), region="poly0", region_key="my_region", instance_key="my_instance_id", ) table1 = TableModel.parse( - AnnData(obs={"my_region": ["poly1"] * n, "my_instance_id": list(range(n))}), + AnnData(obs={"my_region": pd.Categorical(["poly1"] * n), "my_instance_id": list(range(n))}), region="poly1", region_key="my_region", instance_key="my_instance_id", @@ -419,7 +422,12 @@ def _get_table_and_poly(i: int) -> tuple[AnnData, GeoDataFrame]: n = len(poly) region = f"poly{i}" table = TableModel.parse( - AnnData(obs={"region": [region] * n, "instance_id": list(range(n))}), + AnnData( + obs=pd.DataFrame( + {"region": pd.Categorical([region] * n), "instance_id": list(range(n))}, + index=[f"{i}" for i in range(n)], + ) + ), region=region, region_key="region", instance_key="instance_id", @@ -540,7 +548,7 @@ def test_subset(full_sdata: SpatialData) -> None: adata = AnnData( shape=(10, 0), obs={ - "region": ["circles"] * 5 + ["poly"] * 5, + "region": pd.Categorical(["circles"] * 5 + ["poly"] * 5), "instance_id": [0, 1, 2, 3, 4, 0, 1, 2, 3, 4], }, ) @@ -653,7 +661,7 @@ def test_validate_table_in_spatialdata(full_sdata): with pytest.warns(UserWarning, match="in the SpatialData object"): full_sdata.validate_table_in_spatialdata(table) - table.obs[region_key] = "points_0" + table.obs[region_key] = pd.Categorical(["points_0"] * table.n_obs) full_sdata.set_table_annotates_spatialelement("table", region="points_0") full_sdata.validate_table_in_spatialdata(table) diff --git a/tests/core/test_centroids.py b/tests/core/test_centroids.py index 54fead60..aa332f9d 100644 --- a/tests/core/test_centroids.py +++ b/tests/core/test_centroids.py @@ -52,7 +52,7 @@ def test_get_centroids_points(points, coordinate_system: str, is_3d: bool): assert centroids.columns.tolist() == list(axes) # check the index is preserved - assert np.array_equal(centroids.index.compute().values, element.index.values) + assert np.array_equal(centroids.index.compute().values, element.index.compute().values) # the centroids should not contain extra columns assert "genes" in element.columns and "genes" not in centroids.columns @@ -63,7 +63,7 @@ def test_get_centroids_points(points, coordinate_system: str, is_3d: bool): # let's check the values if coordinate_system == "global": - assert np.array_equal(centroids.compute().values, element[list(axes)].values) + assert np.array_equal(centroids.compute().values, element[list(axes)].compute().values) else: matrix = affine.to_affine_matrix(input_axes=axes, output_axes=axes) centroids_untransformed = element[list(axes)].compute().values @@ -178,7 +178,7 @@ def test_get_centroids_invalid_element(images): # cannot compute centroids for tables N = 10 adata = TableModel.parse( - AnnData(X=RNG.random((N, N)), obs={"region": ["dummy" for _ in range(N)], "instance_id": np.arange(N)}), + AnnData(X=RNG.random((N, N)), obs={"region": pd.Categorical(["dummy"] * N), "instance_id": np.arange(N)}), region="dummy", region_key="region", instance_key="instance_id", From bf607b37d97c3ac8dd436e06b79c4b166c38b935 Mon Sep 17 00:00:00 2001 From: Wouter-Michiel Vierdag Date: Sun, 7 Sep 2025 16:49:30 +0200 Subject: [PATCH 055/126] below 1000 warnings --- src/spatialdata/_core/operations/aggregate.py | 2 +- .../_core/operations/rasterize_bins.py | 2 +- src/spatialdata/datasets.py | 2 +- tests/conftest.py | 12 +++++------ tests/core/operations/test_rasterize.py | 19 ++++++++++-------- tests/core/operations/test_rasterize_bins.py | 20 ++++++++++++++++--- tests/core/query/test_relational_query.py | 10 ++++++---- tests/core/query/test_spatial_query.py | 2 +- 8 files changed, 44 insertions(+), 25 deletions(-) diff --git a/src/spatialdata/_core/operations/aggregate.py b/src/spatialdata/_core/operations/aggregate.py index 0a107483..131b8d7f 100644 --- a/src/spatialdata/_core/operations/aggregate.py +++ b/src/spatialdata/_core/operations/aggregate.py @@ -478,7 +478,7 @@ def _aggregate_shapes( anndata = ad.AnnData( X, - obs=pd.DataFrame(index=rows_categories), + obs=pd.DataFrame(index=list(map(str, rows_categories))), var=pd.DataFrame(index=columns_categories), ) diff --git a/src/spatialdata/_core/operations/rasterize_bins.py b/src/spatialdata/_core/operations/rasterize_bins.py index 37558e1b..cff846ac 100644 --- a/src/spatialdata/_core/operations/rasterize_bins.py +++ b/src/spatialdata/_core/operations/rasterize_bins.py @@ -281,7 +281,7 @@ def rasterize_bins_link_table_to_labels(sdata: SpatialData, table_name: str, ras The name of the rasterized labels in the spatial data object. """ _, region_key, instance_key = get_table_keys(sdata[table_name]) - sdata[table_name].obs[region_key] = rasterized_labels_name + sdata[table_name].obs[region_key] = pd.Categorical([rasterized_labels_name] * sdata[table_name].n_obs) relabled_instance_key = _get_relabeled_column_name(instance_key) sdata.set_table_annotates_spatialelement( table_name=table_name, region=rasterized_labels_name, region_key=region_key, instance_key=relabled_instance_key diff --git a/src/spatialdata/datasets.py b/src/spatialdata/datasets.py index 4b3d61f6..a1380752 100644 --- a/src/spatialdata/datasets.py +++ b/src/spatialdata/datasets.py @@ -365,7 +365,7 @@ def blobs_annotating_element(name: BlobsTypes) -> SpatialData: index = sdata[name].index instance_id = index.compute().tolist() if isinstance(index, dask.dataframe.core.Index) else index.tolist() n = len(instance_id) - new_table = AnnData(shape=(n, 0), obs={"region": [name for _ in range(n)], "instance_id": instance_id}) + new_table = AnnData(shape=(n, 0), obs={"region": pd.Categorical([name] * n), "instance_id": instance_id}) new_table = TableModel.parse(new_table, region=name, region_key="region", instance_key="instance_id") del sdata.tables["table"] sdata["table"] = new_table diff --git a/tests/conftest.py b/tests/conftest.py index 01e46e68..9aef4744 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -430,10 +430,10 @@ def _make_sdata_for_testing_querying_and_aggretation() -> SpatialData: # generate table x = RNG.random((21, 1)) - region = np.array(["values_circles"] * 9 + ["values_polygons"] * 12) + region = pd.Categorical(np.array(["values_circles"] * 9 + ["values_polygons"] * 12)) instance_id = np.array(list(range(9)) + list(range(12))) - categorical_obs = pd.Series(pd.Categorical(["a"] * 9 + ["b"] * 9 + ["c"] * 3)) - numerical_obs = pd.Series(RNG.random(21)) + categorical_obs = pd.Categorical(["a"] * 9 + ["b"] * 9 + ["c"] * 3) + numerical_obs = RNG.random(21) table = AnnData( x, obs=pd.DataFrame( @@ -442,11 +442,11 @@ def _make_sdata_for_testing_querying_and_aggretation() -> SpatialData: "instance_id": instance_id, "categorical_in_obs": categorical_obs, "numerical_in_obs": numerical_obs, - } + }, + index=list(map(str, range(21))), ), var=pd.DataFrame(index=["numerical_in_var"]), ) - table.obs["region"] = table.obs["region"].astype("category") table = TableModel.parse( table, region=["values_circles", "values_polygons"], region_key="region", instance_key="instance_id" ) @@ -492,7 +492,7 @@ def adata_labels() -> AnnData: "instance_id": range(n_obs_labels), "region": ["test"] * n_obs_labels, }, - index=np.arange(n_obs_labels), + index=np.arange(n_obs_labels).astype(str), ) uns_labels = { "spatialdata_attrs": {"region": "test", "region_key": "region", "instance_key": "instance_id"}, diff --git a/tests/core/operations/test_rasterize.py b/tests/core/operations/test_rasterize.py index 50be3f45..a109c852 100644 --- a/tests/core/operations/test_rasterize.py +++ b/tests/core/operations/test_rasterize.py @@ -123,10 +123,11 @@ def test_rasterize_labels_value_key_specified(): labels_indices = get_element_instances(raster) obs = pd.DataFrame( { - "region": [element_name] * len(labels_indices), + "region": pd.Categorical([element_name] * len(labels_indices)), "instance_id": labels_indices, value_key: [True] * 10 + [False] * (len(labels_indices) - 10), - } + }, + index=[f"{i}" for i in range(len(labels_indices))], ) table = TableModel.parse( AnnData(shape=(len(labels_indices), 0), obs=obs), @@ -194,11 +195,12 @@ def _rasterize_shapes_prepare_data() -> tuple[SpatialData, GeoDataFrame, str]: X=np.arange(len(gdf)).reshape(-1, 1), obs=pd.DataFrame( { - "region": [element_name] * len(gdf), + "region": pd.Categorical([element_name] * len(gdf)), "instance_id": gdf.index, - "values": gdf["values"], - "cat_values": gdf["cat_values"], - } + "values": gdf["values"].to_numpy(), + "cat_values": gdf["cat_values"].to_numpy(), + }, + index=[str(i) for i in range(len(gdf))], ), ) adata.obs["cat_values"] = adata.obs["cat_values"].astype("category") @@ -333,11 +335,12 @@ def test_rasterize_points(): X=np.arange(len(ddf)).reshape(-1, 1), obs=pd.DataFrame( { - "region": [element_name] * len(ddf), + "region": pd.Categorical([element_name] * len(ddf)), "instance_id": ddf.index, "gene": data["gene"], "value": data["value"], - } + }, + index=[f"{i}" for i in range(len(ddf))], ), ) adata.obs["gene"] = adata.obs["gene"].astype("category") diff --git a/tests/core/operations/test_rasterize_bins.py b/tests/core/operations/test_rasterize_bins.py index 4488b546..84346af3 100644 --- a/tests/core/operations/test_rasterize_bins.py +++ b/tests/core/operations/test_rasterize_bins.py @@ -2,6 +2,7 @@ import re import numpy as np +import pandas as pd import pytest from anndata import AnnData from geopandas import GeoDataFrame @@ -56,7 +57,13 @@ def test_rasterize_bins(geometry: str, value_key: str | list[str] | None, return points = ShapesModel.parse(gdf, transformations={"global": scale}) obs = DataFrame( - data={"region": ["points"] * n * n, "instance_id": np.arange(n * n), "col_index": x, "row_index": y} + data={ + "region": pd.Categorical(["points"] * n * n), + "instance_id": np.arange(n * n), + "col_index": x, + "row_index": y, + }, + index=[f"{i}" for i in range(n * n)], ) X = RNG.normal(size=(n * n, 2)) var = DataFrame(index=["gene0", "gene1"]) @@ -122,7 +129,13 @@ def _get_sdata(n: int): data, x, y = _get_bins_data(n) points = PointsModel.parse(data) obs = DataFrame( - data={"region": ["points"] * n * n, "instance_id": np.arange(n * n), "col_index": x, "row_index": y} + data={ + "region": pd.Categorical(["points"] * n * n), + "instance_id": np.arange(n * n), + "col_index": x, + "row_index": y, + }, + index=[f"{i}" for i in range(n * n)], ) table = TableModel.parse( AnnData(X=RNG.normal(size=(n * n, 2)), obs=obs), @@ -269,7 +282,8 @@ def test_relabel_labels(caplog): "instance_key1": np.arange(10), "instance_key2": [1, 2] + list(range(4, 12)), "instance_key3": [str(i) for i in range(1, 11)], - } + }, + index=[f"{i}" for i in range(10)], ) adata = AnnData(X=RNG.normal(size=(10, 2)), obs=obs) _relabel_labels(table=adata, instance_key="instance_key0") diff --git a/tests/core/query/test_relational_query.py b/tests/core/query/test_relational_query.py index f0b4da7e..e50d0109 100644 --- a/tests/core/query/test_relational_query.py +++ b/tests/core/query/test_relational_query.py @@ -816,7 +816,7 @@ def test_get_values_df_points(points): p = p.drop("instance_id", axis=1) p.index.compute() n = len(p) - obs = pd.DataFrame(index=p.index, data={"region": ["points_0"] * n, "instance_id": range(n)}) + obs = pd.DataFrame(index=p.index.astype(str), data={"region": ["points_0"] * n, "instance_id": range(n)}) obs["region"] = obs["region"].astype("category") table = TableModel.parse( AnnData(shape=(n, 0), obs=obs), @@ -891,8 +891,10 @@ def test_get_values_labels_bug(sdata_blobs): def test_filter_table_categorical_bug(shapes): # one bug that was triggered by: https://github.com/scverse/anndata/issues/1210 - adata = AnnData(obs={"categorical": pd.Categorical(["a", "a", "a", "b", "c"])}) - adata.obs["region"] = "circles" + adata = AnnData( + obs=pd.DataFrame({"categorical": pd.Categorical(["a", "a", "a", "b", "c"])}, index=list(map(str, range(5)))) + ) + adata.obs["region"] = pd.Categorical(["circles"] * adata.n_obs) adata.obs["cell_id"] = np.arange(len(adata)) adata = TableModel.parse(adata, region=["circles"], region_key="region", instance_key="cell_id") adata_subset = adata[adata.obs["categorical"] == "a"].copy() @@ -901,7 +903,7 @@ def test_filter_table_categorical_bug(shapes): def test_filter_table_non_annotating(full_sdata): - obs = pd.DataFrame({"test": ["a", "b", "c"]}) + obs = pd.DataFrame({"test": ["a", "b", "c"]}, index=list(map(str, range(3)))) adata = AnnData(obs=obs) table = TableModel.parse(adata) full_sdata["table"] = table diff --git a/tests/core/query/test_spatial_query.py b/tests/core/query/test_spatial_query.py index d4fcd6da..f5dccbd8 100644 --- a/tests/core/query/test_spatial_query.py +++ b/tests/core/query/test_spatial_query.py @@ -502,7 +502,7 @@ def test_query_filter_table(with_polygon_query: bool): circles0 = ShapesModel.parse(coords0, geometry=0, radius=1) circles1 = ShapesModel.parse(coords1, geometry=0, radius=1) table = AnnData(shape=(3, 0)) - table.obs["region"] = ["circles0", "circles0", "circles1"] + table.obs["region"] = pd.Categorical(["circles0", "circles0", "circles1"]) table.obs["instance"] = [0, 1, 0] table = TableModel.parse(table, region=["circles0", "circles1"], region_key="region", instance_key="instance") sdata = SpatialData(shapes={"circles0": circles0, "circles1": circles1}, tables={"table": table}) From 20469f841c023520518ad65256f31c92d210eda5 Mon Sep 17 00:00:00 2001 From: Wouter-Michiel Vierdag Date: Sun, 7 Sep 2025 21:52:36 +0200 Subject: [PATCH 056/126] further reduction of warnings --- pyproject.toml | 6 ++++++ src/spatialdata/_core/spatialdata.py | 4 ---- src/spatialdata/datasets.py | 2 +- tests/core/operations/test_map.py | 17 ++++++++--------- tests/core/operations/test_transform.py | 2 +- tests/core/query/test_spatial_query.py | 2 +- tests/io/test_multi_table.py | 6 +++--- tests/models/test_models.py | 13 +++++++------ tests/utils/test_element_utils.py | 2 -- tests/utils/test_sanitize.py | 23 +++++++++++++---------- 10 files changed, 40 insertions(+), 37 deletions(-) diff --git a/pyproject.toml b/pyproject.toml index 614fe6b2..274eeb18 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -100,6 +100,12 @@ addopts = [ "--import-mode=importlib", # allow using test files with same name "-s" # print output from tests ] +# These are all markers coming from xarray, dask or anndata. Added here to silence warnings. +markers = [ + "slow: marks tests as slow (deselect with '-m \"not slow\"')", + "gpu: run test on GPU using CuPY.", + "skip_with_pyarrow_strings: skipwhen pyarrow string conversion is turned on", +] # info on how to use this https://stackoverflow.com/questions/57925071/how-do-i-avoid-getting-deprecationwarning-from-inside-dependencies-with-pytest filterwarnings = [ # "ignore:.*U.*mode is deprecated:DeprecationWarning", diff --git a/src/spatialdata/_core/spatialdata.py b/src/spatialdata/_core/spatialdata.py index 44d00522..16f0db51 100644 --- a/src/spatialdata/_core/spatialdata.py +++ b/src/spatialdata/_core/spatialdata.py @@ -140,10 +140,6 @@ def __init__( self._tables: Tables = Tables(shared_keys=self._shared_keys) self.attrs = attrs if attrs else {} # type: ignore[assignment] - # Workaround to allow for backward compatibility - if isinstance(tables, AnnData): - tables = {"table": tables} - element_names = list(chain.from_iterable([e.keys() for e in [images, labels, points, shapes] if e is not None])) if len(element_names) != len(set(element_names)): diff --git a/src/spatialdata/datasets.py b/src/spatialdata/datasets.py index a1380752..cc828c91 100644 --- a/src/spatialdata/datasets.py +++ b/src/spatialdata/datasets.py @@ -160,7 +160,7 @@ def blobs( labels={"blobs_labels": labels, "blobs_multiscale_labels": multiscale_labels}, points={"blobs_points": points}, shapes={"blobs_circles": circles, "blobs_polygons": polygons, "blobs_multipolygons": multipolygons}, - tables=table, + tables={"table": table}, ) def _image_blobs( diff --git a/tests/core/operations/test_map.py b/tests/core/operations/test_map.py index b3fd3165..a032d381 100644 --- a/tests/core/operations/test_map.py +++ b/tests/core/operations/test_map.py @@ -109,7 +109,7 @@ def test_map_raster_output_chunks(sdata_blobs): func_kwargs = {"parameter": 20} output_channels = ["test"] se = map_raster( - sdata_blobs["blobs_image"].chunk((3, 100, 100)), + sdata_blobs["blobs_image"].chunk({"c": 3, "y": 100, "x": 100}), func=_multiply_alter_c, func_kwargs=func_kwargs, chunks=( @@ -160,9 +160,8 @@ def test_map_to_labels_(sdata_blobs, blockwise, chunks, drop_axis): func_kwargs = {"parameter": 20} se = sdata_blobs[img_layer] - se = map_raster( - se.chunk((3, 256, 256)), + se.chunk({"c": 3, "y": 256, "x": 256}), func=_multiply_to_labels, func_kwargs=func_kwargs, c_coords=None, @@ -183,7 +182,7 @@ def test_map_squeeze_z(full_sdata): func_kwargs = {"parameter": 20} se = map_raster( - full_sdata[img_layer].chunk((3, 2, 64, 64)), + full_sdata[img_layer].chunk({"c": 3, "z": 2, "y": 64, "x": 64}), func=_multiply_squeeze_z, func_kwargs=func_kwargs, chunks=((3,), (64,), (64,)), @@ -212,7 +211,7 @@ def test_map_squeeze_z_fails(full_sdata): ), ): map_raster( - full_sdata[img_layer].chunk((3, 2, 64, 64)), + full_sdata[img_layer].chunk({"c": 3, "z": 2, "y": 64, "x": 64}), func=_multiply_squeeze_z, func_kwargs=func_kwargs, chunks=((3,), (64,), (64,)), @@ -266,7 +265,7 @@ def test_map_raster_relabel(sdata_blobs): element_name = "blobs_labels" se = map_raster( - sdata_blobs[element_name].chunk((100, 100)), + sdata_blobs[element_name].chunk({"y": 100, "x": 100}), func=_to_constant, func_kwargs=func_kwargs, c_coords=None, @@ -301,7 +300,7 @@ def test_map_raster_relabel_fail(sdata_blobs): match=re.escape("Relabel was set to True, but"), ): se = map_raster( - sdata_blobs[element_name].chunk((100, 100)), + sdata_blobs[element_name].chunk({"y": 100, "x": 100}), func=_to_constant, func_kwargs=func_kwargs, c_coords=None, @@ -319,8 +318,8 @@ def test_map_raster_relabel_fail(sdata_blobs): ValueError, match=re.escape(f"Relabeling is only supported for arrays of type {np.integer}."), ): - se = map_raster( - sdata_blobs[element_name].astype(float).chunk((100, 100)), + map_raster( + sdata_blobs[element_name].astype(float).chunk({"y": 100, "x": 100}), func=_to_constant, func_kwargs=func_kwargs, c_coords=None, diff --git a/tests/core/operations/test_transform.py b/tests/core/operations/test_transform.py index 13f4bbc3..9c1c6823 100644 --- a/tests/core/operations/test_transform.py +++ b/tests/core/operations/test_transform.py @@ -587,7 +587,7 @@ def test_transform_elements_and_entire_spatial_data_object_multi_hop( labels=dict(full_sdata.labels), points=dict(full_sdata.points), shapes=dict(full_sdata.shapes), - tables=full_sdata["table"], + tables={"table": full_sdata["table"]}, ) temp["transformed_element"] = transformed_element transformation = get_transformation_between_coordinate_systems( diff --git a/tests/core/query/test_spatial_query.py b/tests/core/query/test_spatial_query.py index f5dccbd8..fc59d069 100644 --- a/tests/core/query/test_spatial_query.py +++ b/tests/core/query/test_spatial_query.py @@ -545,7 +545,7 @@ def test_polygon_query_with_multipolygon(sdata_query_aggregation): sdata = sdata_query_aggregation values_sdata = SpatialData( shapes={"values_polygons": sdata["values_polygons"], "values_circles": sdata["values_circles"]}, - tables=sdata["table"], + tables={"table": sdata["table"]}, ) polygon = sdata["by_polygons"].geometry.iloc[0] circle = sdata["by_circles"].geometry.iloc[0] diff --git a/tests/io/test_multi_table.py b/tests/io/test_multi_table.py index 7191f2f8..1b754370 100644 --- a/tests/io/test_multi_table.py +++ b/tests/io/test_multi_table.py @@ -74,7 +74,7 @@ def test_change_annotation_target(self, full_sdata, region_key, instance_key, er full_sdata.set_table_annotates_spatialelement("table", "poly") del full_sdata["table"].obs["instance_id"] - full_sdata["table"].obs["region"] = ["poly"] * n_obs + full_sdata["table"].obs["region"] = pd.Categorical(["poly"] * n_obs) with pytest.raises(ValueError, match=error_msg): full_sdata.set_table_annotates_spatialelement( "table", "poly", region_key=region_key, instance_key=instance_key @@ -114,14 +114,14 @@ def test_set_table_annotates_spatialelement(self, full_sdata, tmp_path): "table", "labels2d", region_key="region", instance_key="instance_id" ) - region = ["circles"] * 50 + ["poly"] * 50 + region = pd.Categorical(["circles"] * 50 + ["poly"] * 50) full_sdata["table"].obs["region"] = region full_sdata.set_table_annotates_spatialelement( "table", pd.Series(["circles", "poly"]), region_key="region", instance_key="instance_id" ) - full_sdata["table"].obs["region"] = "circles" + full_sdata["table"].obs["region"] = pd.Categorical(["circles"] * full_sdata["table"].n_obs) full_sdata.set_table_annotates_spatialelement( "table", "circles", region_key="region", instance_key="instance_id" ) diff --git a/tests/models/test_models.py b/tests/models/test_models.py index fddb08ac..e70cc997 100644 --- a/tests/models/test_models.py +++ b/tests/models/test_models.py @@ -71,7 +71,6 @@ def test_validate_axis_name(): validate_axis_name("invalid") -@pytest.mark.ci_only class TestModels: def _parse_transformation_from_multiple_places(self, model: Any, element: Any, **kwargs) -> None: # This function seems convoluted but the idea is simple: sometimes the parser creates a whole new object, @@ -338,7 +337,7 @@ def test_points_model( assert "cell_id" in points.attrs["spatialdata_attrs"]["instance_key"] @pytest.mark.parametrize("model", [TableModel]) - @pytest.mark.parametrize("region", ["sample", RNG.choice([1, 2], size=10).tolist()]) + @pytest.mark.parametrize("region", [["sample"] * 10, RNG.choice([1, 2], size=10).tolist()]) def test_table_model( self, model: TableModel, @@ -348,8 +347,9 @@ def test_table_model( obs = pd.DataFrame( RNG.choice(np.arange(0, 100, dtype=float), size=(10, 3), replace=False), columns=["A", "B", "C"], + index=list(map(str, range(10))), ) - obs[region_key] = region + obs[region_key] = pd.Categorical(region) adata = AnnData(RNG.normal(size=(10, 2)), obs=obs) with pytest.raises(TypeError, match="Only int"): model.parse(adata, region=region, region_key=region_key, instance_key="A") @@ -357,8 +357,9 @@ def test_table_model( obs = pd.DataFrame( RNG.choice(np.arange(0, 100), size=(10, 3), replace=False), columns=["A", "B", "C"], + index=list(map(str, range(10))), ) - obs[region_key] = region + obs[region_key] = pd.Categorical(region) adata = AnnData(RNG.normal(size=(10, 2)), obs=obs) table = model.parse(adata, region=region, region_key=region_key, instance_key="A") assert region_key in table.obs @@ -399,7 +400,7 @@ def test_table_model( assert instance_key_ == "A" # let's fix the region_key column - table.obs["B"] = ["element"] * len(table) + table.obs["B"] = pd.Categorical(["element"] * len(table)) _ = TableModel.parse(adata, region="element", region_key="B", instance_key="C", overwrite_metadata=True) region_, region_key_, instance_key_ = get_table_keys(table) @@ -445,7 +446,7 @@ def test_model_not_unique_names(self, full_sdata, element_type: str, names: list @pytest.mark.parametrize("region", [["sample_1"] * 5 + ["sample_2"] * 5]) def test_table_instance_key_values_not_unique(self, model: TableModel, region: str | np.ndarray): region_key = "region" - obs = pd.DataFrame(RNG.integers(0, 100, size=(10, 3)), columns=["A", "B", "C"]) + obs = pd.DataFrame(RNG.integers(0, 100, size=(10, 3)), columns=["A", "B", "C"], index=list(map(str, range(10)))) obs[region_key] = region obs["A"] = [1] * 5 + list(range(5)) adata = AnnData(RNG.normal(size=(10, 2)), obs=obs) diff --git a/tests/utils/test_element_utils.py b/tests/utils/test_element_utils.py index 86e75887..1bfd20aa 100644 --- a/tests/utils/test_element_utils.py +++ b/tests/utils/test_element_utils.py @@ -1,7 +1,6 @@ import itertools import dask_image.ndinterp -import pytest import xarray from xarray import DataArray, DataTree @@ -27,7 +26,6 @@ def _pad_raster(data: DataArray, axes: tuple[str, ...]) -> DataArray: return dask_image.ndinterp.affine_transform(data, matrix, output_shape=new_shape) -@pytest.mark.ci_only def test_unpad_raster(images, labels) -> None: for raster in itertools.chain(images.images.values(), labels.labels.values()): schema = get_model(raster) diff --git a/tests/utils/test_sanitize.py b/tests/utils/test_sanitize.py index 082e08b1..b61f1908 100644 --- a/tests/utils/test_sanitize.py +++ b/tests/utils/test_sanitize.py @@ -16,7 +16,8 @@ def invalid_table() -> AnnData: "@invalid#": [1, 2], "valid_name": [3, 4], "__private": [5, 6], - } + }, + index=["0", "1"], ) ) @@ -29,7 +30,8 @@ def invalid_table_with_index() -> AnnData: { "invalid name": [1, 2], "_index": [3, 4], - } + }, + index=["0", "1"], ) ) @@ -103,7 +105,8 @@ def test_sanitize_table_case_insensitive_collisions(): "Column1": [1, 2], "column1": [3, 4], "COLUMN1": [5, 6], - } + }, + index=["0", "1"], ) ad = AnnData(obs=obs) sanitized = sanitize_table(ad, inplace=False) @@ -113,7 +116,7 @@ def test_sanitize_table_case_insensitive_collisions(): def test_sanitize_table_whitespace_collision(): """Ensure 'a b' → 'a_b' doesn't collide silently with existing 'a_b'.""" - obs = pd.DataFrame({"a b": [1], "a_b": [2]}) + obs = pd.DataFrame({"a b": [1], "a_b": [2]}, index=["0"]) ad = AnnData(obs=obs) sanitized = sanitize_table(ad, inplace=False) cols = list(sanitized.obs.columns) @@ -127,13 +130,13 @@ def test_sanitize_table_whitespace_collision(): def test_sanitize_table_obs_and_obs_columns(): - ad = AnnData(obs=pd.DataFrame({"@col": [1, 2]})) + ad = AnnData(obs=pd.DataFrame({"@col": [1, 2]}, index=["0", "1"])) sanitized = sanitize_table(ad, inplace=False) assert list(sanitized.obs.columns) == ["_col"] def test_sanitize_table_obsm_and_obsp(): - ad = AnnData(obs=pd.DataFrame({"@col": [1, 2]})) + ad = AnnData(obs=pd.DataFrame({"@col": [1, 2]}, index=["0", "1"])) ad.obsm["@col"] = np.array([[1, 2], [3, 4]]) ad.obsp["bad name"] = np.array([[1, 2], [3, 4]]) sanitized = sanitize_table(ad, inplace=False) @@ -142,7 +145,7 @@ def test_sanitize_table_obsm_and_obsp(): def test_sanitize_table_varm_and_varp(): - ad = AnnData(obs=pd.DataFrame({"x": [1, 2]}), var=pd.DataFrame(index=["v1", "v2"])) + ad = AnnData(obs=pd.DataFrame({"x": [1, 2]}, index=["0", "1"]), var=pd.DataFrame(index=["v1", "v2"])) ad.varm["__priv"] = np.array([[1, 2], [3, 4]]) ad.varp["_index"] = np.array([[1, 2], [3, 4]]) sanitized = sanitize_table(ad, inplace=False) @@ -151,7 +154,7 @@ def test_sanitize_table_varm_and_varp(): def test_sanitize_table_uns_and_layers(): - ad = AnnData(obs=pd.DataFrame({"x": [1, 2]}), var=pd.DataFrame(index=["v1", "v2"])) + ad = AnnData(obs=pd.DataFrame({"x": [1, 2]}, index=["0", "1"]), var=pd.DataFrame(index=["v1", "v2"])) ad.uns["bad@key"] = "val" ad.layers["bad#layer"] = np.array([[0, 1], [1, 0]]) sanitized = sanitize_table(ad, inplace=False) @@ -168,7 +171,7 @@ def test_sanitize_table_empty_returns_empty(): def test_sanitize_table_preserves_underlying_data(): - ad = AnnData(obs=pd.DataFrame({"@invalid#": [1, 2], "valid": [3, 4]})) + ad = AnnData(obs=pd.DataFrame({"@invalid#": [1, 2], "valid": [3, 4]}, index=["0", "1"])) ad.obsm["@invalid#"] = np.array([[1, 2], [3, 4]]) ad.uns["invalid@key"] = "value" sanitized = sanitize_table(ad, inplace=False) @@ -198,7 +201,7 @@ def test_sanitize_table_in_spatialdata_sanitized_fixture(invalid_table, invalid_ def test_spatialdata_retains_other_elements(full_sdata): # Add another sanitized table into an existing full_sdata - tbl = AnnData(obs=pd.DataFrame({"@foo#": [1, 2], "bar": [3, 4]})) + tbl = AnnData(obs=pd.DataFrame({"@foo#": [1, 2], "bar": [3, 4]}, index=["0", "1"])) sanitize_table(tbl) full_sdata.tables["new_table"] = tbl From 429db460c1b0e2d5f0a71976b036bbd72b982e47 Mon Sep 17 00:00:00 2001 From: Wouter-Michiel Vierdag Date: Sun, 7 Sep 2025 22:35:07 +0200 Subject: [PATCH 057/126] remove deprecated code --- src/spatialdata/__init__.py | 3 +- .../_core/query/relational_query.py | 8 ----- src/spatialdata/_core/query/spatial_query.py | 12 ------- src/spatialdata/_core/spatialdata.py | 34 ------------------- src/spatialdata/_io/_utils.py | 17 ---------- 5 files changed, 1 insertion(+), 73 deletions(-) diff --git a/src/spatialdata/__init__.py b/src/spatialdata/__init__.py index d8f48fcc..5d84e172 100644 --- a/src/spatialdata/__init__.py +++ b/src/spatialdata/__init__.py @@ -46,7 +46,6 @@ "read_zarr", "unpad_raster", "get_pyramid_levels", - "save_transformations", "get_dask_backing_files", "are_extents_equal", "relabel_sequential", @@ -80,7 +79,7 @@ ) from spatialdata._core.query.spatial_query import bounding_box_query, polygon_query from spatialdata._core.spatialdata import SpatialData -from spatialdata._io._utils import get_dask_backing_files, save_transformations +from spatialdata._io._utils import get_dask_backing_files from spatialdata._io.format import SpatialDataFormatType from spatialdata._io.io_zarr import read_zarr from spatialdata._utils import get_pyramid_levels, unpad_raster diff --git a/src/spatialdata/_core/query/relational_query.py b/src/spatialdata/_core/query/relational_query.py index b0f9978d..247bffeb 100644 --- a/src/spatialdata/_core/query/relational_query.py +++ b/src/spatialdata/_core/query/relational_query.py @@ -738,14 +738,6 @@ def match_table_to_element(sdata: SpatialData, element_name: str, table_name: st match_element_to_table : Function to match a spatial element to a table. join_spatialelement_table : General function, to join spatial elements with a table with more control. """ - if table_name is None: - warnings.warn( - "Assumption of table with name `table` being present is being deprecated in SpatialData v0.1. " - "Please provide the name of the table as argument to table_name.", - DeprecationWarning, - stacklevel=2, - ) - table_name = "table" _, table = join_spatialelement_table( sdata=sdata, spatial_element_names=element_name, table_name=table_name, how="left", match_rows="left" ) diff --git a/src/spatialdata/_core/query/spatial_query.py b/src/spatialdata/_core/query/spatial_query.py index 72e9377b..e6dccb45 100644 --- a/src/spatialdata/_core/query/spatial_query.py +++ b/src/spatialdata/_core/query/spatial_query.py @@ -828,18 +828,6 @@ def polygon_query( Importantly, when clipping is enabled, the circles will be converted to polygons before the clipping. This may affect downstream operations that rely on the circle radius or on performance, so it is recommended to disable clipping when querying circles or when querying a `SpatialData` object that contains circles. - shapes [Deprecated] - This argument is now ignored and will be removed. Please filter the SpatialData object before calling this - function. - points [Deprecated] - This argument is now ignored and will be removed. Please filter the SpatialData object before calling this - function. - images [Deprecated] - This argument is now ignored and will be removed. Please filter the SpatialData object before calling this - function. - labels [Deprecated] - This argument is now ignored and will be removed. Please filter the SpatialData object before calling this - function. Returns ------- diff --git a/src/spatialdata/_core/spatialdata.py b/src/spatialdata/_core/spatialdata.py index 16f0db51..d47be901 100644 --- a/src/spatialdata/_core/spatialdata.py +++ b/src/spatialdata/_core/spatialdata.py @@ -18,7 +18,6 @@ from geopandas import GeoDataFrame from ome_zarr.format import FormatV05 from ome_zarr.io import parse_url -from ome_zarr.types import JSONDict from shapely import MultiPolygon, Polygon from xarray import DataArray, DataTree @@ -32,9 +31,6 @@ ) from spatialdata._logging import logger from spatialdata._types import ArrayLike, Raster_T -from spatialdata._utils import ( - _error_message_add_element, -) from spatialdata.models import ( Image2DModel, Image3DModel, @@ -1805,16 +1801,6 @@ def read( return read_zarr(file_path, selection=selection) - def add_image( - self, - name: str, - image: DataArray | DataTree, - storage_options: JSONDict | list[JSONDict] | None = None, - overwrite: bool = False, - ) -> None: - """Deprecated. Use `sdata[name] = image` instead.""" # noqa: D401 - _error_message_add_element() - @property def images(self) -> Images: """Return images as a Dict of name to image data.""" @@ -2210,26 +2196,6 @@ def init_from_elements( assert model == ShapesModel element_type = "shapes" elements_dict.setdefault(element_type, {})[name] = element - # when the "tables" argument is removed, we can remove all this if block - if tables is not None: - warnings.warn( - 'The "tables" argument is deprecated and will be removed in a future version. Please ' - "specifies the tables in the `elements` argument. Until the removal occurs, the `elements` " - "variable will be automatically populated with the tables if the `tables` argument is not None.", - DeprecationWarning, - stacklevel=2, - ) - if "tables" in elements_dict: - raise ValueError( - "The tables key is already present in the elements dictionary. Please do not specify " - "the `tables` argument." - ) - elements_dict["tables"] = {} - if isinstance(tables, AnnData): - elements_dict["tables"]["table"] = tables - else: - for name, table in tables.items(): - elements_dict["tables"][name] = table return cls(**elements_dict, attrs=attrs) def subset( diff --git a/src/spatialdata/_io/_utils.py b/src/spatialdata/_io/_utils.py index b3f4dd3a..d1e3ffbd 100644 --- a/src/spatialdata/_io/_utils.py +++ b/src/spatialdata/_io/_utils.py @@ -439,23 +439,6 @@ def _is_element_self_contained( return all(_backed_elements_contained_in_path(path=element_path, object=element)) -def save_transformations(sdata: SpatialData) -> None: - """ - Save all the transformations of a SpatialData object to disk. - - sdata - The SpatialData object - """ - warnings.warn( - "This function is deprecated and should be replaced by `SpatialData.write_transformations()` or " - "`SpatialData.write_metadata()`, which gives more control over which metadata to write. This function will call" - " `SpatialData.write_transformations()`; please call this function directly.", - DeprecationWarning, - stacklevel=2, - ) - sdata.write_transformations() - - def _open_zarr_store(path: StoreLike, **kwargs: Any) -> zarr.storage.StoreLike: # TODO: ensure kwargs like mode are enforced everywhere and passed correctly to the store if isinstance(path, str | Path): From 2a5ec5b3c4271fd6298573db1d536fe1bf42be01 Mon Sep 17 00:00:00 2001 From: Wouter-Michiel Vierdag Date: Sun, 7 Sep 2025 22:58:34 +0200 Subject: [PATCH 058/126] correct location for storing transforms --- src/spatialdata/_io/_utils.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/spatialdata/_io/_utils.py b/src/spatialdata/_io/_utils.py index d1e3ffbd..c96b76b5 100644 --- a/src/spatialdata/_io/_utils.py +++ b/src/spatialdata/_io/_utils.py @@ -151,7 +151,7 @@ def _write_coordinate_transformations_raster_zarrv3( # zarr v3 ome-zarr requires the coordinate transformations to be written this way, leaving one out won't work. multiscale["coordinateTransformations"] = coordinate_transformations - group.attrs["coordinateTransformations"] = coordinate_transformations + group.attrs["multiscales"] = multiscales def _overwrite_coordinate_transformations_raster_zarrv2( From 016da2e644f7c2ff12c70ea9914f846919bd88c9 Mon Sep 17 00:00:00 2001 From: Wouter-Michiel Vierdag Date: Mon, 8 Sep 2025 08:29:36 +0200 Subject: [PATCH 059/126] consistent naming --- src/spatialdata/_io/_utils.py | 5 ++--- 1 file changed, 2 insertions(+), 3 deletions(-) diff --git a/src/spatialdata/_io/_utils.py b/src/spatialdata/_io/_utils.py index c96b76b5..121f5ffa 100644 --- a/src/spatialdata/_io/_utils.py +++ b/src/spatialdata/_io/_utils.py @@ -125,13 +125,12 @@ def overwrite_coordinate_transformations_raster( coordinate_transformations = [t.to_dict() for t in ngff_transformations] # replace the metadata storage if group.metadata.zarr_format == 3: - _write_coordinate_transformations_raster_zarrv3(group, coordinate_transformations) + _overwrite_coordinate_transformations_raster_zarrv3(group, coordinate_transformations) elif group.metadata.zarr_format == 2: _overwrite_coordinate_transformations_raster_zarrv2(group, coordinate_transformations) -# TODO: check type coordinate_transformations here -def _write_coordinate_transformations_raster_zarrv3( +def _overwrite_coordinate_transformations_raster_zarrv3( group: zarr.Group, coordinate_transformations: list[dict[str, BaseTransformation]] ) -> None: """Write transformations of raster elements to disk in zarr v3. From 3e5490eb1e72d6dd90b86ac21a2a911c6e42b33e Mon Sep 17 00:00:00 2001 From: Wouter-Michiel Vierdag Date: Mon, 8 Sep 2025 08:43:25 +0200 Subject: [PATCH 060/126] update according to ome-zarr-py --- src/spatialdata/_io/io_raster.py | 45 +++++++++++++++++++------------- 1 file changed, 27 insertions(+), 18 deletions(-) diff --git a/src/spatialdata/_io/io_raster.py b/src/spatialdata/_io/io_raster.py index 5a1c69f4..ff3c38b5 100644 --- a/src/spatialdata/_io/io_raster.py +++ b/src/spatialdata/_io/io_raster.py @@ -6,7 +6,7 @@ import zarr from ome_zarr.format import Format from ome_zarr.io import ZarrLocation -from ome_zarr.reader import Label, Multiscales, Node, Reader +from ome_zarr.reader import Multiscales, Node, Reader from ome_zarr.types import JSONDict from ome_zarr.writer import _get_valid_axes from ome_zarr.writer import write_image as write_image_ngff @@ -33,7 +33,26 @@ ) -def _get_nodes_zarr_v3(image_nodes: list[Node], nodes: list[Node]) -> list[Node]: +def _get_multiscale_nodes(image_nodes: list[Node], nodes: list[Node]) -> list[Node]: + """Get nodes with Multiscales spec from a list of nodes. + + The nodes with the Multiscales spec are the nodes used for reading in image and label data. We only have to check + the multiscales now, while before we also had to check the label spec. In the new ome-zarr-py though labels can have + the Label spec, these do not contain the multiscales anymore used to read the data. They can contain label specific + metadata though. + + Parameters + ---------- + image_nodes: list[Node] + List of nodes returned from the ome-zarr-py Reader. + nodes: list[Node] + List to append the nodes with the multiscales spec to. + + Returns + ------- + list[Node] + List of nodes with the multiscales spec. + """ if len(image_nodes): for node in image_nodes: # Labels are now also Multiscales in newer version of ome-zarr-py @@ -42,13 +61,6 @@ def _get_nodes_zarr_v3(image_nodes: list[Node], nodes: list[Node]) -> list[Node] return nodes -def _get_label_nodes_zarr_v2(image_nodes: list[Node], nodes: list[Node]) -> list[Node]: - for node in image_nodes: - if np.any([isinstance(spec, Label) for spec in node.specs]): - nodes.append(node) - return nodes - - def _read_multiscale(store: str | Path, raster_type: Literal["image", "labels"]) -> DataArray | DataTree: assert isinstance(store, str | Path) assert raster_type in ["image", "labels"] @@ -58,25 +70,22 @@ def _read_multiscale(store: str | Path, raster_type: Literal["image", "labels"]) if exists := image_loc.exists(): image_reader = Reader(image_loc)() image_nodes = list(image_reader) - nodes = _get_nodes_zarr_v3(image_nodes, nodes) + nodes = _get_multiscale_nodes(image_nodes, nodes) else: raise OSError( f"Image location {image_loc} does not seem to exist. If it does, potentially the zarr.json file " f"inside is corrupted or not present or the image files themselves are corrupted." ) if len(nodes) != 1: - if exists: - nodes = _get_label_nodes_zarr_v2(image_nodes, nodes) - else: + if not exists: raise ValueError( f"len(nodes) = {len(nodes)}, expected 1 and image location {image_loc} does not exist. Unable to read " f"the NGFF file. Please report this bug and attach a minimal data example." ) - if len(nodes) != 1: - raise OSError( - f"Image location {image_loc} exists, but len(nodes) = {len(nodes)}, expected 1. Element " - f"{image_loc.basename()} is potentially corrupted." - ) + raise OSError( + f"Image location {image_loc} exists, but len(nodes) = {len(nodes)}, expected 1. Element " + f"{image_loc.basename()} is potentially corrupted." + ) node = nodes[0] datasets = node.load(Multiscales).datasets From be1ae2faf095c5ed160dcc552822a586deb9dc3b Mon Sep 17 00:00:00 2001 From: Wouter-Michiel Vierdag Date: Mon, 8 Sep 2025 08:46:48 +0200 Subject: [PATCH 061/126] correct docstring --- src/spatialdata/_io/io_shapes.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/spatialdata/_io/io_shapes.py b/src/spatialdata/_io/io_shapes.py index 50527e77..066eb61c 100644 --- a/src/spatialdata/_io/io_shapes.py +++ b/src/spatialdata/_io/io_shapes.py @@ -140,7 +140,7 @@ def _write_shapes_v01(shapes: GeoDataFrame, group: zarr.Group, element_format: F def _write_shapes_v02_v03(shapes: GeoDataFrame, group: zarr.Group, element_format: Format) -> Any: - """Write shapes to spatialdata zarr store using format ShapesFormatV02 or ShapesFormatV02. + """Write shapes to spatialdata zarr store using format ShapesFormatV02 or ShapesFormatV03. Parameters ---------- From 8873fa06118551078d55a1e80acb5565e07dbd2f Mon Sep 17 00:00:00 2001 From: Wouter-Michiel Vierdag Date: Mon, 8 Sep 2025 08:51:40 +0200 Subject: [PATCH 062/126] update docstring --- src/spatialdata/_io/io_zarr.py | 9 +++++++-- 1 file changed, 7 insertions(+), 2 deletions(-) diff --git a/src/spatialdata/_io/io_zarr.py b/src/spatialdata/_io/io_zarr.py index d548c2f0..56bec050 100644 --- a/src/spatialdata/_io/io_zarr.py +++ b/src/spatialdata/_io/io_zarr.py @@ -226,7 +226,12 @@ def _get_groups_for_element( """ Get the Zarr groups for the root, element_type and element for a specific element. - The store must exist, but creates the element type group and the element group if they don't exist. + The store must exist, but creates the element type group and the element group if they don't exist. In all cases + the zarr group will also be opened. When writing data to disk this should always be done with 'use_consolidated' + being 'False'. If a user wrote the data previously with consolidation of the metadata and then they write new data + in the zarr store, it can give errors otherwise, due to partially written elements not yet being present in the + consolidated metadata store, e.g. when first writing the element and then opening the zarr group again for writing + transformations. Parameters ---------- @@ -243,7 +248,7 @@ def _get_groups_for_element( Returns ------- - either the existing Zarr subgroup or a new one. + The Zarr groups for the root, element_type and element for a specific element. """ if not isinstance(zarr_path, Path): raise ValueError("zarr_path should be a Path object") From b1ade7ecb616b810497a23c5ef4c023f9298a3f8 Mon Sep 17 00:00:00 2001 From: Wouter-Michiel Vierdag Date: Mon, 8 Sep 2025 09:57:14 +0200 Subject: [PATCH 063/126] remove todo --- src/spatialdata/_io/io_zarr.py | 1 - 1 file changed, 1 deletion(-) diff --git a/src/spatialdata/_io/io_zarr.py b/src/spatialdata/_io/io_zarr.py index 56bec050..607171cc 100644 --- a/src/spatialdata/_io/io_zarr.py +++ b/src/spatialdata/_io/io_zarr.py @@ -82,7 +82,6 @@ def read_zarr( tables: dict[str, AnnData] = {} shapes = {} - # TODO: remove table once deprecated. selector = {"images", "labels", "points", "shapes", "tables"} if not selection else set(selection or []) logger.debug(f"Reading selection {selector}") From 421b56f4901f580bb0cdc76d9964773fa87e6445 Mon Sep 17 00:00:00 2001 From: Wouter-Michiel Vierdag Date: Mon, 8 Sep 2025 10:41:25 +0200 Subject: [PATCH 064/126] remove unassigned function call --- src/spatialdata/_core/spatialdata.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/src/spatialdata/_core/spatialdata.py b/src/spatialdata/_core/spatialdata.py index d47be901..9de56e80 100644 --- a/src/spatialdata/_core/spatialdata.py +++ b/src/spatialdata/_core/spatialdata.py @@ -1043,7 +1043,7 @@ def _validate_can_safely_write_to_path( overwrite: bool = False, saving_an_element: bool = False, ) -> None: - from spatialdata._io._utils import _backed_elements_contained_in_path, _is_subfolder, _open_zarr_store + from spatialdata._io._utils import _backed_elements_contained_in_path, _is_subfolder if isinstance(file_path, str): file_path = Path(file_path) @@ -1052,7 +1052,6 @@ def _validate_can_safely_write_to_path( raise ValueError(f"file_path must be a string or a Path object, type(file_path) = {type(file_path)}.") if os.path.exists(file_path): - _open_zarr_store(file_path, mode="r") if parse_url(file_path, mode="r", fmt=FormatV05()) is None: raise ValueError( "The target file path specified already exists, and it has been detected to not be a Zarr store. " From 0bb8bc97c30d519d8846c440b668243439c69ddd Mon Sep 17 00:00:00 2001 From: Wouter-Michiel Vierdag Date: Mon, 8 Sep 2025 13:37:37 +0200 Subject: [PATCH 065/126] refactor new _open_zarr_store to _resolve_zarr_store --- src/spatialdata/_core/spatialdata.py | 4 +- src/spatialdata/_io/_utils.py | 32 ++++++++++- src/spatialdata/_io/io_zarr.py | 79 +++++++++++++--------------- 3 files changed, 69 insertions(+), 46 deletions(-) diff --git a/src/spatialdata/_core/spatialdata.py b/src/spatialdata/_core/spatialdata.py index 9de56e80..485b5cde 100644 --- a/src/spatialdata/_core/spatialdata.py +++ b/src/spatialdata/_core/spatialdata.py @@ -989,12 +989,12 @@ def elements_paths_on_disk(self) -> list[str]: ------- A list of paths of the elements saved in the Zarr store. """ - from spatialdata._io._utils import _open_zarr_store + from spatialdata._io._utils import _resolve_zarr_store if self.path is None: raise ValueError("The SpatialData object is not backed by a Zarr store.") - store = _open_zarr_store(self.path, mode="r") + store = _resolve_zarr_store(self.path) root = zarr.open_group(store=store, mode="r") elements_in_zarr = [] diff --git a/src/spatialdata/_io/_utils.py b/src/spatialdata/_io/_utils.py index 121f5ffa..1efb70a8 100644 --- a/src/spatialdata/_io/_utils.py +++ b/src/spatialdata/_io/_utils.py @@ -438,7 +438,37 @@ def _is_element_self_contained( return all(_backed_elements_contained_in_path(path=element_path, object=element)) -def _open_zarr_store(path: StoreLike, **kwargs: Any) -> zarr.storage.StoreLike: +def _resolve_zarr_store(path: StoreLike, **kwargs: Any) -> zarr.storage.StoreLike: + """ + Normalize different Zarr store inputs into a usable store instance. + + This function accepts various forms of input (e.g. filesystem paths, + UPath objects, existing Zarr stores, or `zarr.Group`s) and resolves + them into a `StoreLike` that can be passed to Zarr APIs. It handles + local files, fsspec-backed stores, consolidated metadata stores, and + groups with nested paths. + + Parameters + ---------- + path : StoreLike | str | Path | UPath | zarr.Group + The input representing a Zarr store or group. Can be a filesystem + path, remote path, existing store, or Zarr group. + **kwargs : Any + Additional keyword arguments forwarded to the underlying store + constructor (e.g. `mode`, `storage_options`). + + Returns + ------- + zarr.storage.StoreLike + A normalized store instance suitable for use with Zarr. + + Raises + ------ + TypeError + If the input type is unsupported. + ValueError + If a `zarr.Group` has an unsupported store type. + """ # TODO: ensure kwargs like mode are enforced everywhere and passed correctly to the store if isinstance(path, str | Path): # if the input is str or Path, map it to UPath diff --git a/src/spatialdata/_io/io_zarr.py b/src/spatialdata/_io/io_zarr.py index 607171cc..a62df235 100644 --- a/src/spatialdata/_io/io_zarr.py +++ b/src/spatialdata/_io/io_zarr.py @@ -10,11 +10,7 @@ from zarr.errors import ArrayNotFoundError, MetadataValidationError from spatialdata._core.spatialdata import SpatialData -from spatialdata._io._utils import ( - BadFileHandleMethod, - handle_read_errors, - ome_zarr_logger, -) +from spatialdata._io._utils import BadFileHandleMethod, _resolve_zarr_store, handle_read_errors, ome_zarr_logger from spatialdata._io.io_points import _read_points from spatialdata._io.io_raster import _read_multiscale from spatialdata._io.io_shapes import _read_shapes @@ -24,7 +20,7 @@ # TODO: remove with incoming remote read / write PR # Not removing this now as it requires substantial extra refactor beyond scope of zarrv3 PR. -def _open_zarr_store( +def _open_zarr( store: str | Path | zarr.Group, mode: Literal["r", "r+", "a", "w", "w-"] = "r", use_consolidated: bool | None = None ) -> tuple[zarr.Group, str]: """ @@ -74,7 +70,10 @@ def read_zarr( ------- A SpatialData object. """ - f, f_store_path = _open_zarr_store(store) + from spatialdata._io._utils import _resolve_zarr_store as rzs + + resolved_store = rzs(store) + root_group, root_store_path = _open_zarr(resolved_store) images = {} labels = {} @@ -87,15 +86,15 @@ def read_zarr( # We raise OS errors instead for some read errors now as in zarr v3 with some corruptions nothing will be read. # related to images / labels. - if "images" in selector and "images" in f: - group = f["images"] + if "images" in selector and "images" in root_group: + group = root_group["images"] count = 0 for subgroup_name in group: if Path(subgroup_name).name.startswith("."): # skip hidden files like .zgroup or .zmetadata continue - f_elem = group[subgroup_name] - f_elem_store = os.path.join(f_store_path, f_elem.path) + elem_group = group[subgroup_name] + elem_group_path = os.path.join(root_store_path, elem_group.path) with handle_read_errors( on_bad_files, location=f"{group.path}/{subgroup_name}", @@ -107,22 +106,22 @@ def read_zarr( TypeError, # instead of ArrayNotFoundError, with dask>=2024.10.0 zarr<=2.18.3 ), ): - element = _read_multiscale(f_elem_store, raster_type="image") + element = _read_multiscale(elem_group_path, raster_type="image") images[subgroup_name] = element count += 1 logger.debug(f"Found {count} elements in {group}") # read multiscale labels with ome_zarr_logger(logging.ERROR): - if "labels" in selector and "labels" in f: - group = f["labels"] + if "labels" in selector and "labels" in root_group: + group = root_group["labels"] count = 0 for subgroup_name in group: if Path(subgroup_name).name.startswith("."): # skip hidden files like .zgroup or .zmetadata continue - f_elem = group[subgroup_name] - f_elem_store = f_store_path / f_elem.path + elem_group = group[subgroup_name] + elem_group_path = root_store_path / elem_group.path with handle_read_errors( on_bad_files, location=f"{group.path}/{subgroup_name}", @@ -134,48 +133,48 @@ def read_zarr( TypeError, ), ): - labels[subgroup_name] = _read_multiscale(f_elem_store, raster_type="labels") + labels[subgroup_name] = _read_multiscale(elem_group_path, raster_type="labels") count += 1 logger.debug(f"Found {count} elements in {group}") # now read rest of the data - if "points" in selector and "points" in f: + if "points" in selector and "points" in root_group: with handle_read_errors( on_bad_files, location="points", exc_types=(JSONDecodeError, MetadataValidationError), ): - group = f["points"] + group = root_group["points"] count = 0 for subgroup_name in group: - f_elem = group[subgroup_name] + elem_group = group[subgroup_name] if Path(subgroup_name).name.startswith("."): # skip hidden files like .zgroup or .zmetadata continue - f_elem_store = os.path.join(f_store_path, f_elem.path) + elem_group_path = os.path.join(root_store_path, elem_group.path) with handle_read_errors( on_bad_files, location=f"{group.path}/{subgroup_name}", exc_types=(JSONDecodeError, KeyError, ArrowInvalid), ): - points[subgroup_name] = _read_points(f_elem_store) + points[subgroup_name] = _read_points(elem_group_path) count += 1 logger.debug(f"Found {count} elements in {group}") - if "shapes" in selector and "shapes" in f: + if "shapes" in selector and "shapes" in root_group: with handle_read_errors( on_bad_files, location="shapes", exc_types=(JSONDecodeError, MetadataValidationError), ): - group = f["shapes"] + group = root_group["shapes"] count = 0 for subgroup_name in group: if Path(subgroup_name).name.startswith("."): # skip hidden files like .zgroup or .zmetadata continue - f_elem = group[subgroup_name] - f_elem_store = os.path.join(f_store_path, f_elem.path) + elem_group = group[subgroup_name] + elem_group_path = os.path.join(root_store_path, elem_group.path) with handle_read_errors( on_bad_files, location=f"{group.path}/{subgroup_name}", @@ -186,20 +185,20 @@ def read_zarr( ArrayNotFoundError, ), ): - shapes[subgroup_name] = _read_shapes(f_elem_store) + shapes[subgroup_name] = _read_shapes(elem_group_path) count += 1 logger.debug(f"Found {count} elements in {group}") - if "tables" in selector and "tables" in f: + if "tables" in selector and "tables" in root_group: with handle_read_errors( on_bad_files, location="tables", exc_types=(JSONDecodeError, MetadataValidationError), ): - group = f["tables"] - tables = _read_table(f_store_path, group, tables, on_bad_files=on_bad_files) + group = root_group["tables"] + tables = _read_table(root_store_path, group, tables, on_bad_files=on_bad_files) # read attrs metadata - attrs = f.attrs.asdict() + attrs = root_group.attrs.asdict() if "spatialdata_attrs" in attrs: # when refactoring the read_zarr function into reading componenets separately (and according to the version), # we can move the code below (.pop()) into attrs_from_dict() @@ -261,19 +260,16 @@ def _get_groups_for_element( "tables", ]: raise ValueError(f"Unknown element type {element_type}") - # TODO: remove local import after remote PR - from spatialdata._io._utils import _open_zarr_store - - store = _open_zarr_store(zarr_path, mode="r+") + resolved_store = _resolve_zarr_store(zarr_path) # When writing, use_consolidated must be set to False. Otherwise, the metadata store # can get out of sync with newly added elements (e.g., labels), leading to errors. - root = zarr.open_group(store=store, mode="a", use_consolidated=use_consolidated) - element_type_group = root.require_group(element_type) + root_group = zarr.open_group(store=resolved_store, mode="a", use_consolidated=use_consolidated) + element_type_group = root_group.require_group(element_type) element_type_group = zarr.open_group(element_type_group.store_path, mode="a", use_consolidated=use_consolidated) element_name_group = element_type_group.require_group(element_name) - return root, element_type_group, element_name_group + return root_group, element_type_group, element_name_group def _group_for_element_exists(zarr_path: Path, element_type: str, element_name: str) -> bool: @@ -291,10 +287,7 @@ def _group_for_element_exists(zarr_path: Path, element_type: str, element_name: ------- True if the group exists, False otherwise. """ - # TODO: remove local import after remote PR - from spatialdata._io._utils import _open_zarr_store - - store = _open_zarr_store(zarr_path, mode="r") + store = _resolve_zarr_store(zarr_path, mode="r") root = zarr.open_group(store=store, mode="r") assert element_type in [ "images", @@ -311,6 +304,6 @@ def _group_for_element_exists(zarr_path: Path, element_type: str, element_name: def _write_consolidated_metadata(path: Path | str | None) -> None: if path is not None: - f, f_store_path = _open_zarr_store(path, mode="r+", use_consolidated=False) + f, f_store_path = _open_zarr(path, mode="r+", use_consolidated=False) zarr.consolidate_metadata(f.store) f.store.close() From fd08907633ea22719ab86cbac432b426695ce105 Mon Sep 17 00:00:00 2001 From: Wouter-Michiel Vierdag Date: Mon, 8 Sep 2025 16:45:36 +0200 Subject: [PATCH 066/126] silence zarr parquet warnings --- src/spatialdata/_io/io_shapes.py | 4 +++- src/spatialdata/_io/io_zarr.py | 7 ++++++- 2 files changed, 9 insertions(+), 2 deletions(-) diff --git a/src/spatialdata/_io/io_shapes.py b/src/spatialdata/_io/io_shapes.py index 066eb61c..bf6f454a 100644 --- a/src/spatialdata/_io/io_shapes.py +++ b/src/spatialdata/_io/io_shapes.py @@ -77,6 +77,9 @@ def write_shapes( ) -> None: """Write shapes to spatialdata zarr store. + Note that the parquet file is not recognized as part of the zarr hierarchy as it is not a valid component of a + zarr store, e.g. group, array or metadata file. + Parameters ---------- shapes: GeoDataFrame @@ -105,7 +108,6 @@ def write_shapes( axes=list(axes), attrs=attrs, ) - overwrite_coordinate_transformations_non_raster(group=group, axes=axes, transformations=transformations) diff --git a/src/spatialdata/_io/io_zarr.py b/src/spatialdata/_io/io_zarr.py index a62df235..7d40a12e 100644 --- a/src/spatialdata/_io/io_zarr.py +++ b/src/spatialdata/_io/io_zarr.py @@ -1,5 +1,6 @@ import logging import os +import warnings from json import JSONDecodeError from pathlib import Path from typing import Literal @@ -305,5 +306,9 @@ def _group_for_element_exists(zarr_path: Path, element_type: str, element_name: def _write_consolidated_metadata(path: Path | str | None) -> None: if path is not None: f, f_store_path = _open_zarr(path, mode="r+", use_consolidated=False) - zarr.consolidate_metadata(f.store) + # .parquet files are not recognized as proper zarr and thus throw a warning. This does not affect SpatialData. + # and therefore we silence it for our users as they can't do anything about this. + with warnings.catch_warnings(): + warnings.filterwarnings("ignore", category=zarr.errors.ZarrUserWarning) + zarr.consolidate_metadata(f.store) f.store.close() From 74a81cd290482108d9a51b64236e381dfe270dd5 Mon Sep 17 00:00:00 2001 From: Wouter-Michiel Vierdag Date: Mon, 8 Sep 2025 17:44:35 +0200 Subject: [PATCH 067/126] change overwriting warning, silence in tests --- src/spatialdata/_core/_elements.py | 8 +++++++- src/spatialdata/_core/spatialdata.py | 13 ++++++++----- tests/core/operations/test_rasterize.py | 7 +++++-- tests/io/test_readwrite.py | 5 ++++- tests/utils/test_testing.py | 5 ++++- 5 files changed, 28 insertions(+), 10 deletions(-) diff --git a/src/spatialdata/_core/_elements.py b/src/spatialdata/_core/_elements.py index d1c2f4a5..a8f51340 100644 --- a/src/spatialdata/_core/_elements.py +++ b/src/spatialdata/_core/_elements.py @@ -42,7 +42,13 @@ def _remove_shared_key(self, key: str) -> None: def _check_key(key: str, element_keys: Iterable[str], shared_keys: set[str | None]) -> None: check_valid_name(key) if key in element_keys: - warn(f"Key `{key}` already exists. Overwriting it in-memory.", UserWarning, stacklevel=2) + warn( + f"Key `{key}` already exists. Overwriting it in-memory. If you want to silence this warning " + f"either delete it from memory 'del sdata[{key}]' (does not make changes on disk) or filter the " + f"warning.", + UserWarning, + stacklevel=2, + ) else: try: check_key_is_case_insensitively_unique(key, shared_keys) diff --git a/src/spatialdata/_core/spatialdata.py b/src/spatialdata/_core/spatialdata.py index 485b5cde..d438b841 100644 --- a/src/spatialdata/_core/spatialdata.py +++ b/src/spatialdata/_core/spatialdata.py @@ -296,9 +296,10 @@ def get_instance_key_column(table: AnnData) -> pd.Series: def set_channel_names(self, element_name: str, channel_names: str | list[str], write: bool = False) -> None: """Set the channel names for an image `SpatialElement` in the `SpatialData` object. - This method assumes that the `SpatialData` object and the element are already stored on disk as it will - also overwrite the channel names metadata on disk. In case either the `SpatialData` object or the - element are not stored on disk, please use `SpatialData.set_image_channel_names` instead. + This method will overwrite the element in memory with the same element, but with new channel names. + If 'write` is 'True', this method assumes that the `SpatialData` object and the element are already stored on + disk as it will also overwrite the channel names metadata on disk. In case either the `SpatialData` object or + the element are not stored on disk, please use `SpatialData.set_image_channel_names` instead. Parameters ---------- @@ -307,9 +308,11 @@ def set_channel_names(self, element_name: str, channel_names: str | list[str], w channel_names The channel names to be assigned to the c dimension of the image `SpatialElement`. write - Whether to overwrite the channel metadata on disk. + Whether to overwrite the channel metadata on disk. This will not rewrite the pixel data itself. """ - self.images[element_name] = set_channel_names(self.images[element_name], channel_names) + with warnings.catch_warnings(): + warnings.filterwarnings("ignore", category=UserWarning) + self.images[element_name] = set_channel_names(self.images[element_name], channel_names) if write: self.write_channel_names(element_name) diff --git a/tests/core/operations/test_rasterize.py b/tests/core/operations/test_rasterize.py index a109c852..4a5f2e2d 100644 --- a/tests/core/operations/test_rasterize.py +++ b/tests/core/operations/test_rasterize.py @@ -158,8 +158,11 @@ def test_rasterize_points_shapes_with_string_index(points, shapes): sdata = SpatialData.init_from_elements({"points_0": points["points_0"], "circles": shapes["circles"]}) # make the indices of the points_0 and circles dataframes strings - sdata["points_0"]["str_index"] = dd.from_pandas(pd.Series([str(i) for i in sdata["points_0"].index]), npartitions=1) - sdata["points_0"] = sdata["points_0"].set_index("str_index") + points = sdata["points_0"] + points["str_index"] = dd.from_pandas(pd.Series([str(i) for i in sdata["points_0"].index]), npartitions=1) + points = points.set_index("str_index") + del sdata["points_0"] + sdata["points_0"] = points sdata["circles"].index = [str(i) for i in sdata["circles"].index] data_extent = get_extent(sdata) diff --git a/tests/io/test_readwrite.py b/tests/io/test_readwrite.py index fed63382..e3d68bf8 100644 --- a/tests/io/test_readwrite.py +++ b/tests/io/test_readwrite.py @@ -63,7 +63,9 @@ def test_points(self, tmp_path: str, points: SpatialData) -> None: # check the index is correctly written and then read new_index = dd.from_array(np.arange(1, len(points["points_0"]) + 1)) - points["points_0"] = points["points_0"].set_index(new_index) + el_point = points["points_0"].set_index(new_index) + del points["points_0"] + points["points_0"] = el_point points.write(tmpdir) sdata = SpatialData.read(tmpdir) @@ -337,6 +339,7 @@ def test_overwrite_fails_when_no_zarr_store_bug_dask_backed_data(self, full_sdat # now we have a sdata with dask-backed elements sdata2 = SpatialData.read(f) p = sdata2[element] + del full_sdata[element] full_sdata[element] = p with pytest.raises( ValueError, diff --git a/tests/utils/test_testing.py b/tests/utils/test_testing.py index a181c87f..37287e89 100644 --- a/tests/utils/test_testing.py +++ b/tests/utils/test_testing.py @@ -46,11 +46,14 @@ def _change_metadata_tables(sdata: SpatialData, element_name: str) -> None: def _change_metadata_image(sdata: SpatialData, element_name: str, coords: bool, transformations: bool) -> None: if coords: if isinstance(sdata[element_name], DataArray): - sdata[element_name] = sdata[element_name].assign_coords({"c": np.array(["m", "l", "b"])}) + element = sdata[element_name].assign_coords({"c": np.array(["m", "l", "b"])}) + del sdata[element_name] + sdata[element_name] = element else: assert isinstance(sdata[element_name], DataTree) dt = sdata[element_name].assign_coords({"c": np.array(["m", "l", "b"])}) + del sdata[element_name] sdata[element_name] = dt if transformations: set_transformation(sdata[element_name], copy.deepcopy(scale)) From 2bc809c53539d2ef4c80a8a4fd3bf0fb492f26cb Mon Sep 17 00:00:00 2001 From: Wouter-Michiel Vierdag Date: Mon, 8 Sep 2025 20:53:24 +0200 Subject: [PATCH 068/126] silence overwriting warnings --- tests/core/query/test_relational_query.py | 33 +++++++++++++++++------ tests/core/query/test_spatial_query.py | 1 + 2 files changed, 26 insertions(+), 8 deletions(-) diff --git a/tests/core/query/test_relational_query.py b/tests/core/query/test_relational_query.py index e50d0109..3cd406bb 100644 --- a/tests/core/query/test_relational_query.py +++ b/tests/core/query/test_relational_query.py @@ -35,8 +35,11 @@ def test_join_using_string_instance_id_and_index(sdata_query_aggregation): [f"string_{i}" for i in sdata_query_aggregation["values_polygons"].index] ) - sdata_query_aggregation["values_polygons"] = sdata_query_aggregation["values_polygons"][:5] - sdata_query_aggregation["values_circles"] = sdata_query_aggregation["values_circles"][:5] + values_polygons = sdata_query_aggregation["values_polygons"][:5] + values_circles = sdata_query_aggregation["values_circles"][:5] + del sdata_query_aggregation["values_polygons"], sdata_query_aggregation["values_circles"] + sdata_query_aggregation["values_polygons"] = values_polygons + sdata_query_aggregation["values_circles"] = values_circles element_dict, table = join_spatialelement_table( sdata=sdata_query_aggregation, @@ -85,7 +88,9 @@ def test_left_inner_right_exclusive_join(sdata_query_aggregation): assert table is None assert all(element_dict[key] is None for key in element_dict) - sdata["values_polygons"] = sdata["values_polygons"].drop([10, 11]) + values_polygons = sdata["values_polygons"].drop([10, 11]) + del sdata["values_polygons"] + sdata["values_polygons"] = values_polygons with pytest.raises(ValueError, match="No table with"): join_spatialelement_table( sdata=sdata, @@ -147,7 +152,9 @@ def test_left_inner_right_exclusive_join(sdata_query_aggregation): assert "by_polygons" in element_dict # check multiple elements joined to table. - sdata["values_circles"] = sdata["values_circles"].drop([7, 8]) + values_circles = sdata["values_circles"].drop([7, 8]) + del sdata["values_circles"] + sdata["values_circles"] = values_circles element_dict, table = join_spatialelement_table( sdata=sdata, spatial_element_names=["values_circles", "values_polygons"], @@ -289,7 +296,9 @@ def test_left_exclusive_and_right_join(sdata_query_aggregation): assert table is None # Dropped indices correspond to instance ids 7, 8 for 'values_circles' and 10, 11 for 'values_polygons' - sdata["table"] = sdata["table"][sdata["table"].obs.index.drop(["7", "8", "19", "20"])] + table_update = sdata["table"][sdata["table"].obs.index.drop(["7", "8", "19", "20"])] + del sdata["table"] + sdata["table"] = table_update with pytest.warns(UserWarning, match="The element"): element_dict, table = join_spatialelement_table( sdata=sdata, @@ -372,7 +381,9 @@ def test_left_exclusive_and_right_join(sdata_query_aggregation): def test_match_rows_inner_join_non_matching_element(sdata_query_aggregation): sdata = sdata_query_aggregation - sdata["values_circles"] = sdata["values_circles"][4:] + circles = sdata["values_circles"][4:] + del sdata["values_circles"] + sdata["values_circles"] = circles original_index = sdata["values_circles"].index reversed_instance_id = [3, 5, 8, 7, 6, 4, 1, 2, 0] + list(reversed(range(12))) sdata["table"].obs["instance_id"] = reversed_instance_id @@ -403,6 +414,7 @@ def test_match_rows_inner_join_non_matching_table(sdata_query_aggregation): original_instance_id = table.obs["instance_id"] reversed_instance_id = [6, 7, 8, 3, 4, 5] + list(reversed(range(12))) table.obs["instance_id"] = reversed_instance_id + del sdata["table"] sdata["table"] = table element_dict, table = join_spatialelement_table( @@ -433,8 +445,11 @@ def test_match_rows_inner_join_non_matching_table(sdata_query_aggregation): def test_inner_join_match_rows_duplicate_obs_indices(sdata_query_aggregation: SpatialData, join_type: str) -> None: sdata = sdata_query_aggregation sdata["table"].obs.index = ["a"] * sdata["table"].n_obs - sdata["values_circles"] = sdata_query_aggregation["values_circles"][:4] - sdata["values_polygons"] = sdata_query_aggregation["values_polygons"][:5] + values_circles = sdata_query_aggregation["values_circles"][:4] + values_polygons = sdata_query_aggregation["values_polygons"][:5] + del sdata["values_circles"], sdata["values_polygons"] + sdata["values_circles"] = values_circles + sdata["values_polygons"] = values_polygons element_dict, table = join_spatialelement_table( sdata=sdata, @@ -824,6 +839,7 @@ def test_get_values_df_points(points): region_key="region", instance_key="instance_id", ) + del points["points_0"] points["points_0"] = p points["table"] = table @@ -906,6 +922,7 @@ def test_filter_table_non_annotating(full_sdata): obs = pd.DataFrame({"test": ["a", "b", "c"]}, index=list(map(str, range(3)))) adata = AnnData(obs=obs) table = TableModel.parse(adata) + del full_sdata["table"] full_sdata["table"] = table full_sdata.filter_by_coordinate_system("global") diff --git a/tests/core/query/test_spatial_query.py b/tests/core/query/test_spatial_query.py index fc59d069..a82f8526 100644 --- a/tests/core/query/test_spatial_query.py +++ b/tests/core/query/test_spatial_query.py @@ -751,6 +751,7 @@ def test_spatial_query_different_axes(full_sdata, name: str): if name == "multipoly": new_data = GeoDataFrame({"geometry": [MultiPolygon([Polygon([(3, 1), (4, 1), (3, 0)])])]}) gdf = pd.concat([gdf, new_data], ignore_index=True) + del full_sdata[name] full_sdata[name] = ShapesModel.parse(gdf) map_axis = MapAxis(map_axis={"x": "y", "y": "x"}) From 6e43b42e72601e0844bd4be7151745cd181a6be3 Mon Sep 17 00:00:00 2001 From: Wouter-Michiel Vierdag Date: Mon, 8 Sep 2025 23:52:09 +0200 Subject: [PATCH 069/126] silence chunk warning --- src/spatialdata/models/models.py | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/src/spatialdata/models/models.py b/src/spatialdata/models/models.py index 52ad27d8..bd5a077f 100644 --- a/src/spatialdata/models/models.py +++ b/src/spatialdata/models/models.py @@ -228,6 +228,10 @@ def parse( parsed_transform = _get_transformations(data) # delete transforms del data.attrs["transform"] + if isinstance(chunks, tuple): + chunks = {dim: chunks[index] for index, dim in enumerate(data.dims)} + if isinstance(chunks, float): + chunks = {dim: chunks for index, dim in data.dims} data = to_multiscale( data, scale_factors=scale_factors, From bd0468e44c31eced0bb4d93e695fc5518aae4b2b Mon Sep 17 00:00:00 2001 From: Wouter-Michiel Vierdag Date: Tue, 9 Sep 2025 14:46:24 +0200 Subject: [PATCH 070/126] remove argument from docstring, update typehint --- src/spatialdata/_core/spatialdata.py | 7 ++----- 1 file changed, 2 insertions(+), 5 deletions(-) diff --git a/src/spatialdata/_core/spatialdata.py b/src/spatialdata/_core/spatialdata.py index d438b841..83a1971b 100644 --- a/src/spatialdata/_core/spatialdata.py +++ b/src/spatialdata/_core/spatialdata.py @@ -2163,8 +2163,7 @@ def _find_element(self, element_name: str) -> tuple[str, str, SpatialElement | A @classmethod def init_from_elements( cls, - elements: dict[str, SpatialElement], - tables: AnnData | dict[str, AnnData] | None = None, + elements: dict[str, SpatialElement | AnnData], attrs: Mapping[Any, Any] | None = None, ) -> SpatialData: """ @@ -2173,9 +2172,7 @@ def init_from_elements( Parameters ---------- elements - A dict of named elements. - tables - An optional table or dictionary of tables + A dict of named elements, e.g. SpatialElements like images and labels and AnnData tables. attrs Additional attributes to store in the SpatialData object. From 54e0d871435deeb2da7e77e4afb30d4a16d1da29 Mon Sep 17 00:00:00 2001 From: Wouter-Michiel Vierdag Date: Tue, 9 Sep 2025 14:47:08 +0200 Subject: [PATCH 071/126] update typehint --- src/spatialdata/_core/spatialdata.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/spatialdata/_core/spatialdata.py b/src/spatialdata/_core/spatialdata.py index 83a1971b..790641c9 100644 --- a/src/spatialdata/_core/spatialdata.py +++ b/src/spatialdata/_core/spatialdata.py @@ -2180,7 +2180,7 @@ def init_from_elements( ------- The SpatialData object. """ - elements_dict: dict[str, SpatialElement] = {} + elements_dict: dict[str, SpatialElement | AnnData] = {} for name, element in elements.items(): model = get_model(element) if model in [Image2DModel, Image3DModel]: From fc052d71b8c46e7a4cb15aa2cd0bc921579ce70f Mon Sep 17 00:00:00 2001 From: Luca Marconato Date: Tue, 9 Sep 2025 14:47:29 +0200 Subject: [PATCH 072/126] small fixes --- docs/api/data_formats.md | 4 ++-- src/spatialdata/_core/spatialdata.py | 3 --- 2 files changed, 2 insertions(+), 5 deletions(-) diff --git a/docs/api/data_formats.md b/docs/api/data_formats.md index 0bb72bf1..81638250 100644 --- a/docs/api/data_formats.md +++ b/docs/api/data_formats.md @@ -1,7 +1,7 @@ # Data formats (advanced) -The SpatialData format is defined as a set of versioned subclasses of `spatialdata._io.format.SpatialDataFormatType`, one per type of element. -These classes are useful to ensure backward compatibility whenever a major version change is introduced. We also provide pointers to the latest format. +The SpatialData format is defined as a set of versioned subclasses of `ome_zarr.format.Format`, one per type of element. The `spatialdata.SpatialDataFormatType` is a union type encompassing the possible valid formats. +These format subclasses are useful to ensure backward compatibility whenever a major version change is introduced. We also provide pointers to the latest format. ```{eval-rst} .. currentmodule:: spatialdata._io.format diff --git a/src/spatialdata/_core/spatialdata.py b/src/spatialdata/_core/spatialdata.py index 485b5cde..9c44781b 100644 --- a/src/spatialdata/_core/spatialdata.py +++ b/src/spatialdata/_core/spatialdata.py @@ -2161,7 +2161,6 @@ def _find_element(self, element_name: str) -> tuple[str, str, SpatialElement | A def init_from_elements( cls, elements: dict[str, SpatialElement], - tables: AnnData | dict[str, AnnData] | None = None, attrs: Mapping[Any, Any] | None = None, ) -> SpatialData: """ @@ -2171,8 +2170,6 @@ def init_from_elements( ---------- elements A dict of named elements. - tables - An optional table or dictionary of tables attrs Additional attributes to store in the SpatialData object. From 7c1bd7e486c887791cdfa9a36f85d96e1b4a0a40 Mon Sep 17 00:00:00 2001 From: Wouter-Michiel Vierdag Date: Tue, 9 Sep 2025 15:04:56 +0200 Subject: [PATCH 073/126] fail if not root does not exist --- src/spatialdata/_io/io_zarr.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/spatialdata/_io/io_zarr.py b/src/spatialdata/_io/io_zarr.py index 7d40a12e..8fcbc4c3 100644 --- a/src/spatialdata/_io/io_zarr.py +++ b/src/spatialdata/_io/io_zarr.py @@ -265,7 +265,7 @@ def _get_groups_for_element( # When writing, use_consolidated must be set to False. Otherwise, the metadata store # can get out of sync with newly added elements (e.g., labels), leading to errors. - root_group = zarr.open_group(store=resolved_store, mode="a", use_consolidated=use_consolidated) + root_group = zarr.open_group(store=resolved_store, mode="r+", use_consolidated=use_consolidated) element_type_group = root_group.require_group(element_type) element_type_group = zarr.open_group(element_type_group.store_path, mode="a", use_consolidated=use_consolidated) From b9e5d92f820662159617bd191bfb65257e59c883 Mon Sep 17 00:00:00 2001 From: Wouter-Michiel Vierdag Date: Tue, 9 Sep 2025 15:07:12 +0200 Subject: [PATCH 074/126] write out function name --- src/spatialdata/_io/io_zarr.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/spatialdata/_io/io_zarr.py b/src/spatialdata/_io/io_zarr.py index 8fcbc4c3..bc413479 100644 --- a/src/spatialdata/_io/io_zarr.py +++ b/src/spatialdata/_io/io_zarr.py @@ -71,9 +71,9 @@ def read_zarr( ------- A SpatialData object. """ - from spatialdata._io._utils import _resolve_zarr_store as rzs + from spatialdata._io._utils import _resolve_zarr_store - resolved_store = rzs(store) + resolved_store = _resolve_zarr_store(store) root_group, root_store_path = _open_zarr(resolved_store) images = {} From 854468476c8ee6ccfe2245c90e692510c48772dd Mon Sep 17 00:00:00 2001 From: Luca Marconato Date: Tue, 9 Sep 2025 17:19:40 +0200 Subject: [PATCH 075/126] small fixes --- src/spatialdata/_io/_utils.py | 1 - src/spatialdata/_io/format.py | 18 +++++++++--------- src/spatialdata/_io/io_points.py | 8 ++++---- src/spatialdata/_io/io_raster.py | 10 ++++++---- 4 files changed, 19 insertions(+), 18 deletions(-) diff --git a/src/spatialdata/_io/_utils.py b/src/spatialdata/_io/_utils.py index 1efb70a8..a9189ae5 100644 --- a/src/spatialdata/_io/_utils.py +++ b/src/spatialdata/_io/_utils.py @@ -161,7 +161,6 @@ def _overwrite_coordinate_transformations_raster_zarrv2( The transformation present in multiscale["datasets"] are the ones for the multiscale, so and we leave them intact we update multiscale["coordinateTransformations"] and multiscale["coordinateSystems"] see the first post of https://github.com/scverse/spatialdata/issues/39 for an overview - fix the io to follow the NGFF specs, see https://github.com/scverse/spatialdata/issues/114 Parameters ---------- diff --git a/src/spatialdata/_io/format.py b/src/spatialdata/_io/format.py index 2cb1867d..71690903 100644 --- a/src/spatialdata/_io/format.py +++ b/src/spatialdata/_io/format.py @@ -300,15 +300,20 @@ def spatialdata_format_version(self) -> str: CurrentTablesFormat = TablesFormatV02 CurrentSpatialDataContainerFormat = SpatialDataContainerFormatV02 +RasterFormatType = RasterFormatV01 | RasterFormatV02 | RasterFormatV03 ShapesFormatType = ShapesFormatV01 | ShapesFormatV02 | ShapesFormatV03 PointsFormatType = PointsFormatV01 | PointsFormatV02 TablesFormatType = TablesFormatV01 | TablesFormatV02 -RasterFormatType = RasterFormatV01 | RasterFormatV02 | RasterFormatV03 SpatialDataContainerFormatType = SpatialDataContainerFormatV01 | SpatialDataContainerFormatV02 SpatialDataFormatType = ( - ShapesFormatType | PointsFormatType | TablesFormatType | RasterFormatType | SpatialDataContainerFormatType + RasterFormatType | ShapesFormatType | PointsFormatType | TablesFormatType | SpatialDataContainerFormatType ) +RasterFormats: dict[str, RasterFormatType] = { + "0.1": RasterFormatV01(), + "0.2": RasterFormatV02(), + "0.3": RasterFormatV03(), +} ShapesFormats: dict[str, ShapesFormatType] = { "0.1": ShapesFormatV01(), "0.2": ShapesFormatV02(), @@ -322,11 +327,6 @@ def spatialdata_format_version(self) -> str: "0.1": TablesFormatV01(), "0.2": TablesFormatV02(), } -RasterFormats: dict[str, RasterFormatType] = { - "0.1": RasterFormatV01(), - "0.2": RasterFormatV02(), - "0.3": RasterFormatV03(), -} SpatialDataContainerFormats: dict[str, SpatialDataContainerFormatType] = { "0.1": SpatialDataContainerFormatV01(), "0.2": SpatialDataContainerFormatV02(), @@ -335,15 +335,15 @@ def spatialdata_format_version(self) -> str: SpatialDataContainerFormatV01().__str__(): [ RasterFormatV01().__str__(), RasterFormatV02().__str__(), - PointsFormatV01().__str__(), ShapesFormatV01().__str__(), ShapesFormatV02().__str__(), + PointsFormatV01().__str__(), TablesFormatV01().__str__(), ], SpatialDataContainerFormatV02().__str__(): [ RasterFormatV03().__str__(), - PointsFormatV02().__str__(), ShapesFormatV03().__str__(), + PointsFormatV02().__str__(), TablesFormatV02().__str__(), ], } diff --git a/src/spatialdata/_io/io_points.py b/src/spatialdata/_io/io_points.py index 1d6c9d2d..a251b042 100644 --- a/src/spatialdata/_io/io_points.py +++ b/src/spatialdata/_io/io_points.py @@ -56,13 +56,13 @@ def write_points( Parameters ---------- - points: DaskDataFrame + points The dataframe of the points element. - group: zarr.Group + group The zarr group in the 'points' zarr group to write the points element to. - group_type: str + group_type The type of the element. - element_format: Format + element_format The format of the points element used to store it. """ axes = get_axes_names(points) diff --git a/src/spatialdata/_io/io_raster.py b/src/spatialdata/_io/io_raster.py index ff3c38b5..b180a322 100644 --- a/src/spatialdata/_io/io_raster.py +++ b/src/spatialdata/_io/io_raster.py @@ -43,9 +43,9 @@ def _get_multiscale_nodes(image_nodes: list[Node], nodes: list[Node]) -> list[No Parameters ---------- - image_nodes: list[Node] + image_nodes List of nodes returned from the ome-zarr-py Reader. - nodes: list[Node] + nodes List to append the nodes with the multiscales spec to. Returns @@ -73,7 +73,7 @@ def _read_multiscale(store: str | Path, raster_type: Literal["image", "labels"]) nodes = _get_multiscale_nodes(image_nodes, nodes) else: raise OSError( - f"Image location {image_loc} does not seem to exist. If it does, potentially the zarr.json file " + f"Image location {image_loc} does not seem to exist. If it does, potentially the zarr.json (or .zattrs) file " f"inside is corrupted or not present or the image files themselves are corrupted." ) if len(nodes) != 1: @@ -84,7 +84,7 @@ def _read_multiscale(store: str | Path, raster_type: Literal["image", "labels"]) ) raise OSError( f"Image location {image_loc} exists, but len(nodes) = {len(nodes)}, expected 1. Element " - f"{image_loc.basename()} is potentially corrupted." + f"{image_loc.basename()} is potentially corrupted. Please report this bug and attach a minimal data example." ) node = nodes[0] @@ -246,6 +246,8 @@ def _write_raster_dataarray( else: storage_options = {"chunks": chunks} # Scaler needs to be None since we are passing the data already downscaled for the multiscale case. + # We need this because the argument of write_image_ngff is called image while the argument of + # write_labels_ngff is called label. metadata[raster_type] = data write_single_scale_ngff( group=group, From fa6203855e20d023b539aecb555196843222591b Mon Sep 17 00:00:00 2001 From: Wouter-Michiel Vierdag Date: Tue, 9 Sep 2025 17:29:50 +0200 Subject: [PATCH 076/126] initial replacement parse_url --- src/spatialdata/_core/spatialdata.py | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/src/spatialdata/_core/spatialdata.py b/src/spatialdata/_core/spatialdata.py index 790641c9..6fe72388 100644 --- a/src/spatialdata/_core/spatialdata.py +++ b/src/spatialdata/_core/spatialdata.py @@ -1146,6 +1146,7 @@ def write( self._validate_can_safely_write_to_path(file_path, overwrite=overwrite) self._validate_all_elements() + # parse_url cannot be replaced here as it actually also initialized an ome-zarr store. store = parse_url(file_path, mode="w", fmt=parsed["SpatialData"]).store zarr_group = zarr.open_group(store=store, mode="w" if overwrite else "a") self.write_attrs(zarr_group=zarr_group) @@ -1390,8 +1391,10 @@ def delete_element_from_disk(self, element_name: str | list[str]) -> None: "more elements in the SpatialData object. Deleting the data would corrupt the SpatialData object." ) + from spatialdata._io._utils import _resolve_zarr_store + # delete the element - store = parse_url(self.path, mode="r+", fmt=FormatV05()).store + store = _resolve_zarr_store(self.path) root = zarr.open_group(store=store, mode="r+", use_consolidated=False) del root[element_type][element_name] store.close() From 6a0c2d0f85f8956727b0fc9a0807d74c1c6b16d3 Mon Sep 17 00:00:00 2001 From: Wouter-Michiel Vierdag Date: Wed, 10 Sep 2025 08:21:10 +0200 Subject: [PATCH 077/126] alter docstring --- src/spatialdata/_core/spatialdata.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/spatialdata/_core/spatialdata.py b/src/spatialdata/_core/spatialdata.py index 6fe72388..67a54cd6 100644 --- a/src/spatialdata/_core/spatialdata.py +++ b/src/spatialdata/_core/spatialdata.py @@ -298,8 +298,8 @@ def set_channel_names(self, element_name: str, channel_names: str | list[str], w This method will overwrite the element in memory with the same element, but with new channel names. If 'write` is 'True', this method assumes that the `SpatialData` object and the element are already stored on - disk as it will also overwrite the channel names metadata on disk. In case either the `SpatialData` object or - the element are not stored on disk, please use `SpatialData.set_image_channel_names` instead. + disk as it will also overwrite the channel names metadata on disk. If you do not want to overwrite the element + on disk, or it is not stored, set `write` to False. Parameters ---------- From 0c44ddb7be5b8225da20d3cb4c750a9b80a51205 Mon Sep 17 00:00:00 2001 From: Wouter-Michiel Vierdag Date: Wed, 10 Sep 2025 08:22:14 +0200 Subject: [PATCH 078/126] adjust argument docstring --- src/spatialdata/_core/spatialdata.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/src/spatialdata/_core/spatialdata.py b/src/spatialdata/_core/spatialdata.py index 67a54cd6..2c8ad4d0 100644 --- a/src/spatialdata/_core/spatialdata.py +++ b/src/spatialdata/_core/spatialdata.py @@ -308,7 +308,8 @@ def set_channel_names(self, element_name: str, channel_names: str | list[str], w channel_names The channel names to be assigned to the c dimension of the image `SpatialElement`. write - Whether to overwrite the channel metadata on disk. This will not rewrite the pixel data itself. + Whether to overwrite the channel metadata on disk (lightweight operation). This will not rewrite the pixel + data itself (heavy operation). """ with warnings.catch_warnings(): warnings.filterwarnings("ignore", category=UserWarning) From a2a2a457d0e3b22772f34a79a63c255d159caabd Mon Sep 17 00:00:00 2001 From: Wouter-Michiel Vierdag Date: Wed, 10 Sep 2025 08:52:18 +0200 Subject: [PATCH 079/126] remove overwrite warnings --- src/spatialdata/_core/_elements.py | 11 +--------- src/spatialdata/_core/validation.py | 5 ++++- tests/core/operations/test_aggregations.py | 2 -- tests/core/operations/test_rasterize.py | 1 - .../operations/test_spatialdata_operations.py | 20 ----------------- tests/core/query/test_relational_query.py | 10 --------- tests/core/query/test_spatial_query.py | 1 - tests/io/test_readwrite.py | 22 ------------------- tests/utils/test_testing.py | 2 -- 9 files changed, 5 insertions(+), 69 deletions(-) diff --git a/src/spatialdata/_core/_elements.py b/src/spatialdata/_core/_elements.py index a8f51340..99ff9d33 100644 --- a/src/spatialdata/_core/_elements.py +++ b/src/spatialdata/_core/_elements.py @@ -3,7 +3,6 @@ from collections import UserDict from collections.abc import Iterable, KeysView, ValuesView from typing import TypeVar -from warnings import warn from anndata import AnnData from dask.dataframe import DataFrame as DaskDataFrame @@ -41,15 +40,7 @@ def _remove_shared_key(self, key: str) -> None: @staticmethod def _check_key(key: str, element_keys: Iterable[str], shared_keys: set[str | None]) -> None: check_valid_name(key) - if key in element_keys: - warn( - f"Key `{key}` already exists. Overwriting it in-memory. If you want to silence this warning " - f"either delete it from memory 'del sdata[{key}]' (does not make changes on disk) or filter the " - f"warning.", - UserWarning, - stacklevel=2, - ) - else: + if key not in element_keys: try: check_key_is_case_insensitively_unique(key, shared_keys) except ValueError as e: diff --git a/src/spatialdata/_core/validation.py b/src/spatialdata/_core/validation.py index 537c49f3..50e1b65d 100644 --- a/src/spatialdata/_core/validation.py +++ b/src/spatialdata/_core/validation.py @@ -185,7 +185,10 @@ def check_key_is_case_insensitively_unique(key: str, other_keys: set[str | None] """ normalized_key = key.lower() if normalized_key in other_keys: - raise ValueError(f"Key `{key}` is not unique, or another case-variant of it exists.") + raise ValueError( + f"Key `{key}` is not unique as it exists with a different element type, or another " + f"case-variant of it exists." + ) def check_valid_dataframe_column_name(name: str) -> None: diff --git a/tests/core/operations/test_aggregations.py b/tests/core/operations/test_aggregations.py index 540161c7..eb2ed089 100644 --- a/tests/core/operations/test_aggregations.py +++ b/tests/core/operations/test_aggregations.py @@ -287,7 +287,6 @@ def test_aggregate_shapes_by_shapes( new_var = pd.concat((sdata.tables["table"].var, pd.DataFrame(index=["another_numerical_in_var"]))) new_x = np.concatenate((sdata.tables["table"].X, np.ones_like(sdata.tables["table"].X[:, :1])), axis=1) new_table = AnnData(X=new_x, obs=sdata.tables["table"].obs, var=new_var, uns=sdata.tables["table"].uns) - del sdata.tables["table"] sdata.tables["table"] = new_table result_adata = aggregate( @@ -500,7 +499,6 @@ def test_aggregate_considering_fractions_multiple_values( new_var = pd.concat((sdata.tables["table"].var, pd.DataFrame(index=["another_numerical_in_var"]))) new_x = np.concatenate((sdata.tables["table"].X, np.ones_like(sdata.tables["table"].X[:, :1])), axis=1) new_table = AnnData(X=new_x, obs=sdata.tables["table"].obs, var=new_var, uns=sdata.tables["table"].uns) - del sdata.tables["table"] sdata.tables["table"] = new_table out = aggregate( values_sdata=sdata, diff --git a/tests/core/operations/test_rasterize.py b/tests/core/operations/test_rasterize.py index 4a5f2e2d..25f3c3d0 100644 --- a/tests/core/operations/test_rasterize.py +++ b/tests/core/operations/test_rasterize.py @@ -161,7 +161,6 @@ def test_rasterize_points_shapes_with_string_index(points, shapes): points = sdata["points_0"] points["str_index"] = dd.from_pandas(pd.Series([str(i) for i in sdata["points_0"].index]), npartitions=1) points = points.set_index("str_index") - del sdata["points_0"] sdata["points_0"] = points sdata["circles"].index = [str(i) for i in sdata["circles"].index] diff --git a/tests/core/operations/test_spatialdata_operations.py b/tests/core/operations/test_spatialdata_operations.py index 7b3f4932..bbf5eb10 100644 --- a/tests/core/operations/test_spatialdata_operations.py +++ b/tests/core/operations/test_spatialdata_operations.py @@ -50,19 +50,6 @@ def test_element_names_unique() -> None: tables={"table": table}, ) - # add elements with the same name - # of element of same type - with pytest.warns(UserWarning): - sdata.images["image"] = image - with pytest.warns(UserWarning): - sdata.points["points"] = points - with pytest.warns(UserWarning): - sdata.shapes["shapes"] = shapes - with pytest.warns(UserWarning): - sdata.labels["labels"] = labels - with pytest.warns(UserWarning): - sdata.tables["table"] = table - # add elements with the same name # of element of different type with pytest.raises(KeyError): @@ -100,8 +87,6 @@ def test_element_names_unique() -> None: # add elements with the same name, test only couples of elements with pytest.raises(KeyError): sdata["labels"] = image - with pytest.warns(UserWarning): - sdata["points"] = points # this should not raise warnings because it's a different (new) name sdata["image2"] = image @@ -176,7 +161,6 @@ def test_filter_by_coordinate_system_also_table(full_sdata: SpatialData) -> None ) adata = full_sdata["table"] del adata.uns[TableModel.ATTRS_KEY] - del full_sdata.tables["table"] full_sdata["table"] = TableModel.parse( adata, region=["circles", "poly"], @@ -374,7 +358,6 @@ def test_concatenate_sdatas(full_sdata: SpatialData) -> None: # TODO: fix this new_region = "sample2" table_new = filtered1["table"].copy() - del filtered1.tables["table"] filtered1["table"] = table_new filtered1["table"].uns[TableModel.ATTRS_KEY][TableModel.REGION_KEY] = new_region filtered1["table"].obs[filtered1["table"].uns[TableModel.ATTRS_KEY][TableModel.REGION_KEY_KEY]] = new_region @@ -492,8 +475,6 @@ def test_get_item(points: SpatialData) -> None: def test_set_item(full_sdata: SpatialData) -> None: for name in ["image2d", "labels2d", "points_0", "circles", "poly"]: full_sdata[name + "_again"] = full_sdata[name] - with pytest.warns(UserWarning): - full_sdata[name] = full_sdata[name] def test_del_item(full_sdata: SpatialData) -> None: @@ -552,7 +533,6 @@ def test_subset(full_sdata: SpatialData) -> None: "instance_id": [0, 1, 2, 3, 4, 0, 1, 2, 3, 4], }, ) - del full_sdata.tables["table"] sdata_table = TableModel.parse( adata, region=["circles", "poly"], diff --git a/tests/core/query/test_relational_query.py b/tests/core/query/test_relational_query.py index 3cd406bb..a821ba03 100644 --- a/tests/core/query/test_relational_query.py +++ b/tests/core/query/test_relational_query.py @@ -37,7 +37,6 @@ def test_join_using_string_instance_id_and_index(sdata_query_aggregation): values_polygons = sdata_query_aggregation["values_polygons"][:5] values_circles = sdata_query_aggregation["values_circles"][:5] - del sdata_query_aggregation["values_polygons"], sdata_query_aggregation["values_circles"] sdata_query_aggregation["values_polygons"] = values_polygons sdata_query_aggregation["values_circles"] = values_circles @@ -89,7 +88,6 @@ def test_left_inner_right_exclusive_join(sdata_query_aggregation): assert all(element_dict[key] is None for key in element_dict) values_polygons = sdata["values_polygons"].drop([10, 11]) - del sdata["values_polygons"] sdata["values_polygons"] = values_polygons with pytest.raises(ValueError, match="No table with"): join_spatialelement_table( @@ -153,7 +151,6 @@ def test_left_inner_right_exclusive_join(sdata_query_aggregation): # check multiple elements joined to table. values_circles = sdata["values_circles"].drop([7, 8]) - del sdata["values_circles"] sdata["values_circles"] = values_circles element_dict, table = join_spatialelement_table( sdata=sdata, @@ -297,7 +294,6 @@ def test_left_exclusive_and_right_join(sdata_query_aggregation): # Dropped indices correspond to instance ids 7, 8 for 'values_circles' and 10, 11 for 'values_polygons' table_update = sdata["table"][sdata["table"].obs.index.drop(["7", "8", "19", "20"])] - del sdata["table"] sdata["table"] = table_update with pytest.warns(UserWarning, match="The element"): element_dict, table = join_spatialelement_table( @@ -382,7 +378,6 @@ def test_left_exclusive_and_right_join(sdata_query_aggregation): def test_match_rows_inner_join_non_matching_element(sdata_query_aggregation): sdata = sdata_query_aggregation circles = sdata["values_circles"][4:] - del sdata["values_circles"] sdata["values_circles"] = circles original_index = sdata["values_circles"].index reversed_instance_id = [3, 5, 8, 7, 6, 4, 1, 2, 0] + list(reversed(range(12))) @@ -414,7 +409,6 @@ def test_match_rows_inner_join_non_matching_table(sdata_query_aggregation): original_instance_id = table.obs["instance_id"] reversed_instance_id = [6, 7, 8, 3, 4, 5] + list(reversed(range(12))) table.obs["instance_id"] = reversed_instance_id - del sdata["table"] sdata["table"] = table element_dict, table = join_spatialelement_table( @@ -447,7 +441,6 @@ def test_inner_join_match_rows_duplicate_obs_indices(sdata_query_aggregation: Sp sdata["table"].obs.index = ["a"] * sdata["table"].n_obs values_circles = sdata_query_aggregation["values_circles"][:4] values_polygons = sdata_query_aggregation["values_polygons"][:5] - del sdata["values_circles"], sdata["values_polygons"] sdata["values_circles"] = values_circles sdata["values_polygons"] = values_polygons @@ -765,7 +758,6 @@ def test_get_values_df_shapes(sdata_query_aggregation): var=pd.DataFrame(index=["numerical_in_var", "another_numerical_in_var"]), uns=adata.uns, ) - del sdata_query_aggregation.tables["table"] sdata_query_aggregation["table"] = new_adata # test v = get_values( @@ -839,7 +831,6 @@ def test_get_values_df_points(points): region_key="region", instance_key="instance_id", ) - del points["points_0"] points["points_0"] = p points["table"] = table @@ -922,7 +913,6 @@ def test_filter_table_non_annotating(full_sdata): obs = pd.DataFrame({"test": ["a", "b", "c"]}, index=list(map(str, range(3)))) adata = AnnData(obs=obs) table = TableModel.parse(adata) - del full_sdata["table"] full_sdata["table"] = table full_sdata.filter_by_coordinate_system("global") diff --git a/tests/core/query/test_spatial_query.py b/tests/core/query/test_spatial_query.py index a82f8526..fc59d069 100644 --- a/tests/core/query/test_spatial_query.py +++ b/tests/core/query/test_spatial_query.py @@ -751,7 +751,6 @@ def test_spatial_query_different_axes(full_sdata, name: str): if name == "multipoly": new_data = GeoDataFrame({"geometry": [MultiPolygon([Polygon([(3, 1), (4, 1), (3, 0)])])]}) gdf = pd.concat([gdf, new_data], ignore_index=True) - del full_sdata[name] full_sdata[name] = ShapesModel.parse(gdf) map_axis = MapAxis(map_axis={"x": "y", "y": "x"}) diff --git a/tests/io/test_readwrite.py b/tests/io/test_readwrite.py index e3d68bf8..feba8409 100644 --- a/tests/io/test_readwrite.py +++ b/tests/io/test_readwrite.py @@ -64,7 +64,6 @@ def test_points(self, tmp_path: str, points: SpatialData) -> None: # check the index is correctly written and then read new_index = dd.from_array(np.arange(1, len(points["points_0"]) + 1)) el_point = points["points_0"].set_index(new_index) - del points["points_0"] points["points_0"] = el_point points.write(tmpdir) @@ -108,46 +107,26 @@ def test_incremental_io_in_memory( for k, v in _get_images().items(): sdata.images[f"additional_{k}"] = v - with pytest.warns(UserWarning): - sdata.images[f"additional_{k}"] = v - with pytest.warns(UserWarning): - sdata[f"additional_{k}"] = v with pytest.raises(KeyError, match="Key `table` is not unique"): sdata["table"] = v for k, v in _get_labels().items(): sdata.labels[f"additional_{k}"] = v - with pytest.warns(UserWarning): - sdata.labels[f"additional_{k}"] = v - with pytest.warns(UserWarning): - sdata[f"additional_{k}"] = v with pytest.raises(KeyError, match="Key `table` is not unique"): sdata["table"] = v for k, v in _get_shapes().items(): sdata.shapes[f"additional_{k}"] = v - with pytest.warns(UserWarning): - sdata.shapes[f"additional_{k}"] = v - with pytest.warns(UserWarning): - sdata[f"additional_{k}"] = v with pytest.raises(KeyError, match="Key `table` is not unique"): sdata["table"] = v for k, v in _get_points().items(): sdata.points[f"additional_{k}"] = v - with pytest.warns(UserWarning): - sdata.points[f"additional_{k}"] = v - with pytest.warns(UserWarning): - sdata[f"additional_{k}"] = v with pytest.raises(KeyError, match="Key `table` is not unique"): sdata["table"] = v for k, v in _get_tables().items(): sdata.tables[f"additional_{k}"] = v - with pytest.warns(UserWarning): - sdata.tables[f"additional_{k}"] = v - with pytest.warns(UserWarning): - sdata[f"additional_{k}"] = v with pytest.raises(KeyError, match="Key `poly` is not unique"): sdata["poly"] = v @@ -339,7 +318,6 @@ def test_overwrite_fails_when_no_zarr_store_bug_dask_backed_data(self, full_sdat # now we have a sdata with dask-backed elements sdata2 = SpatialData.read(f) p = sdata2[element] - del full_sdata[element] full_sdata[element] = p with pytest.raises( ValueError, diff --git a/tests/utils/test_testing.py b/tests/utils/test_testing.py index 37287e89..6125d73c 100644 --- a/tests/utils/test_testing.py +++ b/tests/utils/test_testing.py @@ -47,13 +47,11 @@ def _change_metadata_image(sdata: SpatialData, element_name: str, coords: bool, if coords: if isinstance(sdata[element_name], DataArray): element = sdata[element_name].assign_coords({"c": np.array(["m", "l", "b"])}) - del sdata[element_name] sdata[element_name] = element else: assert isinstance(sdata[element_name], DataTree) dt = sdata[element_name].assign_coords({"c": np.array(["m", "l", "b"])}) - del sdata[element_name] sdata[element_name] = dt if transformations: set_transformation(sdata[element_name], copy.deepcopy(scale)) From ac563722bb3d4097ea567d1826d83f96681d39d9 Mon Sep 17 00:00:00 2001 From: Wouter-Michiel Vierdag Date: Wed, 10 Sep 2025 09:59:28 +0200 Subject: [PATCH 080/126] fix test --- src/spatialdata/_core/spatialdata.py | 7 ++++++- tests/models/test_models.py | 8 +------- 2 files changed, 7 insertions(+), 8 deletions(-) diff --git a/src/spatialdata/_core/spatialdata.py b/src/spatialdata/_core/spatialdata.py index 2c8ad4d0..033e2eb4 100644 --- a/src/spatialdata/_core/spatialdata.py +++ b/src/spatialdata/_core/spatialdata.py @@ -31,6 +31,7 @@ ) from spatialdata._logging import logger from spatialdata._types import ArrayLike, Raster_T +from spatialdata._utils import _deprecation_alias from spatialdata.models import ( Image2DModel, Image3DModel, @@ -1055,6 +1056,7 @@ def _validate_can_safely_write_to_path( if not isinstance(file_path, Path): raise ValueError(f"file_path must be a string or a Path object, type(file_path) = {type(file_path)}.") + # TODO: add test for this if os.path.exists(file_path): if parse_url(file_path, mode="r", fmt=FormatV05()) is None: raise ValueError( @@ -1106,6 +1108,7 @@ def _validate_all_elements(self) -> None: with collect_error(location=element_path): validate_table_attr_keys(element, location=element_path) + @_deprecation_alias(format="sdata_formats", version="0.7.0") def write( self, file_path: str | Path, @@ -1421,8 +1424,10 @@ def write_consolidated_metadata(self) -> None: _write_consolidated_metadata(self.path) def has_consolidated_metadata(self) -> bool: + from spatialdata._io._utils import _resolve_zarr_store + return_value = False - store = parse_url(self.path, mode="r", fmt=FormatV05()).store + store = _resolve_zarr_store(self.path) group = zarr.open_group(store, mode="r") if getattr(group.metadata, "consolidated_metadata", None): return_value = True diff --git a/tests/models/test_models.py b/tests/models/test_models.py index e70cc997..b62eb53b 100644 --- a/tests/models/test_models.py +++ b/tests/models/test_models.py @@ -520,15 +520,9 @@ def test_table_model_invalid_names(self, key: str, attr: str, parse: bool): @pytest.mark.parametrize("attr", ["obs", "var"]) @pytest.mark.parametrize("parse", [True, False]) def test_table_model_not_unique_columns(self, keys: list[str], attr: str, parse: bool): - invalid_key = keys[1] - key_regex = re.escape(invalid_key) df = pd.DataFrame([[None] * len(keys)], columns=keys, index=["1"]) adata = AnnData(np.array([[0]]), **{attr: df}) - with pytest.raises( - ValueError, - match=f"Table contains invalid names(.|\n)*\n {attr}/{invalid_key}: " - + f"Key `{key_regex}` is not unique, or another case-variant of it exists.", - ): + with pytest.raises(ValueError, match="Table contains invalid names.\nFor renaming, please"): if parse: TableModel.parse(adata) else: From 22fce419dda3b0065a6de4e418cf9540d6daf2ec Mon Sep 17 00:00:00 2001 From: Wouter-Michiel Vierdag Date: Wed, 10 Sep 2025 10:19:43 +0200 Subject: [PATCH 081/126] replace parse_url --- src/spatialdata/_core/spatialdata.py | 3 ++- tests/io/test_format.py | 2 ++ 2 files changed, 4 insertions(+), 1 deletion(-) diff --git a/src/spatialdata/_core/spatialdata.py b/src/spatialdata/_core/spatialdata.py index 033e2eb4..a2143172 100644 --- a/src/spatialdata/_core/spatialdata.py +++ b/src/spatialdata/_core/spatialdata.py @@ -1610,6 +1610,7 @@ def write_attrs( format: SpatialDataContainerFormatType | None = None, zarr_group: zarr.Group | None = None, ) -> None: + from spatialdata._io._utils import _resolve_zarr_store from spatialdata._io.format import SpatialDataContainerFormatType, _parse_formats parsed = _parse_formats(formats=format) @@ -1620,7 +1621,7 @@ def write_attrs( if zarr_group is None: assert self.is_backed(), "The SpatialData object must be backed by a Zarr store to write attrs." - store = parse_url(self.path, mode="r+", fmt=FormatV05()).store + store = _resolve_zarr_store(self.path) zarr_group = zarr.open_group(store=store, mode="r+") version = spatialdata_container_format.spatialdata_format_version diff --git a/tests/io/test_format.py b/tests/io/test_format.py index 334cda16..0273fe15 100644 --- a/tests/io/test_format.py +++ b/tests/io/test_format.py @@ -209,8 +209,10 @@ def test_container_v1_to_v2(self, full_sdata): sdata_read_v1 = read_zarr(f1) assert_spatial_data_objects_are_identical(full_sdata, sdata_read_v1) assert sdata_read_v1.is_self_contained() + assert sdata_read_v1.has_consolidated_metadata() sdata_read_v1.write(f2, sdata_formats=[SpatialDataContainerFormatV02()]) sdata_read_v2 = read_zarr(f2) assert_spatial_data_objects_are_identical(full_sdata, sdata_read_v2) assert sdata_read_v2.is_self_contained() + assert sdata_read_v2.has_consolidated_metadata() From 92390d209d225662b34bd2dd6f042ae843e293fd Mon Sep 17 00:00:00 2001 From: Wouter-Michiel Vierdag Date: Wed, 10 Sep 2025 10:46:03 +0200 Subject: [PATCH 082/126] change version --- src/spatialdata/_io/format.py | 2 +- src/spatialdata/_io/io_zarr.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/src/spatialdata/_io/format.py b/src/spatialdata/_io/format.py index 71690903..a9175e38 100644 --- a/src/spatialdata/_io/format.py +++ b/src/spatialdata/_io/format.py @@ -204,7 +204,7 @@ def spatialdata_format_version(self) -> str: def version(self) -> str: # 0.1 -> 0.2 changed the version string for the NGFF format, from 0.4 to 0.6-dev-spatialdata as discussed here # https://github.com/scverse/spatialdata/pull/849 - return "0.4-dev-spatialdata" + return "0.5-dev-spatialdata" class ShapesFormatV01(FormatV04): diff --git a/src/spatialdata/_io/io_zarr.py b/src/spatialdata/_io/io_zarr.py index bc413479..7774a04b 100644 --- a/src/spatialdata/_io/io_zarr.py +++ b/src/spatialdata/_io/io_zarr.py @@ -288,7 +288,7 @@ def _group_for_element_exists(zarr_path: Path, element_type: str, element_name: ------- True if the group exists, False otherwise. """ - store = _resolve_zarr_store(zarr_path, mode="r") + store = _resolve_zarr_store(zarr_path) root = zarr.open_group(store=store, mode="r") assert element_type in [ "images", From ffc7ab0392fbd6fe0fd3d9c422789c2fbf7f882a Mon Sep 17 00:00:00 2001 From: Wouter-Michiel Vierdag Date: Wed, 10 Sep 2025 10:55:05 +0200 Subject: [PATCH 083/126] remove type hints from docstrings --- src/spatialdata/_io/_utils.py | 21 ++++++++++----------- 1 file changed, 10 insertions(+), 11 deletions(-) diff --git a/src/spatialdata/_io/_utils.py b/src/spatialdata/_io/_utils.py index a9189ae5..86686ea9 100644 --- a/src/spatialdata/_io/_utils.py +++ b/src/spatialdata/_io/_utils.py @@ -102,12 +102,12 @@ def overwrite_coordinate_transformations_raster( Parameters ---------- - group: zarr.Group + group The zarr group containing the raster element for which to write the transformations, e.g. the zarr group containing sdata['image2d']. - axes: tuple[ValidAxis_t, ...] + axes The list with axes names in the same order as the dimensions of the raster element. - transformations: MappingToCoordinateSystem_t + transformations Mapping between names of the coordinate system and the transformations. """ _validate_mapping_to_coordinate_system_type(transformations) @@ -137,10 +137,10 @@ def _overwrite_coordinate_transformations_raster_zarrv3( Parameters ---------- - group: zarr.Group + group The zarr group containing the raster element for which to write the transformations, e.g. the zarr group containing sdata['image2d']. - coordinate_transformations: list[dict[str, BaseTransformation]] + coordinate_transformations List of NGFF transformation representations as dictionaries. """ if len(multiscales := group.metadata.attributes["ome"]["multiscales"]) != 1: @@ -164,10 +164,10 @@ def _overwrite_coordinate_transformations_raster_zarrv2( Parameters ---------- - group: zarr.Group + group The zarr group containing the raster element for which to write the transformations, e.g. the zarr group containing sdata['image2d']. - coordinate_transformations: list[dict[str, BaseTransformation]] + coordinate_transformations List of NGFF transformation representations as dictionaries. """ multiscales = group.attrs["multiscales"] @@ -449,17 +449,16 @@ def _resolve_zarr_store(path: StoreLike, **kwargs: Any) -> zarr.storage.StoreLik Parameters ---------- - path : StoreLike | str | Path | UPath | zarr.Group + path The input representing a Zarr store or group. Can be a filesystem path, remote path, existing store, or Zarr group. - **kwargs : Any + **kwargs Additional keyword arguments forwarded to the underlying store constructor (e.g. `mode`, `storage_options`). Returns ------- - zarr.storage.StoreLike - A normalized store instance suitable for use with Zarr. + A normalized store instance suitable for use with Zarr. Raises ------ From 1decd22cc973cfa419f2495ae086aaee79638e0b Mon Sep 17 00:00:00 2001 From: Wouter-Michiel Vierdag Date: Wed, 10 Sep 2025 11:08:30 +0200 Subject: [PATCH 084/126] refactor to one function --- src/spatialdata/_io/_utils.py | 51 ++++------------------------------- 1 file changed, 5 insertions(+), 46 deletions(-) diff --git a/src/spatialdata/_io/_utils.py b/src/spatialdata/_io/_utils.py index 86686ea9..cc3967a0 100644 --- a/src/spatialdata/_io/_utils.py +++ b/src/spatialdata/_io/_utils.py @@ -124,56 +124,15 @@ def overwrite_coordinate_transformations_raster( ) coordinate_transformations = [t.to_dict() for t in ngff_transformations] # replace the metadata storage - if group.metadata.zarr_format == 3: - _overwrite_coordinate_transformations_raster_zarrv3(group, coordinate_transformations) - elif group.metadata.zarr_format == 2: - _overwrite_coordinate_transformations_raster_zarrv2(group, coordinate_transformations) - - -def _overwrite_coordinate_transformations_raster_zarrv3( - group: zarr.Group, coordinate_transformations: list[dict[str, BaseTransformation]] -) -> None: - """Write transformations of raster elements to disk in zarr v3. - - Parameters - ---------- - group - The zarr group containing the raster element for which to write the transformations, e.g. the zarr group - containing sdata['image2d']. - coordinate_transformations - List of NGFF transformation representations as dictionaries. - """ - if len(multiscales := group.metadata.attributes["ome"]["multiscales"]) != 1: + if group.metadata.zarr_format == 3 and len(multiscales := group.metadata.attributes["ome"]["multiscales"]) != 1: len_scales = len(multiscales) raise ValueError(f"The length of multiscales metadata should be 1, found the length to be {len_scales}") + if group.metadata.zarr_format == 2: + multiscales = group.attrs["multiscales"] + if (len_scales := len(multiscales)) != 1: + raise ValueError(f"The length of multiscales metadata should be 1, found length of {len_scales}") multiscale = multiscales[0] - # zarr v3 ome-zarr requires the coordinate transformations to be written this way, leaving one out won't work. - multiscale["coordinateTransformations"] = coordinate_transformations - group.attrs["multiscales"] = multiscales - - -def _overwrite_coordinate_transformations_raster_zarrv2( - group: zarr.Group, coordinate_transformations: list[dict[str, BaseTransformation]] -) -> None: - """Overwrite transformations of raster elements on disk in zarr v2. - - The transformation present in multiscale["datasets"] are the ones for the multiscale, so and we leave them intact - we update multiscale["coordinateTransformations"] and multiscale["coordinateSystems"] - see the first post of https://github.com/scverse/spatialdata/issues/39 for an overview - - Parameters - ---------- - group - The zarr group containing the raster element for which to write the transformations, e.g. the zarr group - containing sdata['image2d']. - coordinate_transformations - List of NGFF transformation representations as dictionaries. - """ - multiscales = group.attrs["multiscales"] - if (len_scales := len(multiscales)) != 1: - raise ValueError(f"The length of multiscales metadata should be 1, found length of {len_scales}") - multiscale = multiscales[0] multiscale["coordinateTransformations"] = coordinate_transformations group.attrs["multiscales"] = multiscales From daa804b6b96117ddd996eab1ba123c33c5baec0b Mon Sep 17 00:00:00 2001 From: Wouter-Michiel Vierdag Date: Wed, 10 Sep 2025 11:15:23 +0200 Subject: [PATCH 085/126] change typehint --- src/spatialdata/_io/io_raster.py | 19 ++++++++++--------- 1 file changed, 10 insertions(+), 9 deletions(-) diff --git a/src/spatialdata/_io/io_raster.py b/src/spatialdata/_io/io_raster.py index b180a322..243062d2 100644 --- a/src/spatialdata/_io/io_raster.py +++ b/src/spatialdata/_io/io_raster.py @@ -4,7 +4,6 @@ import dask.array as da import numpy as np import zarr -from ome_zarr.format import Format from ome_zarr.io import ZarrLocation from ome_zarr.reader import Multiscales, Node, Reader from ome_zarr.types import JSONDict @@ -21,6 +20,7 @@ ) from spatialdata._io.format import ( CurrentRasterFormat, + RasterFormatType, ) from spatialdata._utils import get_pyramid_levels from spatialdata.models._utils import get_channel_names @@ -73,8 +73,8 @@ def _read_multiscale(store: str | Path, raster_type: Literal["image", "labels"]) nodes = _get_multiscale_nodes(image_nodes, nodes) else: raise OSError( - f"Image location {image_loc} does not seem to exist. If it does, potentially the zarr.json (or .zattrs) file " - f"inside is corrupted or not present or the image files themselves are corrupted." + f"Image location {image_loc} does not seem to exist. If it does, potentially the zarr.json (or .zattrs) " + f"file inside is corrupted or not present or the image files themselves are corrupted." ) if len(nodes) != 1: if not exists: @@ -84,7 +84,8 @@ def _read_multiscale(store: str | Path, raster_type: Literal["image", "labels"]) ) raise OSError( f"Image location {image_loc} exists, but len(nodes) = {len(nodes)}, expected 1. Element " - f"{image_loc.basename()} is potentially corrupted. Please report this bug and attach a minimal data example." + f"{image_loc.basename()} is potentially corrupted. Please report this bug and attach a minimal data " + f"example." ) node = nodes[0] @@ -141,7 +142,7 @@ def _write_raster( raster_data: DataArray | DataTree, group: zarr.Group, name: str, - raster_format: Format, + raster_format: RasterFormatType, storage_options: JSONDict | list[JSONDict] | None = None, label_metadata: JSONDict | None = None, **metadata: str | JSONDict | list[JSONDict], @@ -208,7 +209,7 @@ def _write_raster_dataarray( group: zarr.Group, element_name: str, raster_data: DataArray, - raster_format: Format, + raster_format: RasterFormatType, storage_options: JSONDict | list[JSONDict] | None = None, **metadata: str | JSONDict | list[JSONDict], ) -> None: @@ -268,7 +269,7 @@ def _write_raster_datatree( group: zarr.Group, element_name: str, raster_data: DataTree, - raster_format: Format, + raster_format: RasterFormatType, storage_options: JSONDict | list[JSONDict] | None = None, **metadata: str | JSONDict | list[JSONDict], ) -> None: @@ -330,7 +331,7 @@ def write_image( image: DataArray | DataTree, group: zarr.Group, name: str, - element_format: Format = CurrentRasterFormat(), + element_format: RasterFormatType = CurrentRasterFormat(), storage_options: JSONDict | list[JSONDict] | None = None, **metadata: str | JSONDict | list[JSONDict], ) -> None: @@ -349,7 +350,7 @@ def write_labels( labels: DataArray | DataTree, group: zarr.Group, name: str, - element_format: Format = CurrentRasterFormat(), + element_format: RasterFormatType = CurrentRasterFormat(), storage_options: JSONDict | list[JSONDict] | None = None, label_metadata: JSONDict | None = None, **metadata: JSONDict, From 564abae4a174b93bbbad0074e6af0a2bc92e8c39 Mon Sep 17 00:00:00 2001 From: Wouter-Michiel Vierdag Date: Wed, 10 Sep 2025 11:20:08 +0200 Subject: [PATCH 086/126] remove typehints from docstring --- src/spatialdata/_io/io_raster.py | 42 ++++++++++++++++---------------- 1 file changed, 21 insertions(+), 21 deletions(-) diff --git a/src/spatialdata/_io/io_raster.py b/src/spatialdata/_io/io_raster.py index 243062d2..7f1db705 100644 --- a/src/spatialdata/_io/io_raster.py +++ b/src/spatialdata/_io/io_raster.py @@ -151,21 +151,21 @@ def _write_raster( Parameters ---------- - raster_type: Literal["image", "labels"] + raster_type Whether the raster data pertains to a image or labels 'SpatialElement`. - raster_data: DataArray | DataTree + raster_data The raster data to write. - group: zarr.Group + group The zarr group in the 'image' or 'labels' zarr group to write the raster data to. name: str The name of the raster element. - raster_format: Format + raster_format The format used to write the raster data. - storage_options: JSONDict | list[JSONDict] | None + storage_options Additional options for writing the raster data, like chunks and compression. - label_metadata: JSONDict | None + label_metadata Label metadata which can only be defined when writing 'labels'. - metadata: str | JSONDict | list[JSONDict] + metadata Additional metadata for the raster element """ if raster_type not in ["image", "labels"]: @@ -217,19 +217,19 @@ def _write_raster_dataarray( Parameters ---------- - raster_type: Literal["image", "labels"] + raster_type Whether the raster data pertains to a image or labels 'SpatialElement`. - group: zarr.Group + group The zarr group in the 'image' or 'labels' zarr group to write the raster data to. - element_name: str + element_name The name of the raster element. - raster_data: DataArray + raster_data The raster data to write. - raster_format: Format + raster_format The format used to write the raster data. - storage_options: JSONDict | list[JSONDict] | None + storage_options Additional options for writing the raster data, like chunks and compression. - metadata: str | JSONDict | list[JSONDict] + metadata Additional metadata for the raster element """ write_single_scale_ngff = write_image_ngff if raster_type == "image" else write_labels_ngff @@ -277,19 +277,19 @@ def _write_raster_datatree( Parameters ---------- - raster_type: Literal["image", "labels"] + raster_type Whether the raster data pertains to a image or labels 'SpatialElement`. - group: zarr.Group + group The zarr group in the 'image' or 'labels' zarr group to write the raster data to. - element_name: str + element_name The name of the raster element. - raster_data: DataTree + raster_data The raster data to write. - raster_format: Format + raster_format The format used to write the raster data. - storage_options: JSONDict | list[JSONDict] | None + storage_options Additional options for writing the raster data, like chunks and compression. - metadata: str | JSONDict | list[JSONDict] + metadata Additional metadata for the raster element """ write_multi_scale_ngff = write_multiscale_ngff if raster_type == "image" else write_multiscale_labels_ngff From aa8e68696b37dd5d4af4be426f58b9225dabae5d Mon Sep 17 00:00:00 2001 From: Wouter-Michiel Vierdag Date: Wed, 10 Sep 2025 11:33:18 +0200 Subject: [PATCH 087/126] remove type hint return in docstring --- src/spatialdata/_io/io_raster.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/src/spatialdata/_io/io_raster.py b/src/spatialdata/_io/io_raster.py index 7f1db705..bd7b8a96 100644 --- a/src/spatialdata/_io/io_raster.py +++ b/src/spatialdata/_io/io_raster.py @@ -50,8 +50,7 @@ def _get_multiscale_nodes(image_nodes: list[Node], nodes: list[Node]) -> list[No Returns ------- - list[Node] - List of nodes with the multiscales spec. + List of nodes with the multiscales spec. """ if len(image_nodes): for node in image_nodes: From 36c59870dcfc2354c798a0830aababb10866b1d5 Mon Sep 17 00:00:00 2001 From: Wouter-Michiel Vierdag Date: Wed, 10 Sep 2025 11:37:18 +0200 Subject: [PATCH 088/126] remove comment --- src/spatialdata/_io/io_raster.py | 4 ---- 1 file changed, 4 deletions(-) diff --git a/src/spatialdata/_io/io_raster.py b/src/spatialdata/_io/io_raster.py index bd7b8a96..f0992995 100644 --- a/src/spatialdata/_io/io_raster.py +++ b/src/spatialdata/_io/io_raster.py @@ -190,10 +190,6 @@ def _write_raster( else: raise ValueError("Not a valid labels object") - # Since NGFF does not yet support coordinate transformations, we need a SpatialData extension for rasters. This will - # be dropped once NGFF supports it. For now, saving the NGFF version (0.4) is not enough—we must also record the - # SpatialData format version. - group = group["labels"][name] if raster_type == "labels" else group if ATTRS_KEY not in group.attrs: group.attrs[ATTRS_KEY] = {} From 48f0b81b071274da851a0678695d3dabc224c9a4 Mon Sep 17 00:00:00 2001 From: Wouter-Michiel Vierdag Date: Wed, 10 Sep 2025 11:41:58 +0200 Subject: [PATCH 089/126] ensure comment added back --- src/spatialdata/_io/io_raster.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/spatialdata/_io/io_raster.py b/src/spatialdata/_io/io_raster.py index f0992995..9cdf5e94 100644 --- a/src/spatialdata/_io/io_raster.py +++ b/src/spatialdata/_io/io_raster.py @@ -242,7 +242,7 @@ def _write_raster_dataarray( else: storage_options = {"chunks": chunks} # Scaler needs to be None since we are passing the data already downscaled for the multiscale case. - # We need this because the argument of write_image_ngff is called image while the argument of + # We need this because the argument of write_image_ngff is called image while the argument of # write_labels_ngff is called label. metadata[raster_type] = data write_single_scale_ngff( From b6e23f7de89630a1f1b2ae5df05eff60dbdd5aa5 Mon Sep 17 00:00:00 2001 From: Wouter-Michiel Vierdag Date: Wed, 10 Sep 2025 13:52:49 +0200 Subject: [PATCH 090/126] fix channel metadata --- src/spatialdata/_io/_utils.py | 8 ++++++-- src/spatialdata/_io/io_raster.py | 6 +++--- tests/io/test_format.py | 16 ++++++++++++++++ 3 files changed, 25 insertions(+), 5 deletions(-) diff --git a/src/spatialdata/_io/_utils.py b/src/spatialdata/_io/_utils.py index cc3967a0..ceae5dbf 100644 --- a/src/spatialdata/_io/_utils.py +++ b/src/spatialdata/_io/_utils.py @@ -145,9 +145,13 @@ def overwrite_channel_names(group: zarr.Group, element: DataArray | DataTree) -> channel_names = element["scale0"]["image"].coords["c"].data.tolist() channel_metadata = [{"label": name} for name in channel_names] - omero_meta = group.attrs["ome"]["omero"] + omero_meta = group.attrs.get("omero", None) or group.attrs.get("ome", {}).get("omero") omero_meta["channels"] = channel_metadata - group.attrs["omero"] = omero_meta + if ome_meta := group.attrs.get("ome", None): + ome_meta["omero"] = omero_meta + group.attrs["ome"] = ome_meta + else: + group.attrs["omero"] = omero_meta def _write_metadata( diff --git a/src/spatialdata/_io/io_raster.py b/src/spatialdata/_io/io_raster.py index 9cdf5e94..3739df60 100644 --- a/src/spatialdata/_io/io_raster.py +++ b/src/spatialdata/_io/io_raster.py @@ -88,9 +88,9 @@ def _read_multiscale(store: str | Path, raster_type: Literal["image", "labels"]) ) node = nodes[0] - datasets = node.load(Multiscales).datasets - multiscales = node.load(Multiscales).zarr.root_attrs["multiscales"] - omero_metadata = node.load(Multiscales).zarr.root_attrs.get("omero", None) + loaded_node = node.load(Multiscales) + datasets, multiscales = loaded_node.datasets, loaded_node.zarr.root_attrs["multiscales"] + omero_metadata = loaded_node.zarr.root_attrs.get("omero") or loaded_node.zarr.root_attrs.get("ome", {}).get("omero") # TODO: check if below is still valid legacy_channels_metadata = node.load(Multiscales).zarr.root_attrs.get("channels_metadata", None) # legacy v0.1 assert len(multiscales) == 1 diff --git a/tests/io/test_format.py b/tests/io/test_format.py index 0273fe15..dcc5b938 100644 --- a/tests/io/test_format.py +++ b/tests/io/test_format.py @@ -204,6 +204,8 @@ def test_container_v1_to_v2(self, full_sdata): with tempfile.TemporaryDirectory() as tmpdir: f1 = Path(tmpdir) / "data1.zarr" f2 = Path(tmpdir) / "data2.zarr" + f3 = Path(tmpdir) / "data3.zarr" + f4 = Path(tmpdir) / "data4.zarr" full_sdata.write(f1, sdata_formats=[SpatialDataContainerFormatV01()]) sdata_read_v1 = read_zarr(f1) @@ -216,3 +218,17 @@ def test_container_v1_to_v2(self, full_sdata): assert_spatial_data_objects_are_identical(full_sdata, sdata_read_v2) assert sdata_read_v2.is_self_contained() assert sdata_read_v2.has_consolidated_metadata() + + new_channels = ["first", "second", "third"] + sdata_read_v1.set_channel_names("image2d", new_channels, write=True) + sdata_read_v1.set_channel_names("image2d_multiscale", new_channels, write=True) + assert sdata_read_v1["image2d"].coords["c"].data.tolist() == new_channels + assert sdata_read_v1["image2d_multiscale"]["scale0"]["image"].coords["c"].data.tolist() == new_channels + sdata_read_v1.write(f3, sdata_formats=[SpatialDataContainerFormatV01()]) + sdata_read_v1 = read_zarr(f3) + assert sdata_read_v1["image2d"].coords["c"].data.tolist() == new_channels + assert sdata_read_v1["image2d_multiscale"]["scale0"]["image"].coords["c"].data.tolist() == new_channels + sdata_read_v1.write(f4, sdata_formats=[SpatialDataContainerFormatV02()]) + sdata_read_v2 = read_zarr(f4) + assert sdata_read_v2["image2d"].coords["c"].data.tolist() == new_channels + assert sdata_read_v2["image2d_multiscale"]["scale0"]["image"].coords["c"].data.tolist() == new_channels From 7959e02fe7ce29b633a23af86c955d6970fbf3f9 Mon Sep 17 00:00:00 2001 From: Wouter-Michiel Vierdag Date: Wed, 10 Sep 2025 14:33:39 +0200 Subject: [PATCH 091/126] get rid of TableValidateMixin --- src/spatialdata/_core/concatenate.py | 3 +-- src/spatialdata/_io/format.py | 23 ++--------------------- src/spatialdata/_io/io_table.py | 8 +++----- src/spatialdata/models/models.py | 26 +++++++++++++++++++++++++- tests/conftest.py | 5 +++-- tests/io/test_multi_table.py | 11 ++++++++--- 6 files changed, 42 insertions(+), 34 deletions(-) diff --git a/src/spatialdata/_core/concatenate.py b/src/spatialdata/_core/concatenate.py index 68e3d17e..90b897c7 100644 --- a/src/spatialdata/_core/concatenate.py +++ b/src/spatialdata/_core/concatenate.py @@ -68,8 +68,7 @@ def _concatenate_tables( TableModel.INSTANCE_KEY: instance_key, } merged_table.uns[TableModel.ATTRS_KEY] = attrs - - return TableModel().validate(merged_table) + return TableModel().parse(merged_table) def concatenate( diff --git a/src/spatialdata/_io/format.py b/src/spatialdata/_io/format.py index a9175e38..df275cb5 100644 --- a/src/spatialdata/_io/format.py +++ b/src/spatialdata/_io/format.py @@ -4,7 +4,6 @@ import ome_zarr.format import zarr -from anndata import AnnData from ome_zarr.format import ( Format, FormatV01, @@ -13,7 +12,6 @@ FormatV04, FormatV05, ) -from pandas.api.types import CategoricalDtype from shapely import GeometryType from spatialdata.models.models import ATTRS_KEY, PointsModel, ShapesModel @@ -124,23 +122,6 @@ def attrs_to_dict(self, data: dict[str, Any]) -> dict[str, dict[str, Any]]: return d -class TableValidateMixinV01: - def validate_table( - self, - table: AnnData, - region_key: None | str = None, - instance_key: None | str = None, - ) -> None: - if not isinstance(table, AnnData): - raise TypeError(f"`table` must be `anndata.AnnData`, was {type(table)}.") - if region_key is not None and not isinstance(table.obs[region_key].dtype, CategoricalDtype): - raise ValueError( - f"`table.obs[region_key]` must be of type `categorical`, not `{type(table.obs[region_key])}`." - ) - if instance_key is not None and table.obs[instance_key].isnull().values.any(): - raise ValueError("`table.obs[instance_key]` must not contain null values, but it does.") - - class SpatialDataContainerFormatV01(FormatV04): @property def spatialdata_format_version(self) -> str: @@ -278,7 +259,7 @@ def spatialdata_format_version(self) -> str: return "0.2" -class TablesFormatV01(FormatV04, TableValidateMixinV01): +class TablesFormatV01(FormatV04): """Formatter for the table.""" @property @@ -286,7 +267,7 @@ def spatialdata_format_version(self) -> str: return "0.1" -class TablesFormatV02(FormatV05, TableValidateMixinV01): +class TablesFormatV02(FormatV05): """Formatter for the table.""" @property diff --git a/src/spatialdata/_io/io_table.py b/src/spatialdata/_io/io_table.py index 24c19271..937122d3 100644 --- a/src/spatialdata/_io/io_table.py +++ b/src/spatialdata/_io/io_table.py @@ -14,7 +14,7 @@ from spatialdata._io._utils import BadFileHandleMethod, handle_read_errors from spatialdata._io.format import CurrentTablesFormat, TablesFormats, _parse_version from spatialdata._logging import logger -from spatialdata.models import TableModel +from spatialdata.models import TableModel, get_table_keys def _read_table( @@ -97,10 +97,8 @@ def write_table( element_format: Format = CurrentTablesFormat(), ) -> None: if TableModel.ATTRS_KEY in table.uns: - region = table.uns["spatialdata_attrs"]["region"] - region_key = table.uns["spatialdata_attrs"].get("region_key", None) - instance_key = table.uns["spatialdata_attrs"].get("instance_key", None) - element_format.validate_table(table, region_key, instance_key) + region, region_key, instance_key = get_table_keys(table) + TableModel().validate(table) else: region, region_key, instance_key = (None, None, None) diff --git a/src/spatialdata/models/models.py b/src/spatialdata/models/models.py index bd5a077f..bd805e26 100644 --- a/src/spatialdata/models/models.py +++ b/src/spatialdata/models/models.py @@ -1066,10 +1066,33 @@ def validate( ------- The validated data. """ + if not isinstance(data, AnnData): + raise TypeError(f"`table` must be `anndata.AnnData`, was {type(data)}.") + validate_table_attr_keys(data) if ATTRS_KEY not in data.uns: return data + _, region_key, instance_key = get_table_keys(data) + if region_key is not None: + if region_key not in data.obs: + raise ValueError( + f"Region key `{region_key}` not in `adata.obs`. Please create the column and parse " + f"using TableModel.parse(adata)." + ) + if not isinstance(data.obs[region_key].dtype, CategoricalDtype): + raise ValueError( + f"`table.obs[{region_key}]` must be of type `categorical`, not `{type(data.obs[region_key])}`." + ) + if instance_key: + if instance_key not in data.obs: + raise ValueError( + f"Instance key `{instance_key}` not in `adata.obs`. Please create the column and parse" + f" using TableModel.parse(adata)." + ) + if data.obs[instance_key].isnull().values.any(): + raise ValueError("`table.obs[instance_key]` must not contain null values, but it does.") + self._validate_table_annotation_metadata(data) return data @@ -1157,8 +1180,9 @@ def parse( "instance_key": instance_key, } adata.uns[cls.ATTRS_KEY] = attr + convert_region_column_to_categorical(adata) cls().validate(adata) - return convert_region_column_to_categorical(adata) + return adata Schema_t: TypeAlias = ( diff --git a/tests/conftest.py b/tests/conftest.py index 9aef4744..deca11f3 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -461,12 +461,13 @@ def sdata_query_aggregation() -> SpatialData: def generate_adata(n_var: int, obs: pd.DataFrame, obsm: dict[Any, Any], uns: dict[Any, Any]) -> AnnData: rng = np.random.default_rng(SEED) - return AnnData( + adata = AnnData( rng.normal(size=(obs.shape[0], n_var)).astype(np.float64), obs=obs, obsm=obsm, uns=uns, ) + return TableModel().parse(adata) def _get_blobs_galaxy() -> tuple[ArrayLike, ArrayLike]: @@ -490,7 +491,7 @@ def adata_labels() -> AnnData: "categorical": pd.Categorical(rng.integers(0, 2, size=(n_obs_labels,))), "cell_id": pd.Categorical(seg), "instance_id": range(n_obs_labels), - "region": ["test"] * n_obs_labels, + "region": pd.Categorical(["test"] * n_obs_labels), }, index=np.arange(n_obs_labels).astype(str), ) diff --git a/tests/io/test_multi_table.py b/tests/io/test_multi_table.py index 1b754370..0e37e1b4 100644 --- a/tests/io/test_multi_table.py +++ b/tests/io/test_multi_table.py @@ -24,9 +24,14 @@ def test_set_get_tables_from_spatialdata(self, full_sdata: SpatialData, tmp_path adata2 = adata0.copy() del adata2.obs["region"] # fails because either none either all three 'region', 'region_key', 'instance_key' are required - with pytest.raises(ValueError): + with pytest.raises(ValueError, match="Region key `region` not in `adata.obs`."): full_sdata["not_added_table"] = adata2 + adata3 = adata0.copy() + del adata3.obs["instance_id"] + with pytest.raises(ValueError, match="Instance key `instance_id` not in `adata.obs`."): + full_sdata["not_added_table"] = adata3 + assert len(full_sdata.tables) == 3 assert "adata0" in full_sdata.tables and "adata1" in full_sdata.tables full_sdata.write(tmpdir) @@ -247,13 +252,13 @@ def test_static_set_annotation_target(): ) table = _get_table(region="test_non_shapes") table_target = table.copy() - table_target.obs["region"] = "test_shapes" + table_target.obs["region"] = pd.Categorical(["test_shapes"] * table_target.n_obs) table_target = SpatialData.update_annotated_regions_metadata(table_target) assert table_target.uns[TableModel.ATTRS_KEY][TableModel.REGION_KEY] == ["test_shapes"] test_sdata["another_table"] = table_target - table.obs["diff_region"] = "test_shapes" + table.obs["diff_region"] = pd.Categorical(["test_shapes"] * table.n_obs) table = SpatialData.update_annotated_regions_metadata(table, region_key="diff_region") assert table.uns[TableModel.ATTRS_KEY][TableModel.REGION_KEY] == ["test_shapes"] From 8e16c8e109a7f47a443bf026bc969d0dc60ab1ef Mon Sep 17 00:00:00 2001 From: Luca Marconato Date: Wed, 10 Sep 2025 14:52:42 +0200 Subject: [PATCH 092/126] code fixes --- src/spatialdata/_io/_utils.py | 5 ++-- src/spatialdata/_io/io_raster.py | 15 ++++++---- src/spatialdata/_io/io_shapes.py | 26 ++++++++--------- src/spatialdata/_io/io_table.py | 1 - src/spatialdata/_io/io_zarr.py | 48 +++++++++++++++----------------- src/spatialdata/_types.py | 9 ++---- src/spatialdata/models/models.py | 10 +++---- 7 files changed, 53 insertions(+), 61 deletions(-) diff --git a/src/spatialdata/_io/_utils.py b/src/spatialdata/_io/_utils.py index a9189ae5..34ea5ca8 100644 --- a/src/spatialdata/_io/_utils.py +++ b/src/spatialdata/_io/_utils.py @@ -24,7 +24,6 @@ from zarr.storage import FsspecStore, LocalStore from spatialdata._core.spatialdata import SpatialData -from spatialdata._types import StoreLike from spatialdata._utils import get_pyramid_levels from spatialdata.models._utils import ( MappingToCoordinateSystem_t, @@ -437,7 +436,9 @@ def _is_element_self_contained( return all(_backed_elements_contained_in_path(path=element_path, object=element)) -def _resolve_zarr_store(path: StoreLike, **kwargs: Any) -> zarr.storage.StoreLike: +def _resolve_zarr_store( + path: str | Path | UPath | zarr.storage.StoreLike | zarr.Group, **kwargs: Any +) -> zarr.storage.StoreLike: """ Normalize different Zarr store inputs into a usable store instance. diff --git a/src/spatialdata/_io/io_raster.py b/src/spatialdata/_io/io_raster.py index b180a322..1540e076 100644 --- a/src/spatialdata/_io/io_raster.py +++ b/src/spatialdata/_io/io_raster.py @@ -73,18 +73,21 @@ def _read_multiscale(store: str | Path, raster_type: Literal["image", "labels"]) nodes = _get_multiscale_nodes(image_nodes, nodes) else: raise OSError( - f"Image location {image_loc} does not seem to exist. If it does, potentially the zarr.json (or .zattrs) file " - f"inside is corrupted or not present or the image files themselves are corrupted." + f"Image location {image_loc} does not seem to exist. If it does, " + "potentially the zarr.json (or .zattrs) file inside is corrupted or not " + "present or the image files themselves are corrupted." ) if len(nodes) != 1: if not exists: raise ValueError( - f"len(nodes) = {len(nodes)}, expected 1 and image location {image_loc} does not exist. Unable to read " - f"the NGFF file. Please report this bug and attach a minimal data example." + f"len(nodes) = {len(nodes)}, expected 1 and image location {image_loc} " + "does not exist. Unable to read the NGFF file. Please report this bug " + "and attach a minimal data example." ) raise OSError( - f"Image location {image_loc} exists, but len(nodes) = {len(nodes)}, expected 1. Element " - f"{image_loc.basename()} is potentially corrupted. Please report this bug and attach a minimal data example." + f"Image location {image_loc} exists, but len(nodes) = {len(nodes)}, " + f"expected 1. Element {image_loc.basename()} is potentially corrupted. " + "Please report this bug and attach a minimal data example." ) node = nodes[0] diff --git a/src/spatialdata/_io/io_shapes.py b/src/spatialdata/_io/io_shapes.py index bf6f454a..6e14baa4 100644 --- a/src/spatialdata/_io/io_shapes.py +++ b/src/spatialdata/_io/io_shapes.py @@ -5,6 +5,7 @@ import numpy as np import zarr from geopandas import GeoDataFrame, read_parquet +from natsort import natsorted from ome_zarr.format import Format from shapely import from_ragged_array, to_ragged_array @@ -48,10 +49,7 @@ def _read_shapes( geo_df = GeoDataFrame({"geometry": geometry, "radius": radius}, index=index) else: offsets_keys = [k for k in f if k.startswith("offset")] - - # We do this because of async reading not necessarily leading to ordered offset keys. - # We can't use sorted because if offsets are higher than 11 we get 1, 11, 2 - offsets_keys = [f"offset{i}" for i in range(len(offsets_keys))] + offsets_keys = natsorted(offsets_keys) offsets = tuple(np.array(f[k]).flatten() for k in offsets_keys) geometry = from_ragged_array(typ, coords, offsets) geo_df = GeoDataFrame({"geometry": geometry}, index=index) @@ -82,13 +80,13 @@ def write_shapes( Parameters ---------- - shapes: GeoDataFrame + shapes The shapes dataframe - group: zarr.Group + group The zarr group in the 'shapes' zarr group to write the shapes element to. - group_type: str + group_type The type of the element. - element_format: Format + element_format The format of the shapes element used to store it. """ axes = get_axes_names(shapes) @@ -116,11 +114,11 @@ def _write_shapes_v01(shapes: GeoDataFrame, group: zarr.Group, element_format: F Parameters ---------- - shapes: GeoDataFrame + shapes The shapes dataframe - group: zarr.Group + group The zarr group in the 'shapes' zarr group to write the shapes element to. - element_format: Format + element_format The format of the shapes element used to store it. """ import numcodecs @@ -146,11 +144,11 @@ def _write_shapes_v02_v03(shapes: GeoDataFrame, group: zarr.Group, element_forma Parameters ---------- - shapes: GeoDataFrame + shapes The shapes dataframe - group: zarr.Group + group The zarr group in the 'shapes' zarr group to write the shapes element to. - element_format: Format + element_format The format of the shapes element used to store it. """ store_root = group.store_path.store.root diff --git a/src/spatialdata/_io/io_table.py b/src/spatialdata/_io/io_table.py index 24c19271..5df1bda4 100644 --- a/src/spatialdata/_io/io_table.py +++ b/src/spatialdata/_io/io_table.py @@ -10,7 +10,6 @@ from ome_zarr.format import Format from zarr.errors import ArrayNotFoundError -# from zarr.errors import ArrayNotFoundError # removed in zarr 3.0 from spatialdata._io._utils import BadFileHandleMethod, handle_read_errors from spatialdata._io.format import CurrentTablesFormat, TablesFormats, _parse_version from spatialdata._logging import logger diff --git a/src/spatialdata/_io/io_zarr.py b/src/spatialdata/_io/io_zarr.py index bc413479..78f521c9 100644 --- a/src/spatialdata/_io/io_zarr.py +++ b/src/spatialdata/_io/io_zarr.py @@ -11,7 +11,12 @@ from zarr.errors import ArrayNotFoundError, MetadataValidationError from spatialdata._core.spatialdata import SpatialData -from spatialdata._io._utils import BadFileHandleMethod, _resolve_zarr_store, handle_read_errors, ome_zarr_logger +from spatialdata._io._utils import ( + BadFileHandleMethod, + _resolve_zarr_store, + handle_read_errors, + ome_zarr_logger, +) from spatialdata._io.io_points import _read_points from spatialdata._io.io_raster import _read_multiscale from spatialdata._io.io_shapes import _read_shapes @@ -19,29 +24,8 @@ from spatialdata._logging import logger -# TODO: remove with incoming remote read / write PR -# Not removing this now as it requires substantial extra refactor beyond scope of zarrv3 PR. -def _open_zarr( - store: str | Path | zarr.Group, mode: Literal["r", "r+", "a", "w", "w-"] = "r", use_consolidated: bool | None = None -) -> tuple[zarr.Group, str]: - """ - Open a zarr store (on-disk or remote) and return the zarr.Group object and the path to the store. - - Parameters - ---------- - store - Path to the zarr store (on-disk or remote) or a zarr.Group object. - - Returns - ------- - A tuple of the zarr.Group object and the path to the store. - """ - f = store if isinstance(store, zarr.Group) else zarr.open_group(store, mode=mode, use_consolidated=use_consolidated) - return f, f.store.root - - def read_zarr( - store: str | Path | zarr.Group, + store: str | Path, selection: None | tuple[str] = None, on_bad_files: Literal[BadFileHandleMethod.ERROR, BadFileHandleMethod.WARN] = BadFileHandleMethod.ERROR, ) -> SpatialData: @@ -51,7 +35,7 @@ def read_zarr( Parameters ---------- store - Path to the zarr store (on-disk or remote) or a zarr.Group object. + Path to the zarr store (on-disk or remote). selection List of elements to read from the zarr store (images, labels, points, shapes, table). If None, all elements are @@ -74,7 +58,8 @@ def read_zarr( from spatialdata._io._utils import _resolve_zarr_store resolved_store = _resolve_zarr_store(store) - root_group, root_store_path = _open_zarr(resolved_store) + root_group = zarr.open_group(resolved_store, mode="r") + root_store_path = root_group.store.root images = {} labels = {} @@ -197,6 +182,17 @@ def read_zarr( ): group = root_group["tables"] tables = _read_table(root_store_path, group, tables, on_bad_files=on_bad_files) + if "tables" in selector and "table" in root_group: + with handle_read_errors( + on_bad_files, + location="table", + exc_types=(ValueError,), + ): + raise ValueError( + f"`table` group found in zarr store at location {root_store_path} " + "instead of `tables`. Please update the zarr store to use `tables` " + "instead.", + ) # read attrs metadata attrs = root_group.attrs.asdict() @@ -305,7 +301,7 @@ def _group_for_element_exists(zarr_path: Path, element_type: str, element_name: def _write_consolidated_metadata(path: Path | str | None) -> None: if path is not None: - f, f_store_path = _open_zarr(path, mode="r+", use_consolidated=False) + f = zarr.open_group(path, mode="r+", use_consolidated=False) # .parquet files are not recognized as proper zarr and thus throw a warning. This does not affect SpatialData. # and therefore we silence it for our users as they can't do anything about this. with warnings.catch_warnings(): diff --git a/src/spatialdata/_types.py b/src/spatialdata/_types.py index 26fad13e..30d623a5 100644 --- a/src/spatialdata/_types.py +++ b/src/spatialdata/_types.py @@ -1,12 +1,9 @@ -from pathlib import Path -from typing import Any, TypeAlias +from typing import Any import numpy as np -import zarr.storage -from upath import UPath from xarray import DataArray, DataTree -__all__ = ["ArrayLike", "ColorLike", "DTypeLike", "Raster_T", "StoreLike"] +__all__ = ["ArrayLike", "ColorLike", "DTypeLike", "Raster_T"] from numpy.typing import DTypeLike, NDArray @@ -15,5 +12,3 @@ Raster_T = DataArray | DataTree ColorLike = tuple[float, ...] | str - -StoreLike: TypeAlias = str | Path | UPath | zarr.storage.StoreLike | zarr.Group diff --git a/src/spatialdata/models/models.py b/src/spatialdata/models/models.py index bd5a077f..14ad7193 100644 --- a/src/spatialdata/models/models.py +++ b/src/spatialdata/models/models.py @@ -157,11 +157,11 @@ def parse( Please refer to :func:`to_spatial_image` for more information. Note: if you set `rgb=None` in `kwargs`, 3-4 channel images will be interpreted automatically as RGB(A) images. - **Setting axes / dims** - In case of the data being a numpy or dask array, there are no named axes yet. In this case, we first try to - use the dimensions specified by the user in the `dims` argument of `.parse`. These dimensions are potentially - transposed. See the description of the `dims` argument above. If `dims` is not specified, the dims are set - to (c)(z)yx, dependent on the number of dimensions of the data. + **Setting axes / dims** In case of the data being a numpy or dask array, there are no named axes yet. In this + case, we first try to use the dimensions specified by the user in the `dims` argument of `.parse`. These + dimensions are used to potentially transpose the data to match the order (c)(z)yx. See the description of the + `dims` argument above. If `dims` is not specified, the dims are set to (c)(z)yx, dependent on the number of + dimensions of the data. """ if transformations: transformations = transformations.copy() From 906e3a94df529f6c8f57dfda3b0123da237595a6 Mon Sep 17 00:00:00 2001 From: Luca Marconato Date: Wed, 10 Sep 2025 15:21:37 +0200 Subject: [PATCH 093/126] fix --- src/spatialdata/_core/spatialdata.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/spatialdata/_core/spatialdata.py b/src/spatialdata/_core/spatialdata.py index a2143172..fa8ede10 100644 --- a/src/spatialdata/_core/spatialdata.py +++ b/src/spatialdata/_core/spatialdata.py @@ -300,7 +300,7 @@ def set_channel_names(self, element_name: str, channel_names: str | list[str], w This method will overwrite the element in memory with the same element, but with new channel names. If 'write` is 'True', this method assumes that the `SpatialData` object and the element are already stored on disk as it will also overwrite the channel names metadata on disk. If you do not want to overwrite the element - on disk, or it is not stored, set `write` to False. + on disk, or it is not stored, set `write` to False (default). Parameters ---------- From 878dce13b9f0939c2f1d67869023578cba8ec602 Mon Sep 17 00:00:00 2001 From: Wouter-Michiel Vierdag Date: Wed, 10 Sep 2025 15:57:21 +0200 Subject: [PATCH 094/126] remove format without effect --- src/spatialdata/_core/spatialdata.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/src/spatialdata/_core/spatialdata.py b/src/spatialdata/_core/spatialdata.py index fa8ede10..93102d15 100644 --- a/src/spatialdata/_core/spatialdata.py +++ b/src/spatialdata/_core/spatialdata.py @@ -16,7 +16,6 @@ from dask.dataframe import read_parquet from dask.delayed import Delayed from geopandas import GeoDataFrame -from ome_zarr.format import FormatV05 from ome_zarr.io import parse_url from shapely import MultiPolygon, Polygon from xarray import DataArray, DataTree @@ -1058,7 +1057,7 @@ def _validate_can_safely_write_to_path( # TODO: add test for this if os.path.exists(file_path): - if parse_url(file_path, mode="r", fmt=FormatV05()) is None: + if parse_url(file_path, mode="r") is None: raise ValueError( "The target file path specified already exists, and it has been detected to not be a Zarr store. " "Overwriting non-Zarr stores is not supported to prevent accidental data loss." From a9f4ca05427068f0a5da3380b0549658fc76f008 Mon Sep 17 00:00:00 2001 From: Wouter-Michiel Vierdag Date: Wed, 10 Sep 2025 16:48:37 +0200 Subject: [PATCH 095/126] remove unnecessary catch warnings --- src/spatialdata/_core/spatialdata.py | 4 +--- 1 file changed, 1 insertion(+), 3 deletions(-) diff --git a/src/spatialdata/_core/spatialdata.py b/src/spatialdata/_core/spatialdata.py index 93102d15..32b550ab 100644 --- a/src/spatialdata/_core/spatialdata.py +++ b/src/spatialdata/_core/spatialdata.py @@ -311,9 +311,7 @@ def set_channel_names(self, element_name: str, channel_names: str | list[str], w Whether to overwrite the channel metadata on disk (lightweight operation). This will not rewrite the pixel data itself (heavy operation). """ - with warnings.catch_warnings(): - warnings.filterwarnings("ignore", category=UserWarning) - self.images[element_name] = set_channel_names(self.images[element_name], channel_names) + self.images[element_name] = set_channel_names(self.images[element_name], channel_names) if write: self.write_channel_names(element_name) From 09ac15210927283fbfc0909ab11807ed5d6e5ca2 Mon Sep 17 00:00:00 2001 From: Wouter-Michiel Vierdag Date: Wed, 10 Sep 2025 16:50:15 +0200 Subject: [PATCH 096/126] add todo --- src/spatialdata/_io/io_zarr.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/src/spatialdata/_io/io_zarr.py b/src/spatialdata/_io/io_zarr.py index 3aaf800a..9747d2c4 100644 --- a/src/spatialdata/_io/io_zarr.py +++ b/src/spatialdata/_io/io_zarr.py @@ -304,6 +304,8 @@ def _write_consolidated_metadata(path: Path | str | None) -> None: f = zarr.open_group(path, mode="r+", use_consolidated=False) # .parquet files are not recognized as proper zarr and thus throw a warning. This does not affect SpatialData. # and therefore we silence it for our users as they can't do anything about this. + # TODO check with remote PR whether we can prevent this warning at least for points data and whether with zarrv3 + # that pr would still work. with warnings.catch_warnings(): warnings.filterwarnings("ignore", category=zarr.errors.ZarrUserWarning) zarr.consolidate_metadata(f.store) From f6bae29454a2817b56c477ef5428e97bd8d48e5d Mon Sep 17 00:00:00 2001 From: Luca Marconato Date: Wed, 10 Sep 2025 16:54:15 +0200 Subject: [PATCH 097/126] fixes --- tests/core/query/test_relational_query.py | 24 +++----- tests/io/test_format.py | 74 +++++++++++++++-------- tests/io/test_readwrite.py | 3 +- tests/utils/test_testing.py | 3 +- 4 files changed, 60 insertions(+), 44 deletions(-) diff --git a/tests/core/query/test_relational_query.py b/tests/core/query/test_relational_query.py index a821ba03..938d6871 100644 --- a/tests/core/query/test_relational_query.py +++ b/tests/core/query/test_relational_query.py @@ -35,10 +35,8 @@ def test_join_using_string_instance_id_and_index(sdata_query_aggregation): [f"string_{i}" for i in sdata_query_aggregation["values_polygons"].index] ) - values_polygons = sdata_query_aggregation["values_polygons"][:5] - values_circles = sdata_query_aggregation["values_circles"][:5] - sdata_query_aggregation["values_polygons"] = values_polygons - sdata_query_aggregation["values_circles"] = values_circles + sdata_query_aggregation["values_polygons"] = sdata_query_aggregation["values_polygons"][:5] + sdata_query_aggregation["values_circles"] = sdata_query_aggregation["values_circles"][:5] element_dict, table = join_spatialelement_table( sdata=sdata_query_aggregation, @@ -87,8 +85,7 @@ def test_left_inner_right_exclusive_join(sdata_query_aggregation): assert table is None assert all(element_dict[key] is None for key in element_dict) - values_polygons = sdata["values_polygons"].drop([10, 11]) - sdata["values_polygons"] = values_polygons + sdata["values_polygons"] = sdata["values_polygons"].drop([10, 11]) with pytest.raises(ValueError, match="No table with"): join_spatialelement_table( sdata=sdata, @@ -150,8 +147,7 @@ def test_left_inner_right_exclusive_join(sdata_query_aggregation): assert "by_polygons" in element_dict # check multiple elements joined to table. - values_circles = sdata["values_circles"].drop([7, 8]) - sdata["values_circles"] = values_circles + sdata["values_circles"] = sdata["values_circles"].drop([7, 8]) element_dict, table = join_spatialelement_table( sdata=sdata, spatial_element_names=["values_circles", "values_polygons"], @@ -293,8 +289,7 @@ def test_left_exclusive_and_right_join(sdata_query_aggregation): assert table is None # Dropped indices correspond to instance ids 7, 8 for 'values_circles' and 10, 11 for 'values_polygons' - table_update = sdata["table"][sdata["table"].obs.index.drop(["7", "8", "19", "20"])] - sdata["table"] = table_update + sdata["table"] = sdata["table"][sdata["table"].obs.index.drop(["7", "8", "19", "20"])] with pytest.warns(UserWarning, match="The element"): element_dict, table = join_spatialelement_table( sdata=sdata, @@ -377,8 +372,7 @@ def test_left_exclusive_and_right_join(sdata_query_aggregation): def test_match_rows_inner_join_non_matching_element(sdata_query_aggregation): sdata = sdata_query_aggregation - circles = sdata["values_circles"][4:] - sdata["values_circles"] = circles + sdata["values_circles"] = sdata["values_circles"][4:] original_index = sdata["values_circles"].index reversed_instance_id = [3, 5, 8, 7, 6, 4, 1, 2, 0] + list(reversed(range(12))) sdata["table"].obs["instance_id"] = reversed_instance_id @@ -439,10 +433,8 @@ def test_match_rows_inner_join_non_matching_table(sdata_query_aggregation): def test_inner_join_match_rows_duplicate_obs_indices(sdata_query_aggregation: SpatialData, join_type: str) -> None: sdata = sdata_query_aggregation sdata["table"].obs.index = ["a"] * sdata["table"].n_obs - values_circles = sdata_query_aggregation["values_circles"][:4] - values_polygons = sdata_query_aggregation["values_polygons"][:5] - sdata["values_circles"] = values_circles - sdata["values_polygons"] = values_polygons + sdata["values_circles"] = sdata_query_aggregation["values_circles"][:4] + sdata["values_polygons"] = sdata_query_aggregation["values_polygons"][:5] element_dict, table = join_spatialelement_table( sdata=sdata, diff --git a/tests/io/test_format.py b/tests/io/test_format.py index dcc5b938..e0465fe7 100644 --- a/tests/io/test_format.py +++ b/tests/io/test_format.py @@ -14,6 +14,7 @@ RasterFormatV01, RasterFormatV02, RasterFormatV03, + ShapesFormatType, ShapesFormatV01, ShapesFormatV02, ShapesFormatV03, @@ -30,11 +31,11 @@ class TestFormat: """Test format.""" - @pytest.mark.parametrize("element_format", [PointsFormatV01()]) + @pytest.mark.parametrize("element_format", [PointsFormatV01(), PointsFormatV02()]) @pytest.mark.parametrize("attrs_key", [PointsModel.ATTRS_KEY]) @pytest.mark.parametrize("feature_key", [None, PointsModel.FEATURE_KEY]) @pytest.mark.parametrize("instance_key", [None, PointsModel.INSTANCE_KEY]) - def test_format_points_v1( + def test_format_points_v1_v2( self, element_format: PointsFormatType, attrs_key: str | None, @@ -81,28 +82,42 @@ def test_format_shape_v1( geometry = GeometryType(metadata[attrs_key][geos_key][type_key]) assert metadata[attrs_key] == ShapesFormatV01().attrs_to_dict(geometry) + @pytest.mark.parametrize("element_format", [ShapesFormatV02(), ShapesFormatV03()]) @pytest.mark.parametrize("attrs_key", [ShapesModel.ATTRS_KEY]) - def test_format_shapes_v2( + def test_format_shapes_v2_v3( self, + element_format: ShapesFormatType, attrs_key: str, ) -> None: - metadata: dict[str, Any] = {attrs_key: {"version": ShapesFormatV02().spatialdata_format_version}} + metadata: dict[str, Any] = {attrs_key: {"version": element_format.spatialdata_format_version}} metadata[attrs_key].pop("version") - assert metadata[attrs_key] == ShapesFormatV02().attrs_to_dict({}) + assert metadata[attrs_key] == element_format.attrs_to_dict({}) - @pytest.mark.parametrize("rformat", [RasterFormatV01, RasterFormatV02]) - def test_format_raster_v1_v2(self, images, rformat: type[SpatialDataFormatType]) -> None: + @pytest.mark.parametrize("rformat", [RasterFormatV01, RasterFormatV02, RasterFormatV03]) + def test_format_raster_v1_v2_v3(self, images, rformat: type[SpatialDataFormatType]) -> None: with tempfile.TemporaryDirectory() as tmpdir: - images.write(Path(tmpdir) / "images.zarr", sdata_formats=[SpatialDataContainerFormatV01(), rformat()]) - zattrs_file = Path(tmpdir) / "images.zarr/images/image2d/.zattrs" + sdata_container_format = ( + SpatialDataContainerFormatV01() if rformat != RasterFormatV03 else SpatialDataContainerFormatV02() + ) + images.write(Path(tmpdir) / "images.zarr", sdata_formats=[sdata_container_format, rformat()]) + + metadata_file = ".zattrs" if rformat != RasterFormatV03 else "zarr.json" + zattrs_file = Path(tmpdir) / "images.zarr/images/image2d/" / metadata_file + with open(zattrs_file) as infile: zattrs = json.load(infile) - ngff_version = zattrs["multiscales"][0]["version"] if rformat == RasterFormatV01: + ngff_version = zattrs["multiscales"][0]["version"] assert ngff_version == "0.4" - else: - assert rformat == RasterFormatV02 + elif rformat == RasterFormatV02: + ngff_version = zattrs["multiscales"][0]["version"] assert ngff_version == "0.4-dev-spatialdata" + else: + ngff_version = zattrs["attributes"]["ome"]["version"] + assert rformat == RasterFormatV03 + assert ngff_version == "0.5-dev-spatialdata" + + # TODO: add tests for TablesFormatV01 and TablesFormatV02 class TestFormatConversions: @@ -182,10 +197,12 @@ def test_points_v1_to_v2(self, points): points.write(f1, sdata_formats=[PointsFormatV01(), SpatialDataContainerFormatV01()]) points_read_v1 = read_zarr(f1) assert_spatial_data_objects_are_identical(points, points_read_v1) + assert points_read_v1.is_self_contained() points_read_v1.write(f2, sdata_formats=[PointsFormatV02(), SpatialDataContainerFormatV02()]) points_read_v2 = read_zarr(f2) assert_spatial_data_objects_are_identical(points, points_read_v2) + assert points_read_v2.is_self_contained() def test_tables_v1_to_v2(self, table_multiple_annotations): with tempfile.TemporaryDirectory() as tmpdir: @@ -204,8 +221,6 @@ def test_container_v1_to_v2(self, full_sdata): with tempfile.TemporaryDirectory() as tmpdir: f1 = Path(tmpdir) / "data1.zarr" f2 = Path(tmpdir) / "data2.zarr" - f3 = Path(tmpdir) / "data3.zarr" - f4 = Path(tmpdir) / "data4.zarr" full_sdata.write(f1, sdata_formats=[SpatialDataContainerFormatV01()]) sdata_read_v1 = read_zarr(f1) @@ -219,16 +234,27 @@ def test_container_v1_to_v2(self, full_sdata): assert sdata_read_v2.is_self_contained() assert sdata_read_v2.has_consolidated_metadata() + def test_chanel_names_raster_images_v1_to_v2_to_v3(self, images): + with tempfile.TemporaryDirectory() as tmpdir: + f1 = Path(tmpdir) / "data1.zarr" + f2 = Path(tmpdir) / "data2.zarr" + f3 = Path(tmpdir) / "data3.zarr" + new_channels = ["first", "second", "third"] - sdata_read_v1.set_channel_names("image2d", new_channels, write=True) - sdata_read_v1.set_channel_names("image2d_multiscale", new_channels, write=True) - assert sdata_read_v1["image2d"].coords["c"].data.tolist() == new_channels - assert sdata_read_v1["image2d_multiscale"]["scale0"]["image"].coords["c"].data.tolist() == new_channels - sdata_read_v1.write(f3, sdata_formats=[SpatialDataContainerFormatV01()]) - sdata_read_v1 = read_zarr(f3) - assert sdata_read_v1["image2d"].coords["c"].data.tolist() == new_channels - assert sdata_read_v1["image2d_multiscale"]["scale0"]["image"].coords["c"].data.tolist() == new_channels - sdata_read_v1.write(f4, sdata_formats=[SpatialDataContainerFormatV02()]) - sdata_read_v2 = read_zarr(f4) + + images.write(f1, sdata_formats=[RasterFormatV01(), SpatialDataContainerFormatV01()]) + images_read_v1 = read_zarr(f1) + images_read_v1.set_channel_names("image2d", new_channels, write=True) + images_read_v1.set_channel_names("image2d_multiscale", new_channels, write=True) + assert images_read_v1["image2d"].coords["c"].data.tolist() == new_channels + assert images_read_v1["image2d_multiscale"]["scale0"]["image"].coords["c"].data.tolist() == new_channels + + images_read_v1.write(f2, sdata_formats=[SpatialDataContainerFormatV01()]) + images_read_v1 = read_zarr(f2) + assert images_read_v1["image2d"].coords["c"].data.tolist() == new_channels + assert images_read_v1["image2d_multiscale"]["scale0"]["image"].coords["c"].data.tolist() == new_channels + + images_read_v1.write(f3, sdata_formats=[SpatialDataContainerFormatV02()]) + sdata_read_v2 = read_zarr(f3) assert sdata_read_v2["image2d"].coords["c"].data.tolist() == new_channels assert sdata_read_v2["image2d_multiscale"]["scale0"]["image"].coords["c"].data.tolist() == new_channels diff --git a/tests/io/test_readwrite.py b/tests/io/test_readwrite.py index feba8409..1288496d 100644 --- a/tests/io/test_readwrite.py +++ b/tests/io/test_readwrite.py @@ -63,8 +63,7 @@ def test_points(self, tmp_path: str, points: SpatialData) -> None: # check the index is correctly written and then read new_index = dd.from_array(np.arange(1, len(points["points_0"]) + 1)) - el_point = points["points_0"].set_index(new_index) - points["points_0"] = el_point + points["points_0"] = points["points_0"].set_index(new_index) points.write(tmpdir) sdata = SpatialData.read(tmpdir) diff --git a/tests/utils/test_testing.py b/tests/utils/test_testing.py index 6125d73c..a181c87f 100644 --- a/tests/utils/test_testing.py +++ b/tests/utils/test_testing.py @@ -46,8 +46,7 @@ def _change_metadata_tables(sdata: SpatialData, element_name: str) -> None: def _change_metadata_image(sdata: SpatialData, element_name: str, coords: bool, transformations: bool) -> None: if coords: if isinstance(sdata[element_name], DataArray): - element = sdata[element_name].assign_coords({"c": np.array(["m", "l", "b"])}) - sdata[element_name] = element + sdata[element_name] = sdata[element_name].assign_coords({"c": np.array(["m", "l", "b"])}) else: assert isinstance(sdata[element_name], DataTree) From 7997f3beb96358d000e5c0f817efbee976c440a6 Mon Sep 17 00:00:00 2001 From: Wouter-Michiel Vierdag Date: Wed, 10 Sep 2025 17:22:01 +0200 Subject: [PATCH 098/126] remove unnecessary .get('ome') --- src/spatialdata/_io/io_raster.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/src/spatialdata/_io/io_raster.py b/src/spatialdata/_io/io_raster.py index db5df06e..80aa07e3 100644 --- a/src/spatialdata/_io/io_raster.py +++ b/src/spatialdata/_io/io_raster.py @@ -91,7 +91,8 @@ def _read_multiscale(store: str | Path, raster_type: Literal["image", "labels"]) node = nodes[0] loaded_node = node.load(Multiscales) datasets, multiscales = loaded_node.datasets, loaded_node.zarr.root_attrs["multiscales"] - omero_metadata = loaded_node.zarr.root_attrs.get("omero") or loaded_node.zarr.root_attrs.get("ome", {}).get("omero") + # This works for all versions as in zarr v3 the level of the 'ome' key is taken as root_attrs. + omero_metadata = loaded_node.zarr.root_attrs.get("omero") # TODO: check if below is still valid legacy_channels_metadata = node.load(Multiscales).zarr.root_attrs.get("channels_metadata", None) # legacy v0.1 assert len(multiscales) == 1 From b19256bb409123ab491acee77b946442a99e7a10 Mon Sep 17 00:00:00 2001 From: Wouter-Michiel Vierdag Date: Wed, 10 Sep 2025 17:37:00 +0200 Subject: [PATCH 099/126] add clarifying comment --- src/spatialdata/_io/_utils.py | 1 + tests/io/test_format.py | 2 +- 2 files changed, 2 insertions(+), 1 deletion(-) diff --git a/src/spatialdata/_io/_utils.py b/src/spatialdata/_io/_utils.py index e727c963..82f5a7ce 100644 --- a/src/spatialdata/_io/_utils.py +++ b/src/spatialdata/_io/_utils.py @@ -144,6 +144,7 @@ def overwrite_channel_names(group: zarr.Group, element: DataArray | DataTree) -> channel_names = element["scale0"]["image"].coords["c"].data.tolist() channel_metadata = [{"label": name} for name in channel_names] + # This is required here as we do not use the load node API of ome-zarr omero_meta = group.attrs.get("omero", None) or group.attrs.get("ome", {}).get("omero") omero_meta["channels"] = channel_metadata if ome_meta := group.attrs.get("ome", None): diff --git a/tests/io/test_format.py b/tests/io/test_format.py index e0465fe7..c8d9f04c 100644 --- a/tests/io/test_format.py +++ b/tests/io/test_format.py @@ -234,7 +234,7 @@ def test_container_v1_to_v2(self, full_sdata): assert sdata_read_v2.is_self_contained() assert sdata_read_v2.has_consolidated_metadata() - def test_chanel_names_raster_images_v1_to_v2_to_v3(self, images): + def test_channel_names_raster_images_v1_to_v2_to_v3(self, images): with tempfile.TemporaryDirectory() as tmpdir: f1 = Path(tmpdir) / "data1.zarr" f2 = Path(tmpdir) / "data2.zarr" From 2bed0ceeeb0e5f5b926d89dce91b2795f1daab59 Mon Sep 17 00:00:00 2001 From: Luca Marconato Date: Wed, 10 Sep 2025 17:37:21 +0200 Subject: [PATCH 100/126] wip tests readwrite across formats --- src/spatialdata/_core/spatialdata.py | 3 +- tests/io/test_readwrite.py | 424 ++++++++++++++++++--------- 2 files changed, 287 insertions(+), 140 deletions(-) diff --git a/src/spatialdata/_core/spatialdata.py b/src/spatialdata/_core/spatialdata.py index 32b550ab..ff35aed9 100644 --- a/src/spatialdata/_core/spatialdata.py +++ b/src/spatialdata/_core/spatialdata.py @@ -1638,6 +1638,7 @@ def write_metadata( element_name: str | None = None, consolidate_metadata: bool | None = None, write_attrs: bool = True, + format: SpatialDataContainerFormatType | None = None, ) -> None: """ Write the metadata of a single element, or of all elements, to the Zarr store, without rewriting the data. @@ -1674,7 +1675,7 @@ def write_metadata( # TODO: write .attrs['spatialdata_attrs'] metadata for DaskDataFrame. if write_attrs: - self.write_attrs() + self.write_attrs(format=format) # TODO: discuss when has_consolidated_metadata that we should just consolidate it because after a writing # operation the consolidated store could otherwise be out of sync. diff --git a/tests/io/test_readwrite.py b/tests/io/test_readwrite.py index 1288496d..7dc1b59a 100644 --- a/tests/io/test_readwrite.py +++ b/tests/io/test_readwrite.py @@ -16,6 +16,12 @@ from spatialdata import SpatialData, deepcopy, read_zarr from spatialdata._core.validation import ValidationError from spatialdata._io._utils import _are_directories_identical, get_dask_backing_files +from spatialdata._io.format import ( + CurrentSpatialDataContainerFormat, + SpatialDataContainerFormats, + SpatialDataContainerFormatType, + SpatialDataContainerFormatV01, +) from spatialdata.datasets import blobs from spatialdata.models import Image2DModel from spatialdata.models._utils import get_channel_names @@ -25,114 +31,145 @@ set_transformation, ) from spatialdata.transformations.transformations import Identity, Scale -from tests.conftest import _get_images, _get_labels, _get_points, _get_shapes, _get_table, _get_tables +from tests.conftest import ( + _get_images, + _get_labels, + _get_points, + _get_shapes, + _get_table, + _get_tables, +) RNG = default_rng(0) +SDATA_FORMATS = list(SpatialDataContainerFormats.values()) +@pytest.mark.parametrize("sdata_container_format", SDATA_FORMATS) class TestReadWrite: - def test_images(self, tmp_path: str, images: SpatialData) -> None: + def test_images( + self, + tmp_path: str, + images: SpatialData, + sdata_container_format: SpatialDataContainerFormatType, + ) -> None: tmpdir = Path(tmp_path) / "tmp.zarr" # ensures that we are inplicitly testing the read and write of channel names assert get_channel_names(images["image2d"]) == ["r", "g", "b"] assert get_channel_names(images["image2d_multiscale"]) == ["r", "g", "b"] - images.write(tmpdir) + images.write(tmpdir, sdata_formats=sdata_container_format) sdata = SpatialData.read(tmpdir) assert_spatial_data_objects_are_identical(images, sdata) - def test_labels(self, tmp_path: str, labels: SpatialData) -> None: + def test_labels( + self, + tmp_path: str, + labels: SpatialData, + sdata_container_format: SpatialDataContainerFormatType, + ) -> None: tmpdir = Path(tmp_path) / "tmp.zarr" - labels.write(tmpdir) + labels.write(tmpdir, sdata_formats=sdata_container_format) sdata = SpatialData.read(tmpdir) assert_spatial_data_objects_are_identical(labels, sdata) - def test_shapes(self, tmp_path: str, shapes: SpatialData) -> None: + def test_shapes( + self, + tmp_path: str, + shapes: SpatialData, + sdata_container_format: SpatialDataContainerFormatType, + ) -> None: tmpdir = Path(tmp_path) / "tmp.zarr" # check the index is correctly written and then read shapes["circles"].index = np.arange(1, len(shapes["circles"]) + 1) - shapes.write(tmpdir) + shapes.write(tmpdir, sdata_formats=sdata_container_format) sdata = SpatialData.read(tmpdir) assert_spatial_data_objects_are_identical(shapes, sdata) - def test_points(self, tmp_path: str, points: SpatialData) -> None: + def test_points( + self, + tmp_path: str, + points: SpatialData, + sdata_container_format: SpatialDataContainerFormatType, + ) -> None: tmpdir = Path(tmp_path) / "tmp.zarr" # check the index is correctly written and then read new_index = dd.from_array(np.arange(1, len(points["points_0"]) + 1)) points["points_0"] = points["points_0"].set_index(new_index) - points.write(tmpdir) + points.write(tmpdir, sdata_formats=sdata_container_format) sdata = SpatialData.read(tmpdir) assert_spatial_data_objects_are_identical(points, sdata) - def _test_table(self, tmp_path: str, table: SpatialData) -> None: + def _test_table( + self, + tmp_path: str, + table: SpatialData, + sdata_container_format: SpatialDataContainerFormatType, + ) -> None: tmpdir = Path(tmp_path) / "tmp.zarr" - table.write(tmpdir) + table.write(tmpdir, sdata_formats=sdata_container_format) sdata = SpatialData.read(tmpdir) assert_spatial_data_objects_are_identical(table, sdata) - def test_single_table_single_annotation(self, tmp_path: str, table_single_annotation: SpatialData) -> None: - self._test_table(tmp_path, table_single_annotation) + def test_single_table_single_annotation( + self, + tmp_path: str, + table_single_annotation: SpatialData, + sdata_container_format: SpatialDataContainerFormatType, + ) -> None: + self._test_table( + tmp_path, + table_single_annotation, + sdata_container_format=sdata_container_format, + ) - def test_single_table_multiple_annotations(self, tmp_path: str, table_multiple_annotations: SpatialData) -> None: - self._test_table(tmp_path, table_multiple_annotations) + def test_single_table_multiple_annotations( + self, + tmp_path: str, + table_multiple_annotations: SpatialData, + sdata_container_format: SpatialDataContainerFormatType, + ) -> None: + self._test_table( + tmp_path, + table_multiple_annotations, + sdata_container_format=sdata_container_format, + ) - def test_multiple_tables(self, tmp_path: str, tables: list[AnnData]) -> None: + def test_multiple_tables( + self, + tmp_path: str, + tables: list[AnnData], + sdata_container_format: SpatialDataContainerFormatType, + ) -> None: sdata_tables = SpatialData(tables={str(i): tables[i] for i in range(len(tables))}) - self._test_table(tmp_path, sdata_tables) + self._test_table(tmp_path, sdata_tables, sdata_container_format=sdata_container_format) def test_roundtrip( self, tmp_path: str, sdata: SpatialData, + sdata_container_format: SpatialDataContainerFormatType, ) -> None: tmpdir = Path(tmp_path) / "tmp.zarr" - sdata.write(tmpdir) + sdata.write(tmpdir, sdata_formats=sdata_container_format) sdata2 = SpatialData.read(tmpdir) tmpdir2 = Path(tmp_path) / "tmp2.zarr" - sdata2.write(tmpdir2) + sdata2.write(tmpdir2, sdata_formats=sdata_container_format) _are_directories_identical(tmpdir, tmpdir2, exclude_regexp="[1-9][0-9]*.*") - def test_incremental_io_in_memory( + def test_incremental_io_list_of_elements( self, - full_sdata: SpatialData, + shapes: SpatialData, + sdata_container_format: SpatialDataContainerFormatType, ) -> None: - sdata = full_sdata - - for k, v in _get_images().items(): - sdata.images[f"additional_{k}"] = v - with pytest.raises(KeyError, match="Key `table` is not unique"): - sdata["table"] = v - - for k, v in _get_labels().items(): - sdata.labels[f"additional_{k}"] = v - with pytest.raises(KeyError, match="Key `table` is not unique"): - sdata["table"] = v - - for k, v in _get_shapes().items(): - sdata.shapes[f"additional_{k}"] = v - with pytest.raises(KeyError, match="Key `table` is not unique"): - sdata["table"] = v - - for k, v in _get_points().items(): - sdata.points[f"additional_{k}"] = v - with pytest.raises(KeyError, match="Key `table` is not unique"): - sdata["table"] = v - - for k, v in _get_tables().items(): - sdata.tables[f"additional_{k}"] = v - with pytest.raises(KeyError, match="Key `poly` is not unique"): - sdata["poly"] = v - - def test_incremental_io_list_of_elements(self, shapes: SpatialData) -> None: with tempfile.TemporaryDirectory() as tmpdir: f = os.path.join(tmpdir, "data.zarr") - shapes.write(f) + shapes.write(f, sdata_formats=sdata_container_format) new_shapes0 = deepcopy(shapes["circles"]) new_shapes1 = deepcopy(shapes["poly"]) shapes["new_shapes0"] = new_shapes0 @@ -140,7 +177,7 @@ def test_incremental_io_list_of_elements(self, shapes: SpatialData) -> None: assert "shapes/new_shapes0" not in shapes.elements_paths_on_disk() assert "shapes/new_shapes1" not in shapes.elements_paths_on_disk() - shapes.write_element(["new_shapes0", "new_shapes1"]) + shapes.write_element(["new_shapes0", "new_shapes1"], sdata_formats=sdata_container_format) assert "shapes/new_shapes0" in shapes.elements_paths_on_disk() assert "shapes/new_shapes1" in shapes.elements_paths_on_disk() @@ -148,10 +185,59 @@ def test_incremental_io_list_of_elements(self, shapes: SpatialData) -> None: assert "shapes/new_shapes0" not in shapes.elements_paths_on_disk() assert "shapes/new_shapes1" not in shapes.elements_paths_on_disk() + @staticmethod + def _workaround1_non_dask_backed( + sdata: SpatialData, + name: str, + new_name: str, + sdata_container_format: SpatialDataContainerFormatType = CurrentSpatialDataContainerFormat(), + ) -> None: + # a. write a backup copy of the data + sdata[new_name] = sdata[name] + sdata.write_element(new_name, sdata_formats=sdata_container_format) + # b. rewrite the original data + sdata.delete_element_from_disk(name) + sdata.write_element(name, sdata_formats=sdata_container_format) + # c. remove the backup copy + del sdata[new_name] + sdata.delete_element_from_disk(new_name) + + @staticmethod + def _workaround1_dask_backed( + sdata: SpatialData, + name: str, + new_name: str, + sdata_container_format: SpatialDataContainerFormatType = CurrentSpatialDataContainerFormat(), + ) -> None: + # a. write a backup copy of the data + sdata[new_name] = sdata[name] + sdata.write_element(new_name, sdata_formats=sdata_container_format) + # a2. remove the in-memory copy from the SpatialData object (note, + # at this point the backup copy still exists on-disk) + del sdata[new_name] + del sdata[name] + # a3 load the backup copy into memory + sdata_copy = read_zarr(sdata.path) + # b1. rewrite the original data + sdata.delete_element_from_disk(name) + sdata[name] = sdata_copy[new_name] + sdata.write_element(name, sdata_formats=sdata_container_format) + # b2. reload the new data into memory (because it has been written but in-memory it still points + # from the backup location) + sdata = read_zarr(sdata.path) + # c. remove the backup copy + del sdata[new_name] + sdata.delete_element_from_disk(new_name) + @pytest.mark.parametrize("dask_backed", [True, False]) @pytest.mark.parametrize("workaround", [1, 2]) def test_incremental_io_on_disk( - self, tmp_path: str, full_sdata: SpatialData, dask_backed: bool, workaround: int + self, + tmp_path: str, + full_sdata: SpatialData, + dask_backed: bool, + workaround: int, + sdata_container_format: SpatialDataContainerFormatType, ) -> None: """ This tests shows workaround on how to rewrite existing data on disk. @@ -164,7 +250,7 @@ def test_incremental_io_on_disk( """ tmpdir = Path(tmp_path) / "incremental_io.zarr" sdata = SpatialData() - sdata.write(tmpdir) + sdata.write(tmpdir, sdata_formats=sdata_container_format) for name in [ "image2d", @@ -176,19 +262,20 @@ def test_incremental_io_on_disk( "table", ]: sdata[name] = full_sdata[name] - sdata.write_element(name) + sdata.write_element(name, sdata_formats=sdata_container_format) if dask_backed: # this forces the element to write to be dask-backed from disk. In this case, overwriting the data is # more laborious because we are writing the data to the same location that defines the data! sdata = read_zarr(sdata.path) with pytest.raises( - ValueError, match="The Zarr store already exists. Use `overwrite=True` to try overwriting the store." + ValueError, + match="The Zarr store already exists. Use `overwrite=True` to try overwriting the store.", ): - sdata.write_element(name) + sdata.write_element(name, sdata_formats=sdata_container_format) with pytest.raises(ValueError, match="Cannot overwrite."): - sdata.write_element(name, overwrite=True) + sdata.write_element(name, overwrite=True, sdata_formats=sdata_container_format) if workaround == 1: new_name = f"{name}_new_place" @@ -196,35 +283,19 @@ def test_incremental_io_on_disk( # setups, ...). If the scenario matches your use case, please use with caution. if not dask_backed: # easier case - # a. write a backup copy of the data - sdata[new_name] = sdata[name] - sdata.write_element(new_name) - # b. rewrite the original data - sdata.delete_element_from_disk(name) - sdata.write_element(name) - # c. remove the backup copy - del sdata[new_name] - sdata.delete_element_from_disk(new_name) + self._workaround1_non_dask_backed( + sdata=sdata, + name=name, + new_name=new_name, + sdata_container_format=sdata_container_format, + ) else: # dask-backed case, more complex - # a. write a backup copy of the data - sdata[new_name] = sdata[name] - sdata.write_element(new_name) - # a2. remove the in-memory copy from the SpatialData object (note, - # at this point the backup copy still exists on-disk) - del sdata[new_name] - del sdata[name] - # a3 load the backup copy into memory - sdata_copy = read_zarr(sdata.path) - # b1. rewrite the original data - sdata.delete_element_from_disk(name) - sdata[name] = sdata_copy[new_name] - sdata.write_element(name) - # b2. reload the new data into memory (because it has been written but in-memory it still points - # from the backup location) - sdata = read_zarr(sdata.path) - # c. remove the backup copy - del sdata[new_name] - sdata.delete_element_from_disk(new_name) + self._workaround1_dask_backed( + sdata=sdata, + name=name, + new_name=new_name, + sdata_container_format=sdata_container_format, + ) elif workaround == 2: # workaround 2, unsafe but sometimes acceptable depending on the user's workflow. @@ -233,18 +304,18 @@ def test_incremental_io_on_disk( if not dask_backed: # a. rewrite the original data (risky!) sdata.delete_element_from_disk(name) - sdata.write_element(name) + sdata.write_element(name, sdata_formats=sdata_container_format) - def test_io_and_lazy_loading_points(self, points): + def test_io_and_lazy_loading_points(self, points, sdata_container_format: SpatialDataContainerFormatType): with tempfile.TemporaryDirectory() as td: f = os.path.join(td, "data.zarr") - points.write(f) + points.write(f, sdata_formats=sdata_container_format) assert len(get_dask_backing_files(points)) == 0 sdata2 = SpatialData.read(f) assert len(get_dask_backing_files(sdata2)) > 0 - def test_io_and_lazy_loading_raster(self, images, labels): + def test_io_and_lazy_loading_raster(self, images, labels, sdata_container_format: SpatialDataContainerFormatType): sdatas = {"images": images, "labels": labels} for k, sdata in sdatas.items(): d = getattr(sdata, k) @@ -252,7 +323,7 @@ def test_io_and_lazy_loading_raster(self, images, labels): with tempfile.TemporaryDirectory() as td: f = os.path.join(td, "data.zarr") dask0 = sdata[elem_name].data - sdata.write(f) + sdata.write(f, sdata_formats=sdata_container_format) assert all("from-zarr" not in key for key in dask0.dask.layers) assert len(get_dask_backing_files(sdata)) == 0 @@ -261,7 +332,9 @@ def test_io_and_lazy_loading_raster(self, images, labels): assert any("from-zarr" in key for key in dask1.dask.layers) assert len(get_dask_backing_files(sdata2)) > 0 - def test_replace_transformation_on_disk_raster(self, images, labels): + def test_replace_transformation_on_disk_raster( + self, images, labels, sdata_container_format: SpatialDataContainerFormatType + ): sdatas = {"images": images, "labels": labels} for k, sdata in sdatas.items(): d = getattr(sdata, k) @@ -271,7 +344,7 @@ def test_replace_transformation_on_disk_raster(self, images, labels): single_sdata = SpatialData(**kwargs) with tempfile.TemporaryDirectory() as td: f = os.path.join(td, "data.zarr") - single_sdata.write(f) + single_sdata.write(f, sdata_formats=sdata_container_format) t0 = get_transformation(SpatialData.read(f)[elem_name]) assert isinstance(t0, Identity) set_transformation( @@ -282,37 +355,48 @@ def test_replace_transformation_on_disk_raster(self, images, labels): t1 = get_transformation(SpatialData.read(f)[elem_name]) assert isinstance(t1, Scale) - def test_replace_transformation_on_disk_non_raster(self, shapes, points): + def test_replace_transformation_on_disk_non_raster( + self, shapes, points, sdata_container_format: SpatialDataContainerFormatType + ): sdatas = {"shapes": shapes, "points": points} for k, sdata in sdatas.items(): d = sdata.__getattribute__(k) elem_name = list(d.keys())[0] with tempfile.TemporaryDirectory() as td: f = os.path.join(td, "data.zarr") - sdata.write(f) + sdata.write(f, sdata_formats=sdata_container_format) t0 = get_transformation(SpatialData.read(f).__getattribute__(k)[elem_name]) assert isinstance(t0, Identity) set_transformation(sdata[elem_name], Scale([2.0], axes=("x",)), write_to_sdata=sdata) t1 = get_transformation(SpatialData.read(f)[elem_name]) assert isinstance(t1, Scale) - def test_overwrite_works_when_no_zarr_store(self, full_sdata): + def test_overwrite_works_when_no_zarr_store( + self, full_sdata, sdata_container_format: SpatialDataContainerFormatType + ): with tempfile.TemporaryDirectory() as tmpdir: f = os.path.join(tmpdir, "data.zarr") old_data = SpatialData() - old_data.write(f) + old_data.write(f, sdata_formats=sdata_container_format) # Since no, no risk of overwriting backing data. # Should not raise "The file path specified is the same as the one used for backing." - full_sdata.write(f, overwrite=True) + full_sdata.write(f, overwrite=True, sdata_formats=sdata_container_format) - def test_overwrite_fails_when_no_zarr_store_bug_dask_backed_data(self, full_sdata, points, images, labels): + def test_overwrite_fails_when_no_zarr_store_bug_dask_backed_data( + self, + full_sdata, + points, + images, + labels, + sdata_container_format: SpatialDataContainerFormatType, + ): sdatas = {"images": images, "labels": labels, "points": points} elements = {"images": "image2d", "labels": "labels2d", "points": "points_0"} for k, sdata in sdatas.items(): element = elements[k] with tempfile.TemporaryDirectory() as tmpdir: f = os.path.join(tmpdir, "data.zarr") - sdata.write(f) + sdata.write(f, sdata_formats=sdata_container_format) # now we have a sdata with dask-backed elements sdata2 = SpatialData.read(f) @@ -322,31 +406,33 @@ def test_overwrite_fails_when_no_zarr_store_bug_dask_backed_data(self, full_sdat ValueError, match="The Zarr store already exists. Use `overwrite=True` to try overwriting the store.", ): - full_sdata.write(f) + full_sdata.write(f, sdata_formats=sdata_container_format) with pytest.raises( ValueError, match="Cannot overwrite.", ): - full_sdata.write(f, overwrite=True) + full_sdata.write(f, overwrite=True, sdata_formats=sdata_container_format) - def test_overwrite_fails_when_zarr_store_present(self, full_sdata): + def test_overwrite_fails_when_zarr_store_present( + self, full_sdata, sdata_container_format: SpatialDataContainerFormatType + ): # addressing https://github.com/scverse/spatialdata/issues/137 with tempfile.TemporaryDirectory() as tmpdir: f = os.path.join(tmpdir, "data.zarr") - full_sdata.write(f) + full_sdata.write(f, sdata_formats=sdata_container_format) with pytest.raises( ValueError, match="The Zarr store already exists. Use `overwrite=True` to try overwriting the store.", ): - full_sdata.write(f) + full_sdata.write(f, sdata_formats=sdata_container_format) with pytest.raises( ValueError, match="Cannot overwrite.", ): - full_sdata.write(f, overwrite=True) + full_sdata.write(f, overwrite=True, sdata_formats=sdata_container_format) # support for overwriting backed sdata has been temporarily removed # with tempfile.TemporaryDirectory() as tmpdir: @@ -364,7 +450,9 @@ def test_overwrite_fails_when_zarr_store_present(self, full_sdata): # ) # sdata2.write(f, overwrite=True) - def test_overwrite_fails_onto_non_zarr_file(self, full_sdata): + def test_overwrite_fails_onto_non_zarr_file( + self, full_sdata, sdata_container_format: SpatialDataContainerFormatType + ): ERROR_MESSAGE = ( "The target file path specified already exists, and it has been detected to not be a Zarr store." ) @@ -375,18 +463,49 @@ def test_overwrite_fails_onto_non_zarr_file(self, full_sdata): ValueError, match=ERROR_MESSAGE, ): - full_sdata.write(f0) + full_sdata.write(f0, sdata_formats=sdata_container_format) with pytest.raises( ValueError, match=ERROR_MESSAGE, ): - full_sdata.write(f0, overwrite=True) + full_sdata.write(f0, overwrite=True, sdata_formats=sdata_container_format) f1 = os.path.join(tmpdir, "test.zarr") os.mkdir(f1) with pytest.raises(ValueError, match=ERROR_MESSAGE): - full_sdata.write(f1) + full_sdata.write(f1, sdata_formats=sdata_container_format) with pytest.raises(ValueError, match=ERROR_MESSAGE): - full_sdata.write(f1, overwrite=True) + full_sdata.write(f1, overwrite=True, sdata_formats=sdata_container_format) + + +def test_incremental_io_in_memory( + full_sdata: SpatialData, +) -> None: + sdata = full_sdata + + for k, v in _get_images().items(): + sdata.images[f"additional_{k}"] = v + with pytest.raises(KeyError, match="Key `table` is not unique"): + sdata["table"] = v + + for k, v in _get_labels().items(): + sdata.labels[f"additional_{k}"] = v + with pytest.raises(KeyError, match="Key `table` is not unique"): + sdata["table"] = v + + for k, v in _get_shapes().items(): + sdata.shapes[f"additional_{k}"] = v + with pytest.raises(KeyError, match="Key `table` is not unique"): + sdata["table"] = v + + for k, v in _get_points().items(): + sdata.points[f"additional_{k}"] = v + with pytest.raises(KeyError, match="Key `table` is not unique"): + sdata["table"] = v + + for k, v in _get_tables().items(): + sdata.tables[f"additional_{k}"] = v + with pytest.raises(KeyError, match="Key `poly` is not unique"): + sdata["poly"] = v def test_bug_rechunking_after_queried_raster(): @@ -397,14 +516,18 @@ def test_bug_rechunking_after_queried_raster(): images = {"single_scale": single_scale, "multi_scale": multi_scale} sdata = SpatialData(images=images) queried = sdata.query.bounding_box( - axes=("x", "y"), min_coordinate=[2, 5], max_coordinate=[12, 12], target_coordinate_system="global" + axes=("x", "y"), + min_coordinate=[2, 5], + max_coordinate=[12, 12], + target_coordinate_system="global", ) with tempfile.TemporaryDirectory() as tmpdir: f = os.path.join(tmpdir, "data.zarr") queried.write(f) -def test_self_contained(full_sdata: SpatialData) -> None: +@pytest.mark.parametrize("sdata_container_format", SDATA_FORMATS) +def test_self_contained(full_sdata: SpatialData, sdata_container_format: SpatialDataContainerFormatType) -> None: # data only in-memory, so the SpatialData object and all its elements are self-contained assert full_sdata.is_self_contained() description = full_sdata.elements_are_self_contained() @@ -413,7 +536,7 @@ def test_self_contained(full_sdata: SpatialData) -> None: with tempfile.TemporaryDirectory() as tmpdir: # data saved to disk, it's self contained f = os.path.join(tmpdir, "data.zarr") - full_sdata.write(f) + full_sdata.write(f, sdata_formats=sdata_container_format) full_sdata.is_self_contained() # we read the data, so it's self-contained @@ -422,7 +545,7 @@ def test_self_contained(full_sdata: SpatialData) -> None: # we save the data to a new location, so it's not self-contained anymore f2 = os.path.join(tmpdir, "data2.zarr") - sdata2.write(f2) + sdata2.write(f2, sdata_formats=sdata_container_format) assert not sdata2.is_self_contained() # because of the images, labels and points @@ -464,10 +587,13 @@ def test_self_contained(full_sdata: SpatialData) -> None: assert all(description[element_name] for element_name in description if element_name != "combined") -def test_symmetric_different_with_zarr_store(full_sdata: SpatialData) -> None: +@pytest.mark.parametrize("sdata_container_format", SDATA_FORMATS) +def test_symmetric_difference_with_zarr_store( + full_sdata: SpatialData, sdata_container_format: SpatialDataContainerFormatType +) -> None: with tempfile.TemporaryDirectory() as tmpdir: f = os.path.join(tmpdir, "data.zarr") - full_sdata.write(f) + full_sdata.write(f, sdata_formats=sdata_container_format) # the list of element on-disk and in-memory is the same only_in_memory, only_on_disk = full_sdata._symmetric_difference_with_zarr_store() @@ -503,11 +629,12 @@ def test_symmetric_different_with_zarr_store(full_sdata: SpatialData) -> None: } -def test_change_path_of_subset(full_sdata: SpatialData) -> None: +@pytest.mark.parametrize("sdata_container_format", SDATA_FORMATS) +def test_change_path_of_subset(full_sdata: SpatialData, sdata_container_format: SpatialDataContainerFormatType) -> None: """A subset SpatialData object has not Zarr path associated, show that we can reassign the path""" with tempfile.TemporaryDirectory() as tmpdir: f = os.path.join(tmpdir, "data.zarr") - full_sdata.write(f) + full_sdata.write(f, sdata_formats=sdata_container_format) subset = full_sdata.subset(["image2d", "labels2d", "points_0", "circles", "table"]) @@ -520,7 +647,7 @@ def test_change_path_of_subset(full_sdata: SpatialData) -> None: assert len(only_on_disk) > 0 f2 = os.path.join(tmpdir, "data2.zarr") - subset.write(f2) + subset.write(f2, sdata_formats=sdata_container_format) assert subset.is_self_contained() only_in_memory, only_on_disk = subset._symmetric_difference_with_zarr_store() assert len(only_in_memory) == 0 @@ -551,15 +678,18 @@ def _check_valid_name(f: Callable[[str], Any]) -> None: with pytest.raises(ValueError, match="Name cannot start with '__'"): f("__a") with pytest.raises( - ValueError, match="Name must contain only alphanumeric characters, underscores, dots and hyphens." + ValueError, + match="Name must contain only alphanumeric characters, underscores, dots and hyphens.", ): f("has whitespace") with pytest.raises( - ValueError, match="Name must contain only alphanumeric characters, underscores, dots and hyphens." + ValueError, + match="Name must contain only alphanumeric characters, underscores, dots and hyphens.", ): f("this/is/not/valid") with pytest.raises( - ValueError, match="Name must contain only alphanumeric characters, underscores, dots and hyphens." + ValueError, + match="Name must contain only alphanumeric characters, underscores, dots and hyphens.", ): f("non-alnum_#$%&()*+,?@") @@ -570,12 +700,13 @@ def test_incremental_io_valid_name(full_sdata: SpatialData) -> None: _check_valid_name(full_sdata.write_transformations) -def test_incremental_io_attrs(points: SpatialData) -> None: +@pytest.mark.parametrize("sdata_container_format", SDATA_FORMATS) +def test_incremental_io_attrs(points: SpatialData, sdata_container_format: SpatialDataContainerFormatType) -> None: with tempfile.TemporaryDirectory() as tmpdir: f = os.path.join(tmpdir, "data.zarr") my_attrs = {"a": "b", "c": 1} points.attrs = my_attrs - points.write(f) + points.write(f, sdata_formats=sdata_container_format) # test that the attributes are written to disk sdata = SpatialData.read(f) @@ -583,13 +714,13 @@ def test_incremental_io_attrs(points: SpatialData) -> None: # test incremental io attrs (write_attrs()) sdata.attrs["c"] = 2 - sdata.write_attrs() + sdata.write_attrs(format=sdata_container_format) sdata2 = SpatialData.read(f) assert sdata2.attrs["c"] == 2 # test incremental io attrs (write_metadata()) sdata.attrs["c"] = 3 - sdata.write_metadata() + sdata.write_metadata(format=sdata_container_format) sdata2 = SpatialData.read(f) assert sdata2.attrs["c"] == 3 @@ -599,19 +730,24 @@ def test_incremental_io_attrs(points: SpatialData) -> None: # TODO: make consolidated metadata open cleaner @pytest.mark.parametrize("element_name", ["image2d", "labels2d", "points_0", "circles", "table"]) -def test_delete_element_from_disk(full_sdata, element_name: str) -> None: +@pytest.mark.parametrize("sdata_container_format", SDATA_FORMATS) +def test_delete_element_from_disk( + full_sdata, + element_name: str, + sdata_container_format: SpatialDataContainerFormatType, +) -> None: # can't delete an element for a SpatialData object without associated Zarr store with pytest.raises(ValueError, match="The SpatialData object is not backed by a Zarr store."): full_sdata.delete_element_from_disk("image2d") with tempfile.TemporaryDirectory() as tmpdir: f = os.path.join(tmpdir, "data.zarr") - full_sdata.write(f) + full_sdata.write(f, sdata_formats=sdata_container_format) # cannot delete an element which is in-memory, but not in the Zarr store subset = full_sdata.subset(["points_0_1"]) f2 = os.path.join(tmpdir, "data2.zarr") - subset.write(f2) + subset.write(f2, sdata_formats=sdata_container_format) full_sdata.path = Path(f2) with pytest.raises( ValueError, @@ -635,7 +771,7 @@ def test_delete_element_from_disk(full_sdata, element_name: str) -> None: assert element_path in only_in_memory # resave it - full_sdata.write_element(element_name) + full_sdata.write_element(element_name, sdata_formats=sdata_container_format) # now delete it from memory, and then show it can still be deleted on-disk del getattr(full_sdata, element_type)[element_name] @@ -645,14 +781,19 @@ def test_delete_element_from_disk(full_sdata, element_name: str) -> None: @pytest.mark.parametrize("element_name", ["image2d", "labels2d", "points_0", "circles", "table"]) -def test_element_already_on_disk_different_type(full_sdata, element_name: str) -> None: +@pytest.mark.parametrize("sdata_container_format", SDATA_FORMATS) +def test_element_already_on_disk_different_type( + full_sdata, + element_name: str, + sdata_container_format: SpatialDataContainerFormatType, +) -> None: # Constructing a corrupted object (element present both on disk and in-memory but with different type). # Attempting to perform and IO operation will trigger an error. # The checks assessed in this test will not be needed anymore after # https://github.com/scverse/spatialdata/issues/504 is addressed with tempfile.TemporaryDirectory() as tmpdir: f = os.path.join(tmpdir, "data.zarr") - full_sdata.write(f) + full_sdata.write(f, sdata_formats=sdata_container_format) element_type = full_sdata._element_type_from_element_name(element_name) wrong_group = "images" if element_type == "tables" else "tables" @@ -672,13 +813,13 @@ def test_element_already_on_disk_different_type(full_sdata, element_name: str) - ValueError, match=ERROR_MSG, ): - full_sdata.write_element(element_name) + full_sdata.write_element(element_name, sdata_formats=sdata_container_format) with pytest.raises( ValueError, match=ERROR_MSG, ): - full_sdata.write_metadata(element_name) + full_sdata.write_metadata(element_name, format=sdata_container_format) with pytest.raises( ValueError, @@ -772,10 +913,11 @@ def test_reading_invalid_name(tmp_path: Path): ) -def test_write_store_unconsolidated_and_read(full_sdata): +@pytest.mark.parametrize("sdata_container_format", SDATA_FORMATS) +def test_write_store_unconsolidated_and_read(full_sdata, sdata_container_format: SpatialDataContainerFormatType): with tempfile.TemporaryDirectory() as tmpdir: path = Path(tmpdir) / "data.zarr" - full_sdata.write(path, consolidate_metadata=False) + full_sdata.write(path, consolidate_metadata=False, sdata_formats=sdata_container_format) group = zarr.open_group(path, mode="r") assert group.metadata.consolidated_metadata is None @@ -783,12 +925,16 @@ def test_write_store_unconsolidated_and_read(full_sdata): assert_spatial_data_objects_are_identical(full_sdata, second_read) -def test_can_read_sdata_with_reconsolidation(full_sdata): +@pytest.mark.parametrize("sdata_container_format", SDATA_FORMATS) +def test_can_read_sdata_with_reconsolidation(full_sdata, sdata_container_format: SpatialDataContainerFormatType): with tempfile.TemporaryDirectory() as tmpdir: path = Path(tmpdir) / "data.zarr" - full_sdata.write(path) + full_sdata.write(path, sdata_formats=sdata_container_format) - json_path = path / "zarr.json" + if isinstance(sdata_container_format, SpatialDataContainerFormatV01): + json_path = path / "zarr.json" + else: + json_path = path / ".zattrs" json_dict = json.loads(json_path.read_text()) del json_dict["consolidated_metadata"]["metadata"]["images/image2d"] json_path.write_text(json.dumps(json_dict, indent=4)) From 21e57941e617e274ca73afec79f6bf53b1529cad Mon Sep 17 00:00:00 2001 From: Wouter-Michiel Vierdag Date: Wed, 10 Sep 2025 17:39:19 +0200 Subject: [PATCH 101/126] return None instead of AnnData --- src/spatialdata/models/_utils.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/src/spatialdata/models/_utils.py b/src/spatialdata/models/_utils.py index db5ebac9..45b2dd4c 100644 --- a/src/spatialdata/models/_utils.py +++ b/src/spatialdata/models/_utils.py @@ -364,7 +364,7 @@ def get_raster_model_from_data_dims(dims: tuple[str, ...]) -> type[RasterSchema] return Labels3DModel if Z in dims else Labels2DModel -def convert_region_column_to_categorical(table: AnnData) -> AnnData: +def convert_region_column_to_categorical(table: AnnData) -> None: from spatialdata.models.models import TableModel if TableModel.ATTRS_KEY in table.uns: @@ -376,7 +376,6 @@ def convert_region_column_to_categorical(table: AnnData) -> AnnData: stacklevel=2, ) table.obs[region_key] = pd.Categorical(table.obs[region_key]) - return table def set_channel_names(element: DataArray | DataTree, channel_names: str | list[str]) -> DataArray | DataTree: From 46c07535e8747a117742370dd29921b372c1b1d6 Mon Sep 17 00:00:00 2001 From: Wouter-Michiel Vierdag Date: Wed, 10 Sep 2025 17:41:29 +0200 Subject: [PATCH 102/126] remove TODO --- tests/io/test_readwrite.py | 1 - 1 file changed, 1 deletion(-) diff --git a/tests/io/test_readwrite.py b/tests/io/test_readwrite.py index 1288496d..a838b531 100644 --- a/tests/io/test_readwrite.py +++ b/tests/io/test_readwrite.py @@ -597,7 +597,6 @@ def test_incremental_io_attrs(points: SpatialData) -> None: cached_sdata_blobs = blobs() -# TODO: make consolidated metadata open cleaner @pytest.mark.parametrize("element_name", ["image2d", "labels2d", "points_0", "circles", "table"]) def test_delete_element_from_disk(full_sdata, element_name: str) -> None: # can't delete an element for a SpatialData object without associated Zarr store From 9c3914e3557ed060090d4798d5852e020fd69faa Mon Sep 17 00:00:00 2001 From: Wouter-Michiel Vierdag Date: Wed, 10 Sep 2025 17:47:53 +0200 Subject: [PATCH 103/126] remove invalid characters from test --- tests/io/test_readwrite.py | 6 ++---- 1 file changed, 2 insertions(+), 4 deletions(-) diff --git a/tests/io/test_readwrite.py b/tests/io/test_readwrite.py index a838b531..ae9b7edf 100644 --- a/tests/io/test_readwrite.py +++ b/tests/io/test_readwrite.py @@ -753,8 +753,7 @@ def test_reading_invalid_name(tmp_path: Path): # Circumvent validation at construction time and check validation happens again at writing time. (tmp_path / "data.zarr/points" / points_name).rename(tmp_path / "data.zarr/points" / "has whitespace") # This one is not allowed on windows - if os.name != "nt": - (tmp_path / "data.zarr/shapes" / shapes_name).rename(tmp_path / "data.zarr/shapes" / "non-alnum_#$%&()*+,?@") + (tmp_path / "data.zarr/shapes" / shapes_name).rename(tmp_path / "data.zarr/shapes" / "non-alnum_#$%&()+,@") # We do this as the key of the element is otherwise not in the consolidated metadata, leading to an error. valid_sdata.write_consolidated_metadata() @@ -763,8 +762,7 @@ def test_reading_invalid_name(tmp_path: Path): actual_message = str(exc_info.value) assert "points/has whitespace" in actual_message - if os.name != "nt": - assert "shapes/non-alnum_#$%&()*+,?@" in actual_message + assert "shapes/non-alnum_#$%&()+,@" in actual_message assert ( "For renaming, please see the discussion here https://github.com/scverse/spatialdata/discussions/707" in actual_message From 0f230b7269a399d12f223884d9ac5a2168fc31cb Mon Sep 17 00:00:00 2001 From: Wouter-Michiel Vierdag Date: Wed, 10 Sep 2025 17:51:04 +0200 Subject: [PATCH 104/126] remove unused fixture and commented code --- tests/conftest.py | 23 ----------------------- 1 file changed, 23 deletions(-) diff --git a/tests/conftest.py b/tests/conftest.py index deca11f3..d4ad346c 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -84,11 +84,6 @@ def tables() -> list[AnnData]: return _tables -@pytest.fixture # (params=['images', 'labels', 'shapes', 'points', 'tables']) -def corrupted_sdata(request): - return request.getfixturevalue(request.param) - - @pytest.fixture() def full_sdata() -> SpatialData: return SpatialData( @@ -100,24 +95,6 @@ def full_sdata() -> SpatialData: ) -# @pytest.fixture() -# def empty_points() -> SpatialData: -# geo_df = GeoDataFrame( -# geometry=[], -# ) -# from spatialdata import NgffIdentity -# _set_transformations(geo_df, NgffIdentity()) -# -# return SpatialData(points={"empty": geo_df}) - - -# @pytest.fixture() -# def empty_table() -> SpatialData: -# adata = AnnData(shape=(0, 0), obs=pd.DataFrame(columns="region"), var=pd.DataFrame()) -# adata = TableModel.parse(adata=adata) -# return SpatialData(table=adata) - - @pytest.fixture( # params=["labels"] params=["full", "empty"] + ["images", "labels", "points", "table_single_annotation", "table_multiple_annotations"] From 5f5b8d729c093ebfba8676bd9f4465153dc4e974 Mon Sep 17 00:00:00 2001 From: Luca Marconato Date: Wed, 10 Sep 2025 17:57:31 +0200 Subject: [PATCH 105/126] almost completed extending readwrite tests to all container versions --- tests/io/test_readwrite.py | 11 +++++++---- 1 file changed, 7 insertions(+), 4 deletions(-) diff --git a/tests/io/test_readwrite.py b/tests/io/test_readwrite.py index 7dc1b59a..cc8fea70 100644 --- a/tests/io/test_readwrite.py +++ b/tests/io/test_readwrite.py @@ -932,11 +932,14 @@ def test_can_read_sdata_with_reconsolidation(full_sdata, sdata_container_format: full_sdata.write(path, sdata_formats=sdata_container_format) if isinstance(sdata_container_format, SpatialDataContainerFormatV01): - json_path = path / "zarr.json" + json_path = path / ".zmetadata" + json_dict = json.loads(json_path.read_text()) + # TODO: this raises no exception! + del json_dict["metadata"]["images/image2d/0/.zattrs"] else: - json_path = path / ".zattrs" - json_dict = json.loads(json_path.read_text()) - del json_dict["consolidated_metadata"]["metadata"]["images/image2d"] + json_path = path / "zarr.json" + json_dict = json.loads(json_path.read_text()) + del json_dict["consolidated_metadata"]["metadata"]["images/image2d"] json_path.write_text(json.dumps(json_dict, indent=4)) with pytest.raises(GroupNotFoundError): From 02a5df7a0ed6db122febed45c5186d9ced35feb3 Mon Sep 17 00:00:00 2001 From: Wouter-Michiel Vierdag Date: Wed, 10 Sep 2025 23:23:03 +0200 Subject: [PATCH 106/126] add OSError --- src/spatialdata/_io/io_zarr.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/src/spatialdata/_io/io_zarr.py b/src/spatialdata/_io/io_zarr.py index 9747d2c4..0bf95dbf 100644 --- a/src/spatialdata/_io/io_zarr.py +++ b/src/spatialdata/_io/io_zarr.py @@ -90,6 +90,7 @@ def read_zarr( KeyError, # Missing JSON key ArrayNotFoundError, # Image chunks missing TypeError, # instead of ArrayNotFoundError, with dask>=2024.10.0 zarr<=2.18.3 + OSError, ), ): element = _read_multiscale(elem_group_path, raster_type="image") @@ -117,6 +118,7 @@ def read_zarr( ValueError, ArrayNotFoundError, TypeError, + OSError, ), ): labels[subgroup_name] = _read_multiscale(elem_group_path, raster_type="labels") From f10c9bfa712efb1d31c91f5b9f0cad8f6e649bee Mon Sep 17 00:00:00 2001 From: Wouter-Michiel Vierdag Date: Thu, 11 Sep 2025 00:12:08 +0200 Subject: [PATCH 107/126] partial fix writing empty spatialdata --- src/spatialdata/_core/spatialdata.py | 46 +++++++++++++++------------- 1 file changed, 25 insertions(+), 21 deletions(-) diff --git a/src/spatialdata/_core/spatialdata.py b/src/spatialdata/_core/spatialdata.py index ff35aed9..949f174f 100644 --- a/src/spatialdata/_core/spatialdata.py +++ b/src/spatialdata/_core/spatialdata.py @@ -51,7 +51,10 @@ if TYPE_CHECKING: from spatialdata._core.query.spatial_query import BaseSpatialRequest - from spatialdata._io.format import SpatialDataContainerFormatType, SpatialDataFormatType + from spatialdata._io.format import ( + SpatialDataContainerFormatType, + SpatialDataFormatType, + ) # schema for elements Label2D_s = Labels2DModel() @@ -1147,13 +1150,13 @@ def write( self._validate_can_safely_write_to_path(file_path, overwrite=overwrite) self._validate_all_elements() - # parse_url cannot be replaced here as it actually also initialized an ome-zarr store. - store = parse_url(file_path, mode="w", fmt=parsed["SpatialData"]).store - zarr_group = zarr.open_group(store=store, mode="w" if overwrite else "a") - self.write_attrs(zarr_group=zarr_group) - store.close() - - for element_type, element_name, element in self.gen_elements(): + for index, (element_type, element_name, element) in enumerate(self.gen_elements()): + if index == 0: + # parse_url cannot be replaced here as it actually also initialized an ome-zarr store. + store = parse_url(file_path, mode="w", fmt=parsed["SpatialData"]).store + zarr_group = zarr.open_group(store=store, mode="w" if overwrite else "a") + self.write_attrs(zarr_group=zarr_group, sdata_format=parsed["SpatialData"]) + store.close() self._write_element( element=element, zarr_container_path=file_path, @@ -1163,13 +1166,14 @@ def write( parsed_formats=parsed, ) - if self.path != file_path: - old_path = self.path - self.path = file_path - logger.info(f"The Zarr backing store has been changed from {old_path} the new file path: {file_path}") + if parse_url(file_path): + if self.path != file_path: + old_path = self.path + self.path = file_path + logger.info(f"The Zarr backing store has been changed from {old_path} the new file path: {file_path}") - if consolidate_metadata: - self.write_consolidated_metadata() + if consolidate_metadata: + self.write_consolidated_metadata() def _write_element( self, @@ -1602,17 +1606,17 @@ def _element_type_and_name_from_element_path(self, element_path: str) -> tuple[s element_type, element_name = element_path.split("/") return element_type, element_name + @_deprecation_alias(format="sdata_format", version="0.7.0") def write_attrs( self, - format: SpatialDataContainerFormatType | None = None, + sdata_format: SpatialDataContainerFormatType | None = None, zarr_group: zarr.Group | None = None, ) -> None: from spatialdata._io._utils import _resolve_zarr_store - from spatialdata._io.format import SpatialDataContainerFormatType, _parse_formats + from spatialdata._io.format import CurrentSpatialDataContainerFormat, SpatialDataContainerFormatType - parsed = _parse_formats(formats=format) - spatialdata_container_format = parsed["SpatialData"] - assert isinstance(spatialdata_container_format, SpatialDataContainerFormatType) + sdata_format = sdata_format if sdata_format is not None else CurrentSpatialDataContainerFormat + assert isinstance(sdata_format, SpatialDataContainerFormatType) store = None @@ -1621,8 +1625,8 @@ def write_attrs( store = _resolve_zarr_store(self.path) zarr_group = zarr.open_group(store=store, mode="r+") - version = spatialdata_container_format.spatialdata_format_version - version_specific_attrs = spatialdata_container_format.attrs_to_dict() + version = sdata_format.spatialdata_format_version + version_specific_attrs = sdata_format.attrs_to_dict() attrs_to_write = {"spatialdata_attrs": {"version": version} | version_specific_attrs} | self.attrs try: From e73126182c55c386cf9e88308b94f3e699e8c016 Mon Sep 17 00:00:00 2001 From: Wouter-Michiel Vierdag Date: Thu, 11 Sep 2025 08:25:21 +0200 Subject: [PATCH 108/126] fix type --- src/spatialdata/_core/spatialdata.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/spatialdata/_core/spatialdata.py b/src/spatialdata/_core/spatialdata.py index 949f174f..2c80f4db 100644 --- a/src/spatialdata/_core/spatialdata.py +++ b/src/spatialdata/_core/spatialdata.py @@ -1615,7 +1615,7 @@ def write_attrs( from spatialdata._io._utils import _resolve_zarr_store from spatialdata._io.format import CurrentSpatialDataContainerFormat, SpatialDataContainerFormatType - sdata_format = sdata_format if sdata_format is not None else CurrentSpatialDataContainerFormat + sdata_format = sdata_format if sdata_format is not None else CurrentSpatialDataContainerFormat() assert isinstance(sdata_format, SpatialDataContainerFormatType) store = None From b431781d3b6e8e0bffa10933d1b1a4b3979d2927 Mon Sep 17 00:00:00 2001 From: Wouter-Michiel Vierdag Date: Thu, 11 Sep 2025 09:09:44 +0200 Subject: [PATCH 109/126] fix overwrite when no zarr store --- src/spatialdata/_core/spatialdata.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/spatialdata/_core/spatialdata.py b/src/spatialdata/_core/spatialdata.py index 2c80f4db..d924204c 100644 --- a/src/spatialdata/_core/spatialdata.py +++ b/src/spatialdata/_core/spatialdata.py @@ -1154,7 +1154,7 @@ def write( if index == 0: # parse_url cannot be replaced here as it actually also initialized an ome-zarr store. store = parse_url(file_path, mode="w", fmt=parsed["SpatialData"]).store - zarr_group = zarr.open_group(store=store, mode="w" if overwrite else "a") + zarr_group = zarr.open_group(store=store, mode="r+" if overwrite else "a") self.write_attrs(zarr_group=zarr_group, sdata_format=parsed["SpatialData"]) store.close() self._write_element( From 4e33d980fca3c5e7e1984356f16505cfad02bc21 Mon Sep 17 00:00:00 2001 From: Wouter-Michiel Vierdag Date: Thu, 11 Sep 2025 10:19:44 +0200 Subject: [PATCH 110/126] fix write element to empty directory location --- src/spatialdata/_core/spatialdata.py | 31 ++++++++++++++++++++-------- 1 file changed, 22 insertions(+), 9 deletions(-) diff --git a/src/spatialdata/_core/spatialdata.py b/src/spatialdata/_core/spatialdata.py index d924204c..2ae4e43e 100644 --- a/src/spatialdata/_core/spatialdata.py +++ b/src/spatialdata/_core/spatialdata.py @@ -1115,7 +1115,7 @@ def write( overwrite: bool = False, consolidate_metadata: bool = True, sdata_formats: SpatialDataFormatType | list[SpatialDataFormatType] | None = None, - ) -> None: + ) -> tuple[Path | None, Path | None]: """ Write the `SpatialData` object to a Zarr store. @@ -1140,6 +1140,10 @@ def write( By default, the latest format is used for all elements, i.e. :class:`~spatialdata._io.format.CurrentRasterFormat`, :class:`~spatialdata._io.format.CurrentShapesFormat`, :class:`~spatialdata._io.format.CurrentPointsFormat`, :class:`~spatialdata._io.format.CurrentTablesFormat`. + + Returns + ------- + The old path and the new path of the SpatialData zarr store. """ from spatialdata._io.format import _parse_formats @@ -1166,14 +1170,20 @@ def write( parsed_formats=parsed, ) + old_path = self.path + if self.path != file_path: + self.path = file_path if parse_url(file_path): - if self.path != file_path: - old_path = self.path - self.path = file_path - logger.info(f"The Zarr backing store has been changed from {old_path} the new file path: {file_path}") - if consolidate_metadata: self.write_consolidated_metadata() + else: + warnings.warn( + "The SpatialData object is empty. Only the directory has been written, but it is nozarr store.", + UserWarning, + stacklevel=2, + ) + + return old_path, self.path def _write_element( self, @@ -1272,15 +1282,18 @@ def write_element( If you pass a list of names, the elements will be written one by one. If an error occurs during the writing of an element, the writing of the remaining elements will not be attempted. """ + from spatialdata._io.format import _parse_formats + + parsed_formats = _parse_formats(formats=sdata_formats) + if parse_url(self.path) is None: + store = parse_url(self.path, mode="w", fmt=parsed_formats["SpatialData"]).store + store.close() if isinstance(element_name, list): for name in element_name: assert isinstance(name, str) self.write_element(name, overwrite=overwrite) return - from spatialdata._io.format import _parse_formats - - parsed_formats = _parse_formats(formats=sdata_formats) check_valid_name(element_name) self._validate_element_names_are_unique() element = self.get(element_name) From 04aae9fcce50d8a57379106fd92aee2c1e25b693 Mon Sep 17 00:00:00 2001 From: Wouter-Michiel Vierdag Date: Thu, 11 Sep 2025 11:32:43 +0200 Subject: [PATCH 111/126] correct no zarr store write test --- src/spatialdata/_core/spatialdata.py | 26 ++++++++++---------------- tests/io/test_readwrite.py | 18 +++++++++++------- 2 files changed, 21 insertions(+), 23 deletions(-) diff --git a/src/spatialdata/_core/spatialdata.py b/src/spatialdata/_core/spatialdata.py index 2ae4e43e..967d5928 100644 --- a/src/spatialdata/_core/spatialdata.py +++ b/src/spatialdata/_core/spatialdata.py @@ -1154,13 +1154,13 @@ def write( self._validate_can_safely_write_to_path(file_path, overwrite=overwrite) self._validate_all_elements() - for index, (element_type, element_name, element) in enumerate(self.gen_elements()): - if index == 0: - # parse_url cannot be replaced here as it actually also initialized an ome-zarr store. - store = parse_url(file_path, mode="w", fmt=parsed["SpatialData"]).store - zarr_group = zarr.open_group(store=store, mode="r+" if overwrite else "a") - self.write_attrs(zarr_group=zarr_group, sdata_format=parsed["SpatialData"]) - store.close() + # parse_url cannot be replaced here as it actually also initialized an ome-zarr store. + store = parse_url(file_path, mode="w", fmt=parsed["SpatialData"]).store + zarr_group = zarr.open_group(store=store, mode="r+") + self.write_attrs(zarr_group=zarr_group, sdata_format=parsed["SpatialData"]) + store.close() + + for element_type, element_name, element in self.gen_elements(): self._write_element( element=element, zarr_container_path=file_path, @@ -1173,15 +1173,9 @@ def write( old_path = self.path if self.path != file_path: self.path = file_path - if parse_url(file_path): - if consolidate_metadata: - self.write_consolidated_metadata() - else: - warnings.warn( - "The SpatialData object is empty. Only the directory has been written, but it is nozarr store.", - UserWarning, - stacklevel=2, - ) + + if consolidate_metadata: + self.write_consolidated_metadata() return old_path, self.path diff --git a/tests/io/test_readwrite.py b/tests/io/test_readwrite.py index 9deb18ab..c8b28e36 100644 --- a/tests/io/test_readwrite.py +++ b/tests/io/test_readwrite.py @@ -160,7 +160,10 @@ def test_roundtrip( sdata2 = SpatialData.read(tmpdir) tmpdir2 = Path(tmp_path) / "tmp2.zarr" sdata2.write(tmpdir2, sdata_formats=sdata_container_format) - _are_directories_identical(tmpdir, tmpdir2, exclude_regexp="[1-9][0-9]*.*") + if len(list(sdata.gen_elements())) > 0: + _are_directories_identical(tmpdir, tmpdir2, exclude_regexp="[1-9][0-9]*.*") + else: + assert tmpdir.exists() and tmpdir2.exists() def test_incremental_io_list_of_elements( self, @@ -371,16 +374,17 @@ def test_replace_transformation_on_disk_non_raster( t1 = get_transformation(SpatialData.read(f)[elem_name]) assert isinstance(t1, Scale) - def test_overwrite_works_when_no_zarr_store( + def test_write_overwrite_fails_when_no_zarr_store( self, full_sdata, sdata_container_format: SpatialDataContainerFormatType ): with tempfile.TemporaryDirectory() as tmpdir: - f = os.path.join(tmpdir, "data.zarr") + f = Path(tmpdir) / "data.zarr" + f.mkdir() old_data = SpatialData() - old_data.write(f, sdata_formats=sdata_container_format) - # Since no, no risk of overwriting backing data. - # Should not raise "The file path specified is the same as the one used for backing." - full_sdata.write(f, overwrite=True, sdata_formats=sdata_container_format) + with pytest.raises(ValueError, match="The target file path specified already exists"): + old_data.write(f, sdata_formats=sdata_container_format) + with pytest.raises(ValueError, match="The target file path specified already exists"): + full_sdata.write(f, overwrite=True, sdata_formats=sdata_container_format) def test_overwrite_fails_when_no_zarr_store_bug_dask_backed_data( self, From 594003b6ecee992a3be94ba30b3d201ff6c6026c Mon Sep 17 00:00:00 2001 From: Wouter-Michiel Vierdag Date: Thu, 11 Sep 2025 11:38:21 +0200 Subject: [PATCH 112/126] remove unnecessary code --- tests/io/test_readwrite.py | 5 +---- 1 file changed, 1 insertion(+), 4 deletions(-) diff --git a/tests/io/test_readwrite.py b/tests/io/test_readwrite.py index c8b28e36..04cbe367 100644 --- a/tests/io/test_readwrite.py +++ b/tests/io/test_readwrite.py @@ -160,10 +160,7 @@ def test_roundtrip( sdata2 = SpatialData.read(tmpdir) tmpdir2 = Path(tmp_path) / "tmp2.zarr" sdata2.write(tmpdir2, sdata_formats=sdata_container_format) - if len(list(sdata.gen_elements())) > 0: - _are_directories_identical(tmpdir, tmpdir2, exclude_regexp="[1-9][0-9]*.*") - else: - assert tmpdir.exists() and tmpdir2.exists() + _are_directories_identical(tmpdir, tmpdir2, exclude_regexp="[1-9][0-9]*.*") def test_incremental_io_list_of_elements( self, From 17eaf8ddfeb357df8b87593398f62452ca17df2e Mon Sep 17 00:00:00 2001 From: Luca Marconato Date: Thu, 11 Sep 2025 11:54:44 +0200 Subject: [PATCH 113/126] fix write_element --- src/spatialdata/_core/spatialdata.py | 15 +++------------ 1 file changed, 3 insertions(+), 12 deletions(-) diff --git a/src/spatialdata/_core/spatialdata.py b/src/spatialdata/_core/spatialdata.py index 967d5928..d0ad831f 100644 --- a/src/spatialdata/_core/spatialdata.py +++ b/src/spatialdata/_core/spatialdata.py @@ -1115,7 +1115,7 @@ def write( overwrite: bool = False, consolidate_metadata: bool = True, sdata_formats: SpatialDataFormatType | list[SpatialDataFormatType] | None = None, - ) -> tuple[Path | None, Path | None]: + ) -> None: """ Write the `SpatialData` object to a Zarr store. @@ -1140,10 +1140,6 @@ def write( By default, the latest format is used for all elements, i.e. :class:`~spatialdata._io.format.CurrentRasterFormat`, :class:`~spatialdata._io.format.CurrentShapesFormat`, :class:`~spatialdata._io.format.CurrentPointsFormat`, :class:`~spatialdata._io.format.CurrentTablesFormat`. - - Returns - ------- - The old path and the new path of the SpatialData zarr store. """ from spatialdata._io.format import _parse_formats @@ -1170,15 +1166,12 @@ def write( parsed_formats=parsed, ) - old_path = self.path if self.path != file_path: self.path = file_path if consolidate_metadata: self.write_consolidated_metadata() - return old_path, self.path - def _write_element( self, element: SpatialElement | AnnData, @@ -1279,13 +1272,11 @@ def write_element( from spatialdata._io.format import _parse_formats parsed_formats = _parse_formats(formats=sdata_formats) - if parse_url(self.path) is None: - store = parse_url(self.path, mode="w", fmt=parsed_formats["SpatialData"]).store - store.close() + if isinstance(element_name, list): for name in element_name: assert isinstance(name, str) - self.write_element(name, overwrite=overwrite) + self.write_element(name, overwrite=overwrite, sdata_formats=sdata_formats) return check_valid_name(element_name) From ffdc9da766a0c78352ea460056af6e096d77df86 Mon Sep 17 00:00:00 2001 From: Wouter-Michiel Vierdag Date: Thu, 11 Sep 2025 12:14:29 +0200 Subject: [PATCH 114/126] delete group isntead of .zattrs --- tests/io/test_readwrite.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/io/test_readwrite.py b/tests/io/test_readwrite.py index 04cbe367..a00f5405 100644 --- a/tests/io/test_readwrite.py +++ b/tests/io/test_readwrite.py @@ -933,7 +933,7 @@ def test_can_read_sdata_with_reconsolidation(full_sdata, sdata_container_format: json_path = path / ".zmetadata" json_dict = json.loads(json_path.read_text()) # TODO: this raises no exception! - del json_dict["metadata"]["images/image2d/0/.zattrs"] + del json_dict["metadata"]["images/image2d/.zgroup"] else: json_path = path / "zarr.json" json_dict = json.loads(json_path.read_text()) From 8f0f438cb3ef66becd7f17e1c9d1aa616e8c29b9 Mon Sep 17 00:00:00 2001 From: Luca Marconato Date: Thu, 11 Sep 2025 12:27:26 +0200 Subject: [PATCH 115/126] remove parse_url --- src/spatialdata/_core/spatialdata.py | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/src/spatialdata/_core/spatialdata.py b/src/spatialdata/_core/spatialdata.py index d0ad831f..1970eace 100644 --- a/src/spatialdata/_core/spatialdata.py +++ b/src/spatialdata/_core/spatialdata.py @@ -16,7 +16,6 @@ from dask.dataframe import read_parquet from dask.delayed import Delayed from geopandas import GeoDataFrame -from ome_zarr.io import parse_url from shapely import MultiPolygon, Polygon from xarray import DataArray, DataTree @@ -1141,6 +1140,7 @@ def write( :class:`~spatialdata._io.format.CurrentRasterFormat`, :class:`~spatialdata._io.format.CurrentShapesFormat`, :class:`~spatialdata._io.format.CurrentPointsFormat`, :class:`~spatialdata._io.format.CurrentTablesFormat`. """ + from spatialdata._io._utils import _resolve_zarr_store from spatialdata._io.format import _parse_formats parsed = _parse_formats(sdata_formats) @@ -1150,9 +1150,9 @@ def write( self._validate_can_safely_write_to_path(file_path, overwrite=overwrite) self._validate_all_elements() - # parse_url cannot be replaced here as it actually also initialized an ome-zarr store. - store = parse_url(file_path, mode="w", fmt=parsed["SpatialData"]).store - zarr_group = zarr.open_group(store=store, mode="r+") + store = _resolve_zarr_store(file_path) + zarr_format = parsed["SpatialData"].zarr_format + zarr_group = zarr.create_group(store=store, overwrite=overwrite, zarr_format=zarr_format) self.write_attrs(zarr_group=zarr_group, sdata_format=parsed["SpatialData"]) store.close() @@ -1162,7 +1162,7 @@ def write( zarr_container_path=file_path, element_type=element_type, element_name=element_name, - overwrite=overwrite, + overwrite=False, parsed_formats=parsed, ) From d3191a34fc611fcfb31aad2f762a59a10ad43484 Mon Sep 17 00:00:00 2001 From: Luca Marconato Date: Thu, 11 Sep 2025 15:03:34 +0200 Subject: [PATCH 116/126] removed logger.ingo() and most of remaining warnings from tests --- pyproject.toml | 2 +- .../_core/operations/rasterize_bins.py | 3 +- src/spatialdata/_core/spatialdata.py | 64 +++++++++++++------ src/spatialdata/_io/format.py | 8 +-- src/spatialdata/datasets.py | 8 ++- src/spatialdata/models/models.py | 4 +- tests/conftest.py | 45 ++++++++++--- tests/core/operations/test_rasterize_bins.py | 33 +++++++--- tests/io/test_readwrite.py | 14 ++-- tests/models/test_models.py | 2 +- tests/transformations/test_transformations.py | 3 +- 11 files changed, 124 insertions(+), 62 deletions(-) diff --git a/pyproject.toml b/pyproject.toml index 274eeb18..3911ae4c 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -98,7 +98,7 @@ xfail_strict = true addopts = [ # "-Werror", # if 3rd party libs raise DeprecationWarnings, just use filterwarnings below "--import-mode=importlib", # allow using test files with same name - "-s" # print output from tests + "-s", # print output from tests ] # These are all markers coming from xarray, dask or anndata. Added here to silence warnings. markers = [ diff --git a/src/spatialdata/_core/operations/rasterize_bins.py b/src/spatialdata/_core/operations/rasterize_bins.py index cff846ac..17470812 100644 --- a/src/spatialdata/_core/operations/rasterize_bins.py +++ b/src/spatialdata/_core/operations/rasterize_bins.py @@ -132,7 +132,8 @@ def rasterize_bins( random_indices = RNG.choice(table.n_obs, min(20, table.n_obs), replace=True) location_ids = table.obs[instance_key].iloc[random_indices].values - sub_df, sub_table = element.loc[location_ids], table[random_indices] + sub_df = element.loc[location_ids] + sub_table = table[random_indices] src = np.stack([sub_table.obs[col_key] - min_col, sub_table.obs[row_key] - min_row], axis=1) if isinstance(sub_df, GeoDataFrame): diff --git a/src/spatialdata/_core/spatialdata.py b/src/spatialdata/_core/spatialdata.py index 1970eace..27fdb590 100644 --- a/src/spatialdata/_core/spatialdata.py +++ b/src/spatialdata/_core/spatialdata.py @@ -18,6 +18,7 @@ from geopandas import GeoDataFrame from shapely import MultiPolygon, Polygon from xarray import DataArray, DataTree +from zarr.errors import GroupNotFoundError from spatialdata._core._elements import Images, Labels, Points, Shapes, Tables from spatialdata._core.validation import ( @@ -563,13 +564,6 @@ def path(self, value: Path | None) -> None: else: raise TypeError("Path must be `None`, a `str` or a `Path` object.") - if not self.is_self_contained(): - logger.info( - "The SpatialData object is not self-contained (i.e. it contains some elements that are Dask-backed from" - f" locations outside {self.path}). Please see the documentation of `is_self_contained()` to understand" - f" the implications of working with SpatialData objects that are not self-contained." - ) - def locate_element(self, element: SpatialElement) -> list[str]: """ Locate a SpatialElement within the SpatialData object and returns its Zarr paths relative to the root. @@ -1047,7 +1041,7 @@ def _validate_can_safely_write_to_path( overwrite: bool = False, saving_an_element: bool = False, ) -> None: - from spatialdata._io._utils import _backed_elements_contained_in_path, _is_subfolder + from spatialdata._io._utils import _backed_elements_contained_in_path, _is_subfolder, _resolve_zarr_store if isinstance(file_path, str): file_path = Path(file_path) @@ -1057,11 +1051,14 @@ def _validate_can_safely_write_to_path( # TODO: add test for this if os.path.exists(file_path): - if parse_url(file_path, mode="r") is None: + store = _resolve_zarr_store(file_path) + try: + zarr.open(store, mode="r") + except GroupNotFoundError as err: raise ValueError( "The target file path specified already exists, and it has been detected to not be a Zarr store. " "Overwriting non-Zarr stores is not supported to prevent accidental data loss." - ) + ) from err if not overwrite: raise ValueError( "The Zarr store already exists. Use `overwrite=True` to try overwriting the store. " @@ -1113,6 +1110,7 @@ def write( file_path: str | Path, overwrite: bool = False, consolidate_metadata: bool = True, + update_sdata_path: bool = True, sdata_formats: SpatialDataFormatType | list[SpatialDataFormatType] | None = None, ) -> None: """ @@ -1129,16 +1127,35 @@ def write( If `True`, triggers :func:`zarr.convenience.consolidate_metadata`, which writes all the metadata in a single file at the root directory of the store. This makes the data cloud accessible, which is required for certain cloud stores (such as S3). - format + update_sdata_path + Whether to update the `path` attribute of the `SpatialData` object to `file_path` after a successful write + (default yes). Here are the implications. + + - If `True`, and if the `SpatialData` object has dask-backed elements, the object will become "not + self-contained" because the dask-backed element will have a path that is different from the new + `sdata.path` attribute (now equal to `file_path`). By re-reading the object from disk, the object will + become self-contained again. + - If `False`, the `SpatialData` object will keep its current `path` attribute, meaning that calling + `sdata.write_element()`, `sdata.write_attrs()`, ... will write the data to the current `sdata.path` + location, not to the `file_path` location. + + Please consult :func:`spatialdata.SpatialData.is_self_contained` for more information on the implications of + working with self-contained and non-self-contained SpatialData objects. + sdata_formats The format to use for writing the elements of the `SpatialData` object. It is recommended to leave this parameter equal to `None` (default to latest format for all the elements). If not `None`, it must be - either a format for an element, or a list of formats. - For example it can be a subset of the following list `[RasterFormatVXX(), ShapesFormatVXX(), - PointsFormatVXX(), TablesFormatVXX()]`. (XX denote the version number, and should be replaced with the - respective format; the version numbers can differ across elements). - By default, the latest format is used for all elements, i.e. - :class:`~spatialdata._io.format.CurrentRasterFormat`, :class:`~spatialdata._io.format.CurrentShapesFormat`, - :class:`~spatialdata._io.format.CurrentPointsFormat`, :class:`~spatialdata._io.format.CurrentTablesFormat`. + either a format for an element, or a list of formats. For example it can be a subset of the following + list `[RasterFormatVXX(), ShapesFormatVXX(), PointsFormatVXX(), TablesFormatVXX()]`. (XX denote the + version number, and should be replaced with the respective format; the version numbers can differ across + elements). By default, the latest format is used for all elements, + i.e. :class:`~spatialdata._io.format.CurrentRasterFormat`, + :class:`~spatialdata._io.format.CurrentShapesFormat`, + :class:`~spatialdata._io.format.CurrentPointsFormat`, + :class:`~spatialdata._io.format.CurrentTablesFormat`. Also, by default, if a format for the SpatialData + "container" object (i.e. SpatialDataContainerFormatVXX) is specified, but the format for some elements is + unspecified, the element formats will be set to the latest element format compatible with the specified + SpatialData container format. All the formats and relationships between them are defined in + `spatialdata._io.format.py`. """ from spatialdata._io._utils import _resolve_zarr_store from spatialdata._io.format import _parse_formats @@ -1166,7 +1183,7 @@ def write( parsed_formats=parsed, ) - if self.path != file_path: + if self.path != file_path and update_sdata_path: self.path = file_path if consolidate_metadata: @@ -1640,7 +1657,7 @@ def write_metadata( element_name: str | None = None, consolidate_metadata: bool | None = None, write_attrs: bool = True, - format: SpatialDataContainerFormatType | None = None, + sdata_format: SpatialDataContainerFormatType | None = None, ) -> None: """ Write the metadata of a single element, or of all elements, to the Zarr store, without rewriting the data. @@ -1661,6 +1678,11 @@ def write_metadata( consolidate_metadata If True, consolidate the metadata to more easily support remote reading. By default write the metadata only if the metadata was already consolidated. + write_attrs + If True, write the SpatialData.attrs metadata to the root of the Zarr store. + sdata_format + The format to use for writing the metadata of the `SpatialData` object. See more details in the + documentation of `SpatialData.write()`. Notes ----- @@ -1677,7 +1699,7 @@ def write_metadata( # TODO: write .attrs['spatialdata_attrs'] metadata for DaskDataFrame. if write_attrs: - self.write_attrs(format=format) + self.write_attrs(sdata_format=sdata_format) # TODO: discuss when has_consolidated_metadata that we should just consolidate it because after a writing # operation the consolidated store could otherwise be out of sync. diff --git a/src/spatialdata/_io/format.py b/src/spatialdata/_io/format.py index df275cb5..0c1cf0cb 100644 --- a/src/spatialdata/_io/format.py +++ b/src/spatialdata/_io/format.py @@ -1,4 +1,3 @@ -import warnings from collections.abc import Iterator from typing import Any @@ -405,12 +404,7 @@ def _check_modified(element_type: str) -> None: raise ValueError(f"Unsupported format {fmt}") if parsed["SpatialData"].__str__() == "SpatialDataContainerFormatV01": - warnings.warn( - "SpatialData format defined to be 'SpatialDataContainerFormatV01'. Defaulting undefined element " - "formats to element formats valid for 'SpatialDataContainerFormatV01'.", - UserWarning, - stacklevel=2, - ) + # defaulting undefined element formats to element formats valid for 'SpatialDataContainerFormatV01' for el_type, value in modified.items(): if el_type != "SpatialData" and not value: parsed[el_type] = ContainerV01DefaultTypes[el_type] diff --git a/src/spatialdata/datasets.py b/src/spatialdata/datasets.py index cc828c91..63c137cd 100644 --- a/src/spatialdata/datasets.py +++ b/src/spatialdata/datasets.py @@ -1,5 +1,6 @@ """SpatialData datasets.""" +import warnings from typing import Any, Literal import dask.dataframe.core @@ -18,7 +19,6 @@ from spatialdata._core.operations.aggregate import aggregate from spatialdata._core.query.relational_query import get_element_instances from spatialdata._core.spatialdata import SpatialData -from spatialdata._logging import logger from spatialdata._types import ArrayLike from spatialdata.models import Image2DModel, Labels2DModel, PointsModel, ShapesModel, TableModel from spatialdata.transformations import Identity @@ -126,9 +126,11 @@ def __init__( self.c_coords = c_coords if c_coords is not None: if n_channels != len(c_coords): - logger.info( + warnings.warn( f"Number of channels ({n_channels}) and c_coords ({len(c_coords)}) do not match; ignoring " - f"n_channels value" + f"n_channels value", + UserWarning, + stacklevel=2, ) n_channels = len(c_coords) self.n_channels = n_channels diff --git a/src/spatialdata/models/models.py b/src/spatialdata/models/models.py index 576d3b58..60f4ee20 100644 --- a/src/spatialdata/models/models.py +++ b/src/spatialdata/models/models.py @@ -667,8 +667,8 @@ def validate(cls, data: DaskDataFrame) -> None: ) if ATTRS_KEY in data.attrs and "feature_key" in data.attrs[ATTRS_KEY]: feature_key = data.attrs[ATTRS_KEY][cls.FEATURE_KEY] - if not isinstance(data[feature_key].dtype, CategoricalDtype): - logger.info(f"Feature key `{feature_key}`could be of type `pd.Categorical`. Consider casting it.") + if feature_key not in data.columns: + warnings.warn(f"Column `{feature_key}` not found." + SUGGESTION, UserWarning, stacklevel=2) @singledispatchmethod @classmethod diff --git a/tests/conftest.py b/tests/conftest.py index d4ad346c..77572125 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -64,12 +64,16 @@ def points() -> SpatialData: @pytest.fixture() def table_single_annotation() -> SpatialData: - return SpatialData(tables={"table": _get_table(region="labels2d")}) + return SpatialData(tables={"table": _get_table(region="labels2d")}, labels=_get_labels()) @pytest.fixture() def table_multiple_annotations() -> SpatialData: - return SpatialData(tables={"table": _get_table(region=["labels2d", "poly"])}) + return SpatialData( + tables={"table": _get_table(region=["labels2d", "poly"])}, + labels=_get_labels(), + shapes=_get_shapes(), + ) @pytest.fixture() @@ -91,13 +95,20 @@ def full_sdata() -> SpatialData: labels=_get_labels(), shapes=_get_shapes(), points=_get_points(), - tables=_get_tables(region="labels2d"), + tables=_get_tables(region="labels2d", region_key="region", instance_key="instance_id"), ) @pytest.fixture( # params=["labels"] - params=["full", "empty"] + ["images", "labels", "points", "table_single_annotation", "table_multiple_annotations"] + params=["full", "empty"] + + [ + "images", + "labels", + "points", + "table_single_annotation", + "table_multiple_annotations", + ] # + ["empty_" + x for x in ["table"]] # TODO: empty table not supported yet ) def sdata(request) -> SpatialData: @@ -258,7 +269,7 @@ def _get_points() -> dict[str, DaskDataFrame]: def _get_tables( - region: None | str | list[str] = "sample1", + region: None | str | list[str], region_key: None | str = "region", instance_key: None | str = "instance_id", ) -> dict[str, AnnData]: @@ -266,13 +277,17 @@ def _get_tables( def _get_table( - region: None | str | list[str] = "sample1", + region: None | str | list[str], region_key: None | str = "region", instance_key: None | str = "instance_id", ) -> AnnData: adata = AnnData( RNG.normal(size=(100, 10)), - obs=pd.DataFrame(RNG.normal(size=(100, 3)), columns=["a", "b", "c"], index=[f"{i}" for i in range(100)]), + obs=pd.DataFrame( + RNG.normal(size=(100, 3)), + columns=["a", "b", "c"], + index=[f"{i}" for i in range(100)], + ), ) if not all(var for var in (region, region_key, instance_key)): return TableModel.parse(adata=adata) @@ -392,7 +407,10 @@ def _make_sdata_for_testing_querying_and_aggretation() -> SpatialData: s_num = pd.Series(RNG.random(20)) # workaround for https://github.com/dask/dask/issues/11147, let's recompute the dataframe (it's a small one) values_points = PointsModel.parse( - dd.from_pandas(values_points.compute().assign(categorical_in_ddf=s_cat, numerical_in_ddf=s_num), npartitions=1) + dd.from_pandas( + values_points.compute().assign(categorical_in_ddf=s_cat, numerical_in_ddf=s_num), + npartitions=1, + ) ) sdata = SpatialData( @@ -425,7 +443,10 @@ def _make_sdata_for_testing_querying_and_aggretation() -> SpatialData: var=pd.DataFrame(index=["numerical_in_var"]), ) table = TableModel.parse( - table, region=["values_circles", "values_polygons"], region_key="region", instance_key="instance_id" + table, + region=["values_circles", "values_polygons"], + region_key="region", + instance_key="instance_id", ) sdata["table"] = table return sdata @@ -473,7 +494,11 @@ def adata_labels() -> AnnData: index=np.arange(n_obs_labels).astype(str), ) uns_labels = { - "spatialdata_attrs": {"region": "test", "region_key": "region", "instance_key": "instance_id"}, + "spatialdata_attrs": { + "region": "test", + "region_key": "region", + "instance_key": "instance_id", + }, } obsm_labels = { "tensor": rng.integers(0, blobs.shape[0], size=(n_obs_labels, 2)), diff --git a/tests/core/operations/test_rasterize_bins.py b/tests/core/operations/test_rasterize_bins.py index 84346af3..b99508ef 100644 --- a/tests/core/operations/test_rasterize_bins.py +++ b/tests/core/operations/test_rasterize_bins.py @@ -20,7 +20,13 @@ from spatialdata._core.spatialdata import SpatialData from spatialdata._logging import logger from spatialdata._types import ArrayLike -from spatialdata.models.models import Image2DModel, Labels2DModel, PointsModel, ShapesModel, TableModel +from spatialdata.models.models import ( + Image2DModel, + Labels2DModel, + PointsModel, + ShapesModel, + TableModel, +) from spatialdata.transformations.transformations import Scale RNG = default_rng(0) @@ -42,24 +48,29 @@ def test_rasterize_bins(geometry: str, value_key: str | list[str] | None, return n = 10 data, x, y = _get_bins_data(n) scale = Scale([2.0], axes=("x",)) + index = np.arange(1, len(data) + 1) if geometry == "points": - points = PointsModel.parse(data, transformations={"global": scale}) + points = PointsModel.parse( + data, + transformations={"global": scale}, + annotation=pd.DataFrame(index=index), + ) elif geometry == "circles": - points = ShapesModel.parse(data, geometry=0, radius=1, transformations={"global": scale}) + points = ShapesModel.parse(data, geometry=0, radius=1, transformations={"global": scale}, index=index) else: assert geometry == "squares" gdf = GeoDataFrame( - data={"geometry": [Polygon([(x, y), (x + 1, y), (x + 1, y + 1), (x, y + 1), (x, y)]) for x, y in data]} + index=index, + data={"geometry": [Polygon([(x, y), (x + 1, y), (x + 1, y + 1), (x, y + 1), (x, y)]) for x, y in data]}, ) - points = ShapesModel.parse(gdf, transformations={"global": scale}) obs = DataFrame( data={ "region": pd.Categorical(["points"] * n * n), - "instance_id": np.arange(n * n), + "instance_id": index, "col_index": x, "row_index": y, }, @@ -68,7 +79,10 @@ def test_rasterize_bins(geometry: str, value_key: str | list[str] | None, return X = RNG.normal(size=(n * n, 2)) var = DataFrame(index=["gene0", "gene1"]) table = TableModel.parse( - AnnData(X=X, var=var, obs=obs), region="points", region_key="region", instance_key="instance_id" + AnnData(X=X, var=var, obs=obs), + region="points", + region_key="region", + instance_key="instance_id", ) sdata = SpatialData.init_from_elements({"points": points, "table": table}) rasterized = rasterize_bins( @@ -215,7 +229,10 @@ def _get_sdata(n: int): sdata = _get_sdata(n=3) table = sdata.tables["table"] table.obs["region"] = table.obs["region"].astype(str) - with pytest.raises(ValueError, match="Please convert `table.obs.*` to a category series to improve performances"): + with pytest.raises( + ValueError, + match="Please convert `table.obs.*` to a category series to improve performances", + ): _ = rasterize_bins( sdata=sdata, bins="points", diff --git a/tests/io/test_readwrite.py b/tests/io/test_readwrite.py index a00f5405..ad0662e2 100644 --- a/tests/io/test_readwrite.py +++ b/tests/io/test_readwrite.py @@ -503,7 +503,7 @@ def test_incremental_io_in_memory( with pytest.raises(KeyError, match="Key `table` is not unique"): sdata["table"] = v - for k, v in _get_tables().items(): + for k, v in _get_tables(region="labels2d").items(): sdata.tables[f"additional_{k}"] = v with pytest.raises(KeyError, match="Key `poly` is not unique"): sdata["poly"] = v @@ -715,13 +715,13 @@ def test_incremental_io_attrs(points: SpatialData, sdata_container_format: Spati # test incremental io attrs (write_attrs()) sdata.attrs["c"] = 2 - sdata.write_attrs(format=sdata_container_format) + sdata.write_attrs(sdata_format=sdata_container_format) sdata2 = SpatialData.read(f) assert sdata2.attrs["c"] == 2 # test incremental io attrs (write_metadata()) sdata.attrs["c"] = 3 - sdata.write_metadata(format=sdata_container_format) + sdata.write_metadata(sdata_format=sdata_container_format) sdata2 = SpatialData.read(f) assert sdata2.attrs["c"] == 3 @@ -819,7 +819,7 @@ def test_element_already_on_disk_different_type( ValueError, match=ERROR_MSG, ): - full_sdata.write_metadata(element_name, format=sdata_container_format) + full_sdata.write_metadata(element_name, sdata_format=sdata_container_format) with pytest.raises( ValueError, @@ -835,7 +835,7 @@ def test_writing_invalid_name(tmp_path: Path): invalid_sdata.labels.data["."] = next(iter(_get_labels().values())) invalid_sdata.points.data["path/separator"] = next(iter(_get_points().values())) invalid_sdata.shapes.data["non-alnum_#$%&()*+,?@"] = next(iter(_get_shapes().values())) - invalid_sdata.tables.data["has whitespace"] = _get_table() + invalid_sdata.tables.data["has whitespace"] = _get_table(region="any") with pytest.raises(ValueError, match="Name (must|cannot)"): invalid_sdata.write(tmp_path / "data.zarr") @@ -859,7 +859,7 @@ def test_incremental_writing_invalid_name(tmp_path: Path): invalid_sdata.labels.data["."] = next(iter(_get_labels().values())) invalid_sdata.points.data["path/separator"] = next(iter(_get_points().values())) invalid_sdata.shapes.data["non-alnum_#$%&()*+,?@"] = next(iter(_get_shapes().values())) - invalid_sdata.tables.data["has whitespace"] = _get_table() + invalid_sdata.tables.data["has whitespace"] = _get_table(region="any") for element_type in ["images", "labels", "points", "shapes", "tables"]: elements = getattr(invalid_sdata, element_type) @@ -883,7 +883,7 @@ def test_reading_invalid_name(tmp_path: Path): labels_name, labels = next(iter(_get_labels().items())) points_name, points = next(iter(_get_points().items())) shapes_name, shapes = next(iter(_get_shapes().items())) - table_name, table = "table", _get_table() + table_name, table = "table", _get_table(region="labels2d") valid_sdata = SpatialData( images={image_name: image}, labels={labels_name: labels}, diff --git a/tests/models/test_models.py b/tests/models/test_models.py index b62eb53b..3b03da88 100644 --- a/tests/models/test_models.py +++ b/tests/models/test_models.py @@ -534,7 +534,7 @@ def test_get_schema(): labels = _get_labels() points = _get_points() shapes = _get_shapes() - table = _get_table() + table = _get_table(region="any", region_key="region", instance_key="instance_id") for k, v in images.items(): schema = get_model(v) if "2d" in k: diff --git a/tests/transformations/test_transformations.py b/tests/transformations/test_transformations.py index 180d007d..b8571dc9 100644 --- a/tests/transformations/test_transformations.py +++ b/tests/transformations/test_transformations.py @@ -1001,6 +1001,7 @@ def test_keep_numerical_coordinates_c(image_name): def test_keep_string_coordinates_c(image_name): c_coords = ["a", "b", "c"] # n_channels will be ignored, testing also that this works - sdata = blobs(c_coords=c_coords, n_channels=4) + with pytest.warns(UserWarning, match="Number of channels "): + sdata = blobs(c_coords=c_coords, n_channels=4) t_blobs = transform(sdata.images[image_name], to_coordinate_system=DEFAULT_COORDINATE_SYSTEM) assert np.array_equal(get_channel_names(t_blobs), c_coords) From 48ad6fd29b39301197e575b608a7c64b9fd38e98 Mon Sep 17 00:00:00 2001 From: Luca Marconato Date: Thu, 11 Sep 2025 15:33:52 +0200 Subject: [PATCH 117/126] addressing review comments --- pyproject.toml | 1 + tests/io/test_multi_table.py | 8 ++++++++ tests/models/test_models.py | 3 ++- 3 files changed, 11 insertions(+), 1 deletion(-) diff --git a/pyproject.toml b/pyproject.toml index 3911ae4c..766f7515 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -54,6 +54,7 @@ dependencies = [ [project.optional-dependencies] dev = [ "bump2version", + "sentry-prevent-cli", ] test = [ "pytest", diff --git a/tests/io/test_multi_table.py b/tests/io/test_multi_table.py index 0e37e1b4..70974c8a 100644 --- a/tests/io/test_multi_table.py +++ b/tests/io/test_multi_table.py @@ -41,6 +41,14 @@ def test_set_get_tables_from_spatialdata(self, full_sdata: SpatialData, tmp_path assert_equal(adata1, full_sdata["adata1"]) assert "adata0" in full_sdata.tables and "adata1" in full_sdata.tables + def test_null_values_in_instance_key_column(self, full_sdata: SpatialData): + n_obs = full_sdata["table"].n_obs + full_sdata["table"].obs["instance_id"] = range(n_obs) + # introduce null values + full_sdata["table"].obs.loc[0, "instance_id"] = None + with pytest.raises(ValueError, match="must not contain null values, but it does."): + full_sdata.validate_table_in_spatialdata(table=full_sdata["table"]) + @pytest.mark.parametrize( "region_key, instance_key, error_msg", [ diff --git a/tests/models/test_models.py b/tests/models/test_models.py index 3b03da88..2ed108b7 100644 --- a/tests/models/test_models.py +++ b/tests/models/test_models.py @@ -520,9 +520,10 @@ def test_table_model_invalid_names(self, key: str, attr: str, parse: bool): @pytest.mark.parametrize("attr", ["obs", "var"]) @pytest.mark.parametrize("parse", [True, False]) def test_table_model_not_unique_columns(self, keys: list[str], attr: str, parse: bool): + invalid_key = keys[1] df = pd.DataFrame([[None] * len(keys)], columns=keys, index=["1"]) adata = AnnData(np.array([[0]]), **{attr: df}) - with pytest.raises(ValueError, match="Table contains invalid names.\nFor renaming, please"): + with pytest.raises(ValueError, match=f"Table contains invalid names(.|\n)*\n {attr}/{invalid_key}: "): if parse: TableModel.parse(adata) else: From a9e72426004fcb1e91f07bdd9e7fb211f28b2185 Mon Sep 17 00:00:00 2001 From: Luca Marconato Date: Thu, 11 Sep 2025 18:48:35 +0200 Subject: [PATCH 118/126] addressed consolidate metadata comment --- src/spatialdata/_core/spatialdata.py | 4 +--- 1 file changed, 1 insertion(+), 3 deletions(-) diff --git a/src/spatialdata/_core/spatialdata.py b/src/spatialdata/_core/spatialdata.py index 27fdb590..e7ce4da4 100644 --- a/src/spatialdata/_core/spatialdata.py +++ b/src/spatialdata/_core/spatialdata.py @@ -1701,9 +1701,7 @@ def write_metadata( if write_attrs: self.write_attrs(sdata_format=sdata_format) - # TODO: discuss when has_consolidated_metadata that we should just consolidate it because after a writing - # operation the consolidated store could otherwise be out of sync. - if consolidate_metadata is None and self.has_consolidated_metadata(): + if self.has_consolidated_metadata(): consolidate_metadata = True if consolidate_metadata: self.write_consolidated_metadata() From 71adfe1252cc5bc8b159abf789d238ab2a64d5d1 Mon Sep 17 00:00:00 2001 From: Luca Marconato Date: Thu, 11 Sep 2025 19:22:08 +0200 Subject: [PATCH 119/126] make full coverage of _validate_can_safely_write_to_path() easier to understand --- tests/io/test_readwrite.py | 29 ++++++++++++++++++++++++----- 1 file changed, 24 insertions(+), 5 deletions(-) diff --git a/tests/io/test_readwrite.py b/tests/io/test_readwrite.py index ad0662e2..8501687c 100644 --- a/tests/io/test_readwrite.py +++ b/tests/io/test_readwrite.py @@ -229,7 +229,8 @@ def _workaround1_dask_backed( del sdata[new_name] sdata.delete_element_from_disk(new_name) - @pytest.mark.parametrize("dask_backed", [True, False]) + # @pytest.mark.parametrize("dask_backed", [True, False]) + @pytest.mark.parametrize("dask_backed", [True]) @pytest.mark.parametrize("workaround", [1, 2]) def test_incremental_io_on_disk( self, @@ -274,7 +275,24 @@ def test_incremental_io_on_disk( ): sdata.write_element(name, sdata_formats=sdata_container_format) - with pytest.raises(ValueError, match="Cannot overwrite."): + match = ( + "Details: the target path contains one or more files that Dask use for backing elements in the " + "SpatialData object" + if dask_backed + and name + in [ + "image2d", + "labels2d", + "image3d_multiscale_xarray", + "labels3d_multiscale_xarray", + "points_0", + ] + else "Details: the target path in which to save an element is a subfolder of the current Zarr store." + ) + with pytest.raises( + ValueError, + match=match, + ): sdata.write_element(name, overwrite=True, sdata_formats=sdata_container_format) if workaround == 1: @@ -383,7 +401,7 @@ def test_write_overwrite_fails_when_no_zarr_store( with pytest.raises(ValueError, match="The target file path specified already exists"): full_sdata.write(f, overwrite=True, sdata_formats=sdata_container_format) - def test_overwrite_fails_when_no_zarr_store_bug_dask_backed_data( + def test_overwrite_fails_when_no_zarr_store_but_dask_backed_data( self, full_sdata, points, @@ -411,7 +429,8 @@ def test_overwrite_fails_when_no_zarr_store_bug_dask_backed_data( with pytest.raises( ValueError, - match="Cannot overwrite.", + match=r"Details: the target path contains one or more files that Dask use for " + "backing elements in the SpatialData object", ): full_sdata.write(f, overwrite=True, sdata_formats=sdata_container_format) @@ -431,7 +450,7 @@ def test_overwrite_fails_when_zarr_store_present( with pytest.raises( ValueError, - match="Cannot overwrite.", + match=r"Details: the target path either contains, coincides or is contained in the current Zarr store", ): full_sdata.write(f, overwrite=True, sdata_formats=sdata_container_format) From 7c56b251577beac3ecba497ea9d2d70d5478e0e7 Mon Sep 17 00:00:00 2001 From: Wouter-Michiel Vierdag Date: Fri, 12 Sep 2025 10:43:16 +0200 Subject: [PATCH 120/126] restore partial read/write tests --- src/spatialdata/_io/_utils.py | 15 +- src/spatialdata/_io/io_zarr.py | 106 +++++------ tests/io/test_partial_read.py | 336 ++++++++++++++++++++++++++++++--- 3 files changed, 355 insertions(+), 102 deletions(-) diff --git a/src/spatialdata/_io/_utils.py b/src/spatialdata/_io/_utils.py index 82f5a7ce..b501e210 100644 --- a/src/spatialdata/_io/_utils.py +++ b/src/spatialdata/_io/_utils.py @@ -1,5 +1,4 @@ import filecmp -import logging import os.path import re import sys @@ -35,18 +34,6 @@ from spatialdata.transformations.transformations import BaseTransformation, _get_current_output_axes -# suppress logger debug from ome_zarr with context manager -@contextmanager -def ome_zarr_logger(level: Any) -> Generator[None, None, None]: - logger = logging.getLogger("ome_zarr") - current_level = logger.getEffectiveLevel() - logger.setLevel(level) - try: - yield - finally: - logger.setLevel(current_level) - - def _get_transformations_from_ngff_dict( list_of_encoded_ngff_transformations: list[dict[str, Any]], ) -> MappingToCoordinateSystem_t: @@ -472,7 +459,7 @@ class BadFileHandleMethod(Enum): def handle_read_errors( on_bad_files: Literal[BadFileHandleMethod.ERROR, BadFileHandleMethod.WARN], location: str, - exc_types: tuple[type[Exception], ...], + exc_types: type[BaseException] | tuple[type[BaseException], ...], ) -> Generator[None, None, None]: """ Handle read errors according to parameter `on_bad_files`. diff --git a/src/spatialdata/_io/io_zarr.py b/src/spatialdata/_io/io_zarr.py index 0bf95dbf..79a827a5 100644 --- a/src/spatialdata/_io/io_zarr.py +++ b/src/spatialdata/_io/io_zarr.py @@ -1,4 +1,3 @@ -import logging import os import warnings from json import JSONDecodeError @@ -8,14 +7,13 @@ import zarr.storage from anndata import AnnData from pyarrow import ArrowInvalid -from zarr.errors import ArrayNotFoundError, MetadataValidationError +from zarr.errors import ArrayNotFoundError from spatialdata._core.spatialdata import SpatialData from spatialdata._io._utils import ( BadFileHandleMethod, _resolve_zarr_store, handle_read_errors, - ome_zarr_logger, ) from spatialdata._io.io_points import _read_points from spatialdata._io.io_raster import _read_multiscale @@ -72,34 +70,40 @@ def read_zarr( # We raise OS errors instead for some read errors now as in zarr v3 with some corruptions nothing will be read. # related to images / labels. - if "images" in selector and "images" in root_group: - group = root_group["images"] - count = 0 - for subgroup_name in group: - if Path(subgroup_name).name.startswith("."): - # skip hidden files like .zgroup or .zmetadata - continue - elem_group = group[subgroup_name] - elem_group_path = os.path.join(root_store_path, elem_group.path) - with handle_read_errors( - on_bad_files, - location=f"{group.path}/{subgroup_name}", - exc_types=( - JSONDecodeError, # JSON parse error - ValueError, # ome_zarr: Unable to read the NGFF file - KeyError, # Missing JSON key - ArrayNotFoundError, # Image chunks missing - TypeError, # instead of ArrayNotFoundError, with dask>=2024.10.0 zarr<=2.18.3 - OSError, - ), - ): - element = _read_multiscale(elem_group_path, raster_type="image") - images[subgroup_name] = element - count += 1 - logger.debug(f"Found {count} elements in {group}") + with handle_read_errors( + on_bad_files, + location="images", + exc_types=JSONDecodeError, + ): + if "images" in selector and "images" in root_group: + group = root_group["images"] + count = 0 + for subgroup_name in group: + if Path(subgroup_name).name.startswith("."): + # skip hidden files like .zgroup or .zmetadata + continue + elem_group = group[subgroup_name] + elem_group_path = os.path.join(root_store_path, elem_group.path) + with handle_read_errors( + on_bad_files, + location=f"{group.path}/{subgroup_name}", + exc_types=( + KeyError, + ArrayNotFoundError, + OSError, + ), + ): + element = _read_multiscale(elem_group_path, raster_type="image") + images[subgroup_name] = element + count += 1 + logger.debug(f"Found {count} elements in {group}") # read multiscale labels - with ome_zarr_logger(logging.ERROR): + with handle_read_errors( + on_bad_files, + location="labels", + exc_types=JSONDecodeError, + ): if "labels" in selector and "labels" in root_group: group = root_group["labels"] count = 0 @@ -113,25 +117,21 @@ def read_zarr( on_bad_files, location=f"{group.path}/{subgroup_name}", exc_types=( - JSONDecodeError, KeyError, - ValueError, ArrayNotFoundError, - TypeError, OSError, ), ): labels[subgroup_name] = _read_multiscale(elem_group_path, raster_type="labels") count += 1 logger.debug(f"Found {count} elements in {group}") - # now read rest of the data - if "points" in selector and "points" in root_group: - with handle_read_errors( - on_bad_files, - location="points", - exc_types=(JSONDecodeError, MetadataValidationError), - ): + with handle_read_errors( + on_bad_files, + location="points", + exc_types=JSONDecodeError, + ): + if "points" in selector and "points" in root_group: group = root_group["points"] count = 0 for subgroup_name in group: @@ -143,18 +143,18 @@ def read_zarr( with handle_read_errors( on_bad_files, location=f"{group.path}/{subgroup_name}", - exc_types=(JSONDecodeError, KeyError, ArrowInvalid), + exc_types=(KeyError, ArrowInvalid, JSONDecodeError), ): points[subgroup_name] = _read_points(elem_group_path) count += 1 logger.debug(f"Found {count} elements in {group}") - if "shapes" in selector and "shapes" in root_group: - with handle_read_errors( - on_bad_files, - location="shapes", - exc_types=(JSONDecodeError, MetadataValidationError), - ): + with handle_read_errors( + on_bad_files, + location="shapes", + exc_types=JSONDecodeError, + ): + if "shapes" in selector and "shapes" in root_group: group = root_group["shapes"] count = 0 for subgroup_name in group: @@ -168,7 +168,6 @@ def read_zarr( location=f"{group.path}/{subgroup_name}", exc_types=( JSONDecodeError, - ValueError, KeyError, ArrayNotFoundError, ), @@ -180,21 +179,10 @@ def read_zarr( with handle_read_errors( on_bad_files, location="tables", - exc_types=(JSONDecodeError, MetadataValidationError), + exc_types=JSONDecodeError, ): group = root_group["tables"] tables = _read_table(root_store_path, group, tables, on_bad_files=on_bad_files) - if "tables" in selector and "table" in root_group: - with handle_read_errors( - on_bad_files, - location="table", - exc_types=(ValueError,), - ): - raise ValueError( - f"`table` group found in zarr store at location {root_store_path} " - "instead of `tables`. Please update the zarr store to use `tables` " - "instead.", - ) # read attrs metadata attrs = root_group.attrs.asdict() diff --git a/tests/io/test_partial_read.py b/tests/io/test_partial_read.py index 8a57d0a1..27ebb2bb 100644 --- a/tests/io/test_partial_read.py +++ b/tests/io/test_partial_read.py @@ -11,13 +11,15 @@ from pathlib import Path from typing import TYPE_CHECKING +import numpy as np import py import pytest import zarr from pyarrow import ArrowInvalid -from zarr.errors import ArrayNotFoundError +from zarr.errors import ArrayNotFoundError, ZarrUserWarning from spatialdata import SpatialData, read_zarr +from spatialdata._io.format import SpatialDataContainerFormatV01 from spatialdata.datasets import blobs if TYPE_CHECKING: @@ -84,10 +86,55 @@ def session_tmp_path(request: _pytest.fixtures.SubRequest) -> Path: @pytest.fixture(scope="module") -def sdata_with_corrupted_zarr_json(session_tmp_path: Path) -> PartialReadTestCase: +def sdata_with_corrupted_elem_types_zgroup(session_tmp_path: Path) -> PartialReadTestCase: + # Zarr v2 + sdata = blobs() + sdata_path = session_tmp_path / "sdata_with_corrupted_top_level_zgroup.zarr" + # Errors only when no consolidation metadata store is used as this takes precedence over group metadata when reading + sdata.write(sdata_path, sdata_formats=SpatialDataContainerFormatV01(), consolidate_metadata=False) + + (sdata_path / "images" / ".zgroup").unlink() # missing, not detected by reader. So it doesn't raise an exception, + # but it will not be found in the read SpatialData object + (sdata_path / "labels" / ".zgroup").write_text("") # corrupted + (sdata_path / "points" / ".zgroup").write_text("{}") # invalid + not_corrupted = [name for t, name, _ in sdata.gen_elements() if t not in ("images", "labels", "points")] + + return PartialReadTestCase( + path=sdata_path, + expected_elements=not_corrupted, + expected_exceptions=(JSONDecodeError, ZarrUserWarning), + warnings_patterns=["labels: JSONDecodeError", "Object at"], + ) + + +@pytest.fixture(scope="module") +def sdata_with_corrupted_elem_types_zarr_json(session_tmp_path: Path) -> PartialReadTestCase: + # Zarr v2 + sdata = blobs() + sdata_path = session_tmp_path / "sdata_with_corrupted_top_level_zarr_json.zarr" + # Errors only when no consolidation metadata store is used as this takes precedence over group metadata when reading + sdata.write(sdata_path, consolidate_metadata=False) + + (sdata_path / "images" / "zarr.json").unlink() # missing, not detected by reader. So it doesn't raise an exception, + # but it will not be found in the read SpatialData object + (sdata_path / "labels" / "zarr.json").write_text("") # corrupted + (sdata_path / "points" / "zarr.json").write_text('"not_valid": "not_valid"}') # invalid + not_corrupted = [name for t, name, _ in sdata.gen_elements() if t not in ("images", "labels", "points")] + + return PartialReadTestCase( + path=sdata_path, + expected_elements=not_corrupted, + expected_exceptions=(JSONDecodeError), + warnings_patterns=["labels: JSONDecodeError", "Extra data"], + ) + + +@pytest.fixture(scope="module") +def sdata_with_corrupted_zarr_json_elements(session_tmp_path: Path) -> PartialReadTestCase: + # Zarr v3 # zarr.json is a zero-byte file, aborted during write, or contains invalid JSON syntax sdata = blobs() - sdata_path = session_tmp_path / "sdata_with_corrupted_zarr_json.zarr" + sdata_path = session_tmp_path / "sdata_with_corrupted_zarr_json_elements.zarr" sdata.write(sdata_path) corrupted_elements = ["blobs_image", "blobs_labels", "blobs_points", "blobs_polygons", "table"] @@ -95,7 +142,7 @@ def sdata_with_corrupted_zarr_json(session_tmp_path: Path) -> PartialReadTestCas for corrupted_element in corrupted_elements: elem_path = sdata.locate_element(sdata[corrupted_element])[0] (sdata_path / elem_path / "zarr.json").write_bytes(b"") - warnings_patterns.append(f"{elem_path}: JSONDecodeError") + warnings_patterns.append(rf"{elem_path}: (?:OSError|JSONDecodeError):") not_corrupted = [name for _, name, _ in sdata.gen_elements() if name not in corrupted_elements] return PartialReadTestCase( @@ -107,10 +154,34 @@ def sdata_with_corrupted_zarr_json(session_tmp_path: Path) -> PartialReadTestCas @pytest.fixture(scope="module") -def sdata_with_corrupted_image_chunks(session_tmp_path: Path) -> PartialReadTestCase: +def sdata_with_corrupted_zattrs_elements(session_tmp_path: Path) -> PartialReadTestCase: + # Zarr v2 + # .zattrs is a zero-byte file, aborted during write, or contains invalid JSON syntax + sdata = blobs() + sdata_path = session_tmp_path / "sdata_with_corrupted_zattrs_elements.zarr" + sdata.write(sdata_path, sdata_formats=SpatialDataContainerFormatV01()) + + corrupted_elements = ["blobs_image", "blobs_labels", "blobs_points", "blobs_polygons", "table"] + warnings_patterns = [] + for corrupted_element in corrupted_elements: + elem_path = sdata.locate_element(sdata[corrupted_element])[0] + (sdata_path / elem_path / ".zattrs").write_bytes(b"") + warnings_patterns.append(rf"{elem_path}: (?:OSError|JSONDecodeError):") + not_corrupted = [name for _, name, _ in sdata.gen_elements() if name not in corrupted_elements] + + return PartialReadTestCase( + path=sdata_path, + expected_elements=not_corrupted, + expected_exceptions=(OSError, JSONDecodeError), + warnings_patterns=warnings_patterns, + ) + + +@pytest.fixture(scope="module") +def sdata_with_corrupted_image_chunks_zarrv3(session_tmp_path: Path) -> PartialReadTestCase: # images/blobs_image/0 is a zero-byte file or aborted during write sdata = blobs() - sdata_path = session_tmp_path / "sdata_with_corrupted_image_chunks.zarr" + sdata_path = session_tmp_path / "sdata_with_corrupted_image_chunks_zarrv3.zarr" sdata.write(sdata_path) corrupted = "blobs_image" @@ -123,20 +194,37 @@ def sdata_with_corrupted_image_chunks(session_tmp_path: Path) -> PartialReadTest return PartialReadTestCase( path=sdata_path, expected_elements=not_corrupted, - expected_exceptions=( - ArrayNotFoundError, - TypeError, # instead of ArrayNotFoundError, with dask>=2024.10.0 zarr<=2.18.3 - ), - warnings_patterns=[rf"images/{corrupted}: (TypeError)"], - # warnings_patterns=[rf"images/{corrupted}: (ArrayNotFoundError|TypeError)"], + expected_exceptions=(ArrayNotFoundError,), + warnings_patterns=[rf"images/{corrupted}: ArrayNotFoundError"], ) @pytest.fixture(scope="module") -def sdata_with_corrupted_parquet(session_tmp_path: Path) -> PartialReadTestCase: +def sdata_with_corrupted_image_chunks_zarrv2(session_tmp_path: Path) -> PartialReadTestCase: + # images/blobs_image/0 is a zero-byte file or aborted during write + sdata = blobs() + sdata_path = session_tmp_path / "sdata_with_corrupted_image_chunks_zarrv2.zarr" + sdata.write(sdata_path, sdata_formats=SpatialDataContainerFormatV01()) + + corrupted = "blobs_image" + os.unlink(sdata_path / "images" / corrupted / "0" / ".zarray") # it will hide the "0" array from the Zarr reader + os.rename(sdata_path / "images" / corrupted / "0", sdata_path / "images" / corrupted / "0_corrupted") + (sdata_path / "images" / corrupted / "0").touch() + not_corrupted = [name for _, name, _ in sdata.gen_elements() if name != corrupted] + + return PartialReadTestCase( + path=sdata_path, + expected_elements=not_corrupted, + expected_exceptions=(ArrayNotFoundError,), + warnings_patterns=[rf"images/{corrupted}: ArrayNotFoundError"], + ) + + +@pytest.fixture(scope="module") +def sdata_with_corrupted_parquet_zarrv3(session_tmp_path: Path) -> PartialReadTestCase: # points/blobs_points/0 is a zero-byte file or aborted during write sdata = blobs() - sdata_path = session_tmp_path / "sdata_with_corrupted_parquet.zarr" + sdata_path = session_tmp_path / "sdata_with_corrupted_parquet_zarrv3.zarr" sdata.write(sdata_path) corrupted = "blobs_points" @@ -157,10 +245,34 @@ def sdata_with_corrupted_parquet(session_tmp_path: Path) -> PartialReadTestCase: @pytest.fixture(scope="module") -def sdata_with_missing_zarr_json(session_tmp_path: Path) -> PartialReadTestCase: +def sdata_with_corrupted_parquet_zarrv2(session_tmp_path: Path) -> PartialReadTestCase: + # points/blobs_points/0 is a zero-byte file or aborted during write + sdata = blobs() + sdata_path = session_tmp_path / "sdata_with_corrupted_parquet_zarrv2.zarr" + sdata.write(sdata_path, sdata_formats=SpatialDataContainerFormatV01()) + + corrupted = "blobs_points" + os.rename( + sdata_path / "points" / corrupted / "points.parquet", + sdata_path / "points" / corrupted / "points_corrupted.parquet", + ) + (sdata_path / "points" / corrupted / "points.parquet").touch() + + not_corrupted = [name for _, name, _ in sdata.gen_elements() if name != corrupted] + + return PartialReadTestCase( + path=sdata_path, + expected_elements=not_corrupted, + expected_exceptions=ArrowInvalid, + warnings_patterns=[rf"points/{corrupted}: ArrowInvalid"], + ) + + +@pytest.fixture(scope="module") +def sdata_with_missing_zarr_json_element(session_tmp_path: Path) -> PartialReadTestCase: # zarr.json is missing sdata = blobs() - sdata_path = session_tmp_path / "sdata_with_missing_zattrs.zarr" + sdata_path = session_tmp_path / "sdata_with_missing_zarr_json_element.zarr" sdata.write(sdata_path) corrupted = "blobs_image" @@ -171,16 +283,105 @@ def sdata_with_missing_zarr_json(session_tmp_path: Path) -> PartialReadTestCase: path=sdata_path, expected_elements=not_corrupted, expected_exceptions=OSError, - warnings_patterns=[rf"images/{corrupted}: .* Unable to read the NGFF file"], + warnings_patterns=[r"images/blobs_image: OSError:"], ) @pytest.fixture(scope="module") -def sdata_with_invalid_zarr_json_violating_spec(session_tmp_path: Path) -> PartialReadTestCase: - # zarr.json contains readable JSON which is not valid for SpatialData/NGFF specs +def sdata_with_missing_zattrs_element(session_tmp_path: Path) -> PartialReadTestCase: + # Zarrv2 + # .zattrs is missing + sdata = blobs() + sdata_path = session_tmp_path / "sdata_with_missing_zattrs_element.zarr" + sdata.write(sdata_path, sdata_formats=SpatialDataContainerFormatV01()) + + corrupted = "blobs_image" + (sdata_path / "images" / corrupted / ".zattrs").unlink() + not_corrupted = [name for _, name, _ in sdata.gen_elements() if name != corrupted] + + return PartialReadTestCase( + path=sdata_path, + expected_elements=not_corrupted, + expected_exceptions=OSError, + warnings_patterns=["OSError: Image location"], + ) + + +@pytest.fixture(scope="module") +def sdata_with_missing_image_chunks_zarrv3( + session_tmp_path: Path, +) -> PartialReadTestCase: + sdata = blobs() + sdata_path = session_tmp_path / "sdata_with_missing_image_chunks_zarrv3.zarr" + sdata.write(sdata_path) + + corrupted = "blobs_image" + os.unlink(sdata_path / "images" / corrupted / "0" / "zarr.json") + os.rename(sdata_path / "images" / corrupted / "0", sdata_path / "images" / corrupted / "0_corrupted") + + not_corrupted = [name for _, name, _ in sdata.gen_elements() if name != corrupted] + + return PartialReadTestCase( + path=sdata_path, + expected_elements=not_corrupted, + expected_exceptions=(ArrayNotFoundError,), + warnings_patterns=[rf"images/{corrupted}: ArrayNotFoundError"], + ) + + +@pytest.fixture(scope="module") +def sdata_with_missing_image_chunks_zarrv2( + session_tmp_path: Path, +) -> PartialReadTestCase: + # Zarrv2 + # .zattrs exists, but refers to binary array chunks that do not exist + sdata = blobs() + sdata_path = session_tmp_path / "sdata_with_missing_image_chunks_zarrv2.zarr" + sdata.write(sdata_path, sdata_formats=SpatialDataContainerFormatV01()) + + corrupted = "blobs_image" + os.unlink(sdata_path / "images" / corrupted / "0" / ".zarray") + os.rename(sdata_path / "images" / corrupted / "0", sdata_path / "images" / corrupted / "0_corrupted") + + not_corrupted = [name for _, name, _ in sdata.gen_elements() if name != corrupted] + + return PartialReadTestCase( + path=sdata_path, + expected_elements=not_corrupted, + expected_exceptions=(ArrayNotFoundError,), + warnings_patterns=[rf"images/{corrupted}: (ArrayNotFoundError|TypeError)"], + ) + + +@pytest.fixture(scope="module") +def sdata_with_invalid_zattrs_element_violating_spec(session_tmp_path: Path) -> PartialReadTestCase: + # Zarr v2 + # .zattrs contains readable JSON which is not valid for SpatialData/NGFF specs # for example due to a missing/misspelled/renamed key sdata = blobs() sdata_path = session_tmp_path / "sdata_with_invalid_zattrs_violating_spec.zarr" + sdata.write(sdata_path, sdata_formats=SpatialDataContainerFormatV01()) + + corrupted = "blobs_image" + json_dict = json.loads((sdata_path / "images" / corrupted / ".zattrs").read_text()) + del json_dict["multiscales"][0]["coordinateTransformations"] + (sdata_path / "images" / corrupted / ".zattrs").write_text(json.dumps(json_dict, indent=4)) + not_corrupted = [name for _, name, _ in sdata.gen_elements() if name != corrupted] + + return PartialReadTestCase( + path=sdata_path, + expected_elements=not_corrupted, + expected_exceptions=KeyError, + warnings_patterns=[rf"images/{corrupted}: KeyError: coordinateTransformations"], + ) + + +@pytest.fixture(scope="module") +def sdata_with_invalid_zarr_json_element_violating_spec(session_tmp_path: Path) -> PartialReadTestCase: + # zarr.json contains readable JSON which is not valid for SpatialData/NGFF specs + # for example due to a missing/misspelled/renamed key + sdata = blobs() + sdata_path = session_tmp_path / "sdata_with_invalid_zarr_json_violating_spec.zarr" sdata.write(sdata_path) corrupted = "blobs_image" @@ -198,13 +399,67 @@ def sdata_with_invalid_zarr_json_violating_spec(session_tmp_path: Path) -> Parti @pytest.fixture(scope="module") -def sdata_with_table_region_not_found(session_tmp_path: Path) -> PartialReadTestCase: +def sdata_with_invalid_zattrs_table_region_not_found(session_tmp_path: Path) -> PartialReadTestCase: # table/table/.zarr referring to a region that is not found # This has been emitting just a warning, but does not fail reading the table element. sdata = blobs() sdata_path = session_tmp_path / "sdata_with_invalid_zattrs_table_region_not_found.zarr" sdata.write(sdata_path) + corrupted = "blobs_labels" + # The element data is missing + os.unlink(sdata_path / "labels" / corrupted / ".zgroup") + os.rename(sdata_path / "labels" / corrupted, sdata_path / "labels" / f"{corrupted}_corrupted") + # But the labels element is referenced as a region in a table + regions = zarr.open_group(sdata_path / "tables" / "table" / "obs" / "region", mode="r") + assert corrupted in np.asarray(regions.categories)[regions.codes] + not_corrupted = [name for _, name, _ in sdata.gen_elements() if name != corrupted] + + return PartialReadTestCase( + path=sdata_path, + expected_elements=not_corrupted, + expected_exceptions=(), + warnings_patterns=[ + rf"The table is annotating '{re.escape(corrupted)}', which is not present in the SpatialData object" + ], + ) + + +@pytest.fixture(scope="module") +def sdata_with_table_region_not_found_zarrv3(session_tmp_path: Path) -> PartialReadTestCase: + # table/table/.zarr referring to a region that is not found + # This has been emitting just a warning, but does not fail reading the table element. + sdata = blobs() + sdata_path = session_tmp_path / "sdata_with_invalid_table_region_not_found_zarrv3.zarr" + sdata.write(sdata_path) + + corrupted = "blobs_labels" + # The element data is missing + sdata.delete_element_from_disk(corrupted) + # But the labels element is referenced as a region in a table + regions = zarr.open_group(sdata_path / "tables" / "table" / "obs" / "region", mode="r") + arrs = dict(regions.arrays()) + assert corrupted in arrs["categories"][arrs["codes"]] + not_corrupted = [name for _, name, _ in sdata.gen_elements() if name != corrupted] + + return PartialReadTestCase( + path=sdata_path, + expected_elements=not_corrupted, + expected_exceptions=(), + warnings_patterns=[ + rf"The table is annotating '{re.escape(corrupted)}', which is not present in the SpatialData object" + ], + ) + + +@pytest.fixture(scope="module") +def sdata_with_table_region_not_found_zarrv2(session_tmp_path: Path) -> PartialReadTestCase: + # table/table/.zarr referring to a region that is not found + # This has been emitting just a warning, but does not fail reading the table element. + sdata = blobs() + sdata_path = session_tmp_path / "sdata_with_invalid_zattrs_table_region_not_found.zarr" + sdata.write(sdata_path, sdata_formats=SpatialDataContainerFormatV01()) + corrupted = "blobs_labels" # The element data is missing sdata.delete_element_from_disk(corrupted) @@ -227,12 +482,22 @@ def sdata_with_table_region_not_found(session_tmp_path: Path) -> PartialReadTest @pytest.mark.parametrize( "test_case", [ - sdata_with_corrupted_zarr_json, - sdata_with_corrupted_image_chunks, - sdata_with_corrupted_parquet, - sdata_with_missing_zarr_json, - sdata_with_invalid_zarr_json_violating_spec, - sdata_with_table_region_not_found, + sdata_with_corrupted_zattrs_elements, # OSError + sdata_with_corrupted_zarr_json_elements, # OSError + sdata_with_corrupted_image_chunks_zarrv2, # zarr.errors.ArrayNotFoundError + sdata_with_corrupted_image_chunks_zarrv3, # zarr.errors.ArrayNotFoundError + sdata_with_corrupted_parquet_zarrv2, # ArrowInvalid + sdata_with_corrupted_parquet_zarrv3, # ArrowInvalid + sdata_with_missing_zattrs_element, # OSError + sdata_with_missing_zarr_json_element, # OSError + sdata_with_missing_image_chunks_zarrv2, # zarr.errors.ArrayNotFoundError + sdata_with_missing_image_chunks_zarrv3, # zarr.errors.ArrayNotFoundError + sdata_with_invalid_zattrs_element_violating_spec, # KeyError + sdata_with_invalid_zarr_json_element_violating_spec, # KeyError + sdata_with_corrupted_elem_types_zgroup, # JSONDecodeError + sdata_with_corrupted_elem_types_zarr_json, # JSONDecodeError + sdata_with_table_region_not_found_zarrv2, + sdata_with_table_region_not_found_zarrv3, ], indirect=True, ) @@ -247,9 +512,22 @@ def test_read_zarr_with_error(test_case: PartialReadTestCase): @pytest.mark.parametrize( "test_case", [ - # sdata_with_corrupted_parquet, - sdata_with_invalid_zarr_json_violating_spec, - # sdata_with_table_region_not_found, + sdata_with_corrupted_zattrs_elements, # JSONDecodeError for non raster, else OSError + sdata_with_corrupted_zarr_json_elements, # JSONDecodeError for non raster, else OSError + sdata_with_corrupted_image_chunks_zarrv2, # ArrayNotFoundError + sdata_with_corrupted_image_chunks_zarrv3, # ArrayNotFoundError + sdata_with_corrupted_parquet_zarrv2, # ArrowInvalid + sdata_with_corrupted_parquet_zarrv3, # ArrowInvalid + sdata_with_missing_zattrs_element, # OSError + sdata_with_missing_zarr_json_element, # OSError + sdata_with_missing_image_chunks_zarrv2, # ArrayNotFoundError + sdata_with_missing_image_chunks_zarrv3, # ArrayNotFoundError + sdata_with_invalid_zattrs_element_violating_spec, # KeyError + sdata_with_invalid_zarr_json_element_violating_spec, # KeyError + sdata_with_corrupted_elem_types_zgroup, # ZarrUserWarning + sdata_with_corrupted_elem_types_zarr_json, # JSONDecodeError + sdata_with_table_region_not_found_zarrv2, + sdata_with_table_region_not_found_zarrv3, ], indirect=True, ) From c88b488481a51809c2291d42afa8ada6366cd9f0 Mon Sep 17 00:00:00 2001 From: Luca Marconato Date: Fri, 12 Sep 2025 20:53:48 +0200 Subject: [PATCH 121/126] minor changes in test_partial_read(); code review finished --- tests/io/test_partial_read.py | 68 +++++++++++------------------------ 1 file changed, 20 insertions(+), 48 deletions(-) diff --git a/tests/io/test_partial_read.py b/tests/io/test_partial_read.py index 27ebb2bb..e200c1fa 100644 --- a/tests/io/test_partial_read.py +++ b/tests/io/test_partial_read.py @@ -11,7 +11,6 @@ from pathlib import Path from typing import TYPE_CHECKING -import numpy as np import py import pytest import zarr @@ -31,7 +30,7 @@ def pytest_warns_multiple( expected_warning: type[Warning] | tuple[type[Warning], ...] = Warning, matches: Iterable[str] = () ) -> Generator[None, None, None]: """ - Assert that code raises a warnings matching particular patterns. + Assert that code raises warnings matching particular patterns. Like `pytest.warns`, but with multiple patterns which each must match a warning. @@ -109,7 +108,7 @@ def sdata_with_corrupted_elem_types_zgroup(session_tmp_path: Path) -> PartialRea @pytest.fixture(scope="module") def sdata_with_corrupted_elem_types_zarr_json(session_tmp_path: Path) -> PartialReadTestCase: - # Zarr v2 + # Zarr v3 sdata = blobs() sdata_path = session_tmp_path / "sdata_with_corrupted_top_level_zarr_json.zarr" # Errors only when no consolidation metadata store is used as this takes precedence over group metadata when reading @@ -398,33 +397,6 @@ def sdata_with_invalid_zarr_json_element_violating_spec(session_tmp_path: Path) ) -@pytest.fixture(scope="module") -def sdata_with_invalid_zattrs_table_region_not_found(session_tmp_path: Path) -> PartialReadTestCase: - # table/table/.zarr referring to a region that is not found - # This has been emitting just a warning, but does not fail reading the table element. - sdata = blobs() - sdata_path = session_tmp_path / "sdata_with_invalid_zattrs_table_region_not_found.zarr" - sdata.write(sdata_path) - - corrupted = "blobs_labels" - # The element data is missing - os.unlink(sdata_path / "labels" / corrupted / ".zgroup") - os.rename(sdata_path / "labels" / corrupted, sdata_path / "labels" / f"{corrupted}_corrupted") - # But the labels element is referenced as a region in a table - regions = zarr.open_group(sdata_path / "tables" / "table" / "obs" / "region", mode="r") - assert corrupted in np.asarray(regions.categories)[regions.codes] - not_corrupted = [name for _, name, _ in sdata.gen_elements() if name != corrupted] - - return PartialReadTestCase( - path=sdata_path, - expected_elements=not_corrupted, - expected_exceptions=(), - warnings_patterns=[ - rf"The table is annotating '{re.escape(corrupted)}', which is not present in the SpatialData object" - ], - ) - - @pytest.fixture(scope="module") def sdata_with_table_region_not_found_zarrv3(session_tmp_path: Path) -> PartialReadTestCase: # table/table/.zarr referring to a region that is not found @@ -482,22 +454,22 @@ def sdata_with_table_region_not_found_zarrv2(session_tmp_path: Path) -> PartialR @pytest.mark.parametrize( "test_case", [ - sdata_with_corrupted_zattrs_elements, # OSError + sdata_with_corrupted_elem_types_zgroup, # JSONDecodeError + sdata_with_corrupted_elem_types_zarr_json, # JSONDecodeError sdata_with_corrupted_zarr_json_elements, # OSError - sdata_with_corrupted_image_chunks_zarrv2, # zarr.errors.ArrayNotFoundError + sdata_with_corrupted_zattrs_elements, # OSError sdata_with_corrupted_image_chunks_zarrv3, # zarr.errors.ArrayNotFoundError - sdata_with_corrupted_parquet_zarrv2, # ArrowInvalid + sdata_with_corrupted_image_chunks_zarrv2, # zarr.errors.ArrayNotFoundError sdata_with_corrupted_parquet_zarrv3, # ArrowInvalid - sdata_with_missing_zattrs_element, # OSError + sdata_with_corrupted_parquet_zarrv2, # ArrowInvalid sdata_with_missing_zarr_json_element, # OSError - sdata_with_missing_image_chunks_zarrv2, # zarr.errors.ArrayNotFoundError + sdata_with_missing_zattrs_element, # OSError sdata_with_missing_image_chunks_zarrv3, # zarr.errors.ArrayNotFoundError + sdata_with_missing_image_chunks_zarrv2, # zarr.errors.ArrayNotFoundError sdata_with_invalid_zattrs_element_violating_spec, # KeyError sdata_with_invalid_zarr_json_element_violating_spec, # KeyError - sdata_with_corrupted_elem_types_zgroup, # JSONDecodeError - sdata_with_corrupted_elem_types_zarr_json, # JSONDecodeError - sdata_with_table_region_not_found_zarrv2, sdata_with_table_region_not_found_zarrv3, + sdata_with_table_region_not_found_zarrv2, ], indirect=True, ) @@ -512,22 +484,22 @@ def test_read_zarr_with_error(test_case: PartialReadTestCase): @pytest.mark.parametrize( "test_case", [ - sdata_with_corrupted_zattrs_elements, # JSONDecodeError for non raster, else OSError + sdata_with_corrupted_elem_types_zgroup, # JSONDecodeError + sdata_with_corrupted_elem_types_zarr_json, # JSONDecodeError sdata_with_corrupted_zarr_json_elements, # JSONDecodeError for non raster, else OSError - sdata_with_corrupted_image_chunks_zarrv2, # ArrayNotFoundError - sdata_with_corrupted_image_chunks_zarrv3, # ArrayNotFoundError - sdata_with_corrupted_parquet_zarrv2, # ArrowInvalid + sdata_with_corrupted_zattrs_elements, # JSONDecodeError for non raster, else OSError + sdata_with_corrupted_image_chunks_zarrv3, # zarr.errors.ArrayNotFoundError + sdata_with_corrupted_image_chunks_zarrv2, # zarr.errors.ArrayNotFoundError sdata_with_corrupted_parquet_zarrv3, # ArrowInvalid - sdata_with_missing_zattrs_element, # OSError + sdata_with_corrupted_parquet_zarrv2, # ArrowInvalid sdata_with_missing_zarr_json_element, # OSError - sdata_with_missing_image_chunks_zarrv2, # ArrayNotFoundError - sdata_with_missing_image_chunks_zarrv3, # ArrayNotFoundError + sdata_with_missing_zattrs_element, # OSError + sdata_with_missing_image_chunks_zarrv3, # zarr.errors.ArrayNotFoundError + sdata_with_missing_image_chunks_zarrv2, # zarr.errors.ArrayNotFoundError sdata_with_invalid_zattrs_element_violating_spec, # KeyError sdata_with_invalid_zarr_json_element_violating_spec, # KeyError - sdata_with_corrupted_elem_types_zgroup, # ZarrUserWarning - sdata_with_corrupted_elem_types_zarr_json, # JSONDecodeError - sdata_with_table_region_not_found_zarrv2, sdata_with_table_region_not_found_zarrv3, + sdata_with_table_region_not_found_zarrv2, ], indirect=True, ) From 51292d53051a5aa8c5470ef2ac2fd0f625cd6f4b Mon Sep 17 00:00:00 2001 From: Luca Marconato Date: Tue, 16 Sep 2025 14:36:09 +0200 Subject: [PATCH 122/126] support ome-zarr-py master --- src/spatialdata/_io/_utils.py | 14 +++++++++- src/spatialdata/_io/format.py | 8 ++++++ src/spatialdata/_io/io_raster.py | 44 +++++++++++++++++++++++++++----- 3 files changed, 58 insertions(+), 8 deletions(-) diff --git a/src/spatialdata/_io/_utils.py b/src/spatialdata/_io/_utils.py index b501e210..ee6cc1f2 100644 --- a/src/spatialdata/_io/_utils.py +++ b/src/spatialdata/_io/_utils.py @@ -23,6 +23,7 @@ from zarr.storage import FsspecStore, LocalStore from spatialdata._core.spatialdata import SpatialData +from spatialdata._io.format import RasterFormatType, RasterFormatV01, RasterFormatV02, RasterFormatV03 from spatialdata._utils import get_pyramid_levels from spatialdata.models._utils import ( MappingToCoordinateSystem_t, @@ -76,7 +77,10 @@ def overwrite_coordinate_transformations_non_raster( def overwrite_coordinate_transformations_raster( - group: zarr.Group, axes: tuple[ValidAxis_t, ...], transformations: MappingToCoordinateSystem_t + group: zarr.Group, + axes: tuple[ValidAxis_t, ...], + transformations: MappingToCoordinateSystem_t, + raster_format: RasterFormatType | None = None, ) -> None: """Write transformations of raster elements to disk. @@ -120,6 +124,14 @@ def overwrite_coordinate_transformations_raster( multiscale = multiscales[0] multiscale["coordinateTransformations"] = coordinate_transformations + if raster_format is not None: + if isinstance(raster_format, RasterFormatV01 | RasterFormatV02): + multiscale["version"] = raster_format.version + elif isinstance(raster_format, RasterFormatV03): + group.metadata.attributes["ome"]["version"] = raster_format.version + else: + raise ValueError(f"Unsupported raster format: {type(raster_format)}") + group.attrs["multiscales"] = multiscales diff --git a/src/spatialdata/_io/format.py b/src/spatialdata/_io/format.py index 0c1cf0cb..f44a2cf5 100644 --- a/src/spatialdata/_io/format.py +++ b/src/spatialdata/_io/format.py @@ -421,3 +421,11 @@ def _check_modified(element_type: str) -> None: ) return parsed + + +def get_ome_zarr_format(raster_format: RasterFormatType) -> Format: + if isinstance(raster_format, RasterFormatV01 | RasterFormatV02): + return FormatV04() + if isinstance(raster_format, RasterFormatV03): + return FormatV05() + raise ValueError(f"Unsupported raster format {raster_format}") diff --git a/src/spatialdata/_io/io_raster.py b/src/spatialdata/_io/io_raster.py index 80aa07e3..3e4b4232 100644 --- a/src/spatialdata/_io/io_raster.py +++ b/src/spatialdata/_io/io_raster.py @@ -21,6 +21,7 @@ from spatialdata._io.format import ( CurrentRasterFormat, RasterFormatType, + get_ome_zarr_format, ) from spatialdata._utils import get_pyramid_levels from spatialdata.models._utils import get_channel_names @@ -90,7 +91,10 @@ def _read_multiscale(store: str | Path, raster_type: Literal["image", "labels"]) node = nodes[0] loaded_node = node.load(Multiscales) - datasets, multiscales = loaded_node.datasets, loaded_node.zarr.root_attrs["multiscales"] + datasets, multiscales = ( + loaded_node.datasets, + loaded_node.zarr.root_attrs["multiscales"], + ) # This works for all versions as in zarr v3 the level of the 'ome' key is taken as root_attrs. omero_metadata = loaded_node.zarr.root_attrs.get("omero") # TODO: check if below is still valid @@ -186,9 +190,25 @@ def _write_raster( metadata["metadata"]["omero"]["channels"].append({"label": c}) # type: ignore[union-attr, index, call-overload] if isinstance(raster_data, DataArray): - _write_raster_dataarray(raster_type, group, name, raster_data, raster_format, storage_options, **metadata) + _write_raster_dataarray( + raster_type, + group, + name, + raster_data, + raster_format, + storage_options, + **metadata, + ) elif isinstance(raster_data, DataTree): - _write_raster_datatree(raster_type, group, name, raster_data, raster_format, storage_options, **metadata) + _write_raster_datatree( + raster_type, + group, + name, + raster_data, + raster_format, + storage_options, + **metadata, + ) else: raise ValueError("Not a valid labels object") @@ -247,10 +267,11 @@ def _write_raster_dataarray( # We need this because the argument of write_image_ngff is called image while the argument of # write_labels_ngff is called label. metadata[raster_type] = data + ome_zarr_format = get_ome_zarr_format(raster_format) write_single_scale_ngff( group=group, scaler=None, - fmt=raster_format, + fmt=ome_zarr_format, axes=parsed_axes, coordinate_transformations=None, storage_options=storage_options, @@ -258,7 +279,12 @@ def _write_raster_dataarray( ) trans_group = group["labels"][element_name] if raster_type == "labels" else group - overwrite_coordinate_transformations_raster(group=trans_group, transformations=transformations, axes=input_axes) + overwrite_coordinate_transformations_raster( + group=trans_group, + transformations=transformations, + axes=input_axes, + raster_format=raster_format, + ) def _write_raster_datatree( @@ -305,10 +331,11 @@ def _write_raster_datatree( parsed_axes = _get_valid_axes(axes=list(input_axes), fmt=raster_format) storage_options = [{"chunks": chunk} for chunk in chunks] + ome_zarr_format = get_ome_zarr_format(raster_format) dask_delayed = write_multi_scale_ngff( pyramid=data, group=group, - fmt=raster_format, + fmt=ome_zarr_format, axes=parsed_axes, coordinate_transformations=None, storage_options=storage_options, @@ -320,7 +347,10 @@ def _write_raster_datatree( trans_group = group["labels"][element_name] if raster_type == "labels" else group overwrite_coordinate_transformations_raster( - group=trans_group, transformations=transformations, axes=tuple(input_axes) + group=trans_group, + transformations=transformations, + axes=tuple(input_axes), + raster_format=raster_format, ) From d2b0463dc83f622afdc7879786b25e3874bcbb3b Mon Sep 17 00:00:00 2001 From: Luca Marconato Date: Tue, 16 Sep 2025 14:45:56 +0200 Subject: [PATCH 123/126] fix docs --- docs/api/io.md | 1 - src/spatialdata/_core/spatialdata.py | 10 +++++----- 2 files changed, 5 insertions(+), 6 deletions(-) diff --git a/docs/api/io.md b/docs/api/io.md index 7ae79bf3..714f4ab2 100644 --- a/docs/api/io.md +++ b/docs/api/io.md @@ -7,6 +7,5 @@ use any of the [spatialdata-io readers](https://spatialdata.scverse.org/projects .. currentmodule:: spatialdata .. autofunction:: read_zarr -.. autofunction:: save_transformations .. autofunction:: get_dask_backing_files ``` diff --git a/src/spatialdata/_core/spatialdata.py b/src/spatialdata/_core/spatialdata.py index e7ce4da4..326d4dba 100644 --- a/src/spatialdata/_core/spatialdata.py +++ b/src/spatialdata/_core/spatialdata.py @@ -1132,12 +1132,12 @@ def write( (default yes). Here are the implications. - If `True`, and if the `SpatialData` object has dask-backed elements, the object will become "not - self-contained" because the dask-backed element will have a path that is different from the new - `sdata.path` attribute (now equal to `file_path`). By re-reading the object from disk, the object will - become self-contained again. + self-contained" because the dask-backed element will have a path that is different from the new + `sdata.path` attribute (now equal to `file_path`). By re-reading the object from disk, the object + will become self-contained again. - If `False`, the `SpatialData` object will keep its current `path` attribute, meaning that calling - `sdata.write_element()`, `sdata.write_attrs()`, ... will write the data to the current `sdata.path` - location, not to the `file_path` location. + `sdata.write_element()`, `sdata.write_attrs()`, ... will write the data to the current `sdata.path` + location, not to the `file_path` location. Please consult :func:`spatialdata.SpatialData.is_self_contained` for more information on the implications of working with self-contained and non-self-contained SpatialData objects. From 709bdd8312e388f8b6da18f5774442efd642e9a3 Mon Sep 17 00:00:00 2001 From: Wouter-Michiel Vierdag Date: Wed, 17 Sep 2025 11:24:38 +0200 Subject: [PATCH 124/126] ensure multiscales written correctly --- src/spatialdata/_core/spatialdata.py | 10 ++++++-- src/spatialdata/_io/_utils.py | 15 ++++++++---- src/spatialdata/_io/format.py | 1 + src/spatialdata/_io/io_raster.py | 7 ++++-- src/spatialdata/_io/io_zarr.py | 36 ++++++++++++++++++++++++++-- 5 files changed, 59 insertions(+), 10 deletions(-) diff --git a/src/spatialdata/_core/spatialdata.py b/src/spatialdata/_core/spatialdata.py index 326d4dba..47235501 100644 --- a/src/spatialdata/_core/spatialdata.py +++ b/src/spatialdata/_core/spatialdata.py @@ -1586,15 +1586,21 @@ def write_transformations(self, element_name: str | None = None) -> None: from spatialdata._io._utils import ( overwrite_coordinate_transformations_raster, ) + from spatialdata._io.format import RasterFormats - overwrite_coordinate_transformations_raster(group=element_group, axes=axes, transformations=transformations) + raster_format = RasterFormats[element_group.metadata.attributes["spatialdata_attrs"]["version"]] + overwrite_coordinate_transformations_raster( + group=element_group, axes=axes, transformations=transformations, raster_format=raster_format + ) elif isinstance(element, DaskDataFrame | GeoDataFrame | AnnData): from spatialdata._io._utils import ( overwrite_coordinate_transformations_non_raster, ) overwrite_coordinate_transformations_non_raster( - group=element_group, axes=axes, transformations=transformations + group=element_group, + axes=axes, + transformations=transformations, ) else: raise ValueError(f"Unknown element type {type(element)}") diff --git a/src/spatialdata/_io/_utils.py b/src/spatialdata/_io/_utils.py index ee6cc1f2..20c23627 100644 --- a/src/spatialdata/_io/_utils.py +++ b/src/spatialdata/_io/_utils.py @@ -80,7 +80,7 @@ def overwrite_coordinate_transformations_raster( group: zarr.Group, axes: tuple[ValidAxis_t, ...], transformations: MappingToCoordinateSystem_t, - raster_format: RasterFormatType | None = None, + raster_format: RasterFormatType, ) -> None: """Write transformations of raster elements to disk. @@ -99,6 +99,9 @@ def overwrite_coordinate_transformations_raster( The list with axes names in the same order as the dimensions of the raster element. transformations Mapping between names of the coordinate system and the transformations. + raster_format + The raster format of the raster element used to determine where in the metadata the transformations should be + written. """ _validate_mapping_to_coordinate_system_type(transformations) # prepare the transformations in the dict representation @@ -123,17 +126,21 @@ def overwrite_coordinate_transformations_raster( raise ValueError(f"The length of multiscales metadata should be 1, found length of {len_scales}") multiscale = multiscales[0] + # Previously, there was CoordinateTransformations key present at the level of multiscale and datasets in multiscale. + # This is not the case anymore so we are creating a new key here and keeping the one in datasets intact. multiscale["coordinateTransformations"] = coordinate_transformations if raster_format is not None: if isinstance(raster_format, RasterFormatV01 | RasterFormatV02): multiscale["version"] = raster_format.version + group.attrs["multiscales"] = multiscales elif isinstance(raster_format, RasterFormatV03): - group.metadata.attributes["ome"]["version"] = raster_format.version + ome = group.metadata.attributes["ome"] + ome["version"] = raster_format.version + ome["multiscales"] = multiscales + group.attrs["ome"] = ome else: raise ValueError(f"Unsupported raster format: {type(raster_format)}") - group.attrs["multiscales"] = multiscales - def overwrite_channel_names(group: zarr.Group, element: DataArray | DataTree) -> None: """Write channel metadata to a group.""" diff --git a/src/spatialdata/_io/format.py b/src/spatialdata/_io/format.py index f44a2cf5..3e69b862 100644 --- a/src/spatialdata/_io/format.py +++ b/src/spatialdata/_io/format.py @@ -289,6 +289,7 @@ def spatialdata_format_version(self) -> str: RasterFormatType | ShapesFormatType | PointsFormatType | TablesFormatType | SpatialDataContainerFormatType ) +SdataVersion_to_Format = {"0.4": FormatV04(), "0.4-dev-spatialdata": FormatV04(), "0.5-dev-spatialdata": FormatV05()} RasterFormats: dict[str, RasterFormatType] = { "0.1": RasterFormatV01(), "0.2": RasterFormatV02(), diff --git a/src/spatialdata/_io/io_raster.py b/src/spatialdata/_io/io_raster.py index 3e4b4232..af6578cd 100644 --- a/src/spatialdata/_io/io_raster.py +++ b/src/spatialdata/_io/io_raster.py @@ -4,6 +4,7 @@ import dask.array as da import numpy as np import zarr +from ome_zarr.format import Format from ome_zarr.io import ZarrLocation from ome_zarr.reader import Multiscales, Node, Reader from ome_zarr.types import JSONDict @@ -61,12 +62,14 @@ def _get_multiscale_nodes(image_nodes: list[Node], nodes: list[Node]) -> list[No return nodes -def _read_multiscale(store: str | Path, raster_type: Literal["image", "labels"]) -> DataArray | DataTree: +def _read_multiscale( + store: str | Path, raster_type: Literal["image", "labels"], reader_format: Format +) -> DataArray | DataTree: assert isinstance(store, str | Path) assert raster_type in ["image", "labels"] nodes: list[Node] = [] - image_loc = ZarrLocation(store) + image_loc = ZarrLocation(store, fmt=reader_format) if exists := image_loc.exists(): image_reader = Reader(image_loc)() image_nodes = list(image_reader) diff --git a/src/spatialdata/_io/io_zarr.py b/src/spatialdata/_io/io_zarr.py index 79a827a5..73b0ef6e 100644 --- a/src/spatialdata/_io/io_zarr.py +++ b/src/spatialdata/_io/io_zarr.py @@ -6,6 +6,7 @@ import zarr.storage from anndata import AnnData +from ome_zarr.format import Format from pyarrow import ArrowInvalid from zarr.errors import ArrayNotFoundError @@ -22,6 +23,32 @@ from spatialdata._logging import logger +def get_raster_format_for_read(group: zarr.Group, sdata_version: Literal["0.1", "0.2"]) -> Format: + """Get raster format of stored raster data. + + This checks the image or label element zarr group metadata to retrieve the format that is used by + ome-zarr's ZarrLocation for reading the data. + + Parameters + ---------- + group + The zarr group of the raster element to be read. + sdata_version + The version of the SpatialData zarr store retrieved from the spatialdata attributes. + + Returns + ------- + The ome-zarr format to use for reading the raster element. + """ + from spatialdata._io.format import SdataVersion_to_Format + + if sdata_version == "0.1": + group_version = group.metadata.attributes["multiscales"][0]["version"] + if sdata_version == "0.2": + group_version = group.metadata.attributes["ome"]["version"] + return SdataVersion_to_Format[group_version] + + def read_zarr( store: str | Path, selection: None | tuple[str] = None, @@ -57,6 +84,7 @@ def read_zarr( resolved_store = _resolve_zarr_store(store) root_group = zarr.open_group(resolved_store, mode="r") + sdata_version = root_group.metadata.attributes["spatialdata_attrs"]["version"] root_store_path = root_group.store.root images = {} @@ -83,6 +111,7 @@ def read_zarr( # skip hidden files like .zgroup or .zmetadata continue elem_group = group[subgroup_name] + reader_format = get_raster_format_for_read(elem_group, sdata_version) elem_group_path = os.path.join(root_store_path, elem_group.path) with handle_read_errors( on_bad_files, @@ -93,7 +122,7 @@ def read_zarr( OSError, ), ): - element = _read_multiscale(elem_group_path, raster_type="image") + element = _read_multiscale(elem_group_path, raster_type="image", reader_format=reader_format) images[subgroup_name] = element count += 1 logger.debug(f"Found {count} elements in {group}") @@ -112,6 +141,7 @@ def read_zarr( # skip hidden files like .zgroup or .zmetadata continue elem_group = group[subgroup_name] + reader_format = get_raster_format_for_read(elem_group, sdata_version) elem_group_path = root_store_path / elem_group.path with handle_read_errors( on_bad_files, @@ -122,7 +152,9 @@ def read_zarr( OSError, ), ): - labels[subgroup_name] = _read_multiscale(elem_group_path, raster_type="labels") + labels[subgroup_name] = _read_multiscale( + elem_group_path, raster_type="labels", reader_format=reader_format + ) count += 1 logger.debug(f"Found {count} elements in {group}") # now read rest of the data From 59df1cabe523ac453635b53ebbf265ec018a947f Mon Sep 17 00:00:00 2001 From: Wouter-Michiel Vierdag Date: Wed, 17 Sep 2025 13:49:11 +0200 Subject: [PATCH 125/126] refactor read_zarr (#982) * refactor read_zarr * remove unneccesary checks --- src/spatialdata/_io/io_points.py | 5 + src/spatialdata/_io/io_raster.py | 62 +++++---- src/spatialdata/_io/io_shapes.py | 7 +- src/spatialdata/_io/io_table.py | 14 ++- src/spatialdata/_io/io_zarr.py | 207 ++++++++++++------------------- 5 files changed, 138 insertions(+), 157 deletions(-) diff --git a/src/spatialdata/_io/io_points.py b/src/spatialdata/_io/io_points.py index a251b042..bc52c94b 100644 --- a/src/spatialdata/_io/io_points.py +++ b/src/spatialdata/_io/io_points.py @@ -46,6 +46,11 @@ def _read_points( return points +class PointsReader: + def __call__(self, store: str | Path | MutableMapping[str, object] | zarr.Group) -> DaskDataFrame: + return _read_points(store) + + def write_points( points: DaskDataFrame, group: zarr.Group, diff --git a/src/spatialdata/_io/io_raster.py b/src/spatialdata/_io/io_raster.py index af6578cd..fe7f6b36 100644 --- a/src/spatialdata/_io/io_raster.py +++ b/src/spatialdata/_io/io_raster.py @@ -35,33 +35,6 @@ ) -def _get_multiscale_nodes(image_nodes: list[Node], nodes: list[Node]) -> list[Node]: - """Get nodes with Multiscales spec from a list of nodes. - - The nodes with the Multiscales spec are the nodes used for reading in image and label data. We only have to check - the multiscales now, while before we also had to check the label spec. In the new ome-zarr-py though labels can have - the Label spec, these do not contain the multiscales anymore used to read the data. They can contain label specific - metadata though. - - Parameters - ---------- - image_nodes - List of nodes returned from the ome-zarr-py Reader. - nodes - List to append the nodes with the multiscales spec to. - - Returns - ------- - List of nodes with the multiscales spec. - """ - if len(image_nodes): - for node in image_nodes: - # Labels are now also Multiscales in newer version of ome-zarr-py - if np.any([isinstance(spec, Multiscales) for spec in node.specs]): - nodes.append(node) - return nodes - - def _read_multiscale( store: str | Path, raster_type: Literal["image", "labels"], reader_format: Format ) -> DataArray | DataTree: @@ -134,6 +107,7 @@ def _read_multiscale( msi = DataTree.from_dict(multiscale_image) _set_transformations(msi, transformations) return compute_coordinates(msi) + data = node.load(Multiscales).array(resolution=datasets[0]) si = DataArray( data, @@ -145,6 +119,40 @@ def _read_multiscale( return compute_coordinates(si) +def _get_multiscale_nodes(image_nodes: list[Node], nodes: list[Node]) -> list[Node]: + """Get nodes with Multiscales spec from a list of nodes. + + The nodes with the Multiscales spec are the nodes used for reading in image and label data. We only have to check + the multiscales now, while before we also had to check the label spec. In the new ome-zarr-py though labels can have + the Label spec, these do not contain the multiscales anymore used to read the data. They can contain label specific + metadata though. + + Parameters + ---------- + image_nodes + List of nodes returned from the ome-zarr-py Reader. + nodes + List to append the nodes with the multiscales spec to. + + Returns + ------- + List of nodes with the multiscales spec. + """ + if len(image_nodes): + for node in image_nodes: + # Labels are now also Multiscales in newer version of ome-zarr-py + if np.any([isinstance(spec, Multiscales) for spec in node.specs]): + nodes.append(node) + return nodes + + +class MultiscaleReader: + def __call__( + self, path: str | Path, raster_type: Literal["image", "labels"], reader_format: Format + ) -> DataArray | DataTree: + return _read_multiscale(path, raster_type, reader_format) + + def _write_raster( raster_type: Literal["image", "labels"], raster_data: DataArray | DataTree, diff --git a/src/spatialdata/_io/io_shapes.py b/src/spatialdata/_io/io_shapes.py index 6e14baa4..15efdc47 100644 --- a/src/spatialdata/_io/io_shapes.py +++ b/src/spatialdata/_io/io_shapes.py @@ -30,7 +30,7 @@ def _read_shapes( - store: str | Path | MutableMapping | zarr.Group, # type: ignore[type-arg] + store: str | Path | MutableMapping[str, object] | zarr.Group, ) -> GeoDataFrame: """Read shapes from a zarr store.""" assert isinstance(store, str | Path) @@ -67,6 +67,11 @@ def _read_shapes( return geo_df +class ShapesReader: + def __call__(self, store: str | Path | MutableMapping[str, object] | zarr.Group) -> GeoDataFrame: + return _read_shapes(store) + + def write_shapes( shapes: GeoDataFrame, group: zarr.Group, diff --git a/src/spatialdata/_io/io_table.py b/src/spatialdata/_io/io_table.py index 6a9d191f..9e71c4a3 100644 --- a/src/spatialdata/_io/io_table.py +++ b/src/spatialdata/_io/io_table.py @@ -21,7 +21,7 @@ def _read_table( group: zarr.Group, tables: dict[str, AnnData], on_bad_files: Literal[BadFileHandleMethod.ERROR, BadFileHandleMethod.WARN] = BadFileHandleMethod.ERROR, -) -> dict[str, AnnData]: +) -> None: """ Read in tables in the tables Zarr.group of a SpatialData Zarr store. @@ -85,7 +85,17 @@ def _read_table( count += 1 logger.debug(f"Found {count} elements in {group}") - return tables + + +class TablesReader: + def __call__( + self, + path: str, + group: zarr.Group, + container: dict[str, AnnData], + on_bad_files: Literal[BadFileHandleMethod.ERROR, BadFileHandleMethod.WARN], + ) -> None: + return _read_table(path, group, container, on_bad_files) def write_table( diff --git a/src/spatialdata/_io/io_zarr.py b/src/spatialdata/_io/io_zarr.py index 73b0ef6e..b0e1d32b 100644 --- a/src/spatialdata/_io/io_zarr.py +++ b/src/spatialdata/_io/io_zarr.py @@ -2,7 +2,7 @@ import warnings from json import JSONDecodeError from pathlib import Path -from typing import Literal +from typing import Literal, cast import zarr.storage from anndata import AnnData @@ -16,11 +16,59 @@ _resolve_zarr_store, handle_read_errors, ) -from spatialdata._io.io_points import _read_points -from spatialdata._io.io_raster import _read_multiscale -from spatialdata._io.io_shapes import _read_shapes -from spatialdata._io.io_table import _read_table +from spatialdata._io.io_points import PointsReader +from spatialdata._io.io_raster import MultiscaleReader +from spatialdata._io.io_shapes import ShapesReader +from spatialdata._io.io_table import TablesReader from spatialdata._logging import logger +from spatialdata.models import SpatialElement + +ReadClasses = MultiscaleReader | PointsReader | ShapesReader | TablesReader + + +def _read_zarr_group_spatialdata_element( + root_group: zarr.Group, + root_store_path: str, + sdata_version: Literal["0.1", "0.2"], + selector: set[str], + read_func: ReadClasses, + group_name: Literal["images", "labels", "shapes", "points", "tables"], + element_type: Literal["image", "labels", "shapes", "points", "tables"], + element_container: dict[str, SpatialElement | AnnData], + on_bad_files: Literal[BadFileHandleMethod.ERROR, BadFileHandleMethod.WARN], +) -> None: + with handle_read_errors( + on_bad_files, + location=group_name, + exc_types=JSONDecodeError, + ): + if group_name in selector and group_name in root_group: + group = root_group[group_name] + if isinstance(read_func, TablesReader): + read_func(root_store_path, group, element_container, on_bad_files=on_bad_files) + else: + count = 0 + for subgroup_name in group: + if Path(subgroup_name).name.startswith("."): + # skip hidden files like .zgroup or .zmetadata + continue + elem_group = group[subgroup_name] + elem_group_path = os.path.join(root_store_path, elem_group.path) + with handle_read_errors( + on_bad_files, + location=f"{group.path}/{subgroup_name}", + exc_types=(KeyError, ArrayNotFoundError, OSError, ArrowInvalid, JSONDecodeError), + ): + if isinstance(read_func, MultiscaleReader): + reader_format = get_raster_format_for_read(elem_group, sdata_version) + element = read_func( + elem_group_path, cast(Literal["image", "labels"], element_type), reader_format + ) + if isinstance(read_func, PointsReader | ShapesReader): + element = read_func(elem_group_path) + element_container[subgroup_name] = element + count += 1 + logger.debug(f"Found {count} elements in {group}") def get_raster_format_for_read(group: zarr.Group, sdata_version: Literal["0.1", "0.2"]) -> Format: @@ -87,134 +135,39 @@ def read_zarr( sdata_version = root_group.metadata.attributes["spatialdata_attrs"]["version"] root_store_path = root_group.store.root - images = {} - labels = {} - points = {} + images: dict[str, SpatialElement] = {} + labels: dict[str, SpatialElement] = {} + points: dict[str, SpatialElement] = {} tables: dict[str, AnnData] = {} - shapes = {} + shapes: dict[str, SpatialElement] = {} selector = {"images", "labels", "points", "shapes", "tables"} if not selection else set(selection or []) logger.debug(f"Reading selection {selector}") - # We raise OS errors instead for some read errors now as in zarr v3 with some corruptions nothing will be read. - # related to images / labels. - with handle_read_errors( - on_bad_files, - location="images", - exc_types=JSONDecodeError, - ): - if "images" in selector and "images" in root_group: - group = root_group["images"] - count = 0 - for subgroup_name in group: - if Path(subgroup_name).name.startswith("."): - # skip hidden files like .zgroup or .zmetadata - continue - elem_group = group[subgroup_name] - reader_format = get_raster_format_for_read(elem_group, sdata_version) - elem_group_path = os.path.join(root_store_path, elem_group.path) - with handle_read_errors( - on_bad_files, - location=f"{group.path}/{subgroup_name}", - exc_types=( - KeyError, - ArrayNotFoundError, - OSError, - ), - ): - element = _read_multiscale(elem_group_path, raster_type="image", reader_format=reader_format) - images[subgroup_name] = element - count += 1 - logger.debug(f"Found {count} elements in {group}") - - # read multiscale labels - with handle_read_errors( - on_bad_files, - location="labels", - exc_types=JSONDecodeError, - ): - if "labels" in selector and "labels" in root_group: - group = root_group["labels"] - count = 0 - for subgroup_name in group: - if Path(subgroup_name).name.startswith("."): - # skip hidden files like .zgroup or .zmetadata - continue - elem_group = group[subgroup_name] - reader_format = get_raster_format_for_read(elem_group, sdata_version) - elem_group_path = root_store_path / elem_group.path - with handle_read_errors( - on_bad_files, - location=f"{group.path}/{subgroup_name}", - exc_types=( - KeyError, - ArrayNotFoundError, - OSError, - ), - ): - labels[subgroup_name] = _read_multiscale( - elem_group_path, raster_type="labels", reader_format=reader_format - ) - count += 1 - logger.debug(f"Found {count} elements in {group}") - # now read rest of the data - with handle_read_errors( - on_bad_files, - location="points", - exc_types=JSONDecodeError, - ): - if "points" in selector and "points" in root_group: - group = root_group["points"] - count = 0 - for subgroup_name in group: - elem_group = group[subgroup_name] - if Path(subgroup_name).name.startswith("."): - # skip hidden files like .zgroup or .zmetadata - continue - elem_group_path = os.path.join(root_store_path, elem_group.path) - with handle_read_errors( - on_bad_files, - location=f"{group.path}/{subgroup_name}", - exc_types=(KeyError, ArrowInvalid, JSONDecodeError), - ): - points[subgroup_name] = _read_points(elem_group_path) - count += 1 - logger.debug(f"Found {count} elements in {group}") - - with handle_read_errors( - on_bad_files, - location="shapes", - exc_types=JSONDecodeError, - ): - if "shapes" in selector and "shapes" in root_group: - group = root_group["shapes"] - count = 0 - for subgroup_name in group: - if Path(subgroup_name).name.startswith("."): - # skip hidden files like .zgroup or .zmetadata - continue - elem_group = group[subgroup_name] - elem_group_path = os.path.join(root_store_path, elem_group.path) - with handle_read_errors( - on_bad_files, - location=f"{group.path}/{subgroup_name}", - exc_types=( - JSONDecodeError, - KeyError, - ArrayNotFoundError, - ), - ): - shapes[subgroup_name] = _read_shapes(elem_group_path) - count += 1 - logger.debug(f"Found {count} elements in {group}") - if "tables" in selector and "tables" in root_group: - with handle_read_errors( + group_readers: dict[ + Literal["images", "labels", "shapes", "points", "tables"], + tuple[ + ReadClasses, Literal["image", "labels", "shapes", "points", "tables"], dict[str, SpatialElement | AnnData] + ], + ] = { + "images": (MultiscaleReader(), "image", images), + "labels": (MultiscaleReader(), "labels", labels), + "points": (PointsReader(), "points", points), + "shapes": (ShapesReader(), "shapes", shapes), + "tables": (TablesReader(), "tables", tables), + } + for group_name, (reader, raster_type, container) in group_readers.items(): + _read_zarr_group_spatialdata_element( + root_group, + root_store_path, + sdata_version, + selector, + reader, + group_name, + raster_type, + container, on_bad_files, - location="tables", - exc_types=JSONDecodeError, - ): - group = root_group["tables"] - tables = _read_table(root_store_path, group, tables, on_bad_files=on_bad_files) + ) # read attrs metadata attrs = root_group.attrs.asdict() From 5521a2a0c8fd5af52d4c820cbbe4df26ab318e0e Mon Sep 17 00:00:00 2001 From: Wouter-Michiel Vierdag Date: Wed, 17 Sep 2025 15:08:52 +0200 Subject: [PATCH 126/126] Refactor zarrv3 (#986) * refactor read_zarr * remove unneccesary checks * emit warning with old spatialdata storage version detected --- src/spatialdata/_core/spatialdata.py | 5 ++--- src/spatialdata/_io/io_zarr.py | 7 +++++++ 2 files changed, 9 insertions(+), 3 deletions(-) diff --git a/src/spatialdata/_core/spatialdata.py b/src/spatialdata/_core/spatialdata.py index 47235501..3d6a9ed0 100644 --- a/src/spatialdata/_core/spatialdata.py +++ b/src/spatialdata/_core/spatialdata.py @@ -1698,15 +1698,14 @@ def write_metadata( check_valid_name(element_name) if element_name not in self: raise ValueError(f"Element with name {element_name} not found in SpatialData object.") + if write_attrs: + self.write_attrs(sdata_format=sdata_format) self.write_transformations(element_name) self.write_channel_names(element_name) # TODO: write .uns['spatialdata_attrs'] metadata for AnnData. # TODO: write .attrs['spatialdata_attrs'] metadata for DaskDataFrame. - if write_attrs: - self.write_attrs(sdata_format=sdata_format) - if self.has_consolidated_metadata(): consolidate_metadata = True if consolidate_metadata: diff --git a/src/spatialdata/_io/io_zarr.py b/src/spatialdata/_io/io_zarr.py index b0e1d32b..5367dcf0 100644 --- a/src/spatialdata/_io/io_zarr.py +++ b/src/spatialdata/_io/io_zarr.py @@ -133,6 +133,13 @@ def read_zarr( resolved_store = _resolve_zarr_store(store) root_group = zarr.open_group(resolved_store, mode="r") sdata_version = root_group.metadata.attributes["spatialdata_attrs"]["version"] + if sdata_version == "0.1": + warnings.warn( + "SpatialData is not stored in the most current format. If you want to use Zarr v3" + ", please write the store to a new location.", + UserWarning, + stacklevel=2, + ) root_store_path = root_group.store.root images: dict[str, SpatialElement] = {}