diff --git a/.github/workflows/test.yaml b/.github/workflows/test.yaml index 7831b0b4b..1003fddf1 100644 --- a/.github/workflows/test.yaml +++ b/.github/workflows/test.yaml @@ -18,15 +18,17 @@ jobs: strategy: fail-fast: false matrix: - python: ["3.10", "3.12"] + python: ["3.11", "3.13"] os: [ubuntu-latest] include: - os: macos-latest - python: "3.10" + python: "3.11" - os: macos-latest python: "3.12" pip-flags: "--pre" name: "Python 3.12 (pre-release)" + - os: windows-latest + python: "3.11" env: OS: ${{ matrix.os }} diff --git a/.mypy.ini b/.mypy.ini index 0eee2044e..78edd09ca 100644 --- a/.mypy.ini +++ b/.mypy.ini @@ -1,5 +1,5 @@ [mypy] -python_version = 3.10 +python_version = 3.11 ignore_errors = False warn_redundant_casts = True diff --git a/.readthedocs.yaml b/.readthedocs.yaml index ab6cb4fbe..acecf90e6 100644 --- a/.readthedocs.yaml +++ b/.readthedocs.yaml @@ -3,7 +3,7 @@ version: 2 build: os: ubuntu-20.04 tools: - python: "3.10" + python: "3.11" sphinx: configuration: docs/conf.py fail_on_warning: true diff --git a/docs/api/data_formats.md b/docs/api/data_formats.md index 825d2dfdc..816382508 100644 --- a/docs/api/data_formats.md +++ b/docs/api/data_formats.md @@ -1,7 +1,7 @@ # Data formats (advanced) -The SpatialData format is defined as a set of versioned subclasses of `spatialdata._io.format.SpatialDataFormat`, one per type of element. -These classes are useful to ensure backward compatibility whenever a major version change is introduced. We also provide pointers to the latest format. +The SpatialData format is defined as a set of versioned subclasses of `ome_zarr.format.Format`, one per type of element. The `spatialdata.SpatialDataFormatType` is a union type encompassing the possible valid formats. +These format subclasses are useful to ensure backward compatibility whenever a major version change is introduced. We also provide pointers to the latest format. ```{eval-rst} .. currentmodule:: spatialdata._io.format diff --git a/docs/api/io.md b/docs/api/io.md index 7ae79bf36..714f4ab22 100644 --- a/docs/api/io.md +++ b/docs/api/io.md @@ -7,6 +7,5 @@ use any of the [spatialdata-io readers](https://spatialdata.scverse.org/projects .. currentmodule:: spatialdata .. autofunction:: read_zarr -.. autofunction:: save_transformations .. 
autofunction:: get_dask_backing_files ``` diff --git a/pyproject.toml b/pyproject.toml index 1f16398d9..766f75152 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -15,7 +15,7 @@ maintainers = [ urls.Documentation = "https://spatialdata.scverse.org/en/latest" urls.Source = "https://github.com/scverse/spatialdata.git" urls.Home-page = "https://github.com/scverse/spatialdata.git" -requires-python = ">=3.10" +requires-python = ">=3.11" dynamic= [ "version" # allow version to be set by git tags ] @@ -25,15 +25,15 @@ dependencies = [ "anndata>=0.9.1", "click", "dask-image", - "dask>=2024.4.1,<=2024.11.2", + "dask>=2024.10.0,<=2024.11.2", "datashader", - "fsspec", + "fsspec[s3,http]", "geopandas>=0.14", "multiscale_spatial_image>=2.0.3", "networkx", "numba>=0.55.0", "numpy", - "ome_zarr>=0.8.4", + "ome_zarr>=0.12.2", "pandas", "pooch", "pyarrow", @@ -44,15 +44,17 @@ dependencies = [ "scikit-image", "scipy", "typing_extensions>=4.8.0", + "universal_pathlib>=0.2.6", "xarray>=2024.10.0", "xarray-schema", "xarray-spatial>=0.3.5", - "zarr<3", + "zarr>=3.0.0", ] [project.optional-dependencies] dev = [ "bump2version", + "sentry-prevent-cli", ] test = [ "pytest", @@ -97,7 +99,13 @@ xfail_strict = true addopts = [ # "-Werror", # if 3rd party libs raise DeprecationWarnings, just use filterwarnings below "--import-mode=importlib", # allow using test files with same name - "-s" # print output from tests + "-s", # print output from tests +] +# These are all markers coming from xarray, dask or anndata. Added here to silence warnings. +markers = [ + "slow: marks tests as slow (deselect with '-m \"not slow\"')", + "gpu: run test on GPU using CuPY.", + "skip_with_pyarrow_strings: skipwhen pyarrow string conversion is turned on", ] # info on how to use this https://stackoverflow.com/questions/57925071/how-do-i-avoid-getting-deprecationwarning-from-inside-dependencies-with-pytest filterwarnings = [ @@ -131,7 +139,7 @@ exclude = [ ] line-length = 120 -target-version = "py310" +target-version = "py311" [tool.ruff.lint] ignore = [ diff --git a/src/spatialdata/__init__.py b/src/spatialdata/__init__.py index 3a683c661..5d84e172b 100644 --- a/src/spatialdata/__init__.py +++ b/src/spatialdata/__init__.py @@ -42,10 +42,10 @@ "SpatialData", "get_extent", "get_centroids", + "SpatialDataFormatType", "read_zarr", "unpad_raster", "get_pyramid_levels", - "save_transformations", "get_dask_backing_files", "are_extents_equal", "relabel_sequential", @@ -79,7 +79,7 @@ ) from spatialdata._core.query.spatial_query import bounding_box_query, polygon_query from spatialdata._core.spatialdata import SpatialData -from spatialdata._io._utils import get_dask_backing_files, save_transformations -from spatialdata._io.format import SpatialDataFormat +from spatialdata._io._utils import get_dask_backing_files +from spatialdata._io.format import SpatialDataFormatType from spatialdata._io.io_zarr import read_zarr from spatialdata._utils import get_pyramid_levels, unpad_raster diff --git a/src/spatialdata/_core/_deepcopy.py b/src/spatialdata/_core/_deepcopy.py index d55db7561..6a5b43367 100644 --- a/src/spatialdata/_core/_deepcopy.py +++ b/src/spatialdata/_core/_deepcopy.py @@ -46,7 +46,7 @@ def _(sdata: SpatialData) -> SpatialData: for _, element_name, element in sdata.gen_elements(): elements_dict[element_name] = deepcopy(element) deepcopied_attrs = _deepcopy(sdata.attrs) - return SpatialData.from_elements_dict(elements_dict, attrs=deepcopied_attrs) + return SpatialData.init_from_elements(elements_dict, attrs=deepcopied_attrs) 
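The `_deepcopy` hunk above now assembles the copied object with `SpatialData.init_from_elements`, which in this PR accepts tables directly alongside spatial elements. Below is a minimal, hedged sketch of the calling pattern; `sdata` is assumed to be an existing in-memory `SpatialData` object, and the import path is taken from the module shown in the diff (a public `spatialdata.deepcopy` alias may also exist).

```python
from spatialdata._core._deepcopy import deepcopy as sdata_deepcopy

# Deep-copy every element (images, labels, points, shapes, tables) into memory.
copied = sdata_deepcopy(sdata)

# Internally, the implementation above rebuilds the object from a flat dict of
# named elements (tables included) plus the deep-copied attrs, roughly:
#   SpatialData.init_from_elements(elements_dict, attrs=deepcopied_attrs)
```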
@deepcopy.register(DataArray) diff --git a/src/spatialdata/_core/_elements.py b/src/spatialdata/_core/_elements.py index d1c2f4a57..99ff9d33d 100644 --- a/src/spatialdata/_core/_elements.py +++ b/src/spatialdata/_core/_elements.py @@ -3,7 +3,6 @@ from collections import UserDict from collections.abc import Iterable, KeysView, ValuesView from typing import TypeVar -from warnings import warn from anndata import AnnData from dask.dataframe import DataFrame as DaskDataFrame @@ -41,9 +40,7 @@ def _remove_shared_key(self, key: str) -> None: @staticmethod def _check_key(key: str, element_keys: Iterable[str], shared_keys: set[str | None]) -> None: check_valid_name(key) - if key in element_keys: - warn(f"Key `{key}` already exists. Overwriting it in-memory.", UserWarning, stacklevel=2) - else: + if key not in element_keys: try: check_key_is_case_insensitively_unique(key, shared_keys) except ValueError as e: diff --git a/src/spatialdata/_core/concatenate.py b/src/spatialdata/_core/concatenate.py index 953e4f2f3..90b897c71 100644 --- a/src/spatialdata/_core/concatenate.py +++ b/src/spatialdata/_core/concatenate.py @@ -68,8 +68,7 @@ def _concatenate_tables( TableModel.INSTANCE_KEY: instance_key, } merged_table.uns[TableModel.ATTRS_KEY] = attrs - - return TableModel().validate(merged_table) + return TableModel().parse(merged_table) def concatenate( @@ -252,6 +251,8 @@ def _fix_ensure_unique_element_names( tables_by_sdata.append(tables) sdatas_fixed = [] for elements, tables in zip(elements_by_sdata, tables_by_sdata, strict=True): - sdata = SpatialData.init_from_elements(elements, tables=tables) + if tables is not None: + elements.update(tables) + sdata = SpatialData.init_from_elements(elements) sdatas_fixed.append(sdata) return sdatas_fixed diff --git a/src/spatialdata/_core/operations/_utils.py b/src/spatialdata/_core/operations/_utils.py index 2879af456..d3c438abb 100644 --- a/src/spatialdata/_core/operations/_utils.py +++ b/src/spatialdata/_core/operations/_utils.py @@ -136,7 +136,7 @@ def transform_to_data_extent( set_transformation(el, transformation={coordinate_system: Identity()}, set_all=True) for k, v in sdata.tables.items(): sdata_to_return_elements[k] = v.copy() - return SpatialData.from_elements_dict(sdata_to_return_elements, attrs=sdata.attrs) + return SpatialData.init_from_elements(sdata_to_return_elements, attrs=sdata.attrs) def _parse_element( diff --git a/src/spatialdata/_core/operations/aggregate.py b/src/spatialdata/_core/operations/aggregate.py index 83a3d6f9a..131b8d7f3 100644 --- a/src/spatialdata/_core/operations/aggregate.py +++ b/src/spatialdata/_core/operations/aggregate.py @@ -232,7 +232,7 @@ def _create_sdata_from_table_and_shapes( f"Instance key column dtype in table resulting from aggregation cannot be cast to the dtype of" f"element {shapes_name}.index" ) from err - table.obs[region_key] = shapes_name + table.obs[region_key] = pd.Categorical([shapes_name] * len(table)) table = TableModel.parse(table, region=shapes_name, region_key=region_key, instance_key=instance_key) # labels case, needs conversion from str to int @@ -242,7 +242,7 @@ def _create_sdata_from_table_and_shapes( if deepcopy: shapes = _deepcopy(shapes) - return SpatialData.from_elements_dict({shapes_name: shapes, table_name: table}) + return SpatialData.init_from_elements({shapes_name: shapes, table_name: table}) def _aggregate_image_by_labels( @@ -440,7 +440,7 @@ def _aggregate_shapes( vk = value_key[0] if fractions_of_values is not None: joined[ONES_COLUMN] = fractions_of_values - aggregated = 
joined.groupby([INDEX, vk])[ONES_COLUMN].agg(agg_func).reset_index() + aggregated = joined.groupby([INDEX, vk], observed=False)[ONES_COLUMN].agg(agg_func).reset_index() aggregated_values = aggregated[ONES_COLUMN].values else: if fractions_of_values is not None: @@ -478,7 +478,7 @@ def _aggregate_shapes( anndata = ad.AnnData( X, - obs=pd.DataFrame(index=rows_categories), + obs=pd.DataFrame(index=list(map(str, rows_categories))), var=pd.DataFrame(index=columns_categories), ) diff --git a/src/spatialdata/_core/operations/rasterize_bins.py b/src/spatialdata/_core/operations/rasterize_bins.py index 37558e1b7..174708124 100644 --- a/src/spatialdata/_core/operations/rasterize_bins.py +++ b/src/spatialdata/_core/operations/rasterize_bins.py @@ -132,7 +132,8 @@ def rasterize_bins( random_indices = RNG.choice(table.n_obs, min(20, table.n_obs), replace=True) location_ids = table.obs[instance_key].iloc[random_indices].values - sub_df, sub_table = element.loc[location_ids], table[random_indices] + sub_df = element.loc[location_ids] + sub_table = table[random_indices] src = np.stack([sub_table.obs[col_key] - min_col, sub_table.obs[row_key] - min_row], axis=1) if isinstance(sub_df, GeoDataFrame): @@ -281,7 +282,7 @@ def rasterize_bins_link_table_to_labels(sdata: SpatialData, table_name: str, ras The name of the rasterized labels in the spatial data object. """ _, region_key, instance_key = get_table_keys(sdata[table_name]) - sdata[table_name].obs[region_key] = rasterized_labels_name + sdata[table_name].obs[region_key] = pd.Categorical([rasterized_labels_name] * sdata[table_name].n_obs) relabled_instance_key = _get_relabeled_column_name(instance_key) sdata.set_table_annotates_spatialelement( table_name=table_name, region=rasterized_labels_name, region_key=region_key, instance_key=relabled_instance_key diff --git a/src/spatialdata/_core/operations/vectorize.py b/src/spatialdata/_core/operations/vectorize.py index a4f904dc7..40d3e31f4 100644 --- a/src/spatialdata/_core/operations/vectorize.py +++ b/src/spatialdata/_core/operations/vectorize.py @@ -272,7 +272,7 @@ def _(gdf: GeoDataFrame, buffer_resolution: int = 16) -> GeoDataFrame: if isinstance(gdf.geometry.iloc[0], Point): buffered_df = gdf.copy() buffered_df["geometry"] = buffered_df.apply( - lambda row: row.geometry.buffer(row[ShapesModel.RADIUS_KEY], resolution=buffer_resolution), axis=1 + lambda row: row.geometry.buffer(row[ShapesModel.RADIUS_KEY], quad_segs=buffer_resolution), axis=1 ) # Ensure the GeoDataFrame recognizes the updated geometry column diff --git a/src/spatialdata/_core/query/relational_query.py b/src/spatialdata/_core/query/relational_query.py index 0803158ca..247bffebf 100644 --- a/src/spatialdata/_core/query/relational_query.py +++ b/src/spatialdata/_core/query/relational_query.py @@ -679,7 +679,7 @@ def join_spatialelement_table( if sdata is not None: elements_dict = _create_sdata_elements_dict_for_join(sdata, spatial_element_names) else: - derived_sdata = SpatialData.from_elements_dict(dict(zip(spatial_element_names, spatial_elements, strict=True))) + derived_sdata = SpatialData.init_from_elements(dict(zip(spatial_element_names, spatial_elements, strict=True))) element_types = ["labels", "shapes", "points"] elements_dict = defaultdict(lambda: defaultdict(dict)) for element_type in element_types: @@ -738,14 +738,6 @@ def match_table_to_element(sdata: SpatialData, element_name: str, table_name: st match_element_to_table : Function to match a spatial element to a table. 
join_spatialelement_table : General function, to join spatial elements with a table with more control. """ - if table_name is None: - warnings.warn( - "Assumption of table with name `table` being present is being deprecated in SpatialData v0.1. " - "Please provide the name of the table as argument to table_name.", - DeprecationWarning, - stacklevel=2, - ) - table_name = "table" _, table = join_spatialelement_table( sdata=sdata, spatial_element_names=element_name, table_name=table_name, how="left", match_rows="left" ) diff --git a/src/spatialdata/_core/query/spatial_query.py b/src/spatialdata/_core/query/spatial_query.py index 0bb639d2f..e6dccb458 100644 --- a/src/spatialdata/_core/query/spatial_query.py +++ b/src/spatialdata/_core/query/spatial_query.py @@ -797,19 +797,6 @@ def _( return queried_polygons -# TODO: we can replace the manually triggered deprecation warning heres with the decorator from Wouter -def _check_deprecated_kwargs(kwargs: dict[str, Any]) -> None: - deprecated_args = ["shapes", "points", "images", "labels"] - for arg in deprecated_args: - if arg in kwargs and kwargs[arg] is False: - warnings.warn( - f"The '{arg}' argument is deprecated and will be removed in one of the next following releases. Please " - f"filter the SpatialData object before calling this function.", - DeprecationWarning, - stacklevel=2, - ) - - @singledispatch def polygon_query( element: SpatialElement | SpatialData, @@ -817,10 +804,6 @@ def polygon_query( target_coordinate_system: str, filter_table: bool = True, clip: bool = False, - shapes: bool = True, - points: bool = True, - images: bool = True, - labels: bool = True, ) -> SpatialElement | SpatialData | None: """ Query a SpatialData object or a SpatialElement by a polygon or multipolygon. @@ -845,18 +828,6 @@ def polygon_query( Importantly, when clipping is enabled, the circles will be converted to polygons before the clipping. This may affect downstream operations that rely on the circle radius or on performance, so it is recommended to disable clipping when querying circles or when querying a `SpatialData` object that contains circles. - shapes [Deprecated] - This argument is now ignored and will be removed. Please filter the SpatialData object before calling this - function. - points [Deprecated] - This argument is now ignored and will be removed. Please filter the SpatialData object before calling this - function. - images [Deprecated] - This argument is now ignored and will be removed. Please filter the SpatialData object before calling this - function. - labels [Deprecated] - This argument is now ignored and will be removed. Please filter the SpatialData object before calling this - function. 
Returns ------- @@ -879,12 +850,7 @@ def _( target_coordinate_system: str, filter_table: bool = True, clip: bool = False, - shapes: bool = True, - points: bool = True, - images: bool = True, - labels: bool = True, ) -> SpatialData: - _check_deprecated_kwargs({"shapes": shapes, "points": points, "images": images, "labels": labels}) new_elements = {} for element_type in ["points", "images", "labels", "shapes"]: elements = getattr(sdata, element_type) @@ -911,7 +877,6 @@ def _( return_request_only: bool = False, **kwargs: Any, ) -> DataArray | DataTree | None: - _check_deprecated_kwargs(kwargs) gdf = GeoDataFrame(geometry=[polygon]) min_x, min_y, max_x, max_y = gdf.bounds.values.flatten().tolist() return bounding_box_query( @@ -933,7 +898,6 @@ def _( ) -> DaskDataFrame | None: from spatialdata.transformations import get_transformation, set_transformation - _check_deprecated_kwargs(kwargs) polygon_gdf = _get_polygon_in_intrinsic_coordinates(points, target_coordinate_system, polygon) points_gdf = points_dask_dataframe_to_geopandas(points, suppress_z_warning=True) @@ -966,7 +930,6 @@ def _( ) -> GeoDataFrame | None: from spatialdata.transformations import get_transformation, set_transformation - _check_deprecated_kwargs(kwargs) polygon_gdf = _get_polygon_in_intrinsic_coordinates(element, target_coordinate_system, polygon) polygon = polygon_gdf["geometry"].iloc[0] diff --git a/src/spatialdata/_core/spatialdata.py b/src/spatialdata/_core/spatialdata.py index 48f6386ca..3d6a9ed06 100644 --- a/src/spatialdata/_core/spatialdata.py +++ b/src/spatialdata/_core/spatialdata.py @@ -16,10 +16,9 @@ from dask.dataframe import read_parquet from dask.delayed import Delayed from geopandas import GeoDataFrame -from ome_zarr.io import parse_url -from ome_zarr.types import JSONDict from shapely import MultiPolygon, Polygon from xarray import DataArray, DataTree +from zarr.errors import GroupNotFoundError from spatialdata._core._elements import Images, Labels, Points, Shapes, Tables from spatialdata._core.validation import ( @@ -31,10 +30,7 @@ ) from spatialdata._logging import logger from spatialdata._types import ArrayLike, Raster_T -from spatialdata._utils import ( - _deprecation_alias, - _error_message_add_element, -) +from spatialdata._utils import _deprecation_alias from spatialdata.models import ( Image2DModel, Image3DModel, @@ -55,7 +51,10 @@ if TYPE_CHECKING: from spatialdata._core.query.spatial_query import BaseSpatialRequest - from spatialdata._io.format import SpatialDataFormat + from spatialdata._io.format import ( + SpatialDataContainerFormatType, + SpatialDataFormatType, + ) # schema for elements Label2D_s = Labels2DModel() @@ -121,7 +120,6 @@ class SpatialData: annotation directly. """ - @_deprecation_alias(table="tables", version="0.1.0") def __init__( self, images: dict[str, Raster_T] | None = None, @@ -141,10 +139,6 @@ def __init__( self._tables: Tables = Tables(shared_keys=self._shared_keys) self.attrs = attrs if attrs else {} # type: ignore[assignment] - # Workaround to allow for backward compatibility - if isinstance(tables, AnnData): - tables = {"table": tables} - element_names = list(chain.from_iterable([e.keys() for e in [images, labels, points, shapes] if e is not None])) if len(element_names) != len(set(element_names)): @@ -234,33 +228,6 @@ def validate_table_in_spatialdata(self, table: AnnData) -> None: f"the annotated element ({dtype})." 
) - @staticmethod - def from_elements_dict( - elements_dict: dict[str, SpatialElement | AnnData], attrs: Mapping[Any, Any] | None = None - ) -> SpatialData: - """ - Create a SpatialData object from a dict of elements. - - Parameters - ---------- - elements_dict - Dict of elements. The keys are the names of the elements and the values are the elements. - A table can be present in the dict, but only at most one; its name is not used and can be anything. - attrs - Additional attributes to store in the SpatialData object. - - Returns - ------- - The SpatialData object. - """ - warnings.warn( - 'This method is deprecated and will be removed in a future release. Use "SpatialData.init_from_elements(' - ')" instead. For the moment, such methods will be automatically called.', - DeprecationWarning, - stacklevel=2, - ) - return SpatialData.init_from_elements(elements=elements_dict, attrs=attrs) - @staticmethod def get_annotated_regions(table: AnnData) -> list[str]: """ @@ -275,7 +242,9 @@ def get_annotated_regions(table: AnnData) -> list[str]: ------- The annotated regions. """ - from spatialdata.models.models import _get_region_metadata_from_region_key_column + from spatialdata.models.models import ( + _get_region_metadata_from_region_key_column, + ) return _get_region_metadata_from_region_key_column(table) @@ -328,11 +297,12 @@ def get_instance_key_column(table: AnnData) -> pd.Series: raise KeyError(f"{instance_key} is set as instance key column. However the column is not found in table.obs.") def set_channel_names(self, element_name: str, channel_names: str | list[str], write: bool = False) -> None: - """Set the channel names for a image `SpatialElement` in the `SpatialData` object. + """Set the channel names for an image `SpatialElement` in the `SpatialData` object. - This method assumes that the `SpatialData` object and the element are already stored on disk as it will - also overwrite the channel names metadata on disk. In case either the `SpatialData` object or the - element are not stored on disk, please use `SpatialData.set_image_channel_names` instead. + This method will overwrite the element in memory with the same element, but with new channel names. + If 'write` is 'True', this method assumes that the `SpatialData` object and the element are already stored on + disk as it will also overwrite the channel names metadata on disk. If you do not want to overwrite the element + on disk, or it is not stored, set `write` to False (default). Parameters ---------- @@ -341,7 +311,8 @@ def set_channel_names(self, element_name: str, channel_names: str | list[str], w channel_names The channel names to be assigned to the c dimension of the image `SpatialElement`. write - Whether to overwrite the channel metadata on disk. + Whether to overwrite the channel metadata on disk (lightweight operation). This will not rewrite the pixel + data itself (heavy operation). """ self.images[element_name] = set_channel_names(self.images[element_name], channel_names) if write: @@ -593,66 +564,6 @@ def path(self, value: Path | None) -> None: else: raise TypeError("Path must be `None`, a `str` or a `Path` object.") - if not self.is_self_contained(): - logger.info( - "The SpatialData object is not self-contained (i.e. it contains some elements that are Dask-backed from" - f" locations outside {self.path}). Please see the documentation of `is_self_contained()` to understand" - f" the implications of working with SpatialData objects that are not self-contained." 
- ) - - def _get_groups_for_element( - self, zarr_path: Path, element_type: str, element_name: str - ) -> tuple[zarr.Group, zarr.Group, zarr.Group]: - """ - Get the Zarr groups for the root, element_type and element for a specific element. - - The store must exist, but creates the element type group and the element group if they don't exist. - - Parameters - ---------- - zarr_path - The path to the Zarr storage. - element_type - type of the element; must be in ["images", "labels", "points", "polygons", "shapes", "tables"]. - element_name - name of the element - - Returns - ------- - either the existing Zarr subgroup or a new one. - """ - if not isinstance(zarr_path, Path): - raise ValueError("zarr_path should be a Path object") - store = parse_url(zarr_path, mode="r+").store - root = zarr.group(store=store) - if element_type not in ["images", "labels", "points", "polygons", "shapes", "tables"]: - raise ValueError(f"Unknown element type {element_type}") - element_type_group = root.require_group(element_type) - element_name_group = element_type_group.require_group(element_name) - return root, element_type_group, element_name_group - - def _group_for_element_exists(self, zarr_path: Path, element_type: str, element_name: str) -> bool: - """ - Check if the group for an element exists. - - Parameters - ---------- - element_type - type of the element; must be in ["images", "labels", "points", "polygons", "shapes", "tables"]. - element_name - name of the element - - Returns - ------- - True if the group exists, False otherwise. - """ - store = parse_url(zarr_path, mode="r").store - root = zarr.group(store=store) - assert element_type in ["images", "labels", "points", "polygons", "shapes", "tables"] - exists = element_type in root and element_name in root[element_type] - store.close() - return exists - def locate_element(self, element: SpatialElement) -> list[str]: """ Locate a SpatialElement within the SpatialData object and returns its Zarr paths relative to the root. @@ -682,9 +593,11 @@ def locate_element(self, element: SpatialElement) -> list[str]: raise ValueError("Found an element name with a '/' character. This is not allowed.") return [f"{found_element_type[i]}/{found_element_name[i]}" for i in range(len(found))] - @_deprecation_alias(filter_table="filter_tables", version="0.1.0") def filter_by_coordinate_system( - self, coordinate_system: str | list[str], filter_tables: bool = True, include_orphan_tables: bool = False + self, + coordinate_system: str | list[str], + filter_tables: bool = True, + include_orphan_tables: bool = False, ) -> SpatialData: """ Filter the SpatialData by one (or a list of) coordinate system. @@ -726,7 +639,11 @@ def filter_by_coordinate_system( elements[element_type][element_name] = element element_names_in_coordinate_system.append(element_name) tables = self._filter_tables( - set(), filter_tables, "cs", include_orphan_tables, element_names=element_names_in_coordinate_system + set(), + filter_tables, + "cs", + include_orphan_tables, + element_names=element_names_in_coordinate_system, ) return SpatialData(**elements, tables=tables, attrs=self.attrs) @@ -779,14 +696,18 @@ def _filter_tables( continue # each mode here requires paths or elements, using assert here to avoid mypy errors. 
if by == "cs": - from spatialdata._core.query.relational_query import _filter_table_by_element_names + from spatialdata._core.query.relational_query import ( + _filter_table_by_element_names, + ) assert element_names is not None table = _filter_table_by_element_names(table, element_names) if len(table) != 0: tables[table_name] = table elif by == "elements": - from spatialdata._core.query.relational_query import _filter_table_by_elements + from spatialdata._core.query.relational_query import ( + _filter_table_by_elements, + ) assert elements_dict is not None table = _filter_table_by_elements(table, elements_dict=elements_dict) @@ -811,7 +732,10 @@ def rename_coordinate_systems(self, rename_dict: dict[str, str]) -> None: The method does not allow to rename a coordinate system into an existing one, unless the existing one is also renamed in the same call. """ - from spatialdata.transformations.operations import get_transformation, set_transformation + from spatialdata.transformations.operations import ( + get_transformation, + set_transformation, + ) # check that the rename_dict is valid old_names = self.coordinate_systems @@ -853,9 +777,11 @@ def rename_coordinate_systems(self, rename_dict: dict[str, str]) -> None: # set the new transformations set_transformation(element=element, transformation=new_transformations, set_all=True) - @_deprecation_alias(element="element_name", version="0.3.0") def transform_element_to_coordinate_system( - self, element_name: str, target_coordinate_system: str, maintain_positioning: bool = False + self, + element_name: str, + target_coordinate_system: str, + maintain_positioning: bool = False, ) -> SpatialElement: """ Transform an element to a given coordinate system. @@ -884,18 +810,7 @@ def transform_element_to_coordinate_system( set_transformation, ) - # TODO remove after deprecation - if not isinstance(element_name, str): - warnings.warn( - "Passing a SpatialElement is as element will be deprecated in SpatialData v0.3.0. 
Pass" - "element_name as string to silence this warning.", - DeprecationWarning, - stacklevel=2, - ) - element = element_name - else: - element = self.get(element_name) - + element = self.get(element_name) t = get_transformation_between_coordinate_systems(self, element, target_coordinate_system) if maintain_positioning: transformed = transform(element, transformation=t, maintain_positioning=maintain_positioning) @@ -907,7 +822,9 @@ def transform_element_to_coordinate_system( d[target_coordinate_system] = t to_remove = True transformed = transform( - element, to_coordinate_system=target_coordinate_system, maintain_positioning=maintain_positioning + element, + to_coordinate_system=target_coordinate_system, + maintain_positioning=maintain_positioning, ) if to_remove: del d[target_coordinate_system] @@ -963,10 +880,12 @@ def transform_to_coordinate_system( """ sdata = self.filter_by_coordinate_system(target_coordinate_system, filter_tables=False) elements: dict[str, dict[str, SpatialElement]] = {} - for element_type, element_name, element in sdata.gen_elements(): + for element_type, element_name, _ in sdata.gen_elements(): if element_type != "tables": transformed = sdata.transform_element_to_coordinate_system( - element, target_coordinate_system, maintain_positioning=maintain_positioning + element_name, + target_coordinate_system, + maintain_positioning=maintain_positioning, ) if element_type not in elements: elements[element_type] = {} @@ -1068,19 +987,27 @@ def elements_paths_on_disk(self) -> list[str]: ------- A list of paths of the elements saved in the Zarr store. """ + from spatialdata._io._utils import _resolve_zarr_store + if self.path is None: raise ValueError("The SpatialData object is not backed by a Zarr store.") - store = parse_url(self.path, mode="r").store - root = zarr.group(store=store) + + store = _resolve_zarr_store(self.path) + root = zarr.open_group(store=store, mode="r") elements_in_zarr = [] def find_groups(obj: zarr.Group, path: str) -> None: - # with the current implementation, a path of a zarr group if the path for an element if and only if its + # with the current implementation, a path of a zarr group is the path for an element if and only if its # string representation contains exactly one "/" if isinstance(obj, zarr.Group) and path.count("/") == 1: elements_in_zarr.append(path) - root.visit(lambda path: find_groups(root[path], path)) + for element_type in root: + if element_type in ["images", "labels", "points", "shapes", "tables"]: + for element_name in root[element_type]: + path = f"{element_type}/{element_name}" + elements_in_zarr.append(path) + # root.visit(lambda path: find_groups(root[path], path)) store.close() return elements_in_zarr @@ -1114,7 +1041,7 @@ def _validate_can_safely_write_to_path( overwrite: bool = False, saving_an_element: bool = False, ) -> None: - from spatialdata._io._utils import _backed_elements_contained_in_path, _is_subfolder + from spatialdata._io._utils import _backed_elements_contained_in_path, _is_subfolder, _resolve_zarr_store if isinstance(file_path, str): file_path = Path(file_path) @@ -1122,12 +1049,16 @@ def _validate_can_safely_write_to_path( if not isinstance(file_path, Path): raise ValueError(f"file_path must be a string or a Path object, type(file_path) = {type(file_path)}.") + # TODO: add test for this if os.path.exists(file_path): - if parse_url(file_path, mode="r") is None: + store = _resolve_zarr_store(file_path) + try: + zarr.open(store, mode="r") + except GroupNotFoundError as err: raise ValueError( "The target file 
path specified already exists, and it has been detected to not be a Zarr store. " "Overwriting non-Zarr stores is not supported to prevent accidental data loss." - ) + ) from err if not overwrite: raise ValueError( "The Zarr store already exists. Use `overwrite=True` to try overwriting the store. " @@ -1173,12 +1104,14 @@ def _validate_all_elements(self) -> None: with collect_error(location=element_path): validate_table_attr_keys(element, location=element_path) + @_deprecation_alias(format="sdata_formats", version="0.7.0") def write( self, file_path: str | Path, overwrite: bool = False, consolidate_metadata: bool = True, - format: SpatialDataFormat | list[SpatialDataFormat] | None = None, + update_sdata_path: bool = True, + sdata_formats: SpatialDataFormatType | list[SpatialDataFormatType] | None = None, ) -> None: """ Write the `SpatialData` object to a Zarr store. @@ -1194,25 +1127,50 @@ def write( If `True`, triggers :func:`zarr.convenience.consolidate_metadata`, which writes all the metadata in a single file at the root directory of the store. This makes the data cloud accessible, which is required for certain cloud stores (such as S3). - format + update_sdata_path + Whether to update the `path` attribute of the `SpatialData` object to `file_path` after a successful write + (default yes). Here are the implications. + + - If `True`, and if the `SpatialData` object has dask-backed elements, the object will become "not + self-contained" because the dask-backed element will have a path that is different from the new + `sdata.path` attribute (now equal to `file_path`). By re-reading the object from disk, the object + will become self-contained again. + - If `False`, the `SpatialData` object will keep its current `path` attribute, meaning that calling + `sdata.write_element()`, `sdata.write_attrs()`, ... will write the data to the current `sdata.path` + location, not to the `file_path` location. + + Please consult :func:`spatialdata.SpatialData.is_self_contained` for more information on the implications of + working with self-contained and non-self-contained SpatialData objects. + sdata_formats The format to use for writing the elements of the `SpatialData` object. It is recommended to leave this parameter equal to `None` (default to latest format for all the elements). If not `None`, it must be - either a format for an element, or a list of formats. - For example it can be a subset of the following list `[RasterFormatVXX(), ShapesFormatVXX(), - PointsFormatVXX(), TablesFormatVXX()]`. (XX denote the version number, and should be replaced with the - respective format; the version numbers can differ across elements). - By default, the latest format is used for all elements, i.e. - :class:`~spatialdata._io.format.CurrentRasterFormat`, :class:`~spatialdata._io.format.CurrentShapesFormat`, - :class:`~spatialdata._io.format.CurrentPointsFormat`, :class:`~spatialdata._io.format.CurrentTablesFormat`. - """ + either a format for an element, or a list of formats. For example it can be a subset of the following + list `[RasterFormatVXX(), ShapesFormatVXX(), PointsFormatVXX(), TablesFormatVXX()]`. (XX denote the + version number, and should be replaced with the respective format; the version numbers can differ across + elements). By default, the latest format is used for all elements, + i.e. 
:class:`~spatialdata._io.format.CurrentRasterFormat`, + :class:`~spatialdata._io.format.CurrentShapesFormat`, + :class:`~spatialdata._io.format.CurrentPointsFormat`, + :class:`~spatialdata._io.format.CurrentTablesFormat`. Also, by default, if a format for the SpatialData + "container" object (i.e. SpatialDataContainerFormatVXX) is specified, but the format for some elements is + unspecified, the element formats will be set to the latest element format compatible with the specified + SpatialData container format. All the formats and relationships between them are defined in + `spatialdata._io.format.py`. + """ + from spatialdata._io._utils import _resolve_zarr_store + from spatialdata._io.format import _parse_formats + + parsed = _parse_formats(sdata_formats) + if isinstance(file_path, str): file_path = Path(file_path) self._validate_can_safely_write_to_path(file_path, overwrite=overwrite) self._validate_all_elements() - store = parse_url(file_path, mode="w").store - zarr_group = zarr.group(store=store, overwrite=overwrite) - self.write_attrs(zarr_group=zarr_group) + store = _resolve_zarr_store(file_path) + zarr_format = parsed["SpatialData"].zarr_format + zarr_group = zarr.create_group(store=store, overwrite=overwrite, zarr_format=zarr_format) + self.write_attrs(zarr_group=zarr_group, sdata_format=parsed["SpatialData"]) store.close() for element_type, element_name, element in self.gen_elements(): @@ -1222,13 +1180,11 @@ def write( element_type=element_type, element_name=element_name, overwrite=False, - format=format, + parsed_formats=parsed, ) - if self.path != file_path: - old_path = self.path + if self.path != file_path and update_sdata_path: self.path = file_path - logger.info(f"The Zarr backing store has been changed from {old_path} the new file path: {file_path}") if consolidate_metadata: self.write_consolidated_metadata() @@ -1240,8 +1196,10 @@ def _write_element( element_type: str, element_name: str, overwrite: bool, - format: SpatialDataFormat | list[SpatialDataFormat] | None = None, + parsed_formats: dict[str, SpatialDataFormatType] | None = None, ) -> None: + from spatialdata._io.io_zarr import _get_groups_for_element + if not isinstance(zarr_container_path, Path): raise ValueError( f"zarr_container_path must be a Path object, type(zarr_container_path) = {type(zarr_container_path)}." 
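To make the new `write()` signature concrete, here is a hedged usage sketch based on the docstring above. The store paths are hypothetical, `sdata` is assumed to be an existing `SpatialData` object, and leaving `sdata_formats=None` (the default) selects the latest container and element formats.

```python
from pathlib import Path

# Write with the defaults: latest formats, consolidated metadata, and the
# object's `path` attribute updated to point at the new store.
sdata.write("analysis.zarr", overwrite=True)

# Export a copy elsewhere while keeping the object's current `path`, so that
# subsequent calls such as sdata.write_element() still target the original store.
sdata.write(Path("export_copy.zarr"), overwrite=True, update_sdata_path=False)
```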
@@ -1251,24 +1209,54 @@ def _write_element( file_path=file_path_of_element, overwrite=overwrite, saving_an_element=True ) - root_group, element_type_group, _ = self._get_groups_for_element( - zarr_path=zarr_container_path, element_type=element_type, element_name=element_name + root_group, element_type_group, element_group = _get_groups_for_element( + zarr_path=zarr_container_path, element_type=element_type, element_name=element_name, use_consolidated=False + ) + from spatialdata._io import ( + write_image, + write_labels, + write_points, + write_shapes, + write_table, ) - from spatialdata._io import write_image, write_labels, write_points, write_shapes, write_table from spatialdata._io.format import _parse_formats - parsed = _parse_formats(formats=format) + if parsed_formats is None: + parsed_formats = _parse_formats(formats=parsed_formats) if element_type == "images": - write_image(image=element, group=element_type_group, name=element_name, format=parsed["raster"]) + write_image( + image=element, + group=element_group, + name=element_name, + element_format=parsed_formats["raster"], + ) elif element_type == "labels": - write_labels(labels=element, group=root_group, name=element_name, format=parsed["raster"]) + write_labels( + labels=element, + group=root_group, + name=element_name, + element_format=parsed_formats["raster"], + ) elif element_type == "points": - write_points(points=element, group=element_type_group, name=element_name, format=parsed["points"]) + write_points( + points=element, + group=element_group, + element_format=parsed_formats["points"], + ) elif element_type == "shapes": - write_shapes(shapes=element, group=element_type_group, name=element_name, format=parsed["shapes"]) + write_shapes( + shapes=element, + group=element_group, + element_format=parsed_formats["shapes"], + ) elif element_type == "tables": - write_table(table=element, group=element_type_group, name=element_name, format=parsed["tables"]) + write_table( + table=element, + group=element_type_group, + name=element_name, + element_format=parsed_formats["tables"], + ) else: raise ValueError(f"Unknown element type: {element_type}") @@ -1276,7 +1264,7 @@ def write_element( self, element_name: str | list[str], overwrite: bool = False, - format: SpatialDataFormat | list[SpatialDataFormat] | None = None, + sdata_formats: SpatialDataFormatType | list[SpatialDataFormatType] | None = None, ) -> None: """ Write a single element, or a list of elements, to the Zarr store used for backing. @@ -1289,7 +1277,7 @@ def write_element( The name(s) of the element(s) to write. overwrite If True, overwrite the element if it already exists. - format + sdata_formats It is recommended to leave this parameter equal to `None`. See more details in the documentation of `SpatialData.write()`. @@ -1298,10 +1286,14 @@ def write_element( If you pass a list of names, the elements will be written one by one. If an error occurs during the writing of an element, the writing of the remaining elements will not be attempted. 
""" + from spatialdata._io.format import _parse_formats + + parsed_formats = _parse_formats(formats=sdata_formats) + if isinstance(element_name, list): for name in element_name: assert isinstance(name, str) - self.write_element(name, overwrite=overwrite) + self.write_element(name, overwrite=overwrite, sdata_formats=sdata_formats) return check_valid_name(element_name) @@ -1334,8 +1326,11 @@ def write_element( element_type=element_type, element_name=element_name, overwrite=overwrite, - format=format, + parsed_formats=parsed_formats, ) + # After every write, metadata should be consolidated, otherwise this can lead to IO problems like when deleting. + if self.has_consolidated_metadata(): + self.write_consolidated_metadata() def delete_element_from_disk(self, element_name: str | list[str]) -> None: """ @@ -1370,14 +1365,14 @@ def delete_element_from_disk(self, element_name: str | list[str]) -> None: environment (e.g. operating system, local vs network storage, file permissions, ...) and call this function appropriately (or implement a tailored solution), to prevent data loss. """ + from spatialdata._io._utils import _backed_elements_contained_in_path + if isinstance(element_name, list): for name in element_name: assert isinstance(name, str) self.delete_element_from_disk(name) return - from spatialdata._io._utils import _backed_elements_contained_in_path - if self.path is None: raise ValueError("The SpatialData object is not backed by a Zarr store.") @@ -1416,10 +1411,12 @@ def delete_element_from_disk(self, element_name: str | list[str]) -> None: "more elements in the SpatialData object. Deleting the data would corrupt the SpatialData object." ) + from spatialdata._io._utils import _resolve_zarr_store + # delete the element - store = parse_url(self.path, mode="r+").store - root = zarr.group(store=store) - root[element_type].pop(element_name) + store = _resolve_zarr_store(self.path) + root = zarr.open_group(store=store, mode="r+", use_consolidated=False) + del root[element_type][element_name] store.close() if self.has_consolidated_metadata(): @@ -1438,16 +1435,17 @@ def _check_element_not_on_disk_with_different_type(self, element_type: str, elem ) def write_consolidated_metadata(self) -> None: - store = parse_url(self.path, mode="r+").store - # consolidate metadata to more easily support remote reading bug in zarr. 
In reality, 'zmetadata' is written - # instead of '.zmetadata' see discussion https://github.com/zarr-developers/zarr-python/issues/1121 - zarr.consolidate_metadata(store, metadata_key=".zmetadata") - store.close() + from spatialdata._io.io_zarr import _write_consolidated_metadata + + _write_consolidated_metadata(self.path) def has_consolidated_metadata(self) -> bool: + from spatialdata._io._utils import _resolve_zarr_store + return_value = False - store = parse_url(self.path, mode="r").store - if "zmetadata" in store: + store = _resolve_zarr_store(self.path) + group = zarr.open_group(store, mode="r") + if getattr(group.metadata, "consolidated_metadata", None): return_value = True store.close() return return_value @@ -1455,6 +1453,7 @@ def has_consolidated_metadata(self) -> bool: def _validate_can_write_metadata_on_element(self, element_name: str) -> tuple[str, SpatialElement | AnnData] | None: """Validate if metadata can be written on an element, returns None if it cannot be written.""" from spatialdata._io._utils import _is_element_self_contained + from spatialdata._io.io_zarr import _group_for_element_exists # check the element exists in the SpatialData object element = self.get(element_name) @@ -1477,8 +1476,10 @@ def _validate_can_write_metadata_on_element(self, element_name: str) -> tuple[st self._check_element_not_on_disk_with_different_type(element_type=element_type, element_name=element_name) # check if the element exists in the Zarr storage - if not self._group_for_element_exists( - zarr_path=Path(self.path), element_type=element_type, element_name=element_name + if not _group_for_element_exists( + zarr_path=Path(self.path), + element_type=element_type, + element_name=element_name, ): warnings.warn( f"Not saving the metadata to element {element_type}/{element_name} as it is" @@ -1509,6 +1510,8 @@ def write_channel_names(self, element_name: str | None = None) -> None: The name of the element to write the channel names of. If None, write the channel names of all image elements. """ + from spatialdata._io.io_zarr import _get_groups_for_element + if element_name is not None: check_valid_name(element_name) if element_name not in self: @@ -1528,8 +1531,8 @@ def write_channel_names(self, element_name: str | None = None) -> None: # Mypy does not understand that path is not None so we have the check in the conditional if element_type == "images" and self.path is not None: - _, _, element_group = self._get_groups_for_element( - zarr_path=Path(self.path), element_type=element_type, element_name=element_name + _, _, element_group = _get_groups_for_element( + zarr_path=Path(self.path), element_type=element_type, element_name=element_name, use_consolidated=False ) from spatialdata._io._utils import overwrite_channel_names @@ -1547,6 +1550,8 @@ def write_transformations(self, element_name: str | None = None) -> None: element_name The name of the element to write. If None, write the transformations of all elements. 
""" + from spatialdata._io.io_zarr import _get_groups_for_element + if element_name is not None: check_valid_name(element_name) if element_name not in self: @@ -1570,23 +1575,32 @@ def write_transformations(self, element_name: str | None = None) -> None: # Mypy does not understand that path is not None so we have a conditional assert self.path is not None - _, _, element_group = self._get_groups_for_element( - zarr_path=Path(self.path), element_type=element_type, element_name=element_name + _, _, element_group = _get_groups_for_element( + zarr_path=Path(self.path), + element_type=element_type, + element_name=element_name, + use_consolidated=False, ) axes = get_axes_names(element) if isinstance(element, DataArray | DataTree): from spatialdata._io._utils import ( overwrite_coordinate_transformations_raster, ) + from spatialdata._io.format import RasterFormats - overwrite_coordinate_transformations_raster(group=element_group, axes=axes, transformations=transformations) + raster_format = RasterFormats[element_group.metadata.attributes["spatialdata_attrs"]["version"]] + overwrite_coordinate_transformations_raster( + group=element_group, axes=axes, transformations=transformations, raster_format=raster_format + ) elif isinstance(element, DaskDataFrame | GeoDataFrame | AnnData): from spatialdata._io._utils import ( overwrite_coordinate_transformations_non_raster, ) overwrite_coordinate_transformations_non_raster( - group=element_group, axes=axes, transformations=transformations + group=element_group, + axes=axes, + transformations=transformations, ) else: raise ValueError(f"Unknown element type {type(element)}") @@ -1613,20 +1627,27 @@ def _element_type_and_name_from_element_path(self, element_path: str) -> tuple[s element_type, element_name = element_path.split("/") return element_type, element_name - def write_attrs(self, format: SpatialDataFormat | None = None, zarr_group: zarr.Group | None = None) -> None: - from spatialdata._io.format import _parse_formats + @_deprecation_alias(format="sdata_format", version="0.7.0") + def write_attrs( + self, + sdata_format: SpatialDataContainerFormatType | None = None, + zarr_group: zarr.Group | None = None, + ) -> None: + from spatialdata._io._utils import _resolve_zarr_store + from spatialdata._io.format import CurrentSpatialDataContainerFormat, SpatialDataContainerFormatType - parsed = _parse_formats(formats=format) + sdata_format = sdata_format if sdata_format is not None else CurrentSpatialDataContainerFormat() + assert isinstance(sdata_format, SpatialDataContainerFormatType) store = None if zarr_group is None: assert self.is_backed(), "The SpatialData object must be backed by a Zarr store to write attrs." 
- store = parse_url(self.path, mode="r+").store - zarr_group = zarr.group(store=store, overwrite=False) + store = _resolve_zarr_store(self.path) + zarr_group = zarr.open_group(store=store, mode="r+") - version = parsed["SpatialData"].spatialdata_format_version - version_specific_attrs = parsed["SpatialData"].attrs_to_dict() + version = sdata_format.spatialdata_format_version + version_specific_attrs = sdata_format.attrs_to_dict() attrs_to_write = {"spatialdata_attrs": {"version": version} | version_specific_attrs} | self.attrs try: @@ -1642,6 +1663,7 @@ def write_metadata( element_name: str | None = None, consolidate_metadata: bool | None = None, write_attrs: bool = True, + sdata_format: SpatialDataContainerFormatType | None = None, ) -> None: """ Write the metadata of a single element, or of all elements, to the Zarr store, without rewriting the data. @@ -1662,6 +1684,11 @@ def write_metadata( consolidate_metadata If True, consolidate the metadata to more easily support remote reading. By default write the metadata only if the metadata was already consolidated. + write_attrs + If True, write the SpatialData.attrs metadata to the root of the Zarr store. + sdata_format + The format to use for writing the metadata of the `SpatialData` object. See more details in the + documentation of `SpatialData.write()`. Notes ----- @@ -1671,16 +1698,15 @@ def write_metadata( check_valid_name(element_name) if element_name not in self: raise ValueError(f"Element with name {element_name} not found in SpatialData object.") + if write_attrs: + self.write_attrs(sdata_format=sdata_format) self.write_transformations(element_name) self.write_channel_names(element_name) # TODO: write .uns['spatialdata_attrs'] metadata for AnnData. # TODO: write .attrs['spatialdata_attrs'] metadata for DaskDataFrame. - if write_attrs: - self.write_attrs() - - if consolidate_metadata is None and self.has_consolidated_metadata(): + if self.has_consolidated_metadata(): consolidate_metadata = True if consolidate_metadata: self.write_consolidated_metadata() @@ -1783,53 +1809,10 @@ def tables(self, tables: dict[str, AnnData]) -> None: TableModel().validate(v) self._tables[k] = v - @property - def table(self) -> None | AnnData: - """ - Return table with name table from tables if it exists. - - Returns - ------- - The table. - """ - warnings.warn( - "Table accessor will be deprecated with SpatialData version 0.1, use sdata.tables instead.", - DeprecationWarning, - stacklevel=2, - ) - # Isinstance will still return table if anndata has 0 rows. - if isinstance(self.tables.get("table"), AnnData): - return self.tables["table"] - return None - - @table.setter - def table(self, table: AnnData) -> None: - warnings.warn( - "Table setter will be deprecated with SpatialData version 0.1, use tables instead.", - DeprecationWarning, - stacklevel=2, - ) - TableModel().validate(table) - if self.tables.get("table") is not None: - raise ValueError("The table already exists. Use del sdata.tables['table'] to remove it first.") - self.tables["table"] = table - - @table.deleter - def table(self) -> None: - """Delete the table.""" - warnings.warn( - "del sdata.table will be deprecated with SpatialData version 0.1, use del sdata.tables['table'] instead.", - DeprecationWarning, - stacklevel=2, - ) - if self.tables.get("table"): - del self.tables["table"] - else: - # More informative than the error in the zarr library. 
- raise KeyError("table with name 'table' not present in the SpatialData object.") - @staticmethod - def read(file_path: Path | str, selection: tuple[str] | None = None) -> SpatialData: + def read( + file_path: Path | str, selection: tuple[str] | None = None, reconsolidate_metadata: bool = False + ) -> SpatialData: """ Read a SpatialData object from a Zarr storage (on-disk or remote). @@ -1839,6 +1822,8 @@ def read(file_path: Path | str, selection: tuple[str] | None = None) -> SpatialD The path or URL to the Zarr storage. selection The elements to read (images, labels, points, shapes, table). If None, all elements are read. + reconsolidate_metadata + If the consolidated metadata store got corrupted this can lead to errors when trying to read the data. Returns ------- @@ -1846,45 +1831,12 @@ def read(file_path: Path | str, selection: tuple[str] | None = None) -> SpatialD """ from spatialdata import read_zarr - return read_zarr(file_path, selection=selection) + if reconsolidate_metadata: + from spatialdata._io.io_zarr import _write_consolidated_metadata - def add_image( - self, - name: str, - image: DataArray | DataTree, - storage_options: JSONDict | list[JSONDict] | None = None, - overwrite: bool = False, - ) -> None: - """Deprecated. Use `sdata[name] = image` instead.""" # noqa: D401 - _error_message_add_element() + _write_consolidated_metadata(file_path) - def add_labels( - self, - name: str, - labels: DataArray | DataTree, - storage_options: JSONDict | list[JSONDict] | None = None, - overwrite: bool = False, - ) -> None: - """Deprecated. Use `sdata[name] = labels` instead.""" # noqa: D401 - _error_message_add_element() - - def add_points( - self, - name: str, - points: DaskDataFrame, - overwrite: bool = False, - ) -> None: - """Deprecated. Use `sdata[name] = points` instead.""" # noqa: D401 - _error_message_add_element() - - def add_shapes( - self, - name: str, - shapes: GeoDataFrame, - overwrite: bool = False, - ) -> None: - """Deprecated. Use `sdata[name] = shapes` instead.""" # noqa: D401 - _error_message_add_element() + return read_zarr(file_path, selection=selection) @property def images(self) -> Images: @@ -2033,7 +1985,9 @@ def h(s: str) -> str: else: shape_str = ( "(" - + ", ".join([str(dim) if not isinstance(dim, Delayed) else "" for dim in v.shape]) + + ", ".join( + [(str(dim) if not isinstance(dim, Delayed) else "") for dim in v.shape] + ) + ")" ) descr += f"{h(attr + 'level1.1')}{k!r}: {descr_class} with shape: {shape_str} {dim_string}" @@ -2143,14 +2097,14 @@ def _gen_spatial_element_values(self) -> Generator[SpatialElement, None, None]: yield from d.values() def _gen_elements( - self, include_table: bool = False + self, include_tables: bool = False ) -> Generator[tuple[str, str, SpatialElement | AnnData], None, None]: """ Generate elements contained in the SpatialData instance. Parameters ---------- - include_table + include_tables Whether to also generate table elements. Returns @@ -2159,14 +2113,16 @@ def _gen_elements( itself. """ element_types = ["images", "labels", "points", "shapes"] - if include_table: + if include_tables: element_types.append("tables") for element_type in element_types: d = getattr(SpatialData, element_type).fget(self) for k, v in d.items(): yield element_type, k, v - def gen_spatial_elements(self) -> Generator[tuple[str, str, SpatialElement], None, None]: + def gen_spatial_elements( + self, + ) -> Generator[tuple[str, str, SpatialElement], None, None]: """ Generate spatial elements within the SpatialData object. 
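Since the deprecated `add_image`/`add_labels`/`add_points`/`add_shapes` stubs are removed above and `read()` gains a `reconsolidate_metadata` flag, a hedged sketch of the replacement workflow follows. The store path and element name are hypothetical, and `labels_element` is assumed to be an element already parsed with the appropriate model.

```python
from spatialdata import SpatialData

# Re-read a store whose consolidated metadata may have become corrupted.
sdata = SpatialData.read("analysis.zarr", reconsolidate_metadata=True)

# Elements are added via item assignment instead of the removed add_* methods.
sdata["new_labels"] = labels_element
sdata.write_element("new_labels")
```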
@@ -2179,7 +2135,9 @@ def gen_spatial_elements(self) -> Generator[tuple[str, str, SpatialElement], Non """ return self._gen_elements() - def gen_elements(self) -> Generator[tuple[str, str, SpatialElement | AnnData], None, None]: + def gen_elements( + self, + ) -> Generator[tuple[str, str, SpatialElement | AnnData], None, None]: """ Generate elements within the SpatialData object. @@ -2189,7 +2147,7 @@ def gen_elements(self) -> Generator[tuple[str, str, SpatialElement | AnnData], N ------- A generator that yields tuples containing the name, description, and element objects themselves. """ - return self._gen_elements(include_table=True) + return self._gen_elements(include_tables=True) def _validate_element_names_are_unique(self) -> None: """ @@ -2240,8 +2198,7 @@ def _find_element(self, element_name: str) -> tuple[str, str, SpatialElement | A @classmethod def init_from_elements( cls, - elements: dict[str, SpatialElement], - tables: AnnData | dict[str, AnnData] | None = None, + elements: dict[str, SpatialElement | AnnData], attrs: Mapping[Any, Any] | None = None, ) -> SpatialData: """ @@ -2250,9 +2207,7 @@ def init_from_elements( Parameters ---------- elements - A dict of named elements. - tables - An optional table or dictionary of tables + A dict of named elements, e.g. SpatialElements like images and labels and AnnData tables. attrs Additional attributes to store in the SpatialData object. @@ -2260,7 +2215,7 @@ def init_from_elements( ------- The SpatialData object. """ - elements_dict: dict[str, SpatialElement] = {} + elements_dict: dict[str, SpatialElement | AnnData] = {} for name, element in elements.items(): model = get_model(element) if model in [Image2DModel, Image3DModel]: @@ -2275,30 +2230,13 @@ def init_from_elements( assert model == ShapesModel element_type = "shapes" elements_dict.setdefault(element_type, {})[name] = element - # when the "tables" argument is removed, we can remove all this if block - if tables is not None: - warnings.warn( - 'The "tables" argument is deprecated and will be removed in a future version. Please ' - "specifies the tables in the `elements` argument. Until the removal occurs, the `elements` " - "variable will be automatically populated with the tables if the `tables` argument is not None.", - DeprecationWarning, - stacklevel=2, - ) - if "tables" in elements_dict: - raise ValueError( - "The tables key is already present in the elements dictionary. Please do not specify " - "the `tables` argument." - ) - elements_dict["tables"] = {} - if isinstance(tables, AnnData): - elements_dict["tables"]["table"] = tables - else: - for name, table in tables.items(): - elements_dict["tables"][name] = table return cls(**elements_dict, attrs=attrs) def subset( - self, element_names: list[str], filter_tables: bool = True, include_orphan_tables: bool = False + self, + element_names: list[str], + filter_tables: bool = True, + include_orphan_tables: bool = False, ) -> SpatialData: """ Subset the SpatialData object. 
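Because the `tables` keyword is removed from `init_from_elements` in the hunk below, tables are now passed in the same dictionary as the spatial elements. A hedged migration sketch, where `shapes` and `adata` are assumed to be an already-parsed shapes element and an AnnData table:

```python
# Before (signature removed in this PR):
# sdata = SpatialData.init_from_elements({"cells": shapes}, tables={"table": adata})

# After: merge the tables into the same dict of named elements.
sdata = SpatialData.init_from_elements({"cells": shapes, "table": adata})
```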
@@ -2321,7 +2259,7 @@ def subset( """ elements_dict: dict[str, SpatialElement] = {} names_tables_to_keep: set[str] = set() - for element_type, element_name, element in self._gen_elements(include_table=True): + for element_type, element_name, element in self._gen_elements(include_tables=True): if element_name in element_names: if element_type != "tables": elements_dict.setdefault(element_type, {})[element_name] = element @@ -2354,7 +2292,7 @@ def __getitem__(self, item: str) -> SpatialElement | AnnData: def __contains__(self, key: str) -> bool: element_dict = { - element_name: element_value for _, element_name, element_value in self._gen_elements(include_table=True) + element_name: element_value for _, element_name, element_value in self._gen_elements(include_tables=True) } return key in element_dict diff --git a/src/spatialdata/_core/validation.py b/src/spatialdata/_core/validation.py index 537c49f34..50e1b65d5 100644 --- a/src/spatialdata/_core/validation.py +++ b/src/spatialdata/_core/validation.py @@ -185,7 +185,10 @@ def check_key_is_case_insensitively_unique(key: str, other_keys: set[str | None] """ normalized_key = key.lower() if normalized_key in other_keys: - raise ValueError(f"Key `{key}` is not unique, or another case-variant of it exists.") + raise ValueError( + f"Key `{key}` is not unique as it exists with a different element type, or another " + f"case-variant of it exists." + ) def check_valid_dataframe_column_name(name: str) -> None: diff --git a/src/spatialdata/_io/__init__.py b/src/spatialdata/_io/__init__.py index 94d6816a9..b0dc914e6 100644 --- a/src/spatialdata/_io/__init__.py +++ b/src/spatialdata/_io/__init__.py @@ -1,5 +1,5 @@ from spatialdata._io._utils import get_dask_backing_files -from spatialdata._io.format import SpatialDataFormat +from spatialdata._io.format import SpatialDataFormatType from spatialdata._io.io_points import write_points from spatialdata._io.io_raster import write_image, write_labels from spatialdata._io.io_shapes import write_shapes @@ -11,6 +11,6 @@ "write_points", "write_shapes", "write_table", - "SpatialDataFormat", + "SpatialDataFormatType", "get_dask_backing_files", ] diff --git a/src/spatialdata/_io/_utils.py b/src/spatialdata/_io/_utils.py index 5e8eb832d..20c236275 100644 --- a/src/spatialdata/_io/_utils.py +++ b/src/spatialdata/_io/_utils.py @@ -1,5 +1,4 @@ import filecmp -import logging import os.path import re import sys @@ -18,9 +17,13 @@ from dask.array import Array as DaskArray from dask.dataframe import DataFrame as DaskDataFrame from geopandas import GeoDataFrame +from upath import UPath +from upath.implementations.local import PosixUPath, WindowsUPath from xarray import DataArray, DataTree +from zarr.storage import FsspecStore, LocalStore from spatialdata._core.spatialdata import SpatialData +from spatialdata._io.format import RasterFormatType, RasterFormatV01, RasterFormatV02, RasterFormatV03 from spatialdata._utils import get_pyramid_levels from spatialdata.models._utils import ( MappingToCoordinateSystem_t, @@ -32,18 +35,6 @@ from spatialdata.transformations.transformations import BaseTransformation, _get_current_output_axes -# suppress logger debug from ome_zarr with context manager -@contextmanager -def ome_zarr_logger(level: Any) -> Generator[None, None, None]: - logger = logging.getLogger("ome_zarr") - current_level = logger.getEffectiveLevel() - logger.setLevel(level) - try: - yield - finally: - logger.setLevel(current_level) - - def _get_transformations_from_ngff_dict( list_of_encoded_ngff_transformations: 
list[dict[str, Any]], ) -> MappingToCoordinateSystem_t: @@ -59,6 +50,18 @@ def _get_transformations_from_ngff_dict( def overwrite_coordinate_transformations_non_raster( group: zarr.Group, axes: tuple[ValidAxis_t, ...], transformations: MappingToCoordinateSystem_t ) -> None: + """Write coordinate transformations of non-raster element to disk. + + Parameters + ---------- + group: zarr.Group + The zarr group containing the non-raster element for which to write the transformations, e.g. the zarr group + containing sdata['points']. + axes: tuple[ValidAxis_t, ...] + The list with axes names in the same order as the coordinates of the non-raster element. + transformations: MappingToCoordinateSystem_t + Mapping between names of the coordinate system and the transformations. + """ _validate_mapping_to_coordinate_system_type(transformations) ngff_transformations = [] for target_coordinate_system, t in transformations.items(): @@ -74,8 +77,32 @@ def overwrite_coordinate_transformations_non_raster( def overwrite_coordinate_transformations_raster( - group: zarr.Group, axes: tuple[ValidAxis_t, ...], transformations: MappingToCoordinateSystem_t + group: zarr.Group, + axes: tuple[ValidAxis_t, ...], + transformations: MappingToCoordinateSystem_t, + raster_format: RasterFormatType, ) -> None: + """Write transformations of raster elements to disk. + + This function supports both writing of transformations for raster elements stored using zarr v3 and v2. + For the case of zarr v3, there is already a 'coordinateTransformations' from ome-zarr in the metadata of + the group. However, we store our transformations in the first element of the 'multiscales' of the attributes + in the group metadata. This is subject to change. + In the case of zarr v2 the existing 'coordinateTransformations' from ome-zarr is overwritten. + + Parameters + ---------- + group + The zarr group containing the raster element for which to write the transformations, e.g. the zarr group + containing sdata['image2d']. + axes + The list with axes names in the same order as the dimensions of the raster element. + transformations + Mapping between names of the coordinate system and the transformations. + raster_format + The raster format of the raster element used to determine where in the metadata the transformations should be + written. 
+ """ _validate_mapping_to_coordinate_system_type(transformations) # prepare the transformations in the dict representation ngff_transformations = [] @@ -90,16 +117,29 @@ def overwrite_coordinate_transformations_raster( ) coordinate_transformations = [t.to_dict() for t in ngff_transformations] # replace the metadata storage - multiscales = group.attrs["multiscales"] - assert len(multiscales) == 1 + if group.metadata.zarr_format == 3 and len(multiscales := group.metadata.attributes["ome"]["multiscales"]) != 1: + len_scales = len(multiscales) + raise ValueError(f"The length of multiscales metadata should be 1, found the length to be {len_scales}") + if group.metadata.zarr_format == 2: + multiscales = group.attrs["multiscales"] + if (len_scales := len(multiscales)) != 1: + raise ValueError(f"The length of multiscales metadata should be 1, found length of {len_scales}") multiscale = multiscales[0] - # the transformation present in multiscale["datasets"] are the ones for the multiscale, so and we leave them intact - # we update multiscale["coordinateTransformations"] and multiscale["coordinateSystems"] - # see the first post of https://github.com/scverse/spatialdata/issues/39 for an overview - # fix the io to follow the NGFF specs, see https://github.com/scverse/spatialdata/issues/114 + + # Previously, there was CoordinateTransformations key present at the level of multiscale and datasets in multiscale. + # This is not the case anymore so we are creating a new key here and keeping the one in datasets intact. multiscale["coordinateTransformations"] = coordinate_transformations - # multiscale["coordinateSystems"] = [t.output_coordinate_system_name for t in ngff_transformations] - group.attrs["multiscales"] = multiscales + if raster_format is not None: + if isinstance(raster_format, RasterFormatV01 | RasterFormatV02): + multiscale["version"] = raster_format.version + group.attrs["multiscales"] = multiscales + elif isinstance(raster_format, RasterFormatV03): + ome = group.metadata.attributes["ome"] + ome["version"] = raster_format.version + ome["multiscales"] = multiscales + group.attrs["ome"] = ome + else: + raise ValueError(f"Unsupported raster format: {type(raster_format)}") def overwrite_channel_names(group: zarr.Group, element: DataArray | DataTree) -> None: @@ -110,16 +150,14 @@ def overwrite_channel_names(group: zarr.Group, element: DataArray | DataTree) -> channel_names = element["scale0"]["image"].coords["c"].data.tolist() channel_metadata = [{"label": name} for name in channel_names] - omero_meta = group.attrs["omero"] + # This is required here as we do not use the load node API of ome-zarr + omero_meta = group.attrs.get("omero", None) or group.attrs.get("ome", {}).get("omero") omero_meta["channels"] = channel_metadata - group.attrs["omero"] = omero_meta - multiscales_meta = group.attrs["multiscales"] - if len(multiscales_meta) != 1: - raise ValueError( - f"Multiscale metadata must be of length one but got length {len(multiscales_meta)}. Data mightbe corrupted." 
- ) - multiscales_meta[0]["metadata"]["omero"]["channels"] = channel_metadata - group.attrs["multiscales"] = multiscales_meta + if ome_meta := group.attrs.get("ome", None): + ome_meta["omero"] = omero_meta + group.attrs["ome"] = ome_meta + else: + group.attrs["omero"] = omero_meta def _write_metadata( @@ -285,8 +323,9 @@ def _search_for_backing_files_recursively(subgraph: Any, files: list[str]) -> No name = k if name is not None: if name.startswith("original-from-zarr"): - path = v.store.path - files.append(os.path.realpath(path)) + # LocalStore.store does not have an attribute path, but we keep it like this for backward compat. + path = getattr(v.store, "path", None) if getattr(v.store, "path", None) else v.store.root + files.append(str(UPath(path).resolve())) elif name.startswith("read-parquet") or name.startswith("read_parquet"): if hasattr(v, "creation_info"): # https://github.com/dask/dask/blob/ff2488aec44d641696e0b7aa41ed9e995c710705/dask/dataframe/io/parquet/core.py#L625 @@ -297,7 +336,7 @@ def _search_for_backing_files_recursively(subgraph: Any, files: list[str]) -> No f"report this bug." ) parquet_file = t[0] - files.append(os.path.realpath(parquet_file)) + files.append(str(UPath(parquet_file).resolve())) elif isinstance(v, tuple) and len(v) > 1 and isinstance(v[1], dict) and "piece" in v[1]: # https://github.com/dask/dask/blob/ff2488aec44d641696e0b7aa41ed9e995c710705/dask/dataframe/io/parquet/core.py#L870 parquet_file, check0, check1 = v[1]["piece"] @@ -363,24 +402,71 @@ def _is_element_self_contained( ) -> bool: if isinstance(element, DaskDataFrame): pass + # TODO when running test_save_transformations it seems that for the same element this is called multiple times return all(_backed_elements_contained_in_path(path=element_path, object=element)) -def save_transformations(sdata: SpatialData) -> None: +def _resolve_zarr_store( + path: str | Path | UPath | zarr.storage.StoreLike | zarr.Group, **kwargs: Any +) -> zarr.storage.StoreLike: """ - Save all the transformations of a SpatialData object to disk. + Normalize different Zarr store inputs into a usable store instance. + + This function accepts various forms of input (e.g. filesystem paths, + UPath objects, existing Zarr stores, or `zarr.Group`s) and resolves + them into a `StoreLike` that can be passed to Zarr APIs. It handles + local files, fsspec-backed stores, consolidated metadata stores, and + groups with nested paths. + + Parameters + ---------- + path + The input representing a Zarr store or group. Can be a filesystem + path, remote path, existing store, or Zarr group. + **kwargs + Additional keyword arguments forwarded to the underlying store + constructor (e.g. `mode`, `storage_options`). - sdata - The SpatialData object + Returns + ------- + A normalized store instance suitable for use with Zarr. + + Raises + ------ + TypeError + If the input type is unsupported. + ValueError + If a `zarr.Group` has an unsupported store type. """ - warnings.warn( - "This function is deprecated and should be replaced by `SpatialData.write_transformations()` or " - "`SpatialData.write_metadata()`, which gives more control over which metadata to write. 
This function will call" - " `SpatialData.write_transformations()`; please call this function directly.", - DeprecationWarning, - stacklevel=2, - ) - sdata.write_transformations() + # TODO: ensure kwargs like mode are enforced everywhere and passed correctly to the store + if isinstance(path, str | Path): + # if the input is str or Path, map it to UPath + path = UPath(path) + + if isinstance(path, PosixUPath | WindowsUPath): + # if the input is a local path, use LocalStore + return LocalStore(path.path) + + if isinstance(path, zarr.Group): + # if the input is a zarr.Group, wrap it with a store + if isinstance(path.store, LocalStore): + # create a simple FSStore if the store is a LocalStore with just the path + return FsspecStore(os.path.join(path.store.path, path.path), **kwargs) + if isinstance(path.store, FsspecStore): + # if the store within the zarr.Group is an FSStore, return it + # but extend the path of the store with that of the zarr.Group + return FsspecStore(path.store.path + "/" + path.path, fs=path.store.fs, **kwargs) + if isinstance(path.store, zarr.storage.ConsolidatedMetadataStore): + # if the store is a ConsolidatedMetadataStore, just return the underlying FSSpec store + return path.store.store + raise ValueError(f"Unsupported store type or zarr.Group: {type(path.store)}") + if isinstance(path, zarr.storage.StoreLike): + # if the input already a store, wrap it in an FSStore + return FsspecStore(path, **kwargs) + if isinstance(path, UPath): + # if input is a remote UPath, map it to an FSStore + return FsspecStore(path.path, fs=path.fs, **kwargs) + raise TypeError(f"Unsupported type: {type(path)}") class BadFileHandleMethod(Enum): @@ -392,7 +478,7 @@ class BadFileHandleMethod(Enum): def handle_read_errors( on_bad_files: Literal[BadFileHandleMethod.ERROR, BadFileHandleMethod.WARN], location: str, - exc_types: tuple[type[Exception], ...], + exc_types: type[BaseException] | tuple[type[BaseException], ...], ) -> Generator[None, None, None]: """ Handle read errors according to parameter `on_bad_files`. 
diff --git a/src/spatialdata/_io/format.py b/src/spatialdata/_io/format.py index 5ee675be6..3e69b862c 100644 --- a/src/spatialdata/_io/format.py +++ b/src/spatialdata/_io/format.py @@ -3,9 +3,14 @@ import ome_zarr.format import zarr -from anndata import AnnData -from ome_zarr.format import CurrentFormat, Format, FormatV01, FormatV02, FormatV03, FormatV04 -from pandas.api.types import CategoricalDtype +from ome_zarr.format import ( + Format, + FormatV01, + FormatV02, + FormatV03, + FormatV04, + FormatV05, +) from shapely import GeometryType from spatialdata.models.models import ATTRS_KEY, PointsModel, ShapesModel @@ -44,27 +49,7 @@ def _parse_version(group: zarr.Group, expect_attrs_key: bool) -> str | None: return version -class SpatialDataFormat(CurrentFormat): - pass - - -class SpatialDataContainerFormatV01(SpatialDataFormat): - @property - def spatialdata_format_version(self) -> str: - return "0.1" - - def attrs_from_dict(self, metadata: dict[str, Any]) -> dict[str, Any]: - return {} - - def attrs_to_dict(self) -> dict[str, str | dict[str, Any]]: - from spatialdata import __version__ - - return {"spatialdata_software_version": __version__} - - -class RasterFormatV01(SpatialDataFormat): - """Formatter for raster data.""" - +class CoordinateMixinV01: def generate_coordinate_transformations(self, shapes: list[tuple[Any]]) -> None | list[list[dict[str, Any]]]: data_shape = shapes[0] coordinate_transformations: list[list[dict[str, Any]]] = [] @@ -112,6 +97,61 @@ def validate_coordinate_transformations( assert np.all([j0 == j1 for j0, j1 in zip(json0, json1, strict=True)]) + +class PointsAttrsMixinV01: + def attrs_from_dict(self, metadata: dict[str, Any]) -> dict[str, dict[str, Any]]: + if Points_s.ATTRS_KEY not in metadata: + raise KeyError(f"Missing key {Points_s.ATTRS_KEY} in points metadata.") + metadata_ = metadata[Points_s.ATTRS_KEY] + assert self.spatialdata_format_version == metadata_["version"] # type: ignore[attr-defined] + d = {} + if Points_s.FEATURE_KEY in metadata_: + d[Points_s.FEATURE_KEY] = metadata_[Points_s.FEATURE_KEY] + if Points_s.INSTANCE_KEY in metadata_: + d[Points_s.INSTANCE_KEY] = metadata_[Points_s.INSTANCE_KEY] + return d + + def attrs_to_dict(self, data: dict[str, Any]) -> dict[str, dict[str, Any]]: + d = {} + if Points_s.ATTRS_KEY in data: + if Points_s.INSTANCE_KEY in data[Points_s.ATTRS_KEY]: + d[Points_s.INSTANCE_KEY] = data[Points_s.ATTRS_KEY][Points_s.INSTANCE_KEY] + if Points_s.FEATURE_KEY in data[Points_s.ATTRS_KEY]: + d[Points_s.FEATURE_KEY] = data[Points_s.ATTRS_KEY][Points_s.FEATURE_KEY] + return d + + +class SpatialDataContainerFormatV01(FormatV04): + @property + def spatialdata_format_version(self) -> str: + return "0.1" + + def attrs_from_dict(self, metadata: dict[str, Any]) -> dict[str, Any]: + return {} + + def attrs_to_dict(self) -> dict[str, str | dict[str, Any]]: + from spatialdata import __version__ + + return {"spatialdata_software_version": __version__} + + +class SpatialDataContainerFormatV02(FormatV05): + @property + def spatialdata_format_version(self) -> str: + return "0.2" + + def attrs_from_dict(self, metadata: dict[str, Any]) -> dict[str, Any]: + return {} + + def attrs_to_dict(self) -> dict[str, str | dict[str, Any]]: + from spatialdata import __version__ + + return {"spatialdata_software_version": __version__} + + +class RasterFormatV01(FormatV04, CoordinateMixinV01): + """Formatter for raster data.""" + # eventually we are fully compliant with NGFF and we can drop SPATIALDATA_FORMAT_VERSION and simply rely on # "version"; still, 
until the coordinate transformations make it into NGFF, we need to have our extension @property @@ -135,7 +175,19 @@ def version(self) -> str: return "0.4-dev-spatialdata" -class ShapesFormatV01(SpatialDataFormat): +class RasterFormatV03(FormatV05, CoordinateMixinV01): + @property + def spatialdata_format_version(self) -> str: + return "0.3" + + @property + def version(self) -> str: + # 0.1 -> 0.2 changed the version string for the NGFF format, from 0.4 to 0.6-dev-spatialdata as discussed here + # https://github.com/scverse/spatialdata/pull/849 + return "0.5-dev-spatialdata" + + +class ShapesFormatV01(FormatV04): """Formatter for shapes.""" @property @@ -158,10 +210,15 @@ def attrs_from_dict(self, metadata: dict[str, Any]) -> GeometryType: return typ def attrs_to_dict(self, geometry: GeometryType) -> dict[str, str | dict[str, Any]]: - return {Shapes_s.GEOS_KEY: {Shapes_s.NAME_KEY: geometry.name, Shapes_s.TYPE_KEY: geometry.value}} + return { + Shapes_s.GEOS_KEY: { + Shapes_s.NAME_KEY: geometry.name, + Shapes_s.TYPE_KEY: geometry.value, + } + } -class ShapesFormatV02(SpatialDataFormat): +class ShapesFormatV02(FormatV04): """Formatter for shapes.""" @property @@ -173,106 +230,140 @@ def attrs_to_dict(self, data: dict[str, Any]) -> dict[str, str | dict[str, Any]] return {} -class PointsFormatV01(SpatialDataFormat): +class ShapesFormatV03(FormatV05): + """Formatter for shapes.""" + + @property + def spatialdata_format_version(self) -> str: + return "0.3" + + # no need for attrs_from_dict as we are not saving metadata except for the coordinate transformations + def attrs_to_dict(self, data: dict[str, Any]) -> dict[str, str | dict[str, Any]]: + return {} + + +class PointsFormatV01(FormatV04, PointsAttrsMixinV01): """Formatter for points.""" @property def spatialdata_format_version(self) -> str: return "0.1" - def attrs_from_dict(self, metadata: dict[str, Any]) -> dict[str, dict[str, Any]]: - if Points_s.ATTRS_KEY not in metadata: - raise KeyError(f"Missing key {Points_s.ATTRS_KEY} in points metadata.") - metadata_ = metadata[Points_s.ATTRS_KEY] - assert self.spatialdata_format_version == metadata_["version"] - d = {} - if Points_s.FEATURE_KEY in metadata_: - d[Points_s.FEATURE_KEY] = metadata_[Points_s.FEATURE_KEY] - if Points_s.INSTANCE_KEY in metadata_: - d[Points_s.INSTANCE_KEY] = metadata_[Points_s.INSTANCE_KEY] - return d - def attrs_to_dict(self, data: dict[str, Any]) -> dict[str, dict[str, Any]]: - d = {} - if Points_s.ATTRS_KEY in data: - if Points_s.INSTANCE_KEY in data[Points_s.ATTRS_KEY]: - d[Points_s.INSTANCE_KEY] = data[Points_s.ATTRS_KEY][Points_s.INSTANCE_KEY] - if Points_s.FEATURE_KEY in data[Points_s.ATTRS_KEY]: - d[Points_s.FEATURE_KEY] = data[Points_s.ATTRS_KEY][Points_s.FEATURE_KEY] - return d +class PointsFormatV02(FormatV05, PointsAttrsMixinV01): + """Formatter for points.""" + @property + def spatialdata_format_version(self) -> str: + return "0.2" -class TablesFormatV01(SpatialDataFormat): + +class TablesFormatV01(FormatV04): """Formatter for the table.""" @property def spatialdata_format_version(self) -> str: return "0.1" - def validate_table( - self, - table: AnnData, - region_key: None | str = None, - instance_key: None | str = None, - ) -> None: - if not isinstance(table, AnnData): - raise TypeError(f"`table` must be `anndata.AnnData`, was {type(table)}.") - if region_key is not None and not isinstance(table.obs[region_key].dtype, CategoricalDtype): - raise ValueError( - f"`table.obs[region_key]` must be of type `categorical`, not `{type(table.obs[region_key])}`." 
- ) - if instance_key is not None and table.obs[instance_key].isnull().values.any(): - raise ValueError("`table.obs[instance_key]` must not contain null values, but it does.") - - -CurrentRasterFormat = RasterFormatV02 -CurrentShapesFormat = ShapesFormatV02 -CurrentPointsFormat = PointsFormatV01 -CurrentTablesFormat = TablesFormatV01 -CurrentSpatialDataContainerFormats = SpatialDataContainerFormatV01 - -ShapesFormats = { + +class TablesFormatV02(FormatV05): + """Formatter for the table.""" + + @property + def spatialdata_format_version(self) -> str: + return "0.2" + + +CurrentRasterFormat = RasterFormatV03 +CurrentShapesFormat = ShapesFormatV03 +CurrentPointsFormat = PointsFormatV02 +CurrentTablesFormat = TablesFormatV02 +CurrentSpatialDataContainerFormat = SpatialDataContainerFormatV02 + +RasterFormatType = RasterFormatV01 | RasterFormatV02 | RasterFormatV03 +ShapesFormatType = ShapesFormatV01 | ShapesFormatV02 | ShapesFormatV03 +PointsFormatType = PointsFormatV01 | PointsFormatV02 +TablesFormatType = TablesFormatV01 | TablesFormatV02 +SpatialDataContainerFormatType = SpatialDataContainerFormatV01 | SpatialDataContainerFormatV02 +SpatialDataFormatType = ( + RasterFormatType | ShapesFormatType | PointsFormatType | TablesFormatType | SpatialDataContainerFormatType +) + +SdataVersion_to_Format = {"0.4": FormatV04(), "0.4-dev-spatialdata": FormatV04(), "0.5-dev-spatialdata": FormatV05()} +RasterFormats: dict[str, RasterFormatType] = { + "0.1": RasterFormatV01(), + "0.2": RasterFormatV02(), + "0.3": RasterFormatV03(), +} +ShapesFormats: dict[str, ShapesFormatType] = { "0.1": ShapesFormatV01(), "0.2": ShapesFormatV02(), + "0.3": ShapesFormatV03(), } -PointsFormats = { +PointsFormats: dict[str, PointsFormatType] = { "0.1": PointsFormatV01(), + "0.2": PointsFormatV02(), } -TablesFormats = { +TablesFormats: dict[str, TablesFormatType] = { "0.1": TablesFormatV01(), + "0.2": TablesFormatV02(), } -RasterFormats = { - "0.1": RasterFormatV01(), - "0.2": RasterFormatV02(), -} -SpatialDataContainerFormats = { +SpatialDataContainerFormats: dict[str, SpatialDataContainerFormatType] = { "0.1": SpatialDataContainerFormatV01(), + "0.2": SpatialDataContainerFormatV02(), +} +ContainerFormatValidElements = { + SpatialDataContainerFormatV01().__str__(): [ + RasterFormatV01().__str__(), + RasterFormatV02().__str__(), + ShapesFormatV01().__str__(), + ShapesFormatV02().__str__(), + PointsFormatV01().__str__(), + TablesFormatV01().__str__(), + ], + SpatialDataContainerFormatV02().__str__(): [ + RasterFormatV03().__str__(), + ShapesFormatV03().__str__(), + PointsFormatV02().__str__(), + TablesFormatV02().__str__(), + ], +} +ContainerV01DefaultTypes: dict[str, SpatialDataFormatType] = { + "raster": RasterFormatV02(), + "shapes": ShapesFormatV02(), + "points": PointsFormatV01(), + "tables": TablesFormatV01(), } def format_implementations() -> Iterator[Format]: """Return an instance of each format implementation, newest to oldest.""" + yield RasterFormatV03() yield RasterFormatV02() - # yield RasterFormatV01() # same format string as FormatV04 + yield RasterFormatV01() # same format string as FormatV04 + + yield FormatV05() yield FormatV04() yield FormatV03() yield FormatV02() yield FormatV01() -# monkeypatch the ome_zarr.format module to include the SpatialDataFormat (we want to use the APIs from ome_zarr to +# monkeypatch the ome_zarr.format module to include the SpatialDataFormatType (we want to use the APIs from ome_zarr to # read, but signal that the format we are using is a dev version of NGFF, since it builds on 
some open PR that are # not released yet) ome_zarr.format.format_implementations = format_implementations -def _parse_formats(formats: SpatialDataFormat | list[SpatialDataFormat] | None) -> dict[str, SpatialDataFormat]: - parsed = { +def _parse_formats( + formats: SpatialDataFormatType | list[SpatialDataFormatType] | None, +) -> dict[str, SpatialDataFormatType]: + parsed: dict[str, SpatialDataFormatType] = { "raster": CurrentRasterFormat(), "shapes": CurrentShapesFormat(), "points": CurrentPointsFormat(), "tables": CurrentTablesFormat(), - "SpatialData": CurrentSpatialDataContainerFormats(), + "SpatialData": CurrentSpatialDataContainerFormat(), } if formats is None: return parsed @@ -312,4 +403,30 @@ def _check_modified(element_type: str) -> None: parsed["SpatialData"] = fmt else: raise ValueError(f"Unsupported format {fmt}") + + if parsed["SpatialData"].__str__() == "SpatialDataContainerFormatV01": + # defaulting undefined element formats to element formats valid for 'SpatialDataContainerFormatV01' + for el_type, value in modified.items(): + if el_type != "SpatialData" and not value: + parsed[el_type] = ContainerV01DefaultTypes[el_type] + + if any( + (invalid := el_format.__str__()) not in ContainerFormatValidElements[parsed["SpatialData"].__str__()] + for el_type, el_format in parsed.items() + if el_type != "SpatialData" + ): + raise ValueError( + f"Unsupported format '{invalid}' for SpatialDataContainerFormat '{parsed['SpatialData'].__str__()}'. " + f"Please ensure all element formats are either of these: " + f"'{' '.join(f for f in ContainerFormatValidElements[parsed['SpatialData'].__str__()])}'" + ) + return parsed + + +def get_ome_zarr_format(raster_format: RasterFormatType) -> Format: + if isinstance(raster_format, RasterFormatV01 | RasterFormatV02): + return FormatV04() + if isinstance(raster_format, RasterFormatV03): + return FormatV05() + raise ValueError(f"Unsupported raster format {raster_format}") diff --git a/src/spatialdata/_io/io_points.py b/src/spatialdata/_io/io_points.py index 3106c8470..bc52c94ba 100644 --- a/src/spatialdata/_io/io_points.py +++ b/src/spatialdata/_io/io_points.py @@ -1,4 +1,3 @@ -import os from collections.abc import MutableMapping from pathlib import Path @@ -29,35 +28,53 @@ def _read_points( version = _parse_version(f, expect_attrs_key=True) assert version is not None - format = PointsFormats[version] + points_format = PointsFormats[version] - path = os.path.join(f._store.path, f.path, "points.parquet") + store_root = f.store_path.store.root + path = store_root / f.path / "points.parquet" # cache on remote file needed for parquet reader to work # TODO: allow reading in the metadata without caching all the data - points = read_parquet("simplecache::" + path if path.startswith("http") else path) + points = read_parquet("simplecache::" + str(path) if str(path).startswith("http") else path) assert isinstance(points, DaskDataFrame) transformations = _get_transformations_from_ngff_dict(f.attrs.asdict()["coordinateTransformations"]) _set_transformations(points, transformations) - attrs = format.attrs_from_dict(f.attrs.asdict()) + attrs = points_format.attrs_from_dict(f.attrs.asdict()) if len(attrs): points.attrs["spatialdata_attrs"] = attrs return points +class PointsReader: + def __call__(self, store: str | Path | MutableMapping[str, object] | zarr.Group) -> DaskDataFrame: + return _read_points(store) + + def write_points( points: DaskDataFrame, group: zarr.Group, - name: str, group_type: str = "ngff:points", - format: Format = CurrentPointsFormat(), + 
element_format: Format = CurrentPointsFormat(), ) -> None: + """Write a points element to a zarr store. + + Parameters + ---------- + points + The dataframe of the points element. + group + The zarr group in the 'points' zarr group to write the points element to. + group_type + The type of the element. + element_format + The format of the points element used to store it. + """ axes = get_axes_names(points) - t = _get_transformations(points) + transformations = _get_transformations(points) - points_groups = group.require_group(name) - path = Path(points_groups._store.path) / points_groups.path / "points.parquet" + store_root = group.store_path.store.root + path = store_root / group.path / "points.parquet" # The following code iterates through all columns in the 'points' DataFrame. If the column's datatype is # 'category', it checks whether the categories of this column are known. If not, it explicitly converts the @@ -72,14 +89,15 @@ def write_points( points.to_parquet(path) - attrs = format.attrs_to_dict(points.attrs) - attrs["version"] = format.spatialdata_format_version + attrs = element_format.attrs_to_dict(points.attrs) + attrs["version"] = element_format.spatialdata_format_version _write_metadata( - points_groups, + group, group_type=group_type, axes=list(axes), attrs=attrs, ) - assert t is not None - overwrite_coordinate_transformations_non_raster(group=points_groups, axes=axes, transformations=t) + if transformations is None: + raise ValueError(f"No transformations specified for element '{group.basename}'. Cannot write.") + overwrite_coordinate_transformations_non_raster(group=group, axes=axes, transformations=transformations) diff --git a/src/spatialdata/_io/io_raster.py b/src/spatialdata/_io/io_raster.py index 541be3ead..fe7f6b36b 100644 --- a/src/spatialdata/_io/io_raster.py +++ b/src/spatialdata/_io/io_raster.py @@ -6,7 +6,7 @@ import zarr from ome_zarr.format import Format from ome_zarr.io import ZarrLocation -from ome_zarr.reader import Label, Multiscales, Node, Reader +from ome_zarr.reader import Multiscales, Node, Reader from ome_zarr.types import JSONDict from ome_zarr.writer import _get_valid_axes from ome_zarr.writer import write_image as write_image_ngff @@ -21,9 +21,8 @@ ) from spatialdata._io.format import ( CurrentRasterFormat, - RasterFormats, - RasterFormatV01, - _parse_version, + RasterFormatType, + get_ome_zarr_format, ) from spatialdata._utils import get_pyramid_levels from spatialdata.models._utils import get_channel_names @@ -36,40 +35,45 @@ ) -def _read_multiscale(store: str | Path, raster_type: Literal["image", "labels"]) -> DataArray | DataTree: +def _read_multiscale( + store: str | Path, raster_type: Literal["image", "labels"], reader_format: Format +) -> DataArray | DataTree: assert isinstance(store, str | Path) assert raster_type in ["image", "labels"] - f = zarr.open(store, mode="r") - version = _parse_version(f, expect_attrs_key=True) - # old spatialdata datasets don't have format metadata for raster elements; this line ensure backwards compatibility, - # interpreting the lack of such information as the presence of the format v01 - format = RasterFormatV01() if version is None else RasterFormats[version] - f.store.close() - nodes: list[Node] = [] - image_loc = ZarrLocation(store) - if image_loc.exists(): + image_loc = ZarrLocation(store, fmt=reader_format) + if exists := image_loc.exists(): image_reader = Reader(image_loc)() image_nodes = list(image_reader) - if len(image_nodes): - for node in image_nodes: - if np.any([isinstance(spec, Multiscales) for 
spec in node.specs]) and ( - raster_type == "image" - and np.all([not isinstance(spec, Label) for spec in node.specs]) - or raster_type == "labels" - and np.any([isinstance(spec, Label) for spec in node.specs]) - ): - nodes.append(node) + nodes = _get_multiscale_nodes(image_nodes, nodes) + else: + raise OSError( + f"Image location {image_loc} does not seem to exist. If it does, potentially the zarr.json (or .zattrs) " + f"file inside is corrupted or not present or the image files themselves are corrupted." + ) if len(nodes) != 1: - raise ValueError( - f"len(nodes) = {len(nodes)}, expected 1. Unable to read the NGFF file. Please report this " - f"bug and attach a minimal data example." + if not exists: + raise ValueError( + f"len(nodes) = {len(nodes)}, expected 1 and image location {image_loc} " + "does not exist. Unable to read the NGFF file. Please report this bug " + "and attach a minimal data example." + ) + raise OSError( + f"Image location {image_loc} exists, but len(nodes) = {len(nodes)}, expected 1. Element " + f"{image_loc.basename()} is potentially corrupted. Please report this bug and attach a minimal data " + f"example." ) + node = nodes[0] - datasets = node.load(Multiscales).datasets - multiscales = node.load(Multiscales).zarr.root_attrs["multiscales"] - omero_metadata = node.load(Multiscales).zarr.root_attrs.get("omero", None) + loaded_node = node.load(Multiscales) + datasets, multiscales = ( + loaded_node.datasets, + loaded_node.zarr.root_attrs["multiscales"], + ) + # This works for all versions as in zarr v3 the level of the 'ome' key is taken as root_attrs. + omero_metadata = loaded_node.zarr.root_attrs.get("omero") + # TODO: check if below is still valid legacy_channels_metadata = node.load(Multiscales).zarr.root_attrs.get("channels_metadata", None) # legacy v0.1 assert len(multiscales) == 1 # checking for multiscales[0]["coordinateTransformations"] would make fail @@ -78,8 +82,6 @@ def _read_multiscale(store: str | Path, raster_type: Literal["image", "labels"]) # and for instance in the xenium example encoded_ngff_transformations = multiscales[0]["coordinateTransformations"] transformations = _get_transformations_from_ngff_dict(encoded_ngff_transformations) - # TODO: what to do with name? For now remove? - # name = os.path.basename(node.metadata["name"]) # if image, read channels metadata channels: list[Any] | None = None if raster_type == "image": @@ -91,7 +93,7 @@ def _read_multiscale(store: str | Path, raster_type: Literal["image", "labels"]) if len(datasets) > 1: multiscale_image = {} for i, d in enumerate(datasets): - data = node.load(Multiscales).array(resolution=d, version=format.version) + data = node.load(Multiscales).array(resolution=d) multiscale_image[f"scale{i}"] = Dataset( { "image": DataArray( @@ -105,7 +107,8 @@ def _read_multiscale(store: str | Path, raster_type: Literal["image", "labels"]) msi = DataTree.from_dict(multiscale_image) _set_transformations(msi, transformations) return compute_coordinates(msi) - data = node.load(Multiscales).array(resolution=datasets[0], version=format.version) + + data = node.load(Multiscales).array(resolution=datasets[0]) si = DataArray( data, name="image", @@ -116,38 +119,80 @@ def _read_multiscale(store: str | Path, raster_type: Literal["image", "labels"]) return compute_coordinates(si) +def _get_multiscale_nodes(image_nodes: list[Node], nodes: list[Node]) -> list[Node]: + """Get nodes with Multiscales spec from a list of nodes. 
+ + The nodes with the Multiscales spec are the ones used for reading image and label data. We now only have to check + for the Multiscales spec, whereas before we also had to check for the Label spec. In newer versions of ome-zarr-py, + labels can still carry the Label spec, but it no longer contains the multiscales used to read the data; it can still + hold label-specific metadata. + + Parameters + ---------- + image_nodes + List of nodes returned from the ome-zarr-py Reader. + nodes + List to append the nodes with the multiscales spec to. + + Returns + ------- + List of nodes with the multiscales spec. + """ + if len(image_nodes): + for node in image_nodes: + # Labels are now also Multiscales in newer versions of ome-zarr-py + if np.any([isinstance(spec, Multiscales) for spec in node.specs]): + nodes.append(node) + return nodes + + +class MultiscaleReader: + def __call__( + self, path: str | Path, raster_type: Literal["image", "labels"], reader_format: Format + ) -> DataArray | DataTree: + return _read_multiscale(path, raster_type, reader_format) + + def _write_raster( raster_type: Literal["image", "labels"], raster_data: DataArray | DataTree, group: zarr.Group, name: str, - format: Format = CurrentRasterFormat(), + raster_format: RasterFormatType, storage_options: JSONDict | list[JSONDict] | None = None, label_metadata: JSONDict | None = None, **metadata: str | JSONDict | list[JSONDict], ) -> None: - assert raster_type in ["image", "labels"] - # the argument "name" and "label_metadata" are only used for labels (to be precise, name is used in - # write_multiscale_ngff() when writing metadata, but name is not used in write_image_ngff(). Maybe this is bug of - # ome-zarr-py. In any case, we don't need that metadata and we use the argument name so that when we write labels - # the correct group is created by the ome-zarr-py APIs. For images we do it manually in the function - # _get_group_for_writing_data() - if raster_type == "image": - assert label_metadata is None - else: + """Write raster data to disk. + + Parameters + ---------- + raster_type + Whether the raster data pertains to an image or labels `SpatialElement`. + raster_data + The raster data to write. + group + The zarr group, within the 'image' or 'labels' zarr group, to write the raster data to. + name + The name of the raster element. + raster_format + The format used to write the raster data. + storage_options + Additional options for writing the raster data, like chunks and compression. + label_metadata + Label metadata, which can only be defined when writing 'labels'. + metadata + Additional metadata for the raster element. + """ + if raster_type not in ["image", "labels"]: + raise ValueError(f"{raster_type} is not a valid raster type. Must be 'image' or 'labels'.") + # "name" and "label_metadata" are only used for labels. "name" is written in write_multiscale_ngff() but ignored in + # write_image_ngff() (possibly an ome-zarr-py bug). We only use "name" to ensure correct group access in the + # ome-zarr API.
+ if raster_type == "labels": metadata["name"] = name metadata["label_metadata"] = label_metadata - write_single_scale_ngff = write_image_ngff if raster_type == "image" else write_labels_ngff - write_multi_scale_ngff = write_multiscale_ngff if raster_type == "image" else write_multiscale_labels_ngff - - group_data = group.require_group(name) if raster_type == "image" else group - - def _get_group_for_writing_transformations() -> zarr.Group: - if raster_type == "image": - return group.require_group(name) - return group["labels"][name] - # convert channel names to channel metadata in omero if raster_type == "image": metadata["metadata"] = {"omero": {"channels": []}} @@ -156,86 +201,175 @@ def _get_group_for_writing_transformations() -> zarr.Group: metadata["metadata"]["omero"]["channels"].append({"label": c}) # type: ignore[union-attr, index, call-overload] if isinstance(raster_data, DataArray): - data = raster_data.data - transformations = _get_transformations(raster_data) - input_axes: tuple[str, ...] = tuple(raster_data.dims) - chunks = raster_data.chunks - parsed_axes = _get_valid_axes(axes=list(input_axes), fmt=format) - if storage_options is not None: - if "chunks" not in storage_options and isinstance(storage_options, dict): - storage_options["chunks"] = chunks - else: - storage_options = {"chunks": chunks} - # Scaler needs to be None since we are passing the data already downscaled for the multiscale case. - # We need this because the argument of write_image_ngff is called image while the argument of - # write_labels_ngff is called label. - metadata[raster_type] = data - write_single_scale_ngff( - group=group_data, - scaler=None, - fmt=format, - axes=parsed_axes, - coordinate_transformations=None, - storage_options=storage_options, + _write_raster_dataarray( + raster_type, + group, + name, + raster_data, + raster_format, + storage_options, **metadata, ) - assert transformations is not None - overwrite_coordinate_transformations_raster( - group=_get_group_for_writing_transformations(), transformations=transformations, axes=input_axes - ) elif isinstance(raster_data, DataTree): - data = get_pyramid_levels(raster_data, attr="data") - list_of_input_axes: list[Any] = get_pyramid_levels(raster_data, attr="dims") - assert len(set(list_of_input_axes)) == 1 - input_axes = list_of_input_axes[0] - # saving only the transformations of the first scale - d = dict(raster_data["scale0"]) - assert len(d) == 1 - xdata = d.values().__iter__().__next__() - transformations = _get_transformations_xarray(xdata) - assert transformations is not None - assert len(transformations) > 0 - chunks = get_pyramid_levels(raster_data, "chunks") - # coords = iterate_pyramid_levels(raster_data, "coords") - parsed_axes = _get_valid_axes(axes=list(input_axes), fmt=format) - storage_options = [{"chunks": chunk} for chunk in chunks] - dask_delayed = write_multi_scale_ngff( - pyramid=data, - group=group_data, - fmt=format, - axes=parsed_axes, - coordinate_transformations=None, - storage_options=storage_options, + _write_raster_datatree( + raster_type, + group, + name, + raster_data, + raster_format, + storage_options, **metadata, - compute=False, - ) - # Compute all pyramid levels at once to allow Dask to optimize the computational graph. 
- da.compute(*dask_delayed) - assert transformations is not None - overwrite_coordinate_transformations_raster( - group=_get_group_for_writing_transformations(), transformations=transformations, axes=tuple(input_axes) ) else: raise ValueError("Not a valid labels object") - # as explained in a comment in format.py, since coordinate transformations are not part of NGFF yet, we need to have - # our spatialdata extension also for raster type (eventually it will be dropped in favor of pure NGFF). Until then, - # saving the NGFF version (i.e. 0.4) is not enough, and we need to also record which version of the spatialdata - # format we are using for raster types - group = _get_group_for_writing_transformations() + group = group["labels"][name] if raster_type == "labels" else group if ATTRS_KEY not in group.attrs: group.attrs[ATTRS_KEY] = {} attrs = group.attrs[ATTRS_KEY] - attrs["version"] = format.spatialdata_format_version + attrs["version"] = raster_format.spatialdata_format_version # triggers the write operation group.attrs[ATTRS_KEY] = attrs +def _write_raster_dataarray( + raster_type: Literal["image", "labels"], + group: zarr.Group, + element_name: str, + raster_data: DataArray, + raster_format: RasterFormatType, + storage_options: JSONDict | list[JSONDict] | None = None, + **metadata: str | JSONDict | list[JSONDict], +) -> None: + """Write raster data of type DataArray to disk. + + Parameters + ---------- + raster_type + Whether the raster data pertains to a image or labels 'SpatialElement`. + group + The zarr group in the 'image' or 'labels' zarr group to write the raster data to. + element_name + The name of the raster element. + raster_data + The raster data to write. + raster_format + The format used to write the raster data. + storage_options + Additional options for writing the raster data, like chunks and compression. + metadata + Additional metadata for the raster element + """ + write_single_scale_ngff = write_image_ngff if raster_type == "image" else write_labels_ngff + + data = raster_data.data + transformations = _get_transformations(raster_data) + if transformations is None: + raise ValueError(f"{element_name} does not have any transformations and can therefore not be written.") + input_axes: tuple[str, ...] = tuple(raster_data.dims) + chunks = raster_data.chunks + parsed_axes = _get_valid_axes(axes=list(input_axes), fmt=raster_format) + if storage_options is not None: + if "chunks" not in storage_options and isinstance(storage_options, dict): + storage_options["chunks"] = chunks + else: + storage_options = {"chunks": chunks} + # Scaler needs to be None since we are passing the data already downscaled for the multiscale case. + # We need this because the argument of write_image_ngff is called image while the argument of + # write_labels_ngff is called label. 
+ metadata[raster_type] = data + ome_zarr_format = get_ome_zarr_format(raster_format) + write_single_scale_ngff( + group=group, + scaler=None, + fmt=ome_zarr_format, + axes=parsed_axes, + coordinate_transformations=None, + storage_options=storage_options, + **metadata, + ) + + trans_group = group["labels"][element_name] if raster_type == "labels" else group + overwrite_coordinate_transformations_raster( + group=trans_group, + transformations=transformations, + axes=input_axes, + raster_format=raster_format, + ) + + +def _write_raster_datatree( + raster_type: Literal["image", "labels"], + group: zarr.Group, + element_name: str, + raster_data: DataTree, + raster_format: RasterFormatType, + storage_options: JSONDict | list[JSONDict] | None = None, + **metadata: str | JSONDict | list[JSONDict], +) -> None: + """Write raster data of type DataTree to disk. + + Parameters + ---------- + raster_type + Whether the raster data pertains to a image or labels 'SpatialElement`. + group + The zarr group in the 'image' or 'labels' zarr group to write the raster data to. + element_name + The name of the raster element. + raster_data + The raster data to write. + raster_format + The format used to write the raster data. + storage_options + Additional options for writing the raster data, like chunks and compression. + metadata + Additional metadata for the raster element + """ + write_multi_scale_ngff = write_multiscale_ngff if raster_type == "image" else write_multiscale_labels_ngff + data = get_pyramid_levels(raster_data, attr="data") + list_of_input_axes: list[Any] = get_pyramid_levels(raster_data, attr="dims") + assert len(set(list_of_input_axes)) == 1 + input_axes = list_of_input_axes[0] + # saving only the transformations of the first scale + d = dict(raster_data["scale0"]) + assert len(d) == 1 + xdata = d.values().__iter__().__next__() + transformations = _get_transformations_xarray(xdata) + if transformations is None: + raise ValueError(f"{element_name} does not have any transformations and can therefore not be written.") + chunks = get_pyramid_levels(raster_data, "chunks") + + parsed_axes = _get_valid_axes(axes=list(input_axes), fmt=raster_format) + storage_options = [{"chunks": chunk} for chunk in chunks] + ome_zarr_format = get_ome_zarr_format(raster_format) + dask_delayed = write_multi_scale_ngff( + pyramid=data, + group=group, + fmt=ome_zarr_format, + axes=parsed_axes, + coordinate_transformations=None, + storage_options=storage_options, + **metadata, + compute=False, + ) + # Compute all pyramid levels at once to allow Dask to optimize the computational graph. 
+ da.compute(*dask_delayed) + + trans_group = group["labels"][element_name] if raster_type == "labels" else group + overwrite_coordinate_transformations_raster( + group=trans_group, + transformations=transformations, + axes=tuple(input_axes), + raster_format=raster_format, + ) + + def write_image( image: DataArray | DataTree, group: zarr.Group, name: str, - format: Format = CurrentRasterFormat(), + element_format: RasterFormatType = CurrentRasterFormat(), storage_options: JSONDict | list[JSONDict] | None = None, **metadata: str | JSONDict | list[JSONDict], ) -> None: @@ -244,7 +378,7 @@ def write_image( raster_data=image, group=group, name=name, - format=format, + raster_format=element_format, storage_options=storage_options, **metadata, ) @@ -254,7 +388,7 @@ def write_labels( labels: DataArray | DataTree, group: zarr.Group, name: str, - format: Format = CurrentRasterFormat(), + element_format: RasterFormatType = CurrentRasterFormat(), storage_options: JSONDict | list[JSONDict] | None = None, label_metadata: JSONDict | None = None, **metadata: JSONDict, @@ -264,7 +398,7 @@ def write_labels( raster_data=labels, group=group, name=name, - format=format, + raster_format=element_format, storage_options=storage_options, label_metadata=label_metadata, **metadata, diff --git a/src/spatialdata/_io/io_shapes.py b/src/spatialdata/_io/io_shapes.py index c32ce1f34..15efdc471 100644 --- a/src/spatialdata/_io/io_shapes.py +++ b/src/spatialdata/_io/io_shapes.py @@ -1,9 +1,11 @@ from collections.abc import MutableMapping from pathlib import Path +from typing import Any import numpy as np import zarr from geopandas import GeoDataFrame, read_parquet +from natsort import natsorted from ome_zarr.format import Format from shapely import from_ragged_array, to_ragged_array @@ -17,6 +19,7 @@ ShapesFormats, ShapesFormatV01, ShapesFormatV02, + ShapesFormatV03, _parse_version, ) from spatialdata.models import ShapesModel, get_axes_names @@ -27,34 +30,36 @@ def _read_shapes( - store: str | Path | MutableMapping | zarr.Group, # type: ignore[type-arg] + store: str | Path | MutableMapping[str, object] | zarr.Group, ) -> GeoDataFrame: """Read shapes from a zarr store.""" assert isinstance(store, str | Path) f = zarr.open(store, mode="r") version = _parse_version(f, expect_attrs_key=True) assert version is not None - format = ShapesFormats[version] + shape_format = ShapesFormats[version] - if isinstance(format, ShapesFormatV01): + if isinstance(shape_format, ShapesFormatV01): coords = np.array(f["coords"]) index = np.array(f["Index"]) - typ = format.attrs_from_dict(f.attrs.asdict()) + typ = shape_format.attrs_from_dict(f.attrs.asdict()) if typ.name == "POINT": radius = np.array(f["radius"]) geometry = from_ragged_array(typ, coords) geo_df = GeoDataFrame({"geometry": geometry, "radius": radius}, index=index) else: offsets_keys = [k for k in f if k.startswith("offset")] + offsets_keys = natsorted(offsets_keys) offsets = tuple(np.array(f[k]).flatten() for k in offsets_keys) geometry = from_ragged_array(typ, coords, offsets) geo_df = GeoDataFrame({"geometry": geometry}, index=index) - elif isinstance(format, ShapesFormatV02): - path = Path(f._store.path) / f.path / "shapes.parquet" + elif isinstance(shape_format, ShapesFormatV02 | ShapesFormatV03): + store_root = f.store_path.store.root + path = Path(store_root) / f.path / "shapes.parquet" geo_df = read_parquet(path) else: raise ValueError( - f"Unsupported shapes format {format} from version {version}. Please update the spatialdata library." 
+ f"Unsupported shapes format {shape_format} from version {version}. Please update the spatialdata library." ) transformations = _get_transformations_from_ngff_dict(f.attrs.asdict()["coordinateTransformations"]) @@ -62,50 +67,99 @@ def _read_shapes( return geo_df +class ShapesReader: + def __call__(self, store: str | Path | MutableMapping[str, object] | zarr.Group) -> GeoDataFrame: + return _read_shapes(store) + + def write_shapes( shapes: GeoDataFrame, group: zarr.Group, - name: str, group_type: str = "ngff:shapes", - format: Format = CurrentShapesFormat(), + element_format: Format = CurrentShapesFormat(), ) -> None: - import numcodecs + """Write shapes to spatialdata zarr store. + + Note that the parquet file is not recognized as part of the zarr hierarchy as it is not a valid component of a + zarr store, e.g. group, array or metadata file. + Parameters + ---------- + shapes + The shapes dataframe + group + The zarr group in the 'shapes' zarr group to write the shapes element to. + group_type + The type of the element. + element_format + The format of the shapes element used to store it. + """ axes = get_axes_names(shapes) - t = _get_transformations(shapes) - - shapes_group = group.require_group(name) - - if isinstance(format, ShapesFormatV01): - geometry, coords, offsets = to_ragged_array(shapes.geometry) - shapes_group.create_dataset(name="coords", data=coords) - for i, o in enumerate(offsets): - shapes_group.create_dataset(name=f"offset{i}", data=o) - if shapes.index.dtype.kind == "U" or shapes.index.dtype.kind == "O": - shapes_group.create_dataset( - name="Index", data=shapes.index.values, dtype=object, object_codec=numcodecs.VLenUTF8() - ) - else: - shapes_group.create_dataset(name="Index", data=shapes.index.values) - if geometry.name == "POINT": - shapes_group.create_dataset(name=ShapesModel.RADIUS_KEY, data=shapes[ShapesModel.RADIUS_KEY].values) - - attrs = format.attrs_to_dict(geometry) - attrs["version"] = format.spatialdata_format_version - elif isinstance(format, ShapesFormatV02): - path = Path(shapes_group._store.path) / shapes_group.path / "shapes.parquet" - shapes.to_parquet(path) - - attrs = format.attrs_to_dict(shapes.attrs) - attrs["version"] = format.spatialdata_format_version + transformations = _get_transformations(shapes) + if transformations is None: + raise ValueError(f"{group.basename} does not have any transformations and can therefore not be written.") + if isinstance(element_format, ShapesFormatV01): + attrs = _write_shapes_v01(shapes, group, element_format) + elif isinstance(element_format, ShapesFormatV02 | ShapesFormatV03): + attrs = _write_shapes_v02_v03(shapes, group, element_format) else: - raise ValueError(f"Unsupported format version {format.version}. Please update the spatialdata library.") + raise ValueError(f"Unsupported format version {element_format.version}. Please update the spatialdata library.") _write_metadata( - shapes_group, + group, group_type=group_type, axes=list(axes), attrs=attrs, ) - assert t is not None - overwrite_coordinate_transformations_non_raster(group=shapes_group, axes=axes, transformations=t) + overwrite_coordinate_transformations_non_raster(group=group, axes=axes, transformations=transformations) + + +def _write_shapes_v01(shapes: GeoDataFrame, group: zarr.Group, element_format: Format) -> Any: + """Write shapes to spatialdata zarr store using format ShapesFormatV01. + + Parameters + ---------- + shapes + The shapes dataframe + group + The zarr group in the 'shapes' zarr group to write the shapes element to. 
+ element_format + The format of the shapes element used to store it. + """ + import numcodecs + + geometry, coords, offsets = to_ragged_array(shapes.geometry) + group.create_array(name="coords", data=coords) + for i, o in enumerate(offsets): + group.create_array(name=f"offset{i}", data=o) + if shapes.index.dtype.kind == "U" or shapes.index.dtype.kind == "O": + group.create_array(name="Index", data=shapes.index.values, dtype=object, object_codec=numcodecs.VLenUTF8()) + else: + group.create_array(name="Index", data=shapes.index.values) + if geometry.name == "POINT": + group.create_array(name=ShapesModel.RADIUS_KEY, data=shapes[ShapesModel.RADIUS_KEY].values) + + attrs = element_format.attrs_to_dict(geometry) + attrs["version"] = element_format.spatialdata_format_version + return attrs + + +def _write_shapes_v02_v03(shapes: GeoDataFrame, group: zarr.Group, element_format: Format) -> Any: + """Write shapes to spatialdata zarr store using format ShapesFormatV02 or ShapesFormatV03. + + Parameters + ---------- + shapes + The shapes dataframe + group + The zarr group in the 'shapes' zarr group to write the shapes element to. + element_format + The format of the shapes element used to store it. + """ + store_root = group.store_path.store.root + path = store_root / group.path / "shapes.parquet" + shapes.to_parquet(path) + + attrs = element_format.attrs_to_dict(shapes.attrs) + attrs["version"] = element_format.spatialdata_format_version + return attrs diff --git a/src/spatialdata/_io/io_table.py b/src/spatialdata/_io/io_table.py index 92ff64b94..9e71c4a3c 100644 --- a/src/spatialdata/_io/io_table.py +++ b/src/spatialdata/_io/io_table.py @@ -13,16 +13,15 @@ from spatialdata._io._utils import BadFileHandleMethod, handle_read_errors from spatialdata._io.format import CurrentTablesFormat, TablesFormats, _parse_version from spatialdata._logging import logger -from spatialdata.models import TableModel +from spatialdata.models import TableModel, get_table_keys def _read_table( zarr_store_path: str, group: zarr.Group, - subgroup: zarr.Group, tables: dict[str, AnnData], on_bad_files: Literal[BadFileHandleMethod.ERROR, BadFileHandleMethod.WARN] = BadFileHandleMethod.ERROR, -) -> dict[str, AnnData]: +) -> None: """ Read in tables in the tables Zarr.group of a SpatialData Zarr store. @@ -31,9 +30,7 @@ def _read_table( zarr_store_path The path to the Zarr store. group - The parent group containing the subgroup. - subgroup - The subgroup containing the tables. + The tables parent group containing the subgroup. tables A dictionary of tables. on_bad_files @@ -44,14 +41,19 @@ def _read_table( The modified dictionary with the tables. 
""" count = 0 - for table_name in subgroup: - f_elem = subgroup[table_name] + for table_name in group: + f_elem = group[table_name] f_elem_store = os.path.join(zarr_store_path, f_elem.path) with handle_read_errors( on_bad_files=on_bad_files, - location=f"{subgroup.path}/{table_name}", - exc_types=(JSONDecodeError, KeyError, ValueError, ArrayNotFoundError), + location=f"{group.path}/{table_name}", + exc_types=( + JSONDecodeError, + KeyError, + ValueError, + ArrayNotFoundError, + ), ): tables[table_name] = read_anndata_zarr(f_elem_store) @@ -63,6 +65,7 @@ def _read_table( _ = TablesFormats[version] f.store.close() + # TODO: implement per read logic of format # # replace with format from above # version = "0.1" # format = TablesFormats[version] @@ -81,8 +84,18 @@ def _read_table( count += 1 - logger.debug(f"Found {count} elements in {subgroup}") - return tables + logger.debug(f"Found {count} elements in {group}") + + +class TablesReader: + def __call__( + self, + path: str, + group: zarr.Group, + container: dict[str, AnnData], + on_bad_files: Literal[BadFileHandleMethod.ERROR, BadFileHandleMethod.WARN], + ) -> None: + return _read_table(path, group, container, on_bad_files) def write_table( @@ -90,19 +103,18 @@ def write_table( group: zarr.Group, name: str, group_type: str = "ngff:regions_table", - format: Format = CurrentTablesFormat(), + element_format: Format = CurrentTablesFormat(), ) -> None: if TableModel.ATTRS_KEY in table.uns: - region = table.uns["spatialdata_attrs"]["region"] - region_key = table.uns["spatialdata_attrs"].get("region_key", None) - instance_key = table.uns["spatialdata_attrs"].get("instance_key", None) - format.validate_table(table, region_key, instance_key) + region, region_key, instance_key = get_table_keys(table) + TableModel().validate(table) else: region, region_key, instance_key = (None, None, None) - write_adata(group, name, table) # creates group[name] + + write_adata(group, name, table) tables_group = group[name] tables_group.attrs["spatialdata-encoding-type"] = group_type tables_group.attrs["region"] = region tables_group.attrs["region_key"] = region_key tables_group.attrs["instance_key"] = instance_key - tables_group.attrs["version"] = format.spatialdata_format_version + tables_group.attrs["version"] = element_format.spatialdata_format_version diff --git a/src/spatialdata/_io/io_zarr.py b/src/spatialdata/_io/io_zarr.py index 224ef1129..5367dcf0c 100644 --- a/src/spatialdata/_io/io_zarr.py +++ b/src/spatialdata/_io/io_zarr.py @@ -1,47 +1,104 @@ -import logging import os import warnings from json import JSONDecodeError from pathlib import Path -from typing import Literal +from typing import Literal, cast -import zarr +import zarr.storage from anndata import AnnData +from ome_zarr.format import Format from pyarrow import ArrowInvalid -from zarr.errors import ArrayNotFoundError, MetadataError +from zarr.errors import ArrayNotFoundError from spatialdata._core.spatialdata import SpatialData -from spatialdata._io._utils import BadFileHandleMethod, handle_read_errors, ome_zarr_logger -from spatialdata._io.io_points import _read_points -from spatialdata._io.io_raster import _read_multiscale -from spatialdata._io.io_shapes import _read_shapes -from spatialdata._io.io_table import _read_table +from spatialdata._io._utils import ( + BadFileHandleMethod, + _resolve_zarr_store, + handle_read_errors, +) +from spatialdata._io.io_points import PointsReader +from spatialdata._io.io_raster import MultiscaleReader +from spatialdata._io.io_shapes import ShapesReader +from 
spatialdata._io.io_table import TablesReader from spatialdata._logging import logger +from spatialdata.models import SpatialElement +ReadClasses = MultiscaleReader | PointsReader | ShapesReader | TablesReader -def _open_zarr_store(store: str | Path | zarr.Group) -> tuple[zarr.Group, str]: - """ - Open a zarr store (on-disk or remote) and return the zarr.Group object and the path to the store. + +def _read_zarr_group_spatialdata_element( + root_group: zarr.Group, + root_store_path: str, + sdata_version: Literal["0.1", "0.2"], + selector: set[str], + read_func: ReadClasses, + group_name: Literal["images", "labels", "shapes", "points", "tables"], + element_type: Literal["image", "labels", "shapes", "points", "tables"], + element_container: dict[str, SpatialElement | AnnData], + on_bad_files: Literal[BadFileHandleMethod.ERROR, BadFileHandleMethod.WARN], +) -> None: + with handle_read_errors( + on_bad_files, + location=group_name, + exc_types=JSONDecodeError, + ): + if group_name in selector and group_name in root_group: + group = root_group[group_name] + if isinstance(read_func, TablesReader): + read_func(root_store_path, group, element_container, on_bad_files=on_bad_files) + else: + count = 0 + for subgroup_name in group: + if Path(subgroup_name).name.startswith("."): + # skip hidden files like .zgroup or .zmetadata + continue + elem_group = group[subgroup_name] + elem_group_path = os.path.join(root_store_path, elem_group.path) + with handle_read_errors( + on_bad_files, + location=f"{group.path}/{subgroup_name}", + exc_types=(KeyError, ArrayNotFoundError, OSError, ArrowInvalid, JSONDecodeError), + ): + if isinstance(read_func, MultiscaleReader): + reader_format = get_raster_format_for_read(elem_group, sdata_version) + element = read_func( + elem_group_path, cast(Literal["image", "labels"], element_type), reader_format + ) + if isinstance(read_func, PointsReader | ShapesReader): + element = read_func(elem_group_path) + element_container[subgroup_name] = element + count += 1 + logger.debug(f"Found {count} elements in {group}") + + +def get_raster_format_for_read(group: zarr.Group, sdata_version: Literal["0.1", "0.2"]) -> Format: + """Get raster format of stored raster data. + + This checks the image or label element zarr group metadata to retrieve the format that is used by + ome-zarr's ZarrLocation for reading the data. Parameters ---------- - store - Path to the zarr store (on-disk or remote) or a zarr.Group object. + group + The zarr group of the raster element to be read. + sdata_version + The version of the SpatialData zarr store retrieved from the spatialdata attributes. Returns ------- - A tuple of the zarr.Group object and the path to the store. + The ome-zarr format to use for reading the raster element. 
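# Illustrative sketch of how the lookup defined here feeds the raster reading path.
# The metadata shapes are assumptions based on the two branches below: a SpatialData "0.1"
# store carries {"multiscales": [{"version": ...}]} while "0.2" carries {"ome": {"version": ...}};
# the resolved ome-zarr format is then handed to the multiscale reader, as in the dispatch above.
reader_format = get_raster_format_for_read(elem_group, sdata_version="0.2")
element = MultiscaleReader()(elem_group_path, "image", reader_format)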
""" - f = store if isinstance(store, zarr.Group) else zarr.open(store, mode="r") - # workaround: .zmetadata is being written as zmetadata (https://github.com/zarr-developers/zarr-python/issues/1121) - if isinstance(store, str | Path) and str(store).startswith("http") and len(f) == 0: - f = zarr.open_consolidated(store, mode="r", metadata_key="zmetadata") - f_store_path = f.store.store.path if isinstance(f.store, zarr.storage.ConsolidatedMetadataStore) else f.store.path - return f, f_store_path + from spatialdata._io.format import SdataVersion_to_Format + + if sdata_version == "0.1": + group_version = group.metadata.attributes["multiscales"][0]["version"] + if sdata_version == "0.2": + group_version = group.metadata.attributes["ome"]["version"] + return SdataVersion_to_Format[group_version] def read_zarr( - store: str | Path | zarr.Group, + store: str | Path, selection: None | tuple[str] = None, on_bad_files: Literal[BadFileHandleMethod.ERROR, BadFileHandleMethod.WARN] = BadFileHandleMethod.ERROR, ) -> SpatialData: @@ -51,7 +108,7 @@ def read_zarr( Parameters ---------- store - Path to the zarr store (on-disk or remote) or a zarr.Group object. + Path to the zarr store (on-disk or remote). selection List of elements to read from the zarr store (images, labels, points, shapes, table). If None, all elements are @@ -71,154 +128,56 @@ def read_zarr( ------- A SpatialData object. """ - f, f_store_path = _open_zarr_store(store) + from spatialdata._io._utils import _resolve_zarr_store + + resolved_store = _resolve_zarr_store(store) + root_group = zarr.open_group(resolved_store, mode="r") + sdata_version = root_group.metadata.attributes["spatialdata_attrs"]["version"] + if sdata_version == "0.1": + warnings.warn( + "SpatialData is not stored in the most current format. If you want to use Zarr v3" + ", please write the store to a new location.", + UserWarning, + stacklevel=2, + ) + root_store_path = root_group.store.root - images = {} - labels = {} - points = {} + images: dict[str, SpatialElement] = {} + labels: dict[str, SpatialElement] = {} + points: dict[str, SpatialElement] = {} tables: dict[str, AnnData] = {} - shapes = {} + shapes: dict[str, SpatialElement] = {} - # TODO: remove table once deprecated. 
- selector = {"images", "labels", "points", "shapes", "tables", "table"} if not selection else set(selection or []) + selector = {"images", "labels", "points", "shapes", "tables"} if not selection else set(selection or []) logger.debug(f"Reading selection {selector}") - # read multiscale images - if "images" in selector and "images" in f: - with handle_read_errors( - on_bad_files, - location="images", - exc_types=(JSONDecodeError, MetadataError), - ): - group = f["images"] - count = 0 - for subgroup_name in group: - if Path(subgroup_name).name.startswith("."): - # skip hidden files like .zgroup or .zmetadata - continue - f_elem = group[subgroup_name] - f_elem_store = os.path.join(f_store_path, f_elem.path) - with handle_read_errors( - on_bad_files, - location=f"{group.path}/{subgroup_name}", - exc_types=( - JSONDecodeError, # JSON parse error - ValueError, # ome_zarr: Unable to read the NGFF file - KeyError, # Missing JSON key - ArrayNotFoundError, # Image chunks missing - TypeError, # instead of ArrayNotFoundError, with dask>=2024.10.0 zarr<=2.18.3 - ), - ): - element = _read_multiscale(f_elem_store, raster_type="image") - images[subgroup_name] = element - count += 1 - logger.debug(f"Found {count} elements in {group}") - - # read multiscale labels - with ome_zarr_logger(logging.ERROR): - if "labels" in selector and "labels" in f: - with handle_read_errors( - on_bad_files, - location="labels", - exc_types=(JSONDecodeError, MetadataError), - ): - group = f["labels"] - count = 0 - for subgroup_name in group: - if Path(subgroup_name).name.startswith("."): - # skip hidden files like .zgroup or .zmetadata - continue - f_elem = group[subgroup_name] - f_elem_store = os.path.join(f_store_path, f_elem.path) - with handle_read_errors( - on_bad_files, - location=f"{group.path}/{subgroup_name}", - exc_types=(JSONDecodeError, KeyError, ValueError, ArrayNotFoundError, TypeError), - ): - labels[subgroup_name] = _read_multiscale(f_elem_store, raster_type="labels") - count += 1 - logger.debug(f"Found {count} elements in {group}") - - # now read rest of the data - if "points" in selector and "points" in f: - with handle_read_errors( - on_bad_files, - location="points", - exc_types=(JSONDecodeError, MetadataError), - ): - group = f["points"] - count = 0 - for subgroup_name in group: - f_elem = group[subgroup_name] - if Path(subgroup_name).name.startswith("."): - # skip hidden files like .zgroup or .zmetadata - continue - f_elem_store = os.path.join(f_store_path, f_elem.path) - with handle_read_errors( - on_bad_files, - location=f"{group.path}/{subgroup_name}", - exc_types=(JSONDecodeError, KeyError, ArrowInvalid), - ): - points[subgroup_name] = _read_points(f_elem_store) - count += 1 - logger.debug(f"Found {count} elements in {group}") - - if "shapes" in selector and "shapes" in f: - with handle_read_errors( - on_bad_files, - location="shapes", - exc_types=(JSONDecodeError, MetadataError), - ): - group = f["shapes"] - count = 0 - for subgroup_name in group: - if Path(subgroup_name).name.startswith("."): - # skip hidden files like .zgroup or .zmetadata - continue - f_elem = group[subgroup_name] - f_elem_store = os.path.join(f_store_path, f_elem.path) - with handle_read_errors( - on_bad_files, - location=f"{group.path}/{subgroup_name}", - exc_types=( - JSONDecodeError, - ValueError, - KeyError, - ArrayNotFoundError, - ), - ): - shapes[subgroup_name] = _read_shapes(f_elem_store) - count += 1 - logger.debug(f"Found {count} elements in {group}") - if "tables" in selector and "tables" in f: - with 
handle_read_errors( + group_readers: dict[ + Literal["images", "labels", "shapes", "points", "tables"], + tuple[ + ReadClasses, Literal["image", "labels", "shapes", "points", "tables"], dict[str, SpatialElement | AnnData] + ], + ] = { + "images": (MultiscaleReader(), "image", images), + "labels": (MultiscaleReader(), "labels", labels), + "points": (PointsReader(), "points", points), + "shapes": (ShapesReader(), "shapes", shapes), + "tables": (TablesReader(), "tables", tables), + } + for group_name, (reader, raster_type, container) in group_readers.items(): + _read_zarr_group_spatialdata_element( + root_group, + root_store_path, + sdata_version, + selector, + reader, + group_name, + raster_type, + container, on_bad_files, - location="tables", - exc_types=(JSONDecodeError, MetadataError), - ): - group = f["tables"] - tables = _read_table(f_store_path, f, group, tables, on_bad_files=on_bad_files) - - if "table" in selector and "table" in f: - warnings.warn( - f"Table group found in zarr store at location {f_store_path}. Please update the zarr store to use tables " - f"instead.", - DeprecationWarning, - stacklevel=2, ) - subgroup_name = "table" - with handle_read_errors( - on_bad_files, - location=subgroup_name, - exc_types=(JSONDecodeError, MetadataError), - ): - group = f[subgroup_name] - tables = _read_table(f_store_path, f, group, tables, on_bad_files=on_bad_files) - - logger.debug(f"Found {count} elements in {group}") # read attrs metadata - attrs = f.attrs.asdict() + attrs = root_group.attrs.asdict() if "spatialdata_attrs" in attrs: # when refactoring the read_zarr function into reading componenets separately (and according to the version), # we can move the code below (.pop()) into attrs_from_dict() @@ -236,3 +195,100 @@ def read_zarr( ) sdata.path = Path(store) return sdata + + +def _get_groups_for_element( + zarr_path: Path, element_type: str, element_name: str, use_consolidated: bool = True +) -> tuple[zarr.Group, zarr.Group, zarr.Group]: + """ + Get the Zarr groups for the root, element_type and element for a specific element. + + The store must exist, but creates the element type group and the element group if they don't exist. In all cases + the zarr group will also be opened. When writing data to disk this should always be done with 'use_consolidated' + being 'False'. If a user wrote the data previously with consolidation of the metadata and then they write new data + in the zarr store, it can give errors otherwise, due to partially written elements not yet being present in the + consolidated metadata store, e.g. when first writing the element and then opening the zarr group again for writing + transformations. + + Parameters + ---------- + zarr_path + The path to the Zarr storage. + element_type + type of the element; must be in ["images", "labels", "points", "polygons", "shapes", "tables"]. + element_name + name of the element + use_consolidated + whether to open zarr groups using consolidated metadata. This should be false when writing as we open + zarr groups multiple times when writing an element. If the consolidated metadata store is out of sync with + what is written on disk this leads to errors. + + Returns + ------- + The Zarr groups for the root, element_type and element for a specific element. 
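# A sketch under the constraints documented above (path and element names are hypothetical):
# when writing an element, consolidated metadata is bypassed so the opened groups reflect
# what is currently on disk rather than a possibly stale consolidated store.
root, type_group, elem_group = _get_groups_for_element(
    Path("data.zarr"), "labels", "my_labels", use_consolidated=False
)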
+ """ + if not isinstance(zarr_path, Path): + raise ValueError("zarr_path should be a Path object") + + if element_type not in [ + "images", + "labels", + "points", + "polygons", + "shapes", + "tables", + ]: + raise ValueError(f"Unknown element type {element_type}") + resolved_store = _resolve_zarr_store(zarr_path) + + # When writing, use_consolidated must be set to False. Otherwise, the metadata store + # can get out of sync with newly added elements (e.g., labels), leading to errors. + root_group = zarr.open_group(store=resolved_store, mode="r+", use_consolidated=use_consolidated) + element_type_group = root_group.require_group(element_type) + element_type_group = zarr.open_group(element_type_group.store_path, mode="a", use_consolidated=use_consolidated) + + element_name_group = element_type_group.require_group(element_name) + return root_group, element_type_group, element_name_group + + +def _group_for_element_exists(zarr_path: Path, element_type: str, element_name: str) -> bool: + """ + Check if the group for an element exists. + + Parameters + ---------- + element_type + type of the element; must be in ["images", "labels", "points", "polygons", "shapes", "tables"]. + element_name + name of the element + + Returns + ------- + True if the group exists, False otherwise. + """ + store = _resolve_zarr_store(zarr_path) + root = zarr.open_group(store=store, mode="r") + assert element_type in [ + "images", + "labels", + "points", + "polygons", + "shapes", + "tables", + ] + exists = element_type in root and element_name in root[element_type] + store.close() + return exists + + +def _write_consolidated_metadata(path: Path | str | None) -> None: + if path is not None: + f = zarr.open_group(path, mode="r+", use_consolidated=False) + # .parquet files are not recognized as proper zarr and thus throw a warning. This does not affect SpatialData. + # and therefore we silence it for our users as they can't do anything about this. + # TODO check with remote PR whether we can prevent this warning at least for points data and whether with zarrv3 + # that pr would still work. 
+ with warnings.catch_warnings(): + warnings.filterwarnings("ignore", category=zarr.errors.ZarrUserWarning) + zarr.consolidate_metadata(f.store) + f.store.close() diff --git a/src/spatialdata/dataloader/datasets.py b/src/spatialdata/dataloader/datasets.py index 4dc433235..4bf14a916 100644 --- a/src/spatialdata/dataloader/datasets.py +++ b/src/spatialdata/dataloader/datasets.py @@ -258,9 +258,7 @@ def _preprocess( if table_name is not None: table_subset = filtered_table[filtered_table.obs[region_key] == region_name] - circles_sdata = SpatialData.init_from_elements( - {region_name: circles}, tables={"table": table_subset.copy()} - ) + circles_sdata = SpatialData.init_from_elements({region_name: circles, "table": table_subset.copy()}) _, table = join_spatialelement_table( sdata=circles_sdata, spatial_element_names=region_name, diff --git a/src/spatialdata/datasets.py b/src/spatialdata/datasets.py index 4b3d61f6a..63c137cdc 100644 --- a/src/spatialdata/datasets.py +++ b/src/spatialdata/datasets.py @@ -1,5 +1,6 @@ """SpatialData datasets.""" +import warnings from typing import Any, Literal import dask.dataframe.core @@ -18,7 +19,6 @@ from spatialdata._core.operations.aggregate import aggregate from spatialdata._core.query.relational_query import get_element_instances from spatialdata._core.spatialdata import SpatialData -from spatialdata._logging import logger from spatialdata._types import ArrayLike from spatialdata.models import Image2DModel, Labels2DModel, PointsModel, ShapesModel, TableModel from spatialdata.transformations import Identity @@ -126,9 +126,11 @@ def __init__( self.c_coords = c_coords if c_coords is not None: if n_channels != len(c_coords): - logger.info( + warnings.warn( f"Number of channels ({n_channels}) and c_coords ({len(c_coords)}) do not match; ignoring " - f"n_channels value" + f"n_channels value", + UserWarning, + stacklevel=2, ) n_channels = len(c_coords) self.n_channels = n_channels @@ -160,7 +162,7 @@ def blobs( labels={"blobs_labels": labels, "blobs_multiscale_labels": multiscale_labels}, points={"blobs_points": points}, shapes={"blobs_circles": circles, "blobs_polygons": polygons, "blobs_multipolygons": multipolygons}, - tables=table, + tables={"table": table}, ) def _image_blobs( @@ -365,7 +367,7 @@ def blobs_annotating_element(name: BlobsTypes) -> SpatialData: index = sdata[name].index instance_id = index.compute().tolist() if isinstance(index, dask.dataframe.core.Index) else index.tolist() n = len(instance_id) - new_table = AnnData(shape=(n, 0), obs={"region": [name for _ in range(n)], "instance_id": instance_id}) + new_table = AnnData(shape=(n, 0), obs={"region": pd.Categorical([name] * n), "instance_id": instance_id}) new_table = TableModel.parse(new_table, region=name, region_key="region", instance_key="instance_id") del sdata.tables["table"] sdata["table"] = new_table diff --git a/src/spatialdata/models/__init__.py b/src/spatialdata/models/__init__.py index 3c86fa0ec..ba064e0a6 100644 --- a/src/spatialdata/models/__init__.py +++ b/src/spatialdata/models/__init__.py @@ -8,7 +8,6 @@ force_2d, get_axes_names, get_channel_names, - get_channels, get_spatial_axes, points_dask_dataframe_to_geopandas, points_geopandas_to_dask_dataframe, @@ -50,9 +49,7 @@ "points_dask_dataframe_to_geopandas", "check_target_region_column_symmetry", "get_table_keys", - "get_channels", "get_channel_names", "set_channel_names", "force_2d", - "RasterSchema", ] diff --git a/src/spatialdata/models/_utils.py b/src/spatialdata/models/_utils.py index eeaa7ecd9..45b2dd4cd 100644 --- 
a/src/spatialdata/models/_utils.py +++ b/src/spatialdata/models/_utils.py @@ -290,34 +290,6 @@ def get_channel_names(data: Any) -> list[Any]: raise ValueError(f"Cannot get channels from {type(data)}") -def get_channels(data: Any) -> list[Any]: - """Get channels from data for an image element (both single and multiscale). - - [Deprecation] This function will be deprecated in version 0.3.0. Please use - `get_channel_names`. - - Parameters - ---------- - data - data to get channels from - - Returns - ------- - List of channels - - Notes - ----- - For multiscale images, the channels are validated to be consistent across scales. - """ - warnings.warn( - "The function 'get_channels' is deprecated and will be removed in version 0.3.0. " - "Please use 'get_channel_names' instead.", - DeprecationWarning, - stacklevel=2, # Adjust the stack level to point to the caller - ) - return get_channel_names(data) - - @get_channel_names.register def _(data: DataArray) -> list[Any]: return data.coords["c"].values.tolist() # type: ignore[no-any-return] @@ -326,7 +298,7 @@ def _(data: DataArray) -> list[Any]: @get_channel_names.register def _(data: DataTree) -> list[Any]: name = list({list(data[i].data_vars.keys())[0] for i in data})[0] - channels = {tuple(data[i][name].coords["c"].values) for i in data} + channels = {tuple(data[i][name].coords["c"].values.tolist()) for i in data} if len(channels) > 1: raise ValueError(f"Channels are not consistent across scales: {channels}") return list(next(iter(channels))) @@ -392,7 +364,7 @@ def get_raster_model_from_data_dims(dims: tuple[str, ...]) -> type[RasterSchema] return Labels3DModel if Z in dims else Labels2DModel -def convert_region_column_to_categorical(table: AnnData) -> AnnData: +def convert_region_column_to_categorical(table: AnnData) -> None: from spatialdata.models.models import TableModel if TableModel.ATTRS_KEY in table.uns: @@ -404,7 +376,6 @@ def convert_region_column_to_categorical(table: AnnData) -> AnnData: stacklevel=2, ) table.obs[region_key] = pd.Categorical(table.obs[region_key]) - return table def set_channel_names(element: DataArray | DataTree, channel_names: str | list[str]) -> DataArray | DataTree: diff --git a/src/spatialdata/models/models.py b/src/spatialdata/models/models.py index 7aeb0b2c0..60f4ee205 100644 --- a/src/spatialdata/models/models.py +++ b/src/spatialdata/models/models.py @@ -156,6 +156,12 @@ def parse( You can also pass the `rgb` argument to `kwargs` to automatically set the `c_coords` to `["r", "g", "b"]`. Please refer to :func:`to_spatial_image` for more information. Note: if you set `rgb=None` in `kwargs`, 3-4 channel images will be interpreted automatically as RGB(A) images. + + **Setting axes / dims** In case of the data being a numpy or dask array, there are no named axes yet. In this + case, we first try to use the dimensions specified by the user in the `dims` argument of `.parse`. These + dimensions are used to potentially transpose the data to match the order (c)(z)yx. See the description of the + `dims` argument above. If `dims` is not specified, the dims are set to (c)(z)yx, dependent on the number of + dimensions of the data. """ if transformations: transformations = transformations.copy() @@ -170,7 +176,6 @@ def parse( raise ValueError( f"`dims`: {dims} does not match `data.dims`: {data.dims}, please specify the dims only once." 
) - logger.info("`dims` is specified redundantly: found also inside `data`.") else: dims = data.dims # but if dims don't match the model's dims, throw error @@ -183,7 +188,6 @@ def parse( data = from_array(data) if dims is None: dims = cls.dims.dims - logger.info(f"no axes information specified in the object, setting `dims` to: {dims}") else: if len(set(dims).symmetric_difference(cls.dims.dims)) > 0: raise ValueError(f"Wrong `dims`: {dims}. Expected {cls.dims.dims}.") @@ -200,7 +204,6 @@ def parse( data = data.transpose(*[_reindex(d) for d in cls.dims.dims]) else: raise ValueError(f"Unsupported data type: {type(data)}.") - logger.info(f"Transposing `data` of type: {type(data)} to {cls.dims.dims}.") except ValueError as e: raise ValueError( f"Cannot transpose arrays to match `dims`: {dims}.", @@ -225,6 +228,10 @@ def parse( parsed_transform = _get_transformations(data) # delete transforms del data.attrs["transform"] + if isinstance(chunks, tuple): + chunks = {dim: chunks[index] for index, dim in enumerate(data.dims)} + if isinstance(chunks, float): + chunks = {dim: chunks for index, dim in data.dims} data = to_multiscale( data, scale_factors=scale_factors, @@ -260,6 +267,7 @@ def validate(self, data: Any) -> None: def _(self, data: DataArray) -> None: super().validate(data) self._check_chunk_size_not_too_large(data) + self._check_transforms_present(data) @validate.register(DataTree) def _(self, data: DataTree) -> None: @@ -273,6 +281,15 @@ def _(self, data: DataTree) -> None: for d in data: super().validate(data[d][name]) self._check_chunk_size_not_too_large(data) + self._check_transforms_present(data) + + def _check_transforms_present(self, data: DataArray | DataTree) -> None: + parsed_transform = _get_transformations(data) + if parsed_transform is None: + raise ValueError( + f"No transformation found for `{data}`. At least one transformation is required for " + f"raster elements, e.g. images, labels." + ) def _check_chunk_size_not_too_large(self, data: DataArray | DataTree) -> None: if isinstance(data, DataArray): @@ -650,8 +667,8 @@ def validate(cls, data: DaskDataFrame) -> None: ) if ATTRS_KEY in data.attrs and "feature_key" in data.attrs[ATTRS_KEY]: feature_key = data.attrs[ATTRS_KEY][cls.FEATURE_KEY] - if not isinstance(data[feature_key].dtype, CategoricalDtype): - logger.info(f"Feature key `{feature_key}`could be of type `pd.Categorical`. Consider casting it.") + if feature_key not in data.columns: + warnings.warn(f"Column `{feature_key}` not found." + SUGGESTION, UserWarning, stacklevel=2) @singledispatchmethod @classmethod @@ -1049,10 +1066,33 @@ def validate( ------- The validated data. """ + if not isinstance(data, AnnData): + raise TypeError(f"`table` must be `anndata.AnnData`, was {type(data)}.") + validate_table_attr_keys(data) if ATTRS_KEY not in data.uns: return data + _, region_key, instance_key = get_table_keys(data) + if region_key is not None: + if region_key not in data.obs: + raise ValueError( + f"Region key `{region_key}` not in `adata.obs`. Please create the column and parse " + f"using TableModel.parse(adata)." + ) + if not isinstance(data.obs[region_key].dtype, CategoricalDtype): + raise ValueError( + f"`table.obs[{region_key}]` must be of type `categorical`, not `{type(data.obs[region_key])}`." + ) + if instance_key: + if instance_key not in data.obs: + raise ValueError( + f"Instance key `{instance_key}` not in `adata.obs`. Please create the column and parse" + f" using TableModel.parse(adata)." 
+ ) + if data.obs[instance_key].isnull().values.any(): + raise ValueError("`table.obs[instance_key]` must not contain null values, but it does.") + self._validate_table_annotation_metadata(data) return data @@ -1140,8 +1180,9 @@ def parse( "instance_key": instance_key, } adata.uns[cls.ATTRS_KEY] = attr + convert_region_column_to_categorical(adata) cls().validate(adata) - return convert_region_column_to_categorical(adata) + return adata Schema_t: TypeAlias = ( diff --git a/src/spatialdata/transformations/ngff/ngff_transformations.py b/src/spatialdata/transformations/ngff/ngff_transformations.py index 4e63b91c1..9cc2602e7 100644 --- a/src/spatialdata/transformations/ngff/ngff_transformations.py +++ b/src/spatialdata/transformations/ngff/ngff_transformations.py @@ -1,10 +1,9 @@ import math from abc import ABC, abstractmethod from numbers import Number -from typing import Any +from typing import Any, Self import numpy as np -from typing_extensions import Self from spatialdata._types import ArrayLike from spatialdata.transformations.ngff.ngff_coordinate_system import NgffCoordinateSystem diff --git a/tests/conftest.py b/tests/conftest.py index cc86f9777..775721253 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -64,12 +64,16 @@ def points() -> SpatialData: @pytest.fixture() def table_single_annotation() -> SpatialData: - return SpatialData(tables={"table": _get_table(region="labels2d")}) + return SpatialData(tables={"table": _get_table(region="labels2d")}, labels=_get_labels()) @pytest.fixture() def table_multiple_annotations() -> SpatialData: - return SpatialData(tables={"table": _get_table(region=["labels2d", "poly"])}) + return SpatialData( + tables={"table": _get_table(region=["labels2d", "poly"])}, + labels=_get_labels(), + shapes=_get_shapes(), + ) @pytest.fixture() @@ -91,31 +95,20 @@ def full_sdata() -> SpatialData: labels=_get_labels(), shapes=_get_shapes(), points=_get_points(), - tables=_get_tables(region="labels2d"), + tables=_get_tables(region="labels2d", region_key="region", instance_key="instance_id"), ) -# @pytest.fixture() -# def empty_points() -> SpatialData: -# geo_df = GeoDataFrame( -# geometry=[], -# ) -# from spatialdata import NgffIdentity -# _set_transformations(geo_df, NgffIdentity()) -# -# return SpatialData(points={"empty": geo_df}) - - -# @pytest.fixture() -# def empty_table() -> SpatialData: -# adata = AnnData(shape=(0, 0), obs=pd.DataFrame(columns="region"), var=pd.DataFrame()) -# adata = TableModel.parse(adata=adata) -# return SpatialData(table=adata) - - @pytest.fixture( # params=["labels"] - params=["full", "empty"] + ["images", "labels", "points", "table_single_annotation", "table_multiple_annotations"] + params=["full", "empty"] + + [ + "images", + "labels", + "points", + "table_single_annotation", + "table_multiple_annotations", + ] # + ["empty_" + x for x in ["table"]] # TODO: empty table not supported yet ) def sdata(request) -> SpatialData: @@ -276,7 +269,7 @@ def _get_points() -> dict[str, DaskDataFrame]: def _get_tables( - region: None | str | list[str] = "sample1", + region: None | str | list[str], region_key: None | str = "region", instance_key: None | str = "instance_id", ) -> dict[str, AnnData]: @@ -284,11 +277,18 @@ def _get_tables( def _get_table( - region: None | str | list[str] = "sample1", + region: None | str | list[str], region_key: None | str = "region", instance_key: None | str = "instance_id", ) -> AnnData: - adata = AnnData(RNG.normal(size=(100, 10)), obs=pd.DataFrame(RNG.normal(size=(100, 3)), columns=["a", "b", "c"])) + adata = 
AnnData( + RNG.normal(size=(100, 10)), + obs=pd.DataFrame( + RNG.normal(size=(100, 3)), + columns=["a", "b", "c"], + index=[f"{i}" for i in range(100)], + ), + ) if not all(var for var in (region, region_key, instance_key)): return TableModel.parse(adata=adata) adata.obs[instance_key] = np.arange(adata.n_obs) @@ -296,6 +296,7 @@ def _get_table( adata.obs[region_key] = region elif isinstance(region, list): adata.obs[region_key] = RNG.choice(region, size=adata.n_obs) + adata.obs[region_key] = adata.obs[region_key].astype("category") return TableModel.parse(adata=adata, region=region, region_key=region_key, instance_key=instance_key) @@ -406,7 +407,10 @@ def _make_sdata_for_testing_querying_and_aggretation() -> SpatialData: s_num = pd.Series(RNG.random(20)) # workaround for https://github.com/dask/dask/issues/11147, let's recompute the dataframe (it's a small one) values_points = PointsModel.parse( - dd.from_pandas(values_points.compute().assign(categorical_in_ddf=s_cat, numerical_in_ddf=s_num), npartitions=1) + dd.from_pandas( + values_points.compute().assign(categorical_in_ddf=s_cat, numerical_in_ddf=s_num), + npartitions=1, + ) ) sdata = SpatialData( @@ -421,10 +425,10 @@ def _make_sdata_for_testing_querying_and_aggretation() -> SpatialData: # generate table x = RNG.random((21, 1)) - region = np.array(["values_circles"] * 9 + ["values_polygons"] * 12) + region = pd.Categorical(np.array(["values_circles"] * 9 + ["values_polygons"] * 12)) instance_id = np.array(list(range(9)) + list(range(12))) - categorical_obs = pd.Series(pd.Categorical(["a"] * 9 + ["b"] * 9 + ["c"] * 3)) - numerical_obs = pd.Series(RNG.random(21)) + categorical_obs = pd.Categorical(["a"] * 9 + ["b"] * 9 + ["c"] * 3) + numerical_obs = RNG.random(21) table = AnnData( x, obs=pd.DataFrame( @@ -433,14 +437,18 @@ def _make_sdata_for_testing_querying_and_aggretation() -> SpatialData: "instance_id": instance_id, "categorical_in_obs": categorical_obs, "numerical_in_obs": numerical_obs, - } + }, + index=list(map(str, range(21))), ), var=pd.DataFrame(index=["numerical_in_var"]), ) table = TableModel.parse( - table, region=["values_circles", "values_polygons"], region_key="region", instance_key="instance_id" + table, + region=["values_circles", "values_polygons"], + region_key="region", + instance_key="instance_id", ) - sdata.table = table + sdata["table"] = table return sdata @@ -451,12 +459,13 @@ def sdata_query_aggregation() -> SpatialData: def generate_adata(n_var: int, obs: pd.DataFrame, obsm: dict[Any, Any], uns: dict[Any, Any]) -> AnnData: rng = np.random.default_rng(SEED) - return AnnData( + adata = AnnData( rng.normal(size=(obs.shape[0], n_var)).astype(np.float64), obs=obs, obsm=obsm, uns=uns, ) + return TableModel().parse(adata) def _get_blobs_galaxy() -> tuple[ArrayLike, ArrayLike]: @@ -480,12 +489,16 @@ def adata_labels() -> AnnData: "categorical": pd.Categorical(rng.integers(0, 2, size=(n_obs_labels,))), "cell_id": pd.Categorical(seg), "instance_id": range(n_obs_labels), - "region": ["test"] * n_obs_labels, + "region": pd.Categorical(["test"] * n_obs_labels), }, - index=np.arange(n_obs_labels), + index=np.arange(n_obs_labels).astype(str), ) uns_labels = { - "spatialdata_attrs": {"region": "test", "region_key": "region", "instance_key": "instance_id"}, + "spatialdata_attrs": { + "region": "test", + "region_key": "region", + "instance_key": "instance_id", + }, } obsm_labels = { "tensor": rng.integers(0, blobs.shape[0], size=(n_obs_labels, 2)), diff --git a/tests/core/operations/test_aggregations.py 
b/tests/core/operations/test_aggregations.py index 540161c7a..eb2ed089c 100644 --- a/tests/core/operations/test_aggregations.py +++ b/tests/core/operations/test_aggregations.py @@ -287,7 +287,6 @@ def test_aggregate_shapes_by_shapes( new_var = pd.concat((sdata.tables["table"].var, pd.DataFrame(index=["another_numerical_in_var"]))) new_x = np.concatenate((sdata.tables["table"].X, np.ones_like(sdata.tables["table"].X[:, :1])), axis=1) new_table = AnnData(X=new_x, obs=sdata.tables["table"].obs, var=new_var, uns=sdata.tables["table"].uns) - del sdata.tables["table"] sdata.tables["table"] = new_table result_adata = aggregate( @@ -500,7 +499,6 @@ def test_aggregate_considering_fractions_multiple_values( new_var = pd.concat((sdata.tables["table"].var, pd.DataFrame(index=["another_numerical_in_var"]))) new_x = np.concatenate((sdata.tables["table"].X, np.ones_like(sdata.tables["table"].X[:, :1])), axis=1) new_table = AnnData(X=new_x, obs=sdata.tables["table"].obs, var=new_var, uns=sdata.tables["table"].uns) - del sdata.tables["table"] sdata.tables["table"] = new_table out = aggregate( values_sdata=sdata, diff --git a/tests/core/operations/test_map.py b/tests/core/operations/test_map.py index b3fd31651..a032d381b 100644 --- a/tests/core/operations/test_map.py +++ b/tests/core/operations/test_map.py @@ -109,7 +109,7 @@ def test_map_raster_output_chunks(sdata_blobs): func_kwargs = {"parameter": 20} output_channels = ["test"] se = map_raster( - sdata_blobs["blobs_image"].chunk((3, 100, 100)), + sdata_blobs["blobs_image"].chunk({"c": 3, "y": 100, "x": 100}), func=_multiply_alter_c, func_kwargs=func_kwargs, chunks=( @@ -160,9 +160,8 @@ def test_map_to_labels_(sdata_blobs, blockwise, chunks, drop_axis): func_kwargs = {"parameter": 20} se = sdata_blobs[img_layer] - se = map_raster( - se.chunk((3, 256, 256)), + se.chunk({"c": 3, "y": 256, "x": 256}), func=_multiply_to_labels, func_kwargs=func_kwargs, c_coords=None, @@ -183,7 +182,7 @@ def test_map_squeeze_z(full_sdata): func_kwargs = {"parameter": 20} se = map_raster( - full_sdata[img_layer].chunk((3, 2, 64, 64)), + full_sdata[img_layer].chunk({"c": 3, "z": 2, "y": 64, "x": 64}), func=_multiply_squeeze_z, func_kwargs=func_kwargs, chunks=((3,), (64,), (64,)), @@ -212,7 +211,7 @@ def test_map_squeeze_z_fails(full_sdata): ), ): map_raster( - full_sdata[img_layer].chunk((3, 2, 64, 64)), + full_sdata[img_layer].chunk({"c": 3, "z": 2, "y": 64, "x": 64}), func=_multiply_squeeze_z, func_kwargs=func_kwargs, chunks=((3,), (64,), (64,)), @@ -266,7 +265,7 @@ def test_map_raster_relabel(sdata_blobs): element_name = "blobs_labels" se = map_raster( - sdata_blobs[element_name].chunk((100, 100)), + sdata_blobs[element_name].chunk({"y": 100, "x": 100}), func=_to_constant, func_kwargs=func_kwargs, c_coords=None, @@ -301,7 +300,7 @@ def test_map_raster_relabel_fail(sdata_blobs): match=re.escape("Relabel was set to True, but"), ): se = map_raster( - sdata_blobs[element_name].chunk((100, 100)), + sdata_blobs[element_name].chunk({"y": 100, "x": 100}), func=_to_constant, func_kwargs=func_kwargs, c_coords=None, @@ -319,8 +318,8 @@ def test_map_raster_relabel_fail(sdata_blobs): ValueError, match=re.escape(f"Relabeling is only supported for arrays of type {np.integer}."), ): - se = map_raster( - sdata_blobs[element_name].astype(float).chunk((100, 100)), + map_raster( + sdata_blobs[element_name].astype(float).chunk({"y": 100, "x": 100}), func=_to_constant, func_kwargs=func_kwargs, c_coords=None, diff --git a/tests/core/operations/test_rasterize.py 
b/tests/core/operations/test_rasterize.py index 261768e1c..25f3c3d0f 100644 --- a/tests/core/operations/test_rasterize.py +++ b/tests/core/operations/test_rasterize.py @@ -123,10 +123,11 @@ def test_rasterize_labels_value_key_specified(): labels_indices = get_element_instances(raster) obs = pd.DataFrame( { - "region": [element_name] * len(labels_indices), + "region": pd.Categorical([element_name] * len(labels_indices)), "instance_id": labels_indices, value_key: [True] * 10 + [False] * (len(labels_indices) - 10), - } + }, + index=[f"{i}" for i in range(len(labels_indices))], ) table = TableModel.parse( AnnData(shape=(len(labels_indices), 0), obs=obs), @@ -134,7 +135,7 @@ def test_rasterize_labels_value_key_specified(): region_key="region", instance_key="instance_id", ) - sdata = SpatialData.init_from_elements({element_name: raster}, tables={table_name: table}) + sdata = SpatialData.init_from_elements({element_name: raster, table_name: table}) result = rasterize( data=element_name, sdata=sdata, @@ -157,8 +158,10 @@ def test_rasterize_points_shapes_with_string_index(points, shapes): sdata = SpatialData.init_from_elements({"points_0": points["points_0"], "circles": shapes["circles"]}) # make the indices of the points_0 and circles dataframes strings - sdata["points_0"]["str_index"] = dd.from_pandas(pd.Series([str(i) for i in sdata["points_0"].index]), npartitions=1) - sdata["points_0"] = sdata["points_0"].set_index("str_index") + points = sdata["points_0"] + points["str_index"] = dd.from_pandas(pd.Series([str(i) for i in sdata["points_0"].index]), npartitions=1) + points = points.set_index("str_index") + sdata["points_0"] = points sdata["circles"].index = [str(i) for i in sdata["circles"].index] data_extent = get_extent(sdata) @@ -194,11 +197,12 @@ def _rasterize_shapes_prepare_data() -> tuple[SpatialData, GeoDataFrame, str]: X=np.arange(len(gdf)).reshape(-1, 1), obs=pd.DataFrame( { - "region": [element_name] * len(gdf), + "region": pd.Categorical([element_name] * len(gdf)), "instance_id": gdf.index, - "values": gdf["values"], - "cat_values": gdf["cat_values"], - } + "values": gdf["values"].to_numpy(), + "cat_values": gdf["cat_values"].to_numpy(), + }, + index=[str(i) for i in range(len(gdf))], ), ) adata.obs["cat_values"] = adata.obs["cat_values"].astype("category") @@ -333,11 +337,12 @@ def test_rasterize_points(): X=np.arange(len(ddf)).reshape(-1, 1), obs=pd.DataFrame( { - "region": [element_name] * len(ddf), + "region": pd.Categorical([element_name] * len(ddf)), "instance_id": ddf.index, "gene": data["gene"], "value": data["value"], - } + }, + index=[f"{i}" for i in range(len(ddf))], ), ) adata.obs["gene"] = adata.obs["gene"].astype("category") diff --git a/tests/core/operations/test_rasterize_bins.py b/tests/core/operations/test_rasterize_bins.py index 6596325c2..b99508efa 100644 --- a/tests/core/operations/test_rasterize_bins.py +++ b/tests/core/operations/test_rasterize_bins.py @@ -2,6 +2,7 @@ import re import numpy as np +import pandas as pd import pytest from anndata import AnnData from geopandas import GeoDataFrame @@ -19,7 +20,13 @@ from spatialdata._core.spatialdata import SpatialData from spatialdata._logging import logger from spatialdata._types import ArrayLike -from spatialdata.models.models import Image2DModel, Labels2DModel, PointsModel, ShapesModel, TableModel +from spatialdata.models.models import ( + Image2DModel, + Labels2DModel, + PointsModel, + ShapesModel, + TableModel, +) from spatialdata.transformations.transformations import Scale RNG = default_rng(0) @@ -41,29 
+48,43 @@ def test_rasterize_bins(geometry: str, value_key: str | list[str] | None, return n = 10 data, x, y = _get_bins_data(n) scale = Scale([2.0], axes=("x",)) + index = np.arange(1, len(data) + 1) if geometry == "points": - points = PointsModel.parse(data, transformations={"global": scale}) + points = PointsModel.parse( + data, + transformations={"global": scale}, + annotation=pd.DataFrame(index=index), + ) elif geometry == "circles": - points = ShapesModel.parse(data, geometry=0, radius=1, transformations={"global": scale}) + points = ShapesModel.parse(data, geometry=0, radius=1, transformations={"global": scale}, index=index) else: assert geometry == "squares" gdf = GeoDataFrame( - data={"geometry": [Polygon([(x, y), (x + 1, y), (x + 1, y + 1), (x, y + 1), (x, y)]) for x, y in data]} + index=index, + data={"geometry": [Polygon([(x, y), (x + 1, y), (x + 1, y + 1), (x, y + 1), (x, y)]) for x, y in data]}, ) - points = ShapesModel.parse(gdf, transformations={"global": scale}) obs = DataFrame( - data={"region": ["points"] * n * n, "instance_id": np.arange(n * n), "col_index": x, "row_index": y} + data={ + "region": pd.Categorical(["points"] * n * n), + "instance_id": index, + "col_index": x, + "row_index": y, + }, + index=[f"{i}" for i in range(n * n)], ) X = RNG.normal(size=(n * n, 2)) var = DataFrame(index=["gene0", "gene1"]) table = TableModel.parse( - AnnData(X=X, var=var, obs=obs), region="points", region_key="region", instance_key="instance_id" + AnnData(X=X, var=var, obs=obs), + region="points", + region_key="region", + instance_key="instance_id", ) - sdata = SpatialData.init_from_elements({"points": points}, tables={"table": table}) + sdata = SpatialData.init_from_elements({"points": points, "table": table}) rasterized = rasterize_bins( sdata=sdata, bins="points", @@ -122,7 +143,13 @@ def _get_sdata(n: int): data, x, y = _get_bins_data(n) points = PointsModel.parse(data) obs = DataFrame( - data={"region": ["points"] * n * n, "instance_id": np.arange(n * n), "col_index": x, "row_index": y} + data={ + "region": pd.Categorical(["points"] * n * n), + "instance_id": np.arange(n * n), + "col_index": x, + "row_index": y, + }, + index=[f"{i}" for i in range(n * n)], ) table = TableModel.parse( AnnData(X=RNG.normal(size=(n * n, 2)), obs=obs), @@ -130,7 +157,7 @@ def _get_sdata(n: int): region_key="region", instance_key="instance_id", ) - return SpatialData.init_from_elements({"points": points}, tables={"table": table}) + return SpatialData.init_from_elements({"points": points, "table": table}) # sdata with not enough bins (2*2) to estimate transformation sdata = _get_sdata(n=2) @@ -165,7 +192,7 @@ def _get_sdata(n: int): # table annotating multiple elements regions = table.obs["region"].copy() regions = regions.cat.add_categories(["shapes"]) - regions[0] = "shapes" + regions.iloc[0] = "shapes" sdata["shapes"] = sdata["points"] table.obs["region"] = regions with pytest.raises( @@ -202,7 +229,10 @@ def _get_sdata(n: int): sdata = _get_sdata(n=3) table = sdata.tables["table"] table.obs["region"] = table.obs["region"].astype(str) - with pytest.raises(ValueError, match="Please convert `table.obs.*` to a category series to improve performances"): + with pytest.raises( + ValueError, + match="Please convert `table.obs.*` to a category series to improve performances", + ): _ = rasterize_bins( sdata=sdata, bins="points", @@ -269,7 +299,8 @@ def test_relabel_labels(caplog): "instance_key1": np.arange(10), "instance_key2": [1, 2] + list(range(4, 12)), "instance_key3": [str(i) for i in range(1, 11)], 
- } + }, + index=[f"{i}" for i in range(10)], ) adata = AnnData(X=RNG.normal(size=(10, 2)), obs=obs) _relabel_labels(table=adata, instance_key="instance_key0") diff --git a/tests/core/operations/test_spatialdata_operations.py b/tests/core/operations/test_spatialdata_operations.py index db413af31..bbf5eb102 100644 --- a/tests/core/operations/test_spatialdata_operations.py +++ b/tests/core/operations/test_spatialdata_operations.py @@ -1,6 +1,7 @@ import math import numpy as np +import pandas as pd import pytest from anndata import AnnData from geopandas import GeoDataFrame @@ -49,19 +50,6 @@ def test_element_names_unique() -> None: tables={"table": table}, ) - # add elements with the same name - # of element of same type - with pytest.warns(UserWarning): - sdata.images["image"] = image - with pytest.warns(UserWarning): - sdata.points["points"] = points - with pytest.warns(UserWarning): - sdata.shapes["shapes"] = shapes - with pytest.warns(UserWarning): - sdata.labels["labels"] = labels - with pytest.warns(UserWarning): - sdata.tables["table"] = table - # add elements with the same name # of element of different type with pytest.raises(KeyError): @@ -99,8 +87,6 @@ def test_element_names_unique() -> None: # add elements with the same name, test only couples of elements with pytest.raises(KeyError): sdata["labels"] = image - with pytest.warns(UserWarning): - sdata["points"] = points # this should not raise warnings because it's a different (new) name sdata["image2"] = image @@ -148,7 +134,7 @@ def test_element_type_from_element_name(points: SpatialData) -> None: def test_filter_by_coordinate_system(full_sdata: SpatialData) -> None: - sdata = full_sdata.filter_by_coordinate_system(coordinate_system="global", filter_table=False) + sdata = full_sdata.filter_by_coordinate_system(coordinate_system="global", filter_tables=False) assert_spatial_data_objects_are_identical(sdata, full_sdata) scale = Scale([2.0], axes=("x",)) @@ -156,12 +142,12 @@ def test_filter_by_coordinate_system(full_sdata: SpatialData) -> None: set_transformation(full_sdata.shapes["circles"], Identity(), "my_space0") set_transformation(full_sdata.shapes["poly"], Identity(), "my_space1") - sdata_my_space = full_sdata.filter_by_coordinate_system(coordinate_system="my_space0", filter_table=False) + sdata_my_space = full_sdata.filter_by_coordinate_system(coordinate_system="my_space0", filter_tables=False) assert len(list(sdata_my_space.gen_elements())) == 3 assert_elements_dict_are_identical(sdata_my_space.tables, full_sdata.tables) sdata_my_space1 = full_sdata.filter_by_coordinate_system( - coordinate_system=["my_space0", "my_space1", "my_space2"], filter_table=False + coordinate_system=["my_space0", "my_space1", "my_space2"], filter_tables=False ) assert len(list(sdata_my_space1.gen_elements())) == 4 @@ -170,11 +156,12 @@ def test_filter_by_coordinate_system_also_table(full_sdata: SpatialData) -> None from spatialdata.models import TableModel rng = np.random.default_rng(seed=0) - full_sdata["table"].obs["annotated_shapes"] = rng.choice(["circles", "poly"], size=full_sdata["table"].shape[0]) + full_sdata["table"].obs["annotated_shapes"] = pd.Categorical( + rng.choice(["circles", "poly"], size=full_sdata["table"].shape[0]) + ) adata = full_sdata["table"] del adata.uns[TableModel.ATTRS_KEY] - del full_sdata.tables["table"] - full_sdata.table = TableModel.parse( + full_sdata["table"] = TableModel.parse( adata, region=["circles", "poly"], region_key="annotated_shapes", @@ -187,7 +174,7 @@ def 
test_filter_by_coordinate_system_also_table(full_sdata: SpatialData) -> None filtered_sdata0 = full_sdata.filter_by_coordinate_system(coordinate_system="my_space0") filtered_sdata1 = full_sdata.filter_by_coordinate_system(coordinate_system="my_space1") - filtered_sdata2 = full_sdata.filter_by_coordinate_system(coordinate_system="my_space0", filter_table=False) + filtered_sdata2 = full_sdata.filter_by_coordinate_system(coordinate_system="my_space0", filter_tables=False) assert len(filtered_sdata0["table"]) + len(filtered_sdata1["table"]) == len(full_sdata["table"]) assert len(filtered_sdata2["table"]) == len(full_sdata["table"]) @@ -322,13 +309,13 @@ def test_concatenate_custom_table_metadata() -> None: shapes1 = _get_shapes() n = len(shapes0["poly"]) table0 = TableModel.parse( - AnnData(obs={"my_region": ["poly0"] * n, "my_instance_id": list(range(n))}), + AnnData(obs={"my_region": pd.Categorical(["poly0"] * n), "my_instance_id": list(range(n))}), region="poly0", region_key="my_region", instance_key="my_instance_id", ) table1 = TableModel.parse( - AnnData(obs={"my_region": ["poly1"] * n, "my_instance_id": list(range(n))}), + AnnData(obs={"my_region": pd.Categorical(["poly1"] * n), "my_instance_id": list(range(n))}), region="poly1", region_key="my_region", instance_key="my_instance_id", @@ -363,15 +350,14 @@ def test_concatenate_sdatas(full_sdata: SpatialData) -> None: set_transformation(full_sdata.shapes["circles"], Identity(), "my_space0") set_transformation(full_sdata.shapes["poly"], Identity(), "my_space1") - filtered = full_sdata.filter_by_coordinate_system(coordinate_system=["my_space0", "my_space1"], filter_table=False) + filtered = full_sdata.filter_by_coordinate_system(coordinate_system=["my_space0", "my_space1"], filter_tables=False) assert len(list(filtered.gen_elements())) == 3 - filtered0 = filtered.filter_by_coordinate_system(coordinate_system="my_space0", filter_table=False) - filtered1 = filtered.filter_by_coordinate_system(coordinate_system="my_space1", filter_table=False) + filtered0 = filtered.filter_by_coordinate_system(coordinate_system="my_space0", filter_tables=False) + filtered1 = filtered.filter_by_coordinate_system(coordinate_system="my_space1", filter_tables=False) # this is needed cause we can't handle regions with same name. 
# TODO: fix this new_region = "sample2" table_new = filtered1["table"].copy() - del filtered1.tables["table"] filtered1["table"] = table_new filtered1["table"].uns[TableModel.ATTRS_KEY][TableModel.REGION_KEY] = new_region filtered1["table"].obs[filtered1["table"].uns[TableModel.ATTRS_KEY][TableModel.REGION_KEY_KEY]] = new_region @@ -419,7 +405,12 @@ def _get_table_and_poly(i: int) -> tuple[AnnData, GeoDataFrame]: n = len(poly) region = f"poly{i}" table = TableModel.parse( - AnnData(obs={"region": [region] * n, "instance_id": list(range(n))}), + AnnData( + obs=pd.DataFrame( + {"region": pd.Categorical([region] * n), "instance_id": list(range(n))}, + index=[f"{i}" for i in range(n)], + ) + ), region=region, region_key="region", instance_key="instance_id", @@ -484,8 +475,6 @@ def test_get_item(points: SpatialData) -> None: def test_set_item(full_sdata: SpatialData) -> None: for name in ["image2d", "labels2d", "points_0", "circles", "poly"]: full_sdata[name + "_again"] = full_sdata[name] - with pytest.warns(UserWarning): - full_sdata[name] = full_sdata[name] def test_del_item(full_sdata: SpatialData) -> None: @@ -516,11 +505,11 @@ def test_no_shared_transformations() -> None: def test_init_from_elements(full_sdata: SpatialData) -> None: # this first code block needs to be removed when the tables argument is removed from init_from_elements() all_elements = {name: el for _, name, el in full_sdata._gen_elements()} - sdata = SpatialData.init_from_elements(all_elements, tables=full_sdata["table"]) + sdata = SpatialData.init_from_elements(all_elements | {"table": full_sdata["table"]}) for element_type in ["images", "labels", "points", "shapes", "tables"]: assert set(getattr(sdata, element_type).keys()) == set(getattr(full_sdata, element_type).keys()) - all_elements = {name: el for _, name, el in full_sdata._gen_elements(include_table=True)} + all_elements = {name: el for _, name, el in full_sdata._gen_elements(include_tables=True)} sdata = SpatialData.init_from_elements(all_elements) for element_type in ["images", "labels", "points", "shapes", "tables"]: assert set(getattr(sdata, element_type).keys()) == set(getattr(full_sdata, element_type).keys()) @@ -540,11 +529,10 @@ def test_subset(full_sdata: SpatialData) -> None: adata = AnnData( shape=(10, 0), obs={ - "region": ["circles"] * 5 + ["poly"] * 5, + "region": pd.Categorical(["circles"] * 5 + ["poly"] * 5), "instance_id": [0, 1, 2, 3, 4, 0, 1, 2, 3, 4], }, ) - del full_sdata.tables["table"] sdata_table = TableModel.parse( adata, region=["circles", "poly"], @@ -653,7 +641,7 @@ def test_validate_table_in_spatialdata(full_sdata): with pytest.warns(UserWarning, match="in the SpatialData object"): full_sdata.validate_table_in_spatialdata(table) - table.obs[region_key] = "points_0" + table.obs[region_key] = pd.Categorical(["points_0"] * table.n_obs) full_sdata.set_table_annotates_spatialelement("table", region="points_0") full_sdata.validate_table_in_spatialdata(table) diff --git a/tests/core/operations/test_transform.py b/tests/core/operations/test_transform.py index e8031820e..9c1c68235 100644 --- a/tests/core/operations/test_transform.py +++ b/tests/core/operations/test_transform.py @@ -578,16 +578,16 @@ def test_transform_elements_and_entire_spatial_data_object_multi_hop( if "global" in d: remove_transformation(element, "global") - for element in full_sdata._gen_spatial_element_values(): + for _, name, element in full_sdata._gen_elements(include_tables=False): transformed_element = full_sdata.transform_element_to_coordinate_system( - element, 
"multi_hop_space", maintain_positioning=maintain_positioning + name, "multi_hop_space", maintain_positioning=maintain_positioning ) temp = SpatialData( images=dict(full_sdata.images), labels=dict(full_sdata.labels), points=dict(full_sdata.points), shapes=dict(full_sdata.shapes), - table=full_sdata["table"], + tables={"table": full_sdata["table"]}, ) temp["transformed_element"] = transformed_element transformation = get_transformation_between_coordinate_systems( diff --git a/tests/core/query/test_relational_query.py b/tests/core/query/test_relational_query.py index f0b4da7e0..938d6871a 100644 --- a/tests/core/query/test_relational_query.py +++ b/tests/core/query/test_relational_query.py @@ -750,7 +750,6 @@ def test_get_values_df_shapes(sdata_query_aggregation): var=pd.DataFrame(index=["numerical_in_var", "another_numerical_in_var"]), uns=adata.uns, ) - del sdata_query_aggregation.tables["table"] sdata_query_aggregation["table"] = new_adata # test v = get_values( @@ -816,7 +815,7 @@ def test_get_values_df_points(points): p = p.drop("instance_id", axis=1) p.index.compute() n = len(p) - obs = pd.DataFrame(index=p.index, data={"region": ["points_0"] * n, "instance_id": range(n)}) + obs = pd.DataFrame(index=p.index.astype(str), data={"region": ["points_0"] * n, "instance_id": range(n)}) obs["region"] = obs["region"].astype("category") table = TableModel.parse( AnnData(shape=(n, 0), obs=obs), @@ -891,8 +890,10 @@ def test_get_values_labels_bug(sdata_blobs): def test_filter_table_categorical_bug(shapes): # one bug that was triggered by: https://github.com/scverse/anndata/issues/1210 - adata = AnnData(obs={"categorical": pd.Categorical(["a", "a", "a", "b", "c"])}) - adata.obs["region"] = "circles" + adata = AnnData( + obs=pd.DataFrame({"categorical": pd.Categorical(["a", "a", "a", "b", "c"])}, index=list(map(str, range(5)))) + ) + adata.obs["region"] = pd.Categorical(["circles"] * adata.n_obs) adata.obs["cell_id"] = np.arange(len(adata)) adata = TableModel.parse(adata, region=["circles"], region_key="region", instance_key="cell_id") adata_subset = adata[adata.obs["categorical"] == "a"].copy() @@ -901,7 +902,7 @@ def test_filter_table_categorical_bug(shapes): def test_filter_table_non_annotating(full_sdata): - obs = pd.DataFrame({"test": ["a", "b", "c"]}) + obs = pd.DataFrame({"test": ["a", "b", "c"]}, index=list(map(str, range(3)))) adata = AnnData(obs=obs) table = TableModel.parse(adata) full_sdata["table"] = table diff --git a/tests/core/query/test_spatial_query.py b/tests/core/query/test_spatial_query.py index 65905a97d..fc59d0698 100644 --- a/tests/core/query/test_spatial_query.py +++ b/tests/core/query/test_spatial_query.py @@ -502,7 +502,7 @@ def test_query_filter_table(with_polygon_query: bool): circles0 = ShapesModel.parse(coords0, geometry=0, radius=1) circles1 = ShapesModel.parse(coords1, geometry=0, radius=1) table = AnnData(shape=(3, 0)) - table.obs["region"] = ["circles0", "circles0", "circles1"] + table.obs["region"] = pd.Categorical(["circles0", "circles0", "circles1"]) table.obs["instance"] = [0, 1, 0] table = TableModel.parse(table, region=["circles0", "circles1"], region_key="region", instance_key="instance") sdata = SpatialData(shapes={"circles0": circles0, "circles1": circles1}, tables={"table": table}) @@ -545,7 +545,7 @@ def test_polygon_query_with_multipolygon(sdata_query_aggregation): sdata = sdata_query_aggregation values_sdata = SpatialData( shapes={"values_polygons": sdata["values_polygons"], "values_circles": sdata["values_circles"]}, - tables=sdata["table"], + 
tables={"table": sdata["table"]}, ) polygon = sdata["by_polygons"].geometry.iloc[0] circle = sdata["by_circles"].geometry.iloc[0] @@ -555,20 +555,18 @@ def test_polygon_query_with_multipolygon(sdata_query_aggregation): values_sdata, polygon=polygon, target_coordinate_system="global", - shapes=True, - points=False, ) assert len(queried["values_polygons"]) == 4 assert len(queried["values_circles"]) == 4 assert len(queried["table"]) == 8 - multipolygon = GeoDataFrame(geometry=[polygon, circle_pol]).unary_union + multipolygon = GeoDataFrame(geometry=[polygon, circle_pol]).union_all() queried = polygon_query(values_sdata, polygon=multipolygon, target_coordinate_system="global") assert len(queried["values_polygons"]) == 8 assert len(queried["values_circles"]) == 8 assert len(queried["table"]) == 16 - multipolygon = GeoDataFrame(geometry=[polygon, polygon]).unary_union + multipolygon = GeoDataFrame(geometry=[polygon, polygon]).union_all() queried = polygon_query(values_sdata, polygon=multipolygon, target_coordinate_system="global") assert len(queried["values_polygons"]) == 4 assert len(queried["values_circles"]) == 4 diff --git a/tests/core/test_centroids.py b/tests/core/test_centroids.py index b8ecc441b..aa332f9da 100644 --- a/tests/core/test_centroids.py +++ b/tests/core/test_centroids.py @@ -52,7 +52,7 @@ def test_get_centroids_points(points, coordinate_system: str, is_3d: bool): assert centroids.columns.tolist() == list(axes) # check the index is preserved - assert np.array_equal(centroids.index.values, element.index.values) + assert np.array_equal(centroids.index.compute().values, element.index.compute().values) # the centroids should not contain extra columns assert "genes" in element.columns and "genes" not in centroids.columns @@ -82,7 +82,7 @@ def test_get_centroids_shapes(shapes, coordinate_system: str, shapes_name: str): set_transformation(element, transformation=affine, to_coordinate_system=coordinate_system) centroids = get_centroids(element, coordinate_system=coordinate_system) - assert np.array_equal(centroids.index.values, element.index.values) + assert np.array_equal(centroids.index.compute().values, element.index.values) if shapes_name == "circles": xy = element.geometry.get_coordinates().values @@ -154,10 +154,10 @@ def test_get_centroids_labels( centroids = get_centroids(element, coordinate_system=coordinate_system, return_background=return_background) labels_indices = get_element_instances(element, return_background=return_background) - assert np.array_equal(centroids.index.values, labels_indices) + assert np.array_equal(centroids.index.compute().values, labels_indices) if not return_background: - assert 0 not in centroids.index + assert not (centroids.index == 0).any() if coordinate_system == "global": assert np.array_equal(centroids.compute().values, expected_centroids.values) @@ -178,7 +178,7 @@ def test_get_centroids_invalid_element(images): # cannot compute centroids for tables N = 10 adata = TableModel.parse( - AnnData(X=RNG.random((N, N)), obs={"region": ["dummy" for _ in range(N)], "instance_id": np.arange(N)}), + AnnData(X=RNG.random((N, N)), obs={"region": pd.Categorical(["dummy"] * N), "instance_id": np.arange(N)}), region="dummy", region_key="region", instance_key="instance_id", diff --git a/tests/io/test_format.py b/tests/io/test_format.py index 3069e8fa6..c8d9f04c1 100644 --- a/tests/io/test_format.py +++ b/tests/io/test_format.py @@ -6,41 +6,51 @@ import pytest from shapely import GeometryType +from spatialdata import read_zarr from spatialdata._io.format 
import ( - CurrentPointsFormat, - CurrentShapesFormat, + PointsFormatType, + PointsFormatV01, + PointsFormatV02, RasterFormatV01, RasterFormatV02, + RasterFormatV03, + ShapesFormatType, ShapesFormatV01, - SpatialDataFormat, + ShapesFormatV02, + ShapesFormatV03, + SpatialDataContainerFormatV01, + SpatialDataContainerFormatV02, + SpatialDataFormatType, + TablesFormatV01, + TablesFormatV02, ) from spatialdata.models import PointsModel, ShapesModel - -Points_f = CurrentPointsFormat() -Shapes_f = CurrentShapesFormat() +from spatialdata.testing import assert_spatial_data_objects_are_identical class TestFormat: """Test format.""" + @pytest.mark.parametrize("element_format", [PointsFormatV01(), PointsFormatV02()]) @pytest.mark.parametrize("attrs_key", [PointsModel.ATTRS_KEY]) @pytest.mark.parametrize("feature_key", [None, PointsModel.FEATURE_KEY]) @pytest.mark.parametrize("instance_key", [None, PointsModel.INSTANCE_KEY]) - def test_format_points( + def test_format_points_v1_v2( self, + element_format: PointsFormatType, attrs_key: str | None, feature_key: str | None, instance_key: str | None, ) -> None: - metadata: dict[str, Any] = {attrs_key: {"version": Points_f.spatialdata_format_version}} + metadata: dict[str, Any] = {attrs_key: {"version": element_format.spatialdata_format_version}} format_metadata: dict[str, Any] = {attrs_key: {}} if feature_key is not None: metadata[attrs_key][feature_key] = "target" if instance_key is not None: metadata[attrs_key][instance_key] = "cell_id" - format_metadata[attrs_key] = Points_f.attrs_from_dict(metadata) + format_metadata[attrs_key] = element_format.attrs_from_dict(metadata) metadata[attrs_key].pop("version") - assert metadata[attrs_key] == Points_f.attrs_to_dict(format_metadata) + assert metadata[attrs_key] == element_format.attrs_to_dict(format_metadata) if feature_key is None and instance_key is None: assert len(format_metadata[attrs_key]) == len(metadata[attrs_key]) == 0 @@ -49,7 +59,7 @@ def test_format_points( @pytest.mark.parametrize("type_key", [ShapesModel.TYPE_KEY]) @pytest.mark.parametrize("name_key", [ShapesModel.NAME_KEY]) @pytest.mark.parametrize("shapes_type", [0, 3, 6]) - def test_format_shapes_v1( + def test_format_shape_v1( self, attrs_key: str, geos_key: str, @@ -72,26 +82,179 @@ def test_format_shapes_v1( geometry = GeometryType(metadata[attrs_key][geos_key][type_key]) assert metadata[attrs_key] == ShapesFormatV01().attrs_to_dict(geometry) + @pytest.mark.parametrize("element_format", [ShapesFormatV02(), ShapesFormatV03()]) @pytest.mark.parametrize("attrs_key", [ShapesModel.ATTRS_KEY]) - def test_format_shapes_v2( + def test_format_shapes_v2_v3( self, + element_format: ShapesFormatType, attrs_key: str, ) -> None: - # not testing anything, maybe remove - metadata: dict[str, Any] = {attrs_key: {"version": Shapes_f.spatialdata_format_version}} + metadata: dict[str, Any] = {attrs_key: {"version": element_format.spatialdata_format_version}} metadata[attrs_key].pop("version") - assert metadata[attrs_key] == Shapes_f.attrs_to_dict({}) + assert metadata[attrs_key] == element_format.attrs_to_dict({}) - @pytest.mark.parametrize("format", [RasterFormatV01, RasterFormatV02]) - def test_format_raster_v1_v2(self, images, format: type[SpatialDataFormat]) -> None: + @pytest.mark.parametrize("rformat", [RasterFormatV01, RasterFormatV02, RasterFormatV03]) + def test_format_raster_v1_v2_v3(self, images, rformat: type[SpatialDataFormatType]) -> None: with tempfile.TemporaryDirectory() as tmpdir: - images.write(Path(tmpdir) / "images.zarr", format=format()) - 
zattrs_file = Path(tmpdir) / "images.zarr/images/image2d/.zattrs" + sdata_container_format = ( + SpatialDataContainerFormatV01() if rformat != RasterFormatV03 else SpatialDataContainerFormatV02() + ) + images.write(Path(tmpdir) / "images.zarr", sdata_formats=[sdata_container_format, rformat()]) + + metadata_file = ".zattrs" if rformat != RasterFormatV03 else "zarr.json" + zattrs_file = Path(tmpdir) / "images.zarr/images/image2d/" / metadata_file + with open(zattrs_file) as infile: zattrs = json.load(infile) - ngff_version = zattrs["multiscales"][0]["version"] - if format == RasterFormatV01: + if rformat == RasterFormatV01: + ngff_version = zattrs["multiscales"][0]["version"] assert ngff_version == "0.4" - else: - assert format == RasterFormatV02 + elif rformat == RasterFormatV02: + ngff_version = zattrs["multiscales"][0]["version"] assert ngff_version == "0.4-dev-spatialdata" + else: + ngff_version = zattrs["attributes"]["ome"]["version"] + assert rformat == RasterFormatV03 + assert ngff_version == "0.5-dev-spatialdata" + + # TODO: add tests for TablesFormatV01 and TablesFormatV02 + + +class TestFormatConversions: + """Test format conversions between older formats and newer.""" + + def test_shapes_v1_to_v2_to_v3(self, shapes): + with tempfile.TemporaryDirectory() as tmpdir: + f1 = Path(tmpdir) / "data1.zarr" + f2 = Path(tmpdir) / "data2.zarr" + f3 = Path(tmpdir) / "data3.zarr" + + shapes.write(f1, sdata_formats=[ShapesFormatV01(), SpatialDataContainerFormatV01()]) + shapes_read_v1 = read_zarr(f1) + assert_spatial_data_objects_are_identical(shapes, shapes_read_v1) + assert shapes_read_v1.is_self_contained() + + shapes_read_v1.write(f2, sdata_formats=[ShapesFormatV02(), SpatialDataContainerFormatV01()]) + shapes_read_v2 = read_zarr(f2) + assert_spatial_data_objects_are_identical(shapes, shapes_read_v2) + assert shapes_read_v2.is_self_contained() + + shapes_read_v2.write(f3, sdata_formats=[ShapesFormatV03(), SpatialDataContainerFormatV02()]) + shapes_read_v3 = read_zarr(f3) + assert_spatial_data_objects_are_identical(shapes, shapes_read_v3) + assert shapes_read_v3.is_self_contained() + + def test_raster_images_v1_to_v2_to_v3(self, images): + with tempfile.TemporaryDirectory() as tmpdir: + f1 = Path(tmpdir) / "data1.zarr" + f2 = Path(tmpdir) / "data2.zarr" + f3 = Path(tmpdir) / "data3.zarr" + + with pytest.raises(ValueError, match="Unsupported format"): + images.write(f1, sdata_formats=RasterFormatV01()) + + images.write(f1, sdata_formats=[RasterFormatV01(), SpatialDataContainerFormatV01()]) + images_read_v1 = read_zarr(f1) + assert_spatial_data_objects_are_identical(images, images_read_v1) + assert images_read_v1.is_self_contained() + + images_read_v1.write(f2, sdata_formats=[RasterFormatV02(), SpatialDataContainerFormatV01()]) + images_read_v2 = read_zarr(f2) + assert_spatial_data_objects_are_identical(images, images_read_v2) + assert images_read_v2.is_self_contained() + + images_read_v2.write(f3, sdata_formats=[RasterFormatV03(), SpatialDataContainerFormatV02()]) + images_read_v3 = read_zarr(f3) + assert_spatial_data_objects_are_identical(images, images_read_v3) + assert images_read_v3.is_self_contained() + + def test_raster_labels_v1_to_v2_to_v3(self, labels): + with tempfile.TemporaryDirectory() as tmpdir: + f1 = Path(tmpdir) / "data1.zarr" + f2 = Path(tmpdir) / "data2.zarr" + f3 = Path(tmpdir) / "data3.zarr" + + labels.write(f1, sdata_formats=[RasterFormatV01(), SpatialDataContainerFormatV01()]) + labels_read_v1 = read_zarr(f1) + assert_spatial_data_objects_are_identical(labels, 
labels_read_v1) + assert labels_read_v1.is_self_contained() + + labels_read_v1.write(f2, sdata_formats=[RasterFormatV02(), SpatialDataContainerFormatV01()]) + labels_read_v2 = read_zarr(f2) + assert_spatial_data_objects_are_identical(labels, labels_read_v2) + assert labels_read_v2.is_self_contained() + + labels_read_v2.write(f3, sdata_formats=[RasterFormatV03(), SpatialDataContainerFormatV02()]) + labels_read_v3 = read_zarr(f3) + assert_spatial_data_objects_are_identical(labels, labels_read_v3) + assert labels_read_v3.is_self_contained() + + def test_points_v1_to_v2(self, points): + with tempfile.TemporaryDirectory() as tmpdir: + f1 = Path(tmpdir) / "data1.zarr" + f2 = Path(tmpdir) / "data2.zarr" + + points.write(f1, sdata_formats=[PointsFormatV01(), SpatialDataContainerFormatV01()]) + points_read_v1 = read_zarr(f1) + assert_spatial_data_objects_are_identical(points, points_read_v1) + assert points_read_v1.is_self_contained() + + points_read_v1.write(f2, sdata_formats=[PointsFormatV02(), SpatialDataContainerFormatV02()]) + points_read_v2 = read_zarr(f2) + assert_spatial_data_objects_are_identical(points, points_read_v2) + assert points_read_v2.is_self_contained() + + def test_tables_v1_to_v2(self, table_multiple_annotations): + with tempfile.TemporaryDirectory() as tmpdir: + f1 = Path(tmpdir) / "data1.zarr" + f2 = Path(tmpdir) / "data2.zarr" + + table_multiple_annotations.write(f1, sdata_formats=[TablesFormatV01(), SpatialDataContainerFormatV01()]) + table_read_v1 = read_zarr(f1) + assert_spatial_data_objects_are_identical(table_multiple_annotations, table_read_v1) + + table_read_v1.write(f2, sdata_formats=[TablesFormatV02(), SpatialDataContainerFormatV02()]) + table_read_v2 = read_zarr(f2) + assert_spatial_data_objects_are_identical(table_multiple_annotations, table_read_v2) + + def test_container_v1_to_v2(self, full_sdata): + with tempfile.TemporaryDirectory() as tmpdir: + f1 = Path(tmpdir) / "data1.zarr" + f2 = Path(tmpdir) / "data2.zarr" + + full_sdata.write(f1, sdata_formats=[SpatialDataContainerFormatV01()]) + sdata_read_v1 = read_zarr(f1) + assert_spatial_data_objects_are_identical(full_sdata, sdata_read_v1) + assert sdata_read_v1.is_self_contained() + assert sdata_read_v1.has_consolidated_metadata() + + sdata_read_v1.write(f2, sdata_formats=[SpatialDataContainerFormatV02()]) + sdata_read_v2 = read_zarr(f2) + assert_spatial_data_objects_are_identical(full_sdata, sdata_read_v2) + assert sdata_read_v2.is_self_contained() + assert sdata_read_v2.has_consolidated_metadata() + + def test_channel_names_raster_images_v1_to_v2_to_v3(self, images): + with tempfile.TemporaryDirectory() as tmpdir: + f1 = Path(tmpdir) / "data1.zarr" + f2 = Path(tmpdir) / "data2.zarr" + f3 = Path(tmpdir) / "data3.zarr" + + new_channels = ["first", "second", "third"] + + images.write(f1, sdata_formats=[RasterFormatV01(), SpatialDataContainerFormatV01()]) + images_read_v1 = read_zarr(f1) + images_read_v1.set_channel_names("image2d", new_channels, write=True) + images_read_v1.set_channel_names("image2d_multiscale", new_channels, write=True) + assert images_read_v1["image2d"].coords["c"].data.tolist() == new_channels + assert images_read_v1["image2d_multiscale"]["scale0"]["image"].coords["c"].data.tolist() == new_channels + + images_read_v1.write(f2, sdata_formats=[SpatialDataContainerFormatV01()]) + images_read_v1 = read_zarr(f2) + assert images_read_v1["image2d"].coords["c"].data.tolist() == new_channels + assert images_read_v1["image2d_multiscale"]["scale0"]["image"].coords["c"].data.tolist() == new_channels + + 
images_read_v1.write(f3, sdata_formats=[SpatialDataContainerFormatV02()]) + sdata_read_v2 = read_zarr(f3) + assert sdata_read_v2["image2d"].coords["c"].data.tolist() == new_channels + assert sdata_read_v2["image2d_multiscale"]["scale0"]["image"].coords["c"].data.tolist() == new_channels diff --git a/tests/io/test_multi_table.py b/tests/io/test_multi_table.py index dd43cfa8d..70974c8a8 100644 --- a/tests/io/test_multi_table.py +++ b/tests/io/test_multi_table.py @@ -24,9 +24,14 @@ def test_set_get_tables_from_spatialdata(self, full_sdata: SpatialData, tmp_path adata2 = adata0.copy() del adata2.obs["region"] # fails because either none either all three 'region', 'region_key', 'instance_key' are required - with pytest.raises(ValueError): + with pytest.raises(ValueError, match="Region key `region` not in `adata.obs`."): full_sdata["not_added_table"] = adata2 + adata3 = adata0.copy() + del adata3.obs["instance_id"] + with pytest.raises(ValueError, match="Instance key `instance_id` not in `adata.obs`."): + full_sdata["not_added_table"] = adata3 + assert len(full_sdata.tables) == 3 assert "adata0" in full_sdata.tables and "adata1" in full_sdata.tables full_sdata.write(tmpdir) @@ -36,6 +41,14 @@ def test_set_get_tables_from_spatialdata(self, full_sdata: SpatialData, tmp_path assert_equal(adata1, full_sdata["adata1"]) assert "adata0" in full_sdata.tables and "adata1" in full_sdata.tables + def test_null_values_in_instance_key_column(self, full_sdata: SpatialData): + n_obs = full_sdata["table"].n_obs + full_sdata["table"].obs["instance_id"] = range(n_obs) + # introduce null values + full_sdata["table"].obs.loc[0, "instance_id"] = None + with pytest.raises(ValueError, match="must not contain null values, but it does."): + full_sdata.validate_table_in_spatialdata(table=full_sdata["table"]) + @pytest.mark.parametrize( "region_key, instance_key, error_msg", [ @@ -74,7 +87,7 @@ def test_change_annotation_target(self, full_sdata, region_key, instance_key, er full_sdata.set_table_annotates_spatialelement("table", "poly") del full_sdata["table"].obs["instance_id"] - full_sdata["table"].obs["region"] = ["poly"] * n_obs + full_sdata["table"].obs["region"] = pd.Categorical(["poly"] * n_obs) with pytest.raises(ValueError, match=error_msg): full_sdata.set_table_annotates_spatialelement( "table", "poly", region_key=region_key, instance_key=instance_key @@ -114,43 +127,19 @@ def test_set_table_annotates_spatialelement(self, full_sdata, tmp_path): "table", "labels2d", region_key="region", instance_key="instance_id" ) - region = ["circles"] * 50 + ["poly"] * 50 + region = pd.Categorical(["circles"] * 50 + ["poly"] * 50) full_sdata["table"].obs["region"] = region full_sdata.set_table_annotates_spatialelement( "table", pd.Series(["circles", "poly"]), region_key="region", instance_key="instance_id" ) - full_sdata["table"].obs["region"] = "circles" + full_sdata["table"].obs["region"] = pd.Categorical(["circles"] * full_sdata["table"].n_obs) full_sdata.set_table_annotates_spatialelement( "table", "circles", region_key="region", instance_key="instance_id" ) full_sdata.write(tmpdir) - def test_old_accessor_deprecation(self, full_sdata, tmp_path): - # To test self._backed - tmpdir = Path(tmp_path) / "tmp.zarr" - full_sdata.write(tmpdir) - adata0 = _get_table(region="polygon") - - with pytest.warns(DeprecationWarning): - _ = full_sdata.table - with pytest.raises(ValueError): - full_sdata.table = adata0 - with pytest.warns(DeprecationWarning): - del full_sdata.table - with pytest.raises(KeyError): - del full_sdata["table"] - 
with pytest.warns(DeprecationWarning): - full_sdata.table = adata0 # this gets placed in sdata['table'] - - assert_equal(adata0, full_sdata["table"]) - - del full_sdata["table"] - - full_sdata.tables["my_new_table0"] = adata0 - assert full_sdata.get("table") is None - @pytest.mark.parametrize("region", ["test_shapes", "non_existing"]) def test_single_table(self, tmp_path: str, region: str): tmpdir = Path(tmp_path) / "tmp.zarr" @@ -271,13 +260,13 @@ def test_static_set_annotation_target(): ) table = _get_table(region="test_non_shapes") table_target = table.copy() - table_target.obs["region"] = "test_shapes" + table_target.obs["region"] = pd.Categorical(["test_shapes"] * table_target.n_obs) table_target = SpatialData.update_annotated_regions_metadata(table_target) assert table_target.uns[TableModel.ATTRS_KEY][TableModel.REGION_KEY] == ["test_shapes"] test_sdata["another_table"] = table_target - table.obs["diff_region"] = "test_shapes" + table.obs["diff_region"] = pd.Categorical(["test_shapes"] * table.n_obs) table = SpatialData.update_annotated_regions_metadata(table, region_key="diff_region") assert table.uns[TableModel.ATTRS_KEY][TableModel.REGION_KEY] == ["test_shapes"] diff --git a/tests/io/test_partial_read.py b/tests/io/test_partial_read.py index 7c7cdbfa2..e200c1fa9 100644 --- a/tests/io/test_partial_read.py +++ b/tests/io/test_partial_read.py @@ -11,14 +11,14 @@ from pathlib import Path from typing import TYPE_CHECKING -import numpy as np import py import pytest import zarr from pyarrow import ArrowInvalid -from zarr.errors import ArrayNotFoundError, MetadataError +from zarr.errors import ArrayNotFoundError, ZarrUserWarning from spatialdata import SpatialData, read_zarr +from spatialdata._io.format import SpatialDataContainerFormatV01 from spatialdata.datasets import blobs if TYPE_CHECKING: @@ -30,7 +30,7 @@ def pytest_warns_multiple( expected_warning: type[Warning] | tuple[type[Warning], ...] = Warning, matches: Iterable[str] = () ) -> Generator[None, None, None]: """ - Assert that code raises a warnings matching particular patterns. + Assert that code raises warnings matching particular patterns. Like `pytest.warns`, but with multiple patterns which each must match a warning. @@ -67,7 +67,7 @@ def test_case(request: _pytest.fixtures.SubRequest): class PartialReadTestCase: path: Path expected_elements: list[str] - expected_exceptions: type[Exception] | tuple[type[Exception], ...] + expected_exceptions: type[Exception] | tuple[type[Exception] | IOError, ...] warnings_patterns: list[str] @@ -85,11 +85,12 @@ def session_tmp_path(request: _pytest.fixtures.SubRequest) -> Path: @pytest.fixture(scope="module") -def sdata_with_corrupted_elem_type_zgroup(session_tmp_path: Path) -> PartialReadTestCase: - # .zattrs is a zero-byte file, aborted during write, or contains invalid JSON syntax +def sdata_with_corrupted_elem_types_zgroup(session_tmp_path: Path) -> PartialReadTestCase: + # Zarr v2 sdata = blobs() sdata_path = session_tmp_path / "sdata_with_corrupted_top_level_zgroup.zarr" - sdata.write(sdata_path) + # Errors only when no consolidation metadata store is used as this takes precedence over group metadata when reading + sdata.write(sdata_path, sdata_formats=SpatialDataContainerFormatV01(), consolidate_metadata=False) (sdata_path / "images" / ".zgroup").unlink() # missing, not detected by reader. 
So it doesn't raise an exception, # but it will not be found in the read SpatialData object @@ -100,43 +101,90 @@ def sdata_with_corrupted_elem_type_zgroup(session_tmp_path: Path) -> PartialRead return PartialReadTestCase( path=sdata_path, expected_elements=not_corrupted, - expected_exceptions=(JSONDecodeError, MetadataError), - warnings_patterns=["labels: JSONDecodeError", "points: MetadataError"], + expected_exceptions=(JSONDecodeError, ZarrUserWarning), + warnings_patterns=["labels: JSONDecodeError", "Object at"], ) @pytest.fixture(scope="module") -def sdata_with_corrupted_zattrs(session_tmp_path: Path) -> PartialReadTestCase: - # .zattrs is a zero-byte file, aborted during write, or contains invalid JSON syntax +def sdata_with_corrupted_elem_types_zarr_json(session_tmp_path: Path) -> PartialReadTestCase: + # Zarr v3 + sdata = blobs() + sdata_path = session_tmp_path / "sdata_with_corrupted_top_level_zarr_json.zarr" + # Errors only when no consolidation metadata store is used as this takes precedence over group metadata when reading + sdata.write(sdata_path, consolidate_metadata=False) + + (sdata_path / "images" / "zarr.json").unlink() # missing, not detected by reader. So it doesn't raise an exception, + # but it will not be found in the read SpatialData object + (sdata_path / "labels" / "zarr.json").write_text("") # corrupted + (sdata_path / "points" / "zarr.json").write_text('"not_valid": "not_valid"}') # invalid + not_corrupted = [name for t, name, _ in sdata.gen_elements() if t not in ("images", "labels", "points")] + + return PartialReadTestCase( + path=sdata_path, + expected_elements=not_corrupted, + expected_exceptions=(JSONDecodeError), + warnings_patterns=["labels: JSONDecodeError", "Extra data"], + ) + + +@pytest.fixture(scope="module") +def sdata_with_corrupted_zarr_json_elements(session_tmp_path: Path) -> PartialReadTestCase: + # Zarr v3 + # zarr.json is a zero-byte file, aborted during write, or contains invalid JSON syntax sdata = blobs() - sdata_path = session_tmp_path / "sdata_with_corrupted_zattrs.zarr" + sdata_path = session_tmp_path / "sdata_with_corrupted_zarr_json_elements.zarr" sdata.write(sdata_path) + corrupted_elements = ["blobs_image", "blobs_labels", "blobs_points", "blobs_polygons", "table"] + warnings_patterns = [] + for corrupted_element in corrupted_elements: + elem_path = sdata.locate_element(sdata[corrupted_element])[0] + (sdata_path / elem_path / "zarr.json").write_bytes(b"") + warnings_patterns.append(rf"{elem_path}: (?:OSError|JSONDecodeError):") + not_corrupted = [name for _, name, _ in sdata.gen_elements() if name not in corrupted_elements] + + return PartialReadTestCase( + path=sdata_path, + expected_elements=not_corrupted, + expected_exceptions=(JSONDecodeError, OSError), + warnings_patterns=warnings_patterns, + ) + + +@pytest.fixture(scope="module") +def sdata_with_corrupted_zattrs_elements(session_tmp_path: Path) -> PartialReadTestCase: + # Zarr v2 + # .zattrs is a zero-byte file, aborted during write, or contains invalid JSON syntax + sdata = blobs() + sdata_path = session_tmp_path / "sdata_with_corrupted_zattrs_elements.zarr" + sdata.write(sdata_path, sdata_formats=SpatialDataContainerFormatV01()) + corrupted_elements = ["blobs_image", "blobs_labels", "blobs_points", "blobs_polygons", "table"] warnings_patterns = [] for corrupted_element in corrupted_elements: elem_path = sdata.locate_element(sdata[corrupted_element])[0] (sdata_path / elem_path / ".zattrs").write_bytes(b"") - warnings_patterns.append(f"{elem_path}: JSONDecodeError") + 
warnings_patterns.append(rf"{elem_path}: (?:OSError|JSONDecodeError):") not_corrupted = [name for _, name, _ in sdata.gen_elements() if name not in corrupted_elements] return PartialReadTestCase( path=sdata_path, expected_elements=not_corrupted, - expected_exceptions=JSONDecodeError, + expected_exceptions=(OSError, JSONDecodeError), warnings_patterns=warnings_patterns, ) @pytest.fixture(scope="module") -def sdata_with_corrupted_image_chunks(session_tmp_path: Path) -> PartialReadTestCase: +def sdata_with_corrupted_image_chunks_zarrv3(session_tmp_path: Path) -> PartialReadTestCase: # images/blobs_image/0 is a zero-byte file or aborted during write sdata = blobs() - sdata_path = session_tmp_path / "sdata_with_corrupted_image_chunks.zarr" + sdata_path = session_tmp_path / "sdata_with_corrupted_image_chunks_zarrv3.zarr" sdata.write(sdata_path) corrupted = "blobs_image" - os.unlink(sdata_path / "images" / corrupted / "0" / ".zarray") # it will hide the "0" array from the Zarr reader + os.unlink(sdata_path / "images" / corrupted / "0" / "zarr.json") # it will hide the "0" array from the Zarr reader os.rename(sdata_path / "images" / corrupted / "0", sdata_path / "images" / corrupted / "0_corrupted") (sdata_path / "images" / corrupted / "0").touch() @@ -145,19 +193,37 @@ def sdata_with_corrupted_image_chunks(session_tmp_path: Path) -> PartialReadTest return PartialReadTestCase( path=sdata_path, expected_elements=not_corrupted, - expected_exceptions=( - ArrayNotFoundError, - TypeError, # instead of ArrayNotFoundError, with dask>=2024.10.0 zarr<=2.18.3 - ), - warnings_patterns=[rf"images/{corrupted}: (ArrayNotFoundError|TypeError)"], + expected_exceptions=(ArrayNotFoundError,), + warnings_patterns=[rf"images/{corrupted}: ArrayNotFoundError"], + ) + + +@pytest.fixture(scope="module") +def sdata_with_corrupted_image_chunks_zarrv2(session_tmp_path: Path) -> PartialReadTestCase: + # images/blobs_image/0 is a zero-byte file or aborted during write + sdata = blobs() + sdata_path = session_tmp_path / "sdata_with_corrupted_image_chunks_zarrv2.zarr" + sdata.write(sdata_path, sdata_formats=SpatialDataContainerFormatV01()) + + corrupted = "blobs_image" + os.unlink(sdata_path / "images" / corrupted / "0" / ".zarray") # it will hide the "0" array from the Zarr reader + os.rename(sdata_path / "images" / corrupted / "0", sdata_path / "images" / corrupted / "0_corrupted") + (sdata_path / "images" / corrupted / "0").touch() + not_corrupted = [name for _, name, _ in sdata.gen_elements() if name != corrupted] + + return PartialReadTestCase( + path=sdata_path, + expected_elements=not_corrupted, + expected_exceptions=(ArrayNotFoundError,), + warnings_patterns=[rf"images/{corrupted}: ArrayNotFoundError"], ) @pytest.fixture(scope="module") -def sdata_with_corrupted_parquet(session_tmp_path: Path) -> PartialReadTestCase: +def sdata_with_corrupted_parquet_zarrv3(session_tmp_path: Path) -> PartialReadTestCase: # points/blobs_points/0 is a zero-byte file or aborted during write sdata = blobs() - sdata_path = session_tmp_path / "sdata_with_corrupted_parquet.zarr" + sdata_path = session_tmp_path / "sdata_with_corrupted_parquet_zarrv3.zarr" sdata.write(sdata_path) corrupted = "blobs_points" @@ -178,12 +244,56 @@ def sdata_with_corrupted_parquet(session_tmp_path: Path) -> PartialReadTestCase: @pytest.fixture(scope="module") -def sdata_with_missing_zattrs(session_tmp_path: Path) -> PartialReadTestCase: - # .zattrs is missing +def sdata_with_corrupted_parquet_zarrv2(session_tmp_path: Path) -> PartialReadTestCase: + # 
points/blobs_points/0 is a zero-byte file or aborted during write + sdata = blobs() + sdata_path = session_tmp_path / "sdata_with_corrupted_parquet_zarrv2.zarr" + sdata.write(sdata_path, sdata_formats=SpatialDataContainerFormatV01()) + + corrupted = "blobs_points" + os.rename( + sdata_path / "points" / corrupted / "points.parquet", + sdata_path / "points" / corrupted / "points_corrupted.parquet", + ) + (sdata_path / "points" / corrupted / "points.parquet").touch() + + not_corrupted = [name for _, name, _ in sdata.gen_elements() if name != corrupted] + + return PartialReadTestCase( + path=sdata_path, + expected_elements=not_corrupted, + expected_exceptions=ArrowInvalid, + warnings_patterns=[rf"points/{corrupted}: ArrowInvalid"], + ) + + +@pytest.fixture(scope="module") +def sdata_with_missing_zarr_json_element(session_tmp_path: Path) -> PartialReadTestCase: + # zarr.json is missing sdata = blobs() - sdata_path = session_tmp_path / "sdata_with_missing_zattrs.zarr" + sdata_path = session_tmp_path / "sdata_with_missing_zarr_json_element.zarr" sdata.write(sdata_path) + corrupted = "blobs_image" + (sdata_path / "images" / corrupted / "zarr.json").unlink() + not_corrupted = [name for _, name, _ in sdata.gen_elements() if name != corrupted] + + return PartialReadTestCase( + path=sdata_path, + expected_elements=not_corrupted, + expected_exceptions=OSError, + warnings_patterns=[r"images/blobs_image: OSError:"], + ) + + +@pytest.fixture(scope="module") +def sdata_with_missing_zattrs_element(session_tmp_path: Path) -> PartialReadTestCase: + # Zarrv2 + # .zattrs is missing + sdata = blobs() + sdata_path = session_tmp_path / "sdata_with_missing_zattrs_element.zarr" + sdata.write(sdata_path, sdata_formats=SpatialDataContainerFormatV01()) + corrupted = "blobs_image" (sdata_path / "images" / corrupted / ".zattrs").unlink() not_corrupted = [name for _, name, _ in sdata.gen_elements() if name != corrupted] @@ -191,20 +301,43 @@ def sdata_with_missing_zattrs(session_tmp_path: Path) -> PartialReadTestCase: return PartialReadTestCase( path=sdata_path, expected_elements=not_corrupted, - expected_exceptions=ValueError, - warnings_patterns=[rf"images/{corrupted}: .* Unable to read the NGFF file"], + expected_exceptions=OSError, + warnings_patterns=["OSError: Image location"], ) @pytest.fixture(scope="module") -def sdata_with_missing_image_chunks( +def sdata_with_missing_image_chunks_zarrv3( session_tmp_path: Path, ) -> PartialReadTestCase: - # .zattrs exists, but refers to binary array chunks that do not exist sdata = blobs() - sdata_path = session_tmp_path / "sdata_with_missing_image_chunks.zarr" + sdata_path = session_tmp_path / "sdata_with_missing_image_chunks_zarrv3.zarr" sdata.write(sdata_path) + corrupted = "blobs_image" + os.unlink(sdata_path / "images" / corrupted / "0" / "zarr.json") + os.rename(sdata_path / "images" / corrupted / "0", sdata_path / "images" / corrupted / "0_corrupted") + + not_corrupted = [name for _, name, _ in sdata.gen_elements() if name != corrupted] + + return PartialReadTestCase( + path=sdata_path, + expected_elements=not_corrupted, + expected_exceptions=(ArrayNotFoundError,), + warnings_patterns=[rf"images/{corrupted}: ArrayNotFoundError"], + ) + + +@pytest.fixture(scope="module") +def sdata_with_missing_image_chunks_zarrv2( + session_tmp_path: Path, +) -> PartialReadTestCase: + # Zarrv2 + # .zattrs exists, but refers to binary array chunks that do not exist + sdata = blobs() + sdata_path = session_tmp_path / "sdata_with_missing_image_chunks_zarrv2.zarr" + sdata.write(sdata_path, 
sdata_formats=SpatialDataContainerFormatV01()) + corrupted = "blobs_image" os.unlink(sdata_path / "images" / corrupted / "0" / ".zarray") os.rename(sdata_path / "images" / corrupted / "0", sdata_path / "images" / corrupted / "0_corrupted") @@ -214,21 +347,19 @@ def sdata_with_missing_image_chunks( return PartialReadTestCase( path=sdata_path, expected_elements=not_corrupted, - expected_exceptions=( - ArrayNotFoundError, - TypeError, # instead of ArrayNotFoundError, with dask>=2024.10.0 zarr<=2.18.3 - ), + expected_exceptions=(ArrayNotFoundError,), warnings_patterns=[rf"images/{corrupted}: (ArrayNotFoundError|TypeError)"], ) @pytest.fixture(scope="module") -def sdata_with_invalid_zattrs_violating_spec(session_tmp_path: Path) -> PartialReadTestCase: +def sdata_with_invalid_zattrs_element_violating_spec(session_tmp_path: Path) -> PartialReadTestCase: + # Zarr v2 # .zattrs contains readable JSON which is not valid for SpatialData/NGFF specs # for example due to a missing/misspelled/renamed key sdata = blobs() sdata_path = session_tmp_path / "sdata_with_invalid_zattrs_violating_spec.zarr" - sdata.write(sdata_path) + sdata.write(sdata_path, sdata_formats=SpatialDataContainerFormatV01()) corrupted = "blobs_image" json_dict = json.loads((sdata_path / "images" / corrupted / ".zattrs").read_text()) @@ -245,20 +376,69 @@ def sdata_with_invalid_zattrs_violating_spec(session_tmp_path: Path) -> PartialR @pytest.fixture(scope="module") -def sdata_with_invalid_zattrs_table_region_not_found(session_tmp_path: Path) -> PartialReadTestCase: +def sdata_with_invalid_zarr_json_element_violating_spec(session_tmp_path: Path) -> PartialReadTestCase: + # zarr.json contains readable JSON which is not valid for SpatialData/NGFF specs + # for example due to a missing/misspelled/renamed key + sdata = blobs() + sdata_path = session_tmp_path / "sdata_with_invalid_zarr_json_violating_spec.zarr" + sdata.write(sdata_path) + + corrupted = "blobs_image" + json_dict = json.loads((sdata_path / "images" / corrupted / "zarr.json").read_text()) + del json_dict["attributes"]["ome"]["multiscales"][0]["coordinateTransformations"] + (sdata_path / "images" / corrupted / "zarr.json").write_text(json.dumps(json_dict, indent=4)) + not_corrupted = [name for _, name, _ in sdata.gen_elements() if name != corrupted] + + return PartialReadTestCase( + path=sdata_path, + expected_elements=not_corrupted, + expected_exceptions=KeyError, + warnings_patterns=[rf"images/{corrupted}: KeyError: coordinateTransformations"], + ) + + +@pytest.fixture(scope="module") +def sdata_with_table_region_not_found_zarrv3(session_tmp_path: Path) -> PartialReadTestCase: # table/table/.zarr referring to a region that is not found # This has been emitting just a warning, but does not fail reading the table element. 
sdata = blobs() - sdata_path = session_tmp_path / "sdata_with_invalid_zattrs_table_region_not_found.zarr" + sdata_path = session_tmp_path / "sdata_with_invalid_table_region_not_found_zarrv3.zarr" sdata.write(sdata_path) corrupted = "blobs_labels" # The element data is missing - os.unlink(sdata_path / "labels" / corrupted / ".zgroup") - os.rename(sdata_path / "labels" / corrupted, sdata_path / "labels" / f"{corrupted}_corrupted") + sdata.delete_element_from_disk(corrupted) + # But the labels element is referenced as a region in a table + regions = zarr.open_group(sdata_path / "tables" / "table" / "obs" / "region", mode="r") + arrs = dict(regions.arrays()) + assert corrupted in arrs["categories"][arrs["codes"]] + not_corrupted = [name for _, name, _ in sdata.gen_elements() if name != corrupted] + + return PartialReadTestCase( + path=sdata_path, + expected_elements=not_corrupted, + expected_exceptions=(), + warnings_patterns=[ + rf"The table is annotating '{re.escape(corrupted)}', which is not present in the SpatialData object" + ], + ) + + +@pytest.fixture(scope="module") +def sdata_with_table_region_not_found_zarrv2(session_tmp_path: Path) -> PartialReadTestCase: + # table/table/.zarr referring to a region that is not found + # This has been emitting just a warning, but does not fail reading the table element. + sdata = blobs() + sdata_path = session_tmp_path / "sdata_with_invalid_zattrs_table_region_not_found.zarr" + sdata.write(sdata_path, sdata_formats=SpatialDataContainerFormatV01()) + + corrupted = "blobs_labels" + # The element data is missing + sdata.delete_element_from_disk(corrupted) # But the labels element is referenced as a region in a table regions = zarr.open_group(sdata_path / "tables" / "table" / "obs" / "region", mode="r") - assert corrupted in np.asarray(regions.categories)[regions.codes] + arrs = dict(regions.arrays()) + assert corrupted in arrs["categories"][arrs["codes"]] not_corrupted = [name for _, name, _ in sdata.gen_elements() if name != corrupted] return PartialReadTestCase( @@ -274,19 +454,26 @@ def sdata_with_invalid_zattrs_table_region_not_found(session_tmp_path: Path) -> @pytest.mark.parametrize( "test_case", [ - sdata_with_corrupted_zattrs, - sdata_with_corrupted_image_chunks, - sdata_with_corrupted_parquet, - sdata_with_missing_zattrs, - sdata_with_missing_image_chunks, - sdata_with_invalid_zattrs_violating_spec, - sdata_with_invalid_zattrs_table_region_not_found, - sdata_with_corrupted_elem_type_zgroup, + sdata_with_corrupted_elem_types_zgroup, # JSONDecodeError + sdata_with_corrupted_elem_types_zarr_json, # JSONDecodeError + sdata_with_corrupted_zarr_json_elements, # OSError + sdata_with_corrupted_zattrs_elements, # OSError + sdata_with_corrupted_image_chunks_zarrv3, # zarr.errors.ArrayNotFoundError + sdata_with_corrupted_image_chunks_zarrv2, # zarr.errors.ArrayNotFoundError + sdata_with_corrupted_parquet_zarrv3, # ArrowInvalid + sdata_with_corrupted_parquet_zarrv2, # ArrowInvalid + sdata_with_missing_zarr_json_element, # OSError + sdata_with_missing_zattrs_element, # OSError + sdata_with_missing_image_chunks_zarrv3, # zarr.errors.ArrayNotFoundError + sdata_with_missing_image_chunks_zarrv2, # zarr.errors.ArrayNotFoundError + sdata_with_invalid_zattrs_element_violating_spec, # KeyError + sdata_with_invalid_zarr_json_element_violating_spec, # KeyError + sdata_with_table_region_not_found_zarrv3, + sdata_with_table_region_not_found_zarrv2, ], indirect=True, ) def test_read_zarr_with_error(test_case: PartialReadTestCase): - # The specific type of exception 
depends on the read function for the SpatialData element if test_case.expected_exceptions: with pytest.raises(test_case.expected_exceptions): read_zarr(test_case.path, on_bad_files="error") @@ -297,14 +484,22 @@ def test_read_zarr_with_error(test_case: PartialReadTestCase): @pytest.mark.parametrize( "test_case", [ - sdata_with_corrupted_zattrs, - sdata_with_corrupted_image_chunks, - sdata_with_corrupted_parquet, - sdata_with_missing_zattrs, - sdata_with_missing_image_chunks, - sdata_with_invalid_zattrs_violating_spec, - sdata_with_invalid_zattrs_table_region_not_found, - sdata_with_corrupted_elem_type_zgroup, + sdata_with_corrupted_elem_types_zgroup, # JSONDecodeError + sdata_with_corrupted_elem_types_zarr_json, # JSONDecodeError + sdata_with_corrupted_zarr_json_elements, # JSONDecodeError for non raster, else OSError + sdata_with_corrupted_zattrs_elements, # JSONDecodeError for non raster, else OSError + sdata_with_corrupted_image_chunks_zarrv3, # zarr.errors.ArrayNotFoundError + sdata_with_corrupted_image_chunks_zarrv2, # zarr.errors.ArrayNotFoundError + sdata_with_corrupted_parquet_zarrv3, # ArrowInvalid + sdata_with_corrupted_parquet_zarrv2, # ArrowInvalid + sdata_with_missing_zarr_json_element, # OSError + sdata_with_missing_zattrs_element, # OSError + sdata_with_missing_image_chunks_zarrv3, # zarr.errors.ArrayNotFoundError + sdata_with_missing_image_chunks_zarrv2, # zarr.errors.ArrayNotFoundError + sdata_with_invalid_zattrs_element_violating_spec, # KeyError + sdata_with_invalid_zarr_json_element_violating_spec, # KeyError + sdata_with_table_region_not_found_zarrv3, + sdata_with_table_region_not_found_zarrv2, ], indirect=True, ) diff --git a/tests/io/test_pyramids_performance.py b/tests/io/test_pyramids_performance.py index f0ca31a2c..0bfcc2c43 100644 --- a/tests/io/test_pyramids_performance.py +++ b/tests/io/test_pyramids_performance.py @@ -58,17 +58,17 @@ def test_write_image_multiscale_performance(sdata_with_image: SpatialData, tmp_p # (see issue https://github.com/scverse/spatialdata/issues/577). # Instead of measuring the time (which would have high variation if not using big datasets), # we watch the number of read and write accesses and compare to the theoretical number. 
- zarr_chunk_write_spy = mocker.spy(zarr.core.Array, "__setitem__") - zarr_chunk_read_spy = mocker.spy(zarr.core.Array, "__getitem__") + zarr_chunk_write_spy = mocker.spy(zarr.Array, "__setitem__") + zarr_chunk_read_spy = mocker.spy(zarr.Array, "__getitem__") image_name, image = next(iter(sdata_with_image.images.items())) - element_type_group = zarr.group(store=tmp_path / "sdata.zarr", path="/images") + element_type_group = zarr.group(store=tmp_path / "image.zarr", path="/images") write_image( image=image, group=element_type_group, name=image_name, - format=CurrentRasterFormat(), + element_format=CurrentRasterFormat(), ) # The number of chunks of scale level 0 diff --git a/tests/io/test_readwrite.py b/tests/io/test_readwrite.py index ad8c66b4c..8501687ca 100644 --- a/tests/io/test_readwrite.py +++ b/tests/io/test_readwrite.py @@ -1,3 +1,4 @@ +import json import os import tempfile from collections.abc import Callable @@ -7,12 +8,20 @@ import dask.dataframe as dd import numpy as np import pytest +import zarr from anndata import AnnData from numpy.random import default_rng +from zarr.errors import GroupNotFoundError from spatialdata import SpatialData, deepcopy, read_zarr from spatialdata._core.validation import ValidationError from spatialdata._io._utils import _are_directories_identical, get_dask_backing_files +from spatialdata._io.format import ( + CurrentSpatialDataContainerFormat, + SpatialDataContainerFormats, + SpatialDataContainerFormatType, + SpatialDataContainerFormatV01, +) from spatialdata.datasets import blobs from spatialdata.models import Image2DModel from spatialdata.models._utils import get_channel_names @@ -22,134 +31,145 @@ set_transformation, ) from spatialdata.transformations.transformations import Identity, Scale -from tests.conftest import _get_images, _get_labels, _get_points, _get_shapes, _get_table, _get_tables +from tests.conftest import ( + _get_images, + _get_labels, + _get_points, + _get_shapes, + _get_table, + _get_tables, +) RNG = default_rng(0) +SDATA_FORMATS = list(SpatialDataContainerFormats.values()) +@pytest.mark.parametrize("sdata_container_format", SDATA_FORMATS) class TestReadWrite: - def test_images(self, tmp_path: str, images: SpatialData) -> None: + def test_images( + self, + tmp_path: str, + images: SpatialData, + sdata_container_format: SpatialDataContainerFormatType, + ) -> None: tmpdir = Path(tmp_path) / "tmp.zarr" # ensures that we are implicitly testing the read and write of channel names assert get_channel_names(images["image2d"]) == ["r", "g", "b"] assert get_channel_names(images["image2d_multiscale"]) == ["r", "g", "b"] - images.write(tmpdir) + images.write(tmpdir, sdata_formats=sdata_container_format) sdata = SpatialData.read(tmpdir) assert_spatial_data_objects_are_identical(images, sdata) - def test_labels(self, tmp_path: str, labels: SpatialData) -> None: + def test_labels( + self, + tmp_path: str, + labels: SpatialData, + sdata_container_format: SpatialDataContainerFormatType, + ) -> None: tmpdir = Path(tmp_path) / "tmp.zarr" - labels.write(tmpdir) + labels.write(tmpdir, sdata_formats=sdata_container_format) sdata = SpatialData.read(tmpdir) assert_spatial_data_objects_are_identical(labels, sdata) - def test_shapes(self, tmp_path: str, shapes: SpatialData) -> None: + def test_shapes( + self, + tmp_path: str, + shapes: SpatialData, + sdata_container_format: SpatialDataContainerFormatType, + ) -> None: tmpdir = Path(tmp_path) / "tmp.zarr" # check the index is correctly written and then read shapes["circles"].index = np.arange(1, 
len(shapes["circles"]) + 1) - shapes.write(tmpdir) + shapes.write(tmpdir, sdata_formats=sdata_container_format) sdata = SpatialData.read(tmpdir) assert_spatial_data_objects_are_identical(shapes, sdata) - def test_points(self, tmp_path: str, points: SpatialData) -> None: + def test_points( + self, + tmp_path: str, + points: SpatialData, + sdata_container_format: SpatialDataContainerFormatType, + ) -> None: tmpdir = Path(tmp_path) / "tmp.zarr" # check the index is correctly written and then read new_index = dd.from_array(np.arange(1, len(points["points_0"]) + 1)) points["points_0"] = points["points_0"].set_index(new_index) - points.write(tmpdir) + points.write(tmpdir, sdata_formats=sdata_container_format) sdata = SpatialData.read(tmpdir) assert_spatial_data_objects_are_identical(points, sdata) - def _test_table(self, tmp_path: str, table: SpatialData) -> None: + def _test_table( + self, + tmp_path: str, + table: SpatialData, + sdata_container_format: SpatialDataContainerFormatType, + ) -> None: tmpdir = Path(tmp_path) / "tmp.zarr" - table.write(tmpdir) + table.write(tmpdir, sdata_formats=sdata_container_format) sdata = SpatialData.read(tmpdir) assert_spatial_data_objects_are_identical(table, sdata) - def test_single_table_single_annotation(self, tmp_path: str, table_single_annotation: SpatialData) -> None: - self._test_table(tmp_path, table_single_annotation) + def test_single_table_single_annotation( + self, + tmp_path: str, + table_single_annotation: SpatialData, + sdata_container_format: SpatialDataContainerFormatType, + ) -> None: + self._test_table( + tmp_path, + table_single_annotation, + sdata_container_format=sdata_container_format, + ) - def test_single_table_multiple_annotations(self, tmp_path: str, table_multiple_annotations: SpatialData) -> None: - self._test_table(tmp_path, table_multiple_annotations) + def test_single_table_multiple_annotations( + self, + tmp_path: str, + table_multiple_annotations: SpatialData, + sdata_container_format: SpatialDataContainerFormatType, + ) -> None: + self._test_table( + tmp_path, + table_multiple_annotations, + sdata_container_format=sdata_container_format, + ) - def test_multiple_tables(self, tmp_path: str, tables: list[AnnData]) -> None: + def test_multiple_tables( + self, + tmp_path: str, + tables: list[AnnData], + sdata_container_format: SpatialDataContainerFormatType, + ) -> None: sdata_tables = SpatialData(tables={str(i): tables[i] for i in range(len(tables))}) - self._test_table(tmp_path, sdata_tables) + self._test_table(tmp_path, sdata_tables, sdata_container_format=sdata_container_format) def test_roundtrip( self, tmp_path: str, sdata: SpatialData, + sdata_container_format: SpatialDataContainerFormatType, ) -> None: tmpdir = Path(tmp_path) / "tmp.zarr" - sdata.write(tmpdir) + sdata.write(tmpdir, sdata_formats=sdata_container_format) sdata2 = SpatialData.read(tmpdir) tmpdir2 = Path(tmp_path) / "tmp2.zarr" - sdata2.write(tmpdir2) + sdata2.write(tmpdir2, sdata_formats=sdata_container_format) _are_directories_identical(tmpdir, tmpdir2, exclude_regexp="[1-9][0-9]*.*") - def test_incremental_io_in_memory( + def test_incremental_io_list_of_elements( self, - full_sdata: SpatialData, + shapes: SpatialData, + sdata_container_format: SpatialDataContainerFormatType, ) -> None: - sdata = full_sdata - - for k, v in _get_images().items(): - sdata.images[f"additional_{k}"] = v - with pytest.warns(UserWarning): - sdata.images[f"additional_{k}"] = v - with pytest.warns(UserWarning): - sdata[f"additional_{k}"] = v - with pytest.raises(KeyError, 
match="Key `table` is not unique"): - sdata["table"] = v - - for k, v in _get_labels().items(): - sdata.labels[f"additional_{k}"] = v - with pytest.warns(UserWarning): - sdata.labels[f"additional_{k}"] = v - with pytest.warns(UserWarning): - sdata[f"additional_{k}"] = v - with pytest.raises(KeyError, match="Key `table` is not unique"): - sdata["table"] = v - - for k, v in _get_shapes().items(): - sdata.shapes[f"additional_{k}"] = v - with pytest.warns(UserWarning): - sdata.shapes[f"additional_{k}"] = v - with pytest.warns(UserWarning): - sdata[f"additional_{k}"] = v - with pytest.raises(KeyError, match="Key `table` is not unique"): - sdata["table"] = v - - for k, v in _get_points().items(): - sdata.points[f"additional_{k}"] = v - with pytest.warns(UserWarning): - sdata.points[f"additional_{k}"] = v - with pytest.warns(UserWarning): - sdata[f"additional_{k}"] = v - with pytest.raises(KeyError, match="Key `table` is not unique"): - sdata["table"] = v - - for k, v in _get_tables().items(): - sdata.tables[f"additional_{k}"] = v - with pytest.warns(UserWarning): - sdata.tables[f"additional_{k}"] = v - with pytest.warns(UserWarning): - sdata[f"additional_{k}"] = v - with pytest.raises(KeyError, match="Key `poly` is not unique"): - sdata["poly"] = v - - def test_incremental_io_list_of_elements(self, shapes: SpatialData) -> None: with tempfile.TemporaryDirectory() as tmpdir: f = os.path.join(tmpdir, "data.zarr") - shapes.write(f) + shapes.write(f, sdata_formats=sdata_container_format) new_shapes0 = deepcopy(shapes["circles"]) new_shapes1 = deepcopy(shapes["poly"]) shapes["new_shapes0"] = new_shapes0 @@ -157,7 +177,7 @@ def test_incremental_io_list_of_elements(self, shapes: SpatialData) -> None: assert "shapes/new_shapes0" not in shapes.elements_paths_on_disk() assert "shapes/new_shapes1" not in shapes.elements_paths_on_disk() - shapes.write_element(["new_shapes0", "new_shapes1"]) + shapes.write_element(["new_shapes0", "new_shapes1"], sdata_formats=sdata_container_format) assert "shapes/new_shapes0" in shapes.elements_paths_on_disk() assert "shapes/new_shapes1" in shapes.elements_paths_on_disk() @@ -165,10 +185,60 @@ def test_incremental_io_list_of_elements(self, shapes: SpatialData) -> None: assert "shapes/new_shapes0" not in shapes.elements_paths_on_disk() assert "shapes/new_shapes1" not in shapes.elements_paths_on_disk() - @pytest.mark.parametrize("dask_backed", [True, False]) + @staticmethod + def _workaround1_non_dask_backed( + sdata: SpatialData, + name: str, + new_name: str, + sdata_container_format: SpatialDataContainerFormatType = CurrentSpatialDataContainerFormat(), + ) -> None: + # a. write a backup copy of the data + sdata[new_name] = sdata[name] + sdata.write_element(new_name, sdata_formats=sdata_container_format) + # b. rewrite the original data + sdata.delete_element_from_disk(name) + sdata.write_element(name, sdata_formats=sdata_container_format) + # c. remove the backup copy + del sdata[new_name] + sdata.delete_element_from_disk(new_name) + + @staticmethod + def _workaround1_dask_backed( + sdata: SpatialData, + name: str, + new_name: str, + sdata_container_format: SpatialDataContainerFormatType = CurrentSpatialDataContainerFormat(), + ) -> None: + # a. write a backup copy of the data + sdata[new_name] = sdata[name] + sdata.write_element(new_name, sdata_formats=sdata_container_format) + # a2. 
remove the in-memory copy from the SpatialData object (note, + # at this point the backup copy still exists on-disk) + del sdata[new_name] + del sdata[name] + # a3 load the backup copy into memory + sdata_copy = read_zarr(sdata.path) + # b1. rewrite the original data + sdata.delete_element_from_disk(name) + sdata[name] = sdata_copy[new_name] + sdata.write_element(name, sdata_formats=sdata_container_format) + # b2. reload the new data into memory (because it has been written but in-memory it still points + # from the backup location) + sdata = read_zarr(sdata.path) + # c. remove the backup copy + del sdata[new_name] + sdata.delete_element_from_disk(new_name) + + # @pytest.mark.parametrize("dask_backed", [True, False]) + @pytest.mark.parametrize("dask_backed", [True]) @pytest.mark.parametrize("workaround", [1, 2]) def test_incremental_io_on_disk( - self, tmp_path: str, full_sdata: SpatialData, dask_backed: bool, workaround: int + self, + tmp_path: str, + full_sdata: SpatialData, + dask_backed: bool, + workaround: int, + sdata_container_format: SpatialDataContainerFormatType, ) -> None: """ This tests shows workaround on how to rewrite existing data on disk. @@ -181,7 +251,7 @@ def test_incremental_io_on_disk( """ tmpdir = Path(tmp_path) / "incremental_io.zarr" sdata = SpatialData() - sdata.write(tmpdir) + sdata.write(tmpdir, sdata_formats=sdata_container_format) for name in [ "image2d", @@ -193,19 +263,37 @@ def test_incremental_io_on_disk( "table", ]: sdata[name] = full_sdata[name] - sdata.write_element(name) + sdata.write_element(name, sdata_formats=sdata_container_format) if dask_backed: # this forces the element to write to be dask-backed from disk. In this case, overwriting the data is # more laborious because we are writing the data to the same location that defines the data! sdata = read_zarr(sdata.path) with pytest.raises( - ValueError, match="The Zarr store already exists. Use `overwrite=True` to try overwriting the store." + ValueError, + match="The Zarr store already exists. Use `overwrite=True` to try overwriting the store.", ): - sdata.write_element(name) - - with pytest.raises(ValueError, match="Cannot overwrite."): - sdata.write_element(name, overwrite=True) + sdata.write_element(name, sdata_formats=sdata_container_format) + + match = ( + "Details: the target path contains one or more files that Dask use for backing elements in the " + "SpatialData object" + if dask_backed + and name + in [ + "image2d", + "labels2d", + "image3d_multiscale_xarray", + "labels3d_multiscale_xarray", + "points_0", + ] + else "Details: the target path in which to save an element is a subfolder of the current Zarr store." + ) + with pytest.raises( + ValueError, + match=match, + ): + sdata.write_element(name, overwrite=True, sdata_formats=sdata_container_format) if workaround == 1: new_name = f"{name}_new_place" @@ -213,35 +301,19 @@ def test_incremental_io_on_disk( # setups, ...). If the scenario matches your use case, please use with caution. if not dask_backed: # easier case - # a. write a backup copy of the data - sdata[new_name] = sdata[name] - sdata.write_element(new_name) - # b. rewrite the original data - sdata.delete_element_from_disk(name) - sdata.write_element(name) - # c. remove the backup copy - del sdata[new_name] - sdata.delete_element_from_disk(new_name) + self._workaround1_non_dask_backed( + sdata=sdata, + name=name, + new_name=new_name, + sdata_container_format=sdata_container_format, + ) else: # dask-backed case, more complex - # a. 
write a backup copy of the data - sdata[new_name] = sdata[name] - sdata.write_element(new_name) - # a2. remove the in-memory copy from the SpatialData object (note, - # at this point the backup copy still exists on-disk) - del sdata[new_name] - del sdata[name] - # a3 load the backup copy into memory - sdata_copy = read_zarr(sdata.path) - # b1. rewrite the original data - sdata.delete_element_from_disk(name) - sdata[name] = sdata_copy[new_name] - sdata.write_element(name) - # b2. reload the new data into memory (because it has been written but in-memory it still points - # from the backup location) - sdata = read_zarr(sdata.path) - # c. remove the backup copy - del sdata[new_name] - sdata.delete_element_from_disk(new_name) + self._workaround1_dask_backed( + sdata=sdata, + name=name, + new_name=new_name, + sdata_container_format=sdata_container_format, + ) elif workaround == 2: # workaround 2, unsafe but sometimes acceptable depending on the user's workflow. @@ -250,39 +322,18 @@ def test_incremental_io_on_disk( if not dask_backed: # a. rewrite the original data (risky!) sdata.delete_element_from_disk(name) - sdata.write_element(name) - - def test_incremental_io_table_legacy(self, table_single_annotation: SpatialData) -> None: - s = table_single_annotation - t = s["table"][:10, :].copy() - with pytest.raises(ValueError): - s.table = t - del s["table"] - s.table = t + sdata.write_element(name, sdata_formats=sdata_container_format) + def test_io_and_lazy_loading_points(self, points, sdata_container_format: SpatialDataContainerFormatType): with tempfile.TemporaryDirectory() as td: f = os.path.join(td, "data.zarr") - s.write(f) - s2 = SpatialData.read(f) - assert len(s2["table"]) == len(t) - del s2["table"] - s2.table = s["table"] - assert len(s2["table"]) == len(s["table"]) - f2 = os.path.join(td, "data2.zarr") - s2.write(f2) - s3 = SpatialData.read(f2) - assert len(s3["table"]) == len(s2["table"]) - - def test_io_and_lazy_loading_points(self, points): - with tempfile.TemporaryDirectory() as td: - f = os.path.join(td, "data.zarr") - points.write(f) + points.write(f, sdata_formats=sdata_container_format) assert len(get_dask_backing_files(points)) == 0 sdata2 = SpatialData.read(f) assert len(get_dask_backing_files(sdata2)) > 0 - def test_io_and_lazy_loading_raster(self, images, labels): + def test_io_and_lazy_loading_raster(self, images, labels, sdata_container_format: SpatialDataContainerFormatType): sdatas = {"images": images, "labels": labels} for k, sdata in sdatas.items(): d = getattr(sdata, k) @@ -290,7 +341,7 @@ def test_io_and_lazy_loading_raster(self, images, labels): with tempfile.TemporaryDirectory() as td: f = os.path.join(td, "data.zarr") dask0 = sdata[elem_name].data - sdata.write(f) + sdata.write(f, sdata_formats=sdata_container_format) assert all("from-zarr" not in key for key in dask0.dask.layers) assert len(get_dask_backing_files(sdata)) == 0 @@ -299,7 +350,9 @@ def test_io_and_lazy_loading_raster(self, images, labels): assert any("from-zarr" in key for key in dask1.dask.layers) assert len(get_dask_backing_files(sdata2)) > 0 - def test_replace_transformation_on_disk_raster(self, images, labels): + def test_replace_transformation_on_disk_raster( + self, images, labels, sdata_container_format: SpatialDataContainerFormatType + ): sdatas = {"images": images, "labels": labels} for k, sdata in sdatas.items(): d = getattr(sdata, k) @@ -309,7 +362,7 @@ def test_replace_transformation_on_disk_raster(self, images, labels): single_sdata = SpatialData(**kwargs) with 
tempfile.TemporaryDirectory() as td: f = os.path.join(td, "data.zarr") - single_sdata.write(f) + single_sdata.write(f, sdata_formats=sdata_container_format) t0 = get_transformation(SpatialData.read(f)[elem_name]) assert isinstance(t0, Identity) set_transformation( @@ -320,37 +373,49 @@ def test_replace_transformation_on_disk_raster(self, images, labels): t1 = get_transformation(SpatialData.read(f)[elem_name]) assert isinstance(t1, Scale) - def test_replace_transformation_on_disk_non_raster(self, shapes, points): + def test_replace_transformation_on_disk_non_raster( + self, shapes, points, sdata_container_format: SpatialDataContainerFormatType + ): sdatas = {"shapes": shapes, "points": points} for k, sdata in sdatas.items(): d = sdata.__getattribute__(k) elem_name = list(d.keys())[0] with tempfile.TemporaryDirectory() as td: f = os.path.join(td, "data.zarr") - sdata.write(f) + sdata.write(f, sdata_formats=sdata_container_format) t0 = get_transformation(SpatialData.read(f).__getattribute__(k)[elem_name]) assert isinstance(t0, Identity) set_transformation(sdata[elem_name], Scale([2.0], axes=("x",)), write_to_sdata=sdata) t1 = get_transformation(SpatialData.read(f)[elem_name]) assert isinstance(t1, Scale) - def test_overwrite_works_when_no_zarr_store(self, full_sdata): + def test_write_overwrite_fails_when_no_zarr_store( + self, full_sdata, sdata_container_format: SpatialDataContainerFormatType + ): with tempfile.TemporaryDirectory() as tmpdir: - f = os.path.join(tmpdir, "data.zarr") + f = Path(tmpdir) / "data.zarr" + f.mkdir() old_data = SpatialData() - old_data.write(f) - # Since no, no risk of overwriting backing data. - # Should not raise "The file path specified is the same as the one used for backing." - full_sdata.write(f, overwrite=True) + with pytest.raises(ValueError, match="The target file path specified already exists"): + old_data.write(f, sdata_formats=sdata_container_format) + with pytest.raises(ValueError, match="The target file path specified already exists"): + full_sdata.write(f, overwrite=True, sdata_formats=sdata_container_format) - def test_overwrite_fails_when_no_zarr_store_bug_dask_backed_data(self, full_sdata, points, images, labels): + def test_overwrite_fails_when_no_zarr_store_but_dask_backed_data( + self, + full_sdata, + points, + images, + labels, + sdata_container_format: SpatialDataContainerFormatType, + ): sdatas = {"images": images, "labels": labels, "points": points} elements = {"images": "image2d", "labels": "labels2d", "points": "points_0"} for k, sdata in sdatas.items(): element = elements[k] with tempfile.TemporaryDirectory() as tmpdir: f = os.path.join(tmpdir, "data.zarr") - sdata.write(f) + sdata.write(f, sdata_formats=sdata_container_format) # now we have a sdata with dask-backed elements sdata2 = SpatialData.read(f) @@ -360,31 +425,34 @@ def test_overwrite_fails_when_no_zarr_store_bug_dask_backed_data(self, full_sdat ValueError, match="The Zarr store already exists. 

-    def test_overwrite_fails_when_no_zarr_store_bug_dask_backed_data(self, full_sdata, points, images, labels):
+    def test_overwrite_fails_when_no_zarr_store_but_dask_backed_data(
+        self,
+        full_sdata,
+        points,
+        images,
+        labels,
+        sdata_container_format: SpatialDataContainerFormatType,
+    ):
        sdatas = {"images": images, "labels": labels, "points": points}
        elements = {"images": "image2d", "labels": "labels2d", "points": "points_0"}
        for k, sdata in sdatas.items():
            element = elements[k]
            with tempfile.TemporaryDirectory() as tmpdir:
                f = os.path.join(tmpdir, "data.zarr")
-                sdata.write(f)
+                sdata.write(f, sdata_formats=sdata_container_format)

                # now we have a sdata with dask-backed elements
                sdata2 = SpatialData.read(f)
@@ -360,31 +425,34 @@ def test_overwrite_fails_when_no_zarr_store_bug_dask_backed_data(self, full_sdat
                    ValueError,
                    match="The Zarr store already exists. Use `overwrite=True` to try overwriting the store.",
                ):
-                    full_sdata.write(f)
+                    full_sdata.write(f, sdata_formats=sdata_container_format)
                with pytest.raises(
                    ValueError,
-                    match="Cannot overwrite.",
+                    match=r"Details: the target path contains one or more files that Dask use for "
+                    "backing elements in the SpatialData object",
                ):
-                    full_sdata.write(f, overwrite=True)
+                    full_sdata.write(f, overwrite=True, sdata_formats=sdata_container_format)

-    def test_overwrite_fails_when_zarr_store_present(self, full_sdata):
+    def test_overwrite_fails_when_zarr_store_present(
+        self, full_sdata, sdata_container_format: SpatialDataContainerFormatType
+    ):
        # addressing https://github.com/scverse/spatialdata/issues/137
        with tempfile.TemporaryDirectory() as tmpdir:
            f = os.path.join(tmpdir, "data.zarr")
-            full_sdata.write(f)
+            full_sdata.write(f, sdata_formats=sdata_container_format)
            with pytest.raises(
                ValueError,
                match="The Zarr store already exists. Use `overwrite=True` to try overwriting the store.",
            ):
-                full_sdata.write(f)
+                full_sdata.write(f, sdata_formats=sdata_container_format)
            with pytest.raises(
                ValueError,
-                match="Cannot overwrite.",
+                match=r"Details: the target path either contains, coincides or is contained in the current Zarr store",
            ):
-                full_sdata.write(f, overwrite=True)
+                full_sdata.write(f, overwrite=True, sdata_formats=sdata_container_format)

        # support for overwriting backed sdata has been temporarily removed
        # with tempfile.TemporaryDirectory() as tmpdir:
@@ -402,7 +470,9 @@ def test_overwrite_fails_when_zarr_store_present(self, full_sdata):
        #     )
        #     sdata2.write(f, overwrite=True)
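The new error messages above distinguish a store that merely exists from one that still backs Dask-loaded elements of the object being written. A hedged sketch of how one could check for the latter case before attempting overwrite=True, using get_dask_backing_files (the helper these tests import); the function below is illustrative and not part of the library:

# Illustrative helper, not library API.
from spatialdata import SpatialData
from spatialdata._io._utils import get_dask_backing_files


def can_safely_overwrite(sdata: SpatialData, target_store: str) -> bool:
    backing_files = get_dask_backing_files(sdata)
    # If any backing file lives inside the target store, an overwrite would delete data
    # that the in-memory object still reads lazily, which is why write() refuses it.
    return not any(str(target_store) in f for f in backing_files)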

-    def test_overwrite_fails_onto_non_zarr_file(self, full_sdata):
+    def test_overwrite_fails_onto_non_zarr_file(
+        self, full_sdata, sdata_container_format: SpatialDataContainerFormatType
+    ):
        ERROR_MESSAGE = (
            "The target file path specified already exists, and it has been detected to not be a Zarr store."
        )
@@ -413,18 +483,49 @@ def test_overwrite_fails_onto_non_zarr_file(self, full_sdata):
                ValueError,
                match=ERROR_MESSAGE,
            ):
-                full_sdata.write(f0)
+                full_sdata.write(f0, sdata_formats=sdata_container_format)
            with pytest.raises(
                ValueError,
                match=ERROR_MESSAGE,
            ):
-                full_sdata.write(f0, overwrite=True)
+                full_sdata.write(f0, overwrite=True, sdata_formats=sdata_container_format)

            f1 = os.path.join(tmpdir, "test.zarr")
            os.mkdir(f1)
            with pytest.raises(ValueError, match=ERROR_MESSAGE):
-                full_sdata.write(f1)
+                full_sdata.write(f1, sdata_formats=sdata_container_format)
            with pytest.raises(ValueError, match=ERROR_MESSAGE):
-                full_sdata.write(f1, overwrite=True)
+                full_sdata.write(f1, overwrite=True, sdata_formats=sdata_container_format)
+
+
+def test_incremental_io_in_memory(
+    full_sdata: SpatialData,
+) -> None:
+    sdata = full_sdata
+
+    for k, v in _get_images().items():
+        sdata.images[f"additional_{k}"] = v
+        with pytest.raises(KeyError, match="Key `table` is not unique"):
+            sdata["table"] = v
+
+    for k, v in _get_labels().items():
+        sdata.labels[f"additional_{k}"] = v
+        with pytest.raises(KeyError, match="Key `table` is not unique"):
+            sdata["table"] = v
+
+    for k, v in _get_shapes().items():
+        sdata.shapes[f"additional_{k}"] = v
+        with pytest.raises(KeyError, match="Key `table` is not unique"):
+            sdata["table"] = v
+
+    for k, v in _get_points().items():
+        sdata.points[f"additional_{k}"] = v
+        with pytest.raises(KeyError, match="Key `table` is not unique"):
+            sdata["table"] = v
+
+    for k, v in _get_tables(region="labels2d").items():
+        sdata.tables[f"additional_{k}"] = v
+        with pytest.raises(KeyError, match="Key `poly` is not unique"):
+            sdata["poly"] = v


def test_bug_rechunking_after_queried_raster():
@@ -435,14 +536,18 @@ def test_bug_rechunking_after_queried_raster():
    images = {"single_scale": single_scale, "multi_scale": multi_scale}
    sdata = SpatialData(images=images)
    queried = sdata.query.bounding_box(
-        axes=("x", "y"), min_coordinate=[2, 5], max_coordinate=[12, 12], target_coordinate_system="global"
+        axes=("x", "y"),
+        min_coordinate=[2, 5],
+        max_coordinate=[12, 12],
+        target_coordinate_system="global",
    )
    with tempfile.TemporaryDirectory() as tmpdir:
        f = os.path.join(tmpdir, "data.zarr")
        queried.write(f)


-def test_self_contained(full_sdata: SpatialData) -> None:
+@pytest.mark.parametrize("sdata_container_format", SDATA_FORMATS)
+def test_self_contained(full_sdata: SpatialData, sdata_container_format: SpatialDataContainerFormatType) -> None:
    # data only in-memory, so the SpatialData object and all its elements are self-contained
    assert full_sdata.is_self_contained()
    description = full_sdata.elements_are_self_contained()
@@ -451,7 +556,7 @@ def test_self_contained(full_sdata: SpatialData) -> None:
    with tempfile.TemporaryDirectory() as tmpdir:
        # data saved to disk, it's self contained
        f = os.path.join(tmpdir, "data.zarr")
-        full_sdata.write(f)
+        full_sdata.write(f, sdata_formats=sdata_container_format)
        full_sdata.is_self_contained()

        # we read the data, so it's self-contained
@@ -460,7 +565,7 @@ def test_self_contained(full_sdata: SpatialData) -> None:
        # we save the data to a new location, so it's not self-contained anymore
        f2 = os.path.join(tmpdir, "data2.zarr")
-        sdata2.write(f2)
+        sdata2.write(f2, sdata_formats=sdata_container_format)
        assert not sdata2.is_self_contained()

        # because of the images, labels and points
@@ -502,10 +607,13 @@ def test_self_contained(full_sdata: SpatialData) -> None:
        assert all(description[element_name] for element_name in description if element_name != "combined")


-def test_symmetric_different_with_zarr_store(full_sdata: SpatialData) -> None:
+@pytest.mark.parametrize("sdata_container_format", SDATA_FORMATS)
+def test_symmetric_difference_with_zarr_store(
+    full_sdata: SpatialData, sdata_container_format: SpatialDataContainerFormatType
+) -> None:
    with tempfile.TemporaryDirectory() as tmpdir:
        f = os.path.join(tmpdir, "data.zarr")
-        full_sdata.write(f)
+        full_sdata.write(f, sdata_formats=sdata_container_format)

        # the list of elements on-disk and in-memory is the same
        only_in_memory, only_on_disk = full_sdata._symmetric_difference_with_zarr_store()
@@ -541,11 +649,12 @@ def test_symmetric_different_with_zarr_store(full_sdata: SpatialData) -> None:
    }


-def test_change_path_of_subset(full_sdata: SpatialData) -> None:
+@pytest.mark.parametrize("sdata_container_format", SDATA_FORMATS)
+def test_change_path_of_subset(full_sdata: SpatialData, sdata_container_format: SpatialDataContainerFormatType) -> None:
    """A subset SpatialData object has no Zarr path associated; show that we can reassign the path."""
    with tempfile.TemporaryDirectory() as tmpdir:
        f = os.path.join(tmpdir, "data.zarr")
-        full_sdata.write(f)
+        full_sdata.write(f, sdata_formats=sdata_container_format)

        subset = full_sdata.subset(["image2d", "labels2d", "points_0", "circles", "table"])
@@ -558,7 +667,7 @@ def test_change_path_of_subset(full_sdata: SpatialData) -> None:
        assert len(only_on_disk) > 0

        f2 = os.path.join(tmpdir, "data2.zarr")
-        subset.write(f2)
+        subset.write(f2, sdata_formats=sdata_container_format)
        assert subset.is_self_contained()
        only_in_memory, only_on_disk = subset._symmetric_difference_with_zarr_store()
        assert len(only_in_memory) == 0
@@ -589,15 +698,18 @@ def _check_valid_name(f: Callable[[str], Any]) -> None:
    with pytest.raises(ValueError, match="Name cannot start with '__'"):
        f("__a")
    with pytest.raises(
-        ValueError, match="Name must contain only alphanumeric characters, underscores, dots and hyphens."
+        ValueError,
+        match="Name must contain only alphanumeric characters, underscores, dots and hyphens.",
    ):
        f("has whitespace")
    with pytest.raises(
-        ValueError, match="Name must contain only alphanumeric characters, underscores, dots and hyphens."
+        ValueError,
+        match="Name must contain only alphanumeric characters, underscores, dots and hyphens.",
    ):
        f("this/is/not/valid")
    with pytest.raises(
-        ValueError, match="Name must contain only alphanumeric characters, underscores, dots and hyphens."
+        ValueError,
+        match="Name must contain only alphanumeric characters, underscores, dots and hyphens.",
    ):
        f("non-alnum_#$%&()*+,?@")
@@ -608,12 +720,13 @@ def test_incremental_io_valid_name(full_sdata: SpatialData) -> None:
    _check_valid_name(full_sdata.write_transformations)


-def test_incremental_io_attrs(points: SpatialData) -> None:
+@pytest.mark.parametrize("sdata_container_format", SDATA_FORMATS)
+def test_incremental_io_attrs(points: SpatialData, sdata_container_format: SpatialDataContainerFormatType) -> None:
    with tempfile.TemporaryDirectory() as tmpdir:
        f = os.path.join(tmpdir, "data.zarr")
        my_attrs = {"a": "b", "c": 1}
        points.attrs = my_attrs
-        points.write(f)
+        points.write(f, sdata_formats=sdata_container_format)

        # test that the attributes are written to disk
        sdata = SpatialData.read(f)
@@ -621,13 +734,13 @@ def test_incremental_io_attrs(points: SpatialData) -> None:
        # test incremental io attrs (write_attrs())
        sdata.attrs["c"] = 2
-        sdata.write_attrs()
+        sdata.write_attrs(sdata_format=sdata_container_format)
        sdata2 = SpatialData.read(f)
        assert sdata2.attrs["c"] == 2

        # test incremental io attrs (write_metadata())
        sdata.attrs["c"] = 3
-        sdata.write_metadata()
+        sdata.write_metadata(sdata_format=sdata_container_format)
        sdata2 = SpatialData.read(f)
        assert sdata2.attrs["c"] == 3
@@ -636,19 +749,24 @@ def test_incremental_io_attrs(points: SpatialData) -> None:

@pytest.mark.parametrize("element_name", ["image2d", "labels2d", "points_0", "circles", "table"])
-def test_delete_element_from_disk(full_sdata, element_name: str) -> None:
+@pytest.mark.parametrize("sdata_container_format", SDATA_FORMATS)
+def test_delete_element_from_disk(
+    full_sdata,
+    element_name: str,
+    sdata_container_format: SpatialDataContainerFormatType,
+) -> None:
    # can't delete an element for a SpatialData object without associated Zarr store
    with pytest.raises(ValueError, match="The SpatialData object is not backed by a Zarr store."):
        full_sdata.delete_element_from_disk("image2d")

    with tempfile.TemporaryDirectory() as tmpdir:
        f = os.path.join(tmpdir, "data.zarr")
-        full_sdata.write(f)
+        full_sdata.write(f, sdata_formats=sdata_container_format)

        # cannot delete an element which is in-memory, but not in the Zarr store
        subset = full_sdata.subset(["points_0_1"])
        f2 = os.path.join(tmpdir, "data2.zarr")
-        subset.write(f2)
+        subset.write(f2, sdata_formats=sdata_container_format)
        full_sdata.path = Path(f2)
        with pytest.raises(
            ValueError,
@@ -672,7 +790,7 @@ def test_delete_element_from_disk(full_sdata, element_name: str) -> None:
        assert element_path in only_in_memory

        # resave it
-        full_sdata.write_element(element_name)
+        full_sdata.write_element(element_name, sdata_formats=sdata_container_format)

        # now delete it from memory, and then show it can still be deleted on-disk
        del getattr(full_sdata, element_type)[element_name]
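The incremental element IO tested above (delete a single element on disk, then re-save it) can be sketched as follows; this snippet only illustrates the tested API, and "blobs_image" is the image name in the example blobs() dataset:

# Sketch of delete_element_from_disk + write_element on a toy dataset.
import os
import tempfile

from spatialdata import SpatialData
from spatialdata.datasets import blobs

with tempfile.TemporaryDirectory() as tmpdir:
    path = os.path.join(tmpdir, "data.zarr")
    sdata = blobs()
    sdata.write(path)

    sdata.delete_element_from_disk("blobs_image")
    sdata.write_element("blobs_image")  # re-save only this element; the rest of the store is untouched
    assert "blobs_image" in SpatialData.read(path).images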
@@ -682,14 +800,19 @@ def test_delete_element_from_disk(full_sdata, element_name: str) -> None:

@pytest.mark.parametrize("element_name", ["image2d", "labels2d", "points_0", "circles", "table"])
-def test_element_already_on_disk_different_type(full_sdata, element_name: str) -> None:
+@pytest.mark.parametrize("sdata_container_format", SDATA_FORMATS)
+def test_element_already_on_disk_different_type(
+    full_sdata,
+    element_name: str,
+    sdata_container_format: SpatialDataContainerFormatType,
+) -> None:
    # Constructing a corrupted object (element present both on disk and in-memory but with different type).
    # Attempting to perform an IO operation will trigger an error.
    # The checks assessed in this test will not be needed anymore after
    # https://github.com/scverse/spatialdata/issues/504 is addressed
    with tempfile.TemporaryDirectory() as tmpdir:
        f = os.path.join(tmpdir, "data.zarr")
-        full_sdata.write(f)
+        full_sdata.write(f, sdata_formats=sdata_container_format)

        element_type = full_sdata._element_type_from_element_name(element_name)
        wrong_group = "images" if element_type == "tables" else "tables"
@@ -709,13 +832,13 @@ def test_element_already_on_disk_different_type(full_sdata, element_name: str) -
            ValueError,
            match=ERROR_MSG,
        ):
-            full_sdata.write_element(element_name)
+            full_sdata.write_element(element_name, sdata_formats=sdata_container_format)

        with pytest.raises(
            ValueError,
            match=ERROR_MSG,
        ):
-            full_sdata.write_metadata(element_name)
+            full_sdata.write_metadata(element_name, sdata_format=sdata_container_format)

        with pytest.raises(
            ValueError,
@@ -731,7 +854,7 @@ def test_writing_invalid_name(tmp_path: Path):
    invalid_sdata.labels.data["."] = next(iter(_get_labels().values()))
    invalid_sdata.points.data["path/separator"] = next(iter(_get_points().values()))
    invalid_sdata.shapes.data["non-alnum_#$%&()*+,?@"] = next(iter(_get_shapes().values()))
-    invalid_sdata.tables.data["has whitespace"] = _get_table()
+    invalid_sdata.tables.data["has whitespace"] = _get_table(region="any")

    with pytest.raises(ValueError, match="Name (must|cannot)"):
        invalid_sdata.write(tmp_path / "data.zarr")
@@ -755,7 +878,7 @@ def test_incremental_writing_invalid_name(tmp_path: Path):
    invalid_sdata.labels.data["."] = next(iter(_get_labels().values()))
    invalid_sdata.points.data["path/separator"] = next(iter(_get_points().values()))
    invalid_sdata.shapes.data["non-alnum_#$%&()*+,?@"] = next(iter(_get_shapes().values()))
-    invalid_sdata.tables.data["has whitespace"] = _get_table()
+    invalid_sdata.tables.data["has whitespace"] = _get_table(region="any")

    for element_type in ["images", "labels", "points", "shapes", "tables"]:
        elements = getattr(invalid_sdata, element_type)
@@ -779,7 +902,7 @@ def test_reading_invalid_name(tmp_path: Path):
    labels_name, labels = next(iter(_get_labels().items()))
    points_name, points = next(iter(_get_points().items()))
    shapes_name, shapes = next(iter(_get_shapes().items()))
-    table_name, table = "table", _get_table()
+    table_name, table = "table", _get_table(region="labels2d")
    valid_sdata = SpatialData(
        images={image_name: image},
        labels={labels_name: labels},
@@ -790,15 +913,54 @@ def test_reading_invalid_name(tmp_path: Path):
    valid_sdata.write(tmp_path / "data.zarr")
    # Circumvent validation at construction time and check validation happens again at writing time.
    (tmp_path / "data.zarr/points" / points_name).rename(tmp_path / "data.zarr/points" / "has whitespace")
-    (tmp_path / "data.zarr/shapes" / shapes_name).rename(tmp_path / "data.zarr/shapes" / "non-alnum_#$%&()*+,?@")
+    # This one is not allowed on Windows
+    (tmp_path / "data.zarr/shapes" / shapes_name).rename(tmp_path / "data.zarr/shapes" / "non-alnum_#$%&()+,@")
+    # We do this as the key of the element is otherwise not in the consolidated metadata, leading to an error.
+    valid_sdata.write_consolidated_metadata()

    with pytest.raises(ValidationError, match="Cannot construct SpatialData") as exc_info:
        read_zarr(tmp_path / "data.zarr")
    actual_message = str(exc_info.value)
    assert "points/has whitespace" in actual_message
-    assert "shapes/non-alnum_#$%&()*+,?@" in actual_message
+    assert "shapes/non-alnum_#$%&()+,@" in actual_message
    assert (
        "For renaming, please see the discussion here https://github.com/scverse/spatialdata/discussions/707"
        in actual_message
    )
+
+
+@pytest.mark.parametrize("sdata_container_format", SDATA_FORMATS)
+def test_write_store_unconsolidated_and_read(full_sdata, sdata_container_format: SpatialDataContainerFormatType):
+    with tempfile.TemporaryDirectory() as tmpdir:
+        path = Path(tmpdir) / "data.zarr"
+        full_sdata.write(path, consolidate_metadata=False, sdata_formats=sdata_container_format)
+
+        group = zarr.open_group(path, mode="r")
+        assert group.metadata.consolidated_metadata is None
+        second_read = SpatialData.read(path)
+        assert_spatial_data_objects_are_identical(full_sdata, second_read)
+
+
+@pytest.mark.parametrize("sdata_container_format", SDATA_FORMATS)
+def test_can_read_sdata_with_reconsolidation(full_sdata, sdata_container_format: SpatialDataContainerFormatType):
+    with tempfile.TemporaryDirectory() as tmpdir:
+        path = Path(tmpdir) / "data.zarr"
+        full_sdata.write(path, sdata_formats=sdata_container_format)
+
+        if isinstance(sdata_container_format, SpatialDataContainerFormatV01):
+            json_path = path / ".zmetadata"
+            json_dict = json.loads(json_path.read_text())
+            # TODO: this raises no exception!
+            del json_dict["metadata"]["images/image2d/.zgroup"]
+        else:
+            json_path = path / "zarr.json"
+            json_dict = json.loads(json_path.read_text())
+            del json_dict["consolidated_metadata"]["metadata"]["images/image2d"]
+        json_path.write_text(json.dumps(json_dict, indent=4))
+
+        with pytest.raises(GroupNotFoundError):
+            SpatialData.read(path)
+
+        new_sdata = SpatialData.read(path, reconsolidate_metadata=True)
+        assert_spatial_data_objects_are_identical(full_sdata, new_sdata)
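The reconsolidation test above reflects that the two container formats keep consolidated metadata in different places: ".zmetadata" for the Zarr v2 based format and the "consolidated_metadata" block of "zarr.json" for the Zarr v3 based one. A sketch of the reconsolidate_metadata=True escape hatch it exercises; the broad exception handling below is for illustration only:

# Sketch: retry a read after rebuilding consolidated metadata.
from pathlib import Path

from spatialdata import SpatialData


def read_with_reconsolidation(path: Path) -> SpatialData:
    try:
        return SpatialData.read(path)
    except Exception:
        # reconsolidate_metadata is the new keyword tested above; it rebuilds the
        # consolidated metadata before reading.
        return SpatialData.read(path, reconsolidate_metadata=True)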
diff --git a/tests/io/test_utils.py b/tests/io/test_utils.py
index f9778f5c7..0a430704f 100644
--- a/tests/io/test_utils.py
+++ b/tests/io/test_utils.py
@@ -4,6 +4,7 @@
import dask.dataframe as dd
import pytest
+from upath import UPath

from spatialdata import read_zarr
from spatialdata._io._utils import get_dask_backing_files, handle_read_errors
@@ -37,8 +38,8 @@ def test_backing_files_images(images):
    computational graph
    """
    with tempfile.TemporaryDirectory() as tmp_dir:
-        f0 = os.path.join(tmp_dir, "images0.zarr")
-        f1 = os.path.join(tmp_dir, "images1.zarr")
+        f0 = UPath(tmp_dir) / "images0.zarr"
+        f1 = UPath(tmp_dir) / "images1.zarr"
        images.write(f0)
        images.write(f1)
        images0 = read_zarr(f0)
@@ -49,7 +50,7 @@ def test_backing_files_images(images):
        im1 = images1.images["image2d"]
        im2 = im0 + im1
        files = get_dask_backing_files(im2)
-        expected_zarr_locations = [os.path.realpath(os.path.join(f, "images/image2d")) for f in [f0, f1]]
+        expected_zarr_locations = [str((f / "images" / "image2d").resolve()) for f in [f0, f1]]
        assert set(files) == set(expected_zarr_locations)

        # multiscale
@@ -57,7 +58,7 @@ def test_backing_files_images(images):
        im4 = images1.images["image2d_multiscale"]
        im5 = im3 + im4
        files = get_dask_backing_files(im5)
-        expected_zarr_locations = [os.path.realpath(os.path.join(f, "images/image2d_multiscale")) for f in [f0, f1]]
+        expected_zarr_locations = [str((f / "images" / "image2d_multiscale").resolve()) for f in [f0, f1]]
        assert set(files) == set(expected_zarr_locations)
@@ -98,8 +99,8 @@ def test_backing_files_combining_points_and_images(points, images):
    from examining its computational graph
    """
    with tempfile.TemporaryDirectory() as tmp_dir:
-        f0 = os.path.join(tmp_dir, "points0.zarr")
-        f1 = os.path.join(tmp_dir, "images1.zarr")
+        f0 = UPath(tmp_dir) / "points0.zarr"
+        f1 = UPath(tmp_dir) / "images1.zarr"
        points.write(f0)
        images.write(f1)
        points0 = read_zarr(f0)
@@ -112,12 +113,12 @@ def test_backing_files_combining_points_and_images(points, images):
        im2 = v + im1
        files = get_dask_backing_files(im2)
        expected_zarr_locations_old = [
-            os.path.realpath(os.path.join(f0, "points/points_0/points.parquet")),
-            os.path.realpath(os.path.join(f1, "images/image2d")),
+            str((f0 / "points" / "points_0" / "points.parquet").resolve()),
+            str((f1 / "images" / "image2d").resolve()),
        ]
        expected_zarr_locations_new = [
-            os.path.realpath(os.path.join(f0, "points/points_0/points.parquet/part.0.parquet")),
-            os.path.realpath(os.path.join(f1, "images/image2d")),
+            str((f0 / "points" / "points_0" / "points.parquet" / "part.0.parquet").resolve()),
+            str((f1 / "images" / "image2d").resolve()),
        ]
        assert set(files) == set(expected_zarr_locations_old) or set(files) == set(expected_zarr_locations_new)
diff --git a/tests/io/test_versions.py b/tests/io/test_versions.py
deleted file mode 100644
index 15b87b970..000000000
--- a/tests/io/test_versions.py
+++ /dev/null
@@ -1,28 +0,0 @@
-import tempfile
-from pathlib import Path
-
-from spatialdata import read_zarr
-from spatialdata._io.format import ShapesFormatV01, ShapesFormatV02
-from spatialdata.testing import assert_spatial_data_objects_are_identical
-
-
-def test_shapes_v1_to_v2(shapes):
-    with tempfile.TemporaryDirectory() as tmpdir:
-        f0 = Path(tmpdir) / "data0.zarr"
-        f1 = Path(tmpdir) / "data1.zarr"
-
-        # write shapes in version 1
-        shapes.write(f0, format=ShapesFormatV01())
-
-        # reading from v1 works
-        shapes_read = read_zarr(f0)
-
-        assert_spatial_data_objects_are_identical(shapes, shapes_read)
-
-        # write shapes using the v2 version
-        shapes_read.write(f1, format=ShapesFormatV02())
-
-        # read again
-        shapes_read = read_zarr(f1)
-
-        assert_spatial_data_objects_are_identical(shapes, shapes_read)
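The backing-file tests above now build expected locations with UPath (from the upath package) instead of os.path. A small sketch of the motivation: the same joining and resolution code works for local directories and for remote stores; the paths below are placeholders.

# Sketch only; placeholder paths.
from upath import UPath

store = UPath("/tmp/images0.zarr")
element = store / "images" / "image2d"
print(str(element.resolve()))

remote = UPath("s3://bucket/data.zarr")  # identical joining syntax for object storage
print(remote / "images" / "image2d")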
diff --git a/tests/models/test_models.py b/tests/models/test_models.py
index 5b27f3ab9..2ed108b7c 100644
--- a/tests/models/test_models.py
+++ b/tests/models/test_models.py
@@ -71,7 +71,6 @@ def test_validate_axis_name():
        validate_axis_name("invalid")


-@pytest.mark.ci_only
class TestModels:
    def _parse_transformation_from_multiple_places(self, model: Any, element: Any, **kwargs) -> None:
        # This function seems convoluted but the idea is simple: sometimes the parser creates a whole new object,
@@ -238,10 +237,10 @@ def test_shapes_model(self, model: ShapesModel, path: Path) -> None:
        assert poly.equals(other_poly)

        if ShapesModel.RADIUS_KEY in poly.columns:
-            poly[ShapesModel.RADIUS_KEY].iloc[0] = -1
+            poly.loc[0, ShapesModel.RADIUS_KEY] = -1
            with pytest.raises(ValueError, match="Radii of circles must be positive."):
                ShapesModel.validate(poly)
-            poly[ShapesModel.RADIUS_KEY].iloc[0] = 0
+            poly.loc[0, ShapesModel.RADIUS_KEY] = 0
            with pytest.raises(ValueError, match="Radii of circles must be positive."):
                ShapesModel.validate(poly)
@@ -338,7 +337,7 @@ def test_points_model(
        assert "cell_id" in points.attrs["spatialdata_attrs"]["instance_key"]

    @pytest.mark.parametrize("model", [TableModel])
-    @pytest.mark.parametrize("region", ["sample", RNG.choice([1, 2], size=10).tolist()])
+    @pytest.mark.parametrize("region", [["sample"] * 10, RNG.choice([1, 2], size=10).tolist()])
    def test_table_model(
        self,
        model: TableModel,
@@ -348,8 +347,9 @@ def test_table_model(
        obs = pd.DataFrame(
            RNG.choice(np.arange(0, 100, dtype=float), size=(10, 3), replace=False),
            columns=["A", "B", "C"],
+            index=list(map(str, range(10))),
        )
-        obs[region_key] = region
+        obs[region_key] = pd.Categorical(region)
        adata = AnnData(RNG.normal(size=(10, 2)), obs=obs)
        with pytest.raises(TypeError, match="Only int"):
            model.parse(adata, region=region, region_key=region_key, instance_key="A")
@@ -357,8 +357,9 @@ def test_table_model(
        obs = pd.DataFrame(
            RNG.choice(np.arange(0, 100), size=(10, 3), replace=False),
            columns=["A", "B", "C"],
+            index=list(map(str, range(10))),
        )
-        obs[region_key] = region
+        obs[region_key] = pd.Categorical(region)
        adata = AnnData(RNG.normal(size=(10, 2)), obs=obs)
        table = model.parse(adata, region=region, region_key=region_key, instance_key="A")
        assert region_key in table.obs
@@ -399,7 +400,7 @@ def test_table_model(
        assert instance_key_ == "A"

        # let's fix the region_key column
-        table.obs["B"] = ["element"] * len(table)
+        table.obs["B"] = pd.Categorical(["element"] * len(table))
        _ = TableModel.parse(adata, region="element", region_key="B", instance_key="C", overwrite_metadata=True)

        region_, region_key_, instance_key_ = get_table_keys(table)
@@ -445,7 +446,7 @@ def test_model_not_unique_names(self, full_sdata, element_type: str, names: list
    @pytest.mark.parametrize("region", [["sample_1"] * 5 + ["sample_2"] * 5])
    def test_table_instance_key_values_not_unique(self, model: TableModel, region: str | np.ndarray):
        region_key = "region"
-        obs = pd.DataFrame(RNG.integers(0, 100, size=(10, 3)), columns=["A", "B", "C"])
+        obs = pd.DataFrame(RNG.integers(0, 100, size=(10, 3)), columns=["A", "B", "C"], index=list(map(str, range(10))))
        obs[region_key] = region
        obs["A"] = [1] * 5 + list(range(5))
        adata = AnnData(RNG.normal(size=(10, 2)), obs=obs)
@@ -520,14 +521,9 @@ def test_table_model_invalid_names(self, key: str, attr: str, parse: bool):
    @pytest.mark.parametrize("parse", [True, False])
    def test_table_model_not_unique_columns(self, keys: list[str], attr: str, parse: bool):
        invalid_key = keys[1]
-        key_regex = re.escape(invalid_key)
        df = pd.DataFrame([[None] * len(keys)], columns=keys, index=["1"])
        adata = AnnData(np.array([[0]]), **{attr: df})
-        with pytest.raises(
-            ValueError,
-            match=f"Table contains invalid names(.|\n)*\n {attr}/{invalid_key}: "
-            + f"Key `{key_regex}` is not unique, or another case-variant of it exists.",
-        ):
+        with pytest.raises(ValueError, match=f"Table contains invalid names(.|\n)*\n {attr}/{invalid_key}: "):
            if parse:
                TableModel.parse(adata)
            else:
@@ -539,7 +535,7 @@ def test_get_schema():
    labels = _get_labels()
    points = _get_points()
    shapes = _get_shapes()
-    table = _get_table()
+    table = _get_table(region="any", region_key="region", instance_key="instance_id")
    for k, v in images.items():
        schema = get_model(v)
        if "2d" in k:
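The table-model tests above now pass a categorical region column and a string obs index. A sketch of the corresponding TableModel.parse call; the column names and sizes are arbitrary and only mirror the updated tests.

# Sketch of TableModel.parse with a categorical region column and a string obs index.
import numpy as np
import pandas as pd
from anndata import AnnData

from spatialdata.models import TableModel

obs = pd.DataFrame(
    {
        "region": pd.Categorical(["sample"] * 10),
        "instance_id": np.arange(10),
    },
    index=list(map(str, range(10))),
)
adata = AnnData(np.zeros((10, 2)), obs=obs)
table = TableModel.parse(adata, region="sample", region_key="region", instance_key="instance_id")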
diff --git a/tests/transformations/test_transformations.py b/tests/transformations/test_transformations.py
index 180d007d8..b8571dc9a 100644
--- a/tests/transformations/test_transformations.py
+++ b/tests/transformations/test_transformations.py
@@ -1001,6 +1001,7 @@ def test_keep_numerical_coordinates_c(image_name):
def test_keep_string_coordinates_c(image_name):
    c_coords = ["a", "b", "c"]
    # n_channels will be ignored, testing also that this works
-    sdata = blobs(c_coords=c_coords, n_channels=4)
+    with pytest.warns(UserWarning, match="Number of channels "):
+        sdata = blobs(c_coords=c_coords, n_channels=4)
    t_blobs = transform(sdata.images[image_name], to_coordinate_system=DEFAULT_COORDINATE_SYSTEM)
    assert np.array_equal(get_channel_names(t_blobs), c_coords)
diff --git a/tests/utils/test_element_utils.py b/tests/utils/test_element_utils.py
index 86e75887d..1bfd20aa4 100644
--- a/tests/utils/test_element_utils.py
+++ b/tests/utils/test_element_utils.py
@@ -1,7 +1,6 @@
import itertools

import dask_image.ndinterp
-import pytest
import xarray
from xarray import DataArray, DataTree
@@ -27,7 +26,6 @@ def _pad_raster(data: DataArray, axes: tuple[str, ...]) -> DataArray:
    return dask_image.ndinterp.affine_transform(data, matrix, output_shape=new_shape)


-@pytest.mark.ci_only
def test_unpad_raster(images, labels) -> None:
    for raster in itertools.chain(images.images.values(), labels.labels.values()):
        schema = get_model(raster)
diff --git a/tests/utils/test_sanitize.py b/tests/utils/test_sanitize.py
index 082e08b17..b61f19084 100644
--- a/tests/utils/test_sanitize.py
+++ b/tests/utils/test_sanitize.py
@@ -16,7 +16,8 @@ def invalid_table() -> AnnData:
                "@invalid#": [1, 2],
                "valid_name": [3, 4],
                "__private": [5, 6],
-            }
+            },
+            index=["0", "1"],
        )
    )
@@ -29,7 +30,8 @@ def invalid_table_with_index() -> AnnData:
            {
                "invalid name": [1, 2],
                "_index": [3, 4],
-            }
+            },
+            index=["0", "1"],
        )
    )
@@ -103,7 +105,8 @@ def test_sanitize_table_case_insensitive_collisions():
            "Column1": [1, 2],
            "column1": [3, 4],
            "COLUMN1": [5, 6],
-        }
+        },
+        index=["0", "1"],
    )
    ad = AnnData(obs=obs)
    sanitized = sanitize_table(ad, inplace=False)
@@ -113,7 +116,7 @@ def test_sanitize_table_case_insensitive_collisions():
def test_sanitize_table_whitespace_collision():
    """Ensure 'a b' → 'a_b' doesn't collide silently with existing 'a_b'."""
-    obs = pd.DataFrame({"a b": [1], "a_b": [2]})
+    obs = pd.DataFrame({"a b": [1], "a_b": [2]}, index=["0"])
    ad = AnnData(obs=obs)
    sanitized = sanitize_table(ad, inplace=False)
    cols = list(sanitized.obs.columns)
@@ -127,13 +130,13 @@ def test_sanitize_table_whitespace_collision():
def test_sanitize_table_obs_and_obs_columns():
-    ad = AnnData(obs=pd.DataFrame({"@col": [1, 2]}))
+    ad = AnnData(obs=pd.DataFrame({"@col": [1, 2]}, index=["0", "1"]))
    sanitized = sanitize_table(ad, inplace=False)
    assert list(sanitized.obs.columns) == ["_col"]


def test_sanitize_table_obsm_and_obsp():
-    ad = AnnData(obs=pd.DataFrame({"@col": [1, 2]}))
+    ad = AnnData(obs=pd.DataFrame({"@col": [1, 2]}, index=["0", "1"]))
    ad.obsm["@col"] = np.array([[1, 2], [3, 4]])
    ad.obsp["bad name"] = np.array([[1, 2], [3, 4]])
    sanitized = sanitize_table(ad, inplace=False)
@@ -142,7 +145,7 @@ def test_sanitize_table_obsm_and_obsp():
def test_sanitize_table_varm_and_varp():
-    ad = AnnData(obs=pd.DataFrame({"x": [1, 2]}), var=pd.DataFrame(index=["v1", "v2"]))
+    ad = AnnData(obs=pd.DataFrame({"x": [1, 2]}, index=["0", "1"]), var=pd.DataFrame(index=["v1", "v2"]))
    ad.varm["__priv"] = np.array([[1, 2], [3, 4]])
    ad.varp["_index"] = np.array([[1, 2], [3, 4]])
    sanitized = sanitize_table(ad, inplace=False)
@@ -151,7 +154,7 @@ def test_sanitize_table_varm_and_varp():
def test_sanitize_table_uns_and_layers():
-    ad = AnnData(obs=pd.DataFrame({"x": [1, 2]}), var=pd.DataFrame(index=["v1", "v2"]))
+    ad = AnnData(obs=pd.DataFrame({"x": [1, 2]}, index=["0", "1"]), var=pd.DataFrame(index=["v1", "v2"]))
    ad.uns["bad@key"] = "val"
    ad.layers["bad#layer"] = np.array([[0, 1], [1, 0]])
    sanitized = sanitize_table(ad, inplace=False)
@@ -168,7 +171,7 @@ def test_sanitize_table_empty_returns_empty():
def test_sanitize_table_preserves_underlying_data():
-    ad = AnnData(obs=pd.DataFrame({"@invalid#": [1, 2], "valid": [3, 4]}))
+    ad = AnnData(obs=pd.DataFrame({"@invalid#": [1, 2], "valid": [3, 4]}, index=["0", "1"]))
    ad.obsm["@invalid#"] = np.array([[1, 2], [3, 4]])
    ad.uns["invalid@key"] = "value"
    sanitized = sanitize_table(ad, inplace=False)
@@ -198,7 +201,7 @@ def test_sanitize_table_in_spatialdata_sanitized_fixture(invalid_table, invalid_
def test_spatialdata_retains_other_elements(full_sdata):
    # Add another sanitized table into an existing full_sdata
-    tbl = AnnData(obs=pd.DataFrame({"@foo#": [1, 2], "bar": [3, 4]}))
+    tbl = AnnData(obs=pd.DataFrame({"@foo#": [1, 2], "bar": [3, 4]}, index=["0", "1"]))
    sanitize_table(tbl)
    full_sdata.tables["new_table"] = tbl
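A sketch of the sanitize_table behaviour covered above, using the explicit string obs index the updated fixtures rely on; the import path of sanitize_table is assumed here and the printed result is indicative only.

# Sketch only; import location of sanitize_table assumed.
import pandas as pd
from anndata import AnnData

from spatialdata import sanitize_table

ad = AnnData(obs=pd.DataFrame({"@invalid#": [1, 2], "valid": [3, 4]}, index=["0", "1"]))
clean = sanitize_table(ad, inplace=False)
print(list(clean.obs.columns))  # invalid characters replaced; "valid" preserved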