diff --git a/pyproject.toml b/pyproject.toml index 1f16398d9..77e6edecb 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -27,13 +27,14 @@ dependencies = [ "dask-image", "dask>=2024.4.1,<=2024.11.2", "datashader", - "fsspec", + "fsspec[s3,http]", "geopandas>=0.14", "multiscale_spatial_image>=2.0.3", "networkx", "numba>=0.55.0", "numpy", "ome_zarr>=0.8.4", + "universal_pathlib>=0.2.6", "pandas", "pooch", "pyarrow", @@ -59,6 +60,7 @@ test = [ "pytest-cov", "pytest-mock", "torch", + "moto[s3,server]" ] docs = [ "sphinx>=4.5", diff --git a/src/spatialdata/_core/query/relational_query.py b/src/spatialdata/_core/query/relational_query.py index 0803158ca..96b774d8f 100644 --- a/src/spatialdata/_core/query/relational_query.py +++ b/src/spatialdata/_core/query/relational_query.py @@ -716,7 +716,7 @@ def _call_join( return elements_dict, table -def match_table_to_element(sdata: SpatialData, element_name: str, table_name: str = "table") -> AnnData: +def match_table_to_element(sdata: SpatialData, element_name: str, table_name: str) -> AnnData: """ Filter the table and reorders the rows to match the instances (rows/labels) of the specified SpatialElement. @@ -738,14 +738,6 @@ def match_table_to_element(sdata: SpatialData, element_name: str, table_name: st match_element_to_table : Function to match a spatial element to a table. join_spatialelement_table : General function, to join spatial elements with a table with more control. """ - if table_name is None: - warnings.warn( - "Assumption of table with name `table` being present is being deprecated in SpatialData v0.1. " - "Please provide the name of the table as argument to table_name.", - DeprecationWarning, - stacklevel=2, - ) - table_name = "table" _, table = join_spatialelement_table( sdata=sdata, spatial_element_names=element_name, table_name=table_name, how="left", match_rows="left" ) diff --git a/src/spatialdata/_core/spatialdata.py b/src/spatialdata/_core/spatialdata.py index 48f6386ca..bcd993d70 100644 --- a/src/spatialdata/_core/spatialdata.py +++ b/src/spatialdata/_core/spatialdata.py @@ -17,7 +17,6 @@ from dask.delayed import Delayed from geopandas import GeoDataFrame from ome_zarr.io import parse_url -from ome_zarr.types import JSONDict from shapely import MultiPolygon, Polygon from xarray import DataArray, DataTree @@ -30,11 +29,8 @@ validate_table_attr_keys, ) from spatialdata._logging import logger -from spatialdata._types import ArrayLike, Raster_T -from spatialdata._utils import ( - _deprecation_alias, - _error_message_add_element, -) +from spatialdata._types import ArrayLike, Raster_T, StoreLike +from spatialdata._utils import _deprecation_alias from spatialdata.models import ( Image2DModel, Image3DModel, @@ -601,7 +597,7 @@ def path(self, value: Path | None) -> None: ) def _get_groups_for_element( - self, zarr_path: Path, element_type: str, element_name: str + self, zarr_path: StoreLike, element_type: str, element_name: str ) -> tuple[zarr.Group, zarr.Group, zarr.Group]: """ Get the Zarr groups for the root, element_type and element for a specific element. @@ -621,9 +617,9 @@ def _get_groups_for_element( ------- either the existing Zarr subgroup or a new one. 
""" - if not isinstance(zarr_path, Path): - raise ValueError("zarr_path should be a Path object") - store = parse_url(zarr_path, mode="r+").store + from spatialdata._io._utils import _open_zarr_store + + store = _open_zarr_store(zarr_path, mode="r+") root = zarr.group(store=store) if element_type not in ["images", "labels", "points", "polygons", "shapes", "tables"]: raise ValueError(f"Unknown element type {element_type}") @@ -1068,9 +1064,12 @@ def elements_paths_on_disk(self) -> list[str]: ------- A list of paths of the elements saved in the Zarr store. """ + from spatialdata._io._utils import _open_zarr_store + if self.path is None: raise ValueError("The SpatialData object is not backed by a Zarr store.") - store = parse_url(self.path, mode="r").store + + store = _open_zarr_store(self.path) root = zarr.group(store=store) elements_in_zarr = [] @@ -1205,12 +1204,16 @@ def write( :class:`~spatialdata._io.format.CurrentRasterFormat`, :class:`~spatialdata._io.format.CurrentShapesFormat`, :class:`~spatialdata._io.format.CurrentPointsFormat`, :class:`~spatialdata._io.format.CurrentTablesFormat`. """ + from spatialdata._io._utils import _open_zarr_store + if isinstance(file_path, str): file_path = Path(file_path) - self._validate_can_safely_write_to_path(file_path, overwrite=overwrite) - self._validate_all_elements() + if isinstance(file_path, Path): + # TODO: also validate remote paths + self._validate_can_safely_write_to_path(file_path, overwrite=overwrite) + self._validate_all_elements() - store = parse_url(file_path, mode="w").store + store = _open_zarr_store(file_path, mode="w") zarr_group = zarr.group(store=store, overwrite=overwrite) self.write_attrs(zarr_group=zarr_group) store.close() @@ -1236,20 +1239,22 @@ def write( def _write_element( self, element: SpatialElement | AnnData, - zarr_container_path: Path, + zarr_container_path: StoreLike, element_type: str, element_name: str, overwrite: bool, format: SpatialDataFormat | list[SpatialDataFormat] | None = None, ) -> None: - if not isinstance(zarr_container_path, Path): + if not isinstance(zarr_container_path, StoreLike): raise ValueError( - f"zarr_container_path must be a Path object, type(zarr_container_path) = {type(zarr_container_path)}." + f"zarr_container_path must be a 'StoreLike' object " + f"(str | Path | UPath | zarr.storage.StoreLike | zarr.Group), got: {type(zarr_container_path)}." + ) + if isinstance(zarr_container_path, Path): + file_path_of_element = zarr_container_path / element_type / element_name + self._validate_can_safely_write_to_path( + file_path=file_path_of_element, overwrite=overwrite, saving_an_element=True ) - file_path_of_element = zarr_container_path / element_type / element_name - self._validate_can_safely_write_to_path( - file_path=file_path_of_element, overwrite=overwrite, saving_an_element=True - ) root_group, element_type_group, _ = self._get_groups_for_element( zarr_path=zarr_container_path, element_type=element_type, element_name=element_name @@ -1259,14 +1264,27 @@ def _write_element( parsed = _parse_formats(formats=format) + # We pass on zarr_container_path to ensure proper paths when writing to remote system even when on windows. 
if element_type == "images": write_image(image=element, group=element_type_group, name=element_name, format=parsed["raster"]) elif element_type == "labels": write_labels(labels=element, group=root_group, name=element_name, format=parsed["raster"]) elif element_type == "points": - write_points(points=element, group=element_type_group, name=element_name, format=parsed["points"]) + write_points( + points=element, + group=element_type_group, + name=element_name, + zarr_container_path=zarr_container_path, + format=parsed["points"], + ) elif element_type == "shapes": - write_shapes(shapes=element, group=element_type_group, name=element_name, format=parsed["shapes"]) + write_shapes( + shapes=element, + group=element_type_group, + name=element_name, + zarr_container_path=zarr_container_path, + format=parsed["shapes"], + ) elif element_type == "tables": write_table(table=element, group=element_type_group, name=element_name, format=parsed["tables"]) else: @@ -1376,7 +1394,7 @@ def delete_element_from_disk(self, element_name: str | list[str]) -> None: self.delete_element_from_disk(name) return - from spatialdata._io._utils import _backed_elements_contained_in_path + from spatialdata._io._utils import _backed_elements_contained_in_path, _open_zarr_store if self.path is None: raise ValueError("The SpatialData object is not backed by a Zarr store.") @@ -1417,7 +1435,7 @@ def delete_element_from_disk(self, element_name: str | list[str]) -> None: ) # delete the element - store = parse_url(self.path, mode="r+").store + store = _open_zarr_store(self.path) root = zarr.group(store=store) root[element_type].pop(element_name) store.close() @@ -1438,15 +1456,24 @@ def _check_element_not_on_disk_with_different_type(self, element_type: str, elem ) def write_consolidated_metadata(self) -> None: - store = parse_url(self.path, mode="r+").store - # consolidate metadata to more easily support remote reading bug in zarr. In reality, 'zmetadata' is written - # instead of '.zmetadata' see discussion https://github.com/zarr-developers/zarr-python/issues/1121 - zarr.consolidate_metadata(store, metadata_key=".zmetadata") + from spatialdata._io._utils import _open_zarr_store + + store = _open_zarr_store(self.path) + # Note that the store can be local (which does not have the zmetadata bug) + # or a remote FSStore (which has the bug). + # Consolidate metadata to more easily support remote reading; this works around a bug in zarr. + # We write 'zmetadata' instead of the standard '.zmetadata' to avoid the FSStore bug. + # See discussion https://github.com/zarr-developers/zarr-python/issues/1121 + zarr.consolidate_metadata(store, metadata_key="zmetadata") store.close() def has_consolidated_metadata(self) -> bool: + from spatialdata._io._utils import _open_zarr_store + return_value = False - store = parse_url(self.path, mode="r").store + store = _open_zarr_store(self.path) + # Note that the store can be local (which does not have the zmetadata bug) + # or a remote FSStore (which has the bug).
if "zmetadata" in store: return_value = True store.close() @@ -1575,15 +1602,11 @@ def write_transformations(self, element_name: str | None = None) -> None: ) axes = get_axes_names(element) if isinstance(element, DataArray | DataTree): - from spatialdata._io._utils import ( - overwrite_coordinate_transformations_raster, - ) + from spatialdata._io._utils import overwrite_coordinate_transformations_raster overwrite_coordinate_transformations_raster(group=element_group, axes=axes, transformations=transformations) elif isinstance(element, DaskDataFrame | GeoDataFrame | AnnData): - from spatialdata._io._utils import ( - overwrite_coordinate_transformations_non_raster, - ) + from spatialdata._io._utils import overwrite_coordinate_transformations_non_raster overwrite_coordinate_transformations_non_raster( group=element_group, axes=axes, transformations=transformations @@ -1792,41 +1815,16 @@ def table(self) -> None | AnnData: ------- The table. """ - warnings.warn( - "Table accessor will be deprecated with SpatialData version 0.1, use sdata.tables instead.", - DeprecationWarning, - stacklevel=2, - ) - # Isinstance will still return table if anndata has 0 rows. - if isinstance(self.tables.get("table"), AnnData): - return self.tables["table"] - return None + raise AttributeError("The property 'table' is deprecated. use '.tables' instead.") @table.setter def table(self, table: AnnData) -> None: - warnings.warn( - "Table setter will be deprecated with SpatialData version 0.1, use tables instead.", - DeprecationWarning, - stacklevel=2, - ) - TableModel().validate(table) - if self.tables.get("table") is not None: - raise ValueError("The table already exists. Use del sdata.tables['table'] to remove it first.") - self.tables["table"] = table + raise AttributeError("The property 'table' is deprecated. use '.tables' instead.") @table.deleter def table(self) -> None: """Delete the table.""" - warnings.warn( - "del sdata.table will be deprecated with SpatialData version 0.1, use del sdata.tables['table'] instead.", - DeprecationWarning, - stacklevel=2, - ) - if self.tables.get("table"): - del self.tables["table"] - else: - # More informative than the error in the zarr library. - raise KeyError("table with name 'table' not present in the SpatialData object.") + raise AttributeError("The property 'table' is deprecated. use '.tables' instead.") @staticmethod def read(file_path: Path | str, selection: tuple[str] | None = None) -> SpatialData: @@ -1848,44 +1846,6 @@ def read(file_path: Path | str, selection: tuple[str] | None = None) -> SpatialD return read_zarr(file_path, selection=selection) - def add_image( - self, - name: str, - image: DataArray | DataTree, - storage_options: JSONDict | list[JSONDict] | None = None, - overwrite: bool = False, - ) -> None: - """Deprecated. Use `sdata[name] = image` instead.""" # noqa: D401 - _error_message_add_element() - - def add_labels( - self, - name: str, - labels: DataArray | DataTree, - storage_options: JSONDict | list[JSONDict] | None = None, - overwrite: bool = False, - ) -> None: - """Deprecated. Use `sdata[name] = labels` instead.""" # noqa: D401 - _error_message_add_element() - - def add_points( - self, - name: str, - points: DaskDataFrame, - overwrite: bool = False, - ) -> None: - """Deprecated. Use `sdata[name] = points` instead.""" # noqa: D401 - _error_message_add_element() - - def add_shapes( - self, - name: str, - shapes: GeoDataFrame, - overwrite: bool = False, - ) -> None: - """Deprecated. 
Use `sdata[name] = shapes` instead.""" # noqa: D401 - _error_message_add_element() - @property def images(self) -> Images: """Return images as a Dict of name to image data.""" diff --git a/src/spatialdata/_io/_utils.py b/src/spatialdata/_io/_utils.py index 5e8eb832d..e1756d6bd 100644 --- a/src/spatialdata/_io/_utils.py +++ b/src/spatialdata/_io/_utils.py @@ -13,14 +13,18 @@ from pathlib import Path from typing import Any, Literal -import zarr +import zarr.storage from anndata import AnnData from dask.array import Array as DaskArray from dask.dataframe import DataFrame as DaskDataFrame from geopandas import GeoDataFrame +from upath import UPath +from upath.implementations.local import PosixUPath, WindowsUPath from xarray import DataArray, DataTree +from zarr.storage import FSStore from spatialdata._core.spatialdata import SpatialData +from spatialdata._types import StoreLike from spatialdata._utils import get_pyramid_levels from spatialdata.models._utils import ( MappingToCoordinateSystem_t, @@ -383,6 +387,48 @@ def save_transformations(sdata: SpatialData) -> None: sdata.write_transformations() +def _open_zarr_store(path: StoreLike, **kwargs: Any) -> zarr.storage.BaseStore: + # TODO: ensure kwargs like mode are enforced everywhere and passed correctly to the store + if isinstance(path, str | Path): + # if the input is str or Path, map it to UPath + path = UPath(path) + if isinstance(path, PosixUPath | WindowsUPath): + # if the input is a local path, use DirectoryStore + return zarr.storage.DirectoryStore(path.path, dimension_separator="/") + if isinstance(path, zarr.Group): + # if the input is a zarr.Group, wrap it with a store + if isinstance(path.store, zarr.storage.DirectoryStore): + # create a simple FSStore if the store is a DirectoryStore with just the path + return FSStore(os.path.join(path.store.path, path.path), **kwargs) + if isinstance(path.store, FSStore): + # if the store within the zarr.Group is an FSStore, return it + # but extend the path of the store with that of the zarr.Group + return FSStore(path.store.path + "/" + path.path, fs=path.store.fs, **kwargs) + if isinstance(path.store, zarr.storage.ConsolidatedMetadataStore): + # if the store is a ConsolidatedMetadataStore, just return the underlying FSSpec store + return path.store.store + raise ValueError(f"Unsupported store type or zarr.Group: {type(path.store)}") + if isinstance(path, zarr.storage.StoreLike): + # if the input already a store, wrap it in an FSStore + return FSStore(path, **kwargs) + if isinstance(path, UPath): + # if input is a remote UPath, map it to an FSStore + return FSStore(path.path, fs=path.fs, **kwargs) + raise TypeError(f"Unsupported type: {type(path)}") + + +def _create_upath(path: StoreLike) -> UPath | None: + # try to create a UPath from the input + if isinstance(path, zarr.storage.ConsolidatedMetadataStore): + path = path.store # get the fsstore from the consolidated store + if isinstance(path, FSStore): + protocol = path.fs.protocol if isinstance(path.fs.protocol, str) else path.fs.protocol[0] + return UPath(path.path, protocol=protocol, **path.fs.storage_options) + if isinstance(path, zarr.storage.DirectoryStore): + return UPath(path.path) + return None + + class BadFileHandleMethod(Enum): ERROR = "error" WARN = "warn" diff --git a/src/spatialdata/_io/format.py b/src/spatialdata/_io/format.py index 5ee675be6..910cc7c8b 100644 --- a/src/spatialdata/_io/format.py +++ b/src/spatialdata/_io/format.py @@ -253,7 +253,7 @@ def validate_table( def format_implementations() -> Iterator[Format]: 
"""Return an instance of each format implementation, newest to oldest.""" yield RasterFormatV02() - # yield RasterFormatV01() # same format string as FormatV04 + yield RasterFormatV01() # same format string as FormatV04 yield FormatV04() yield FormatV03() yield FormatV02() diff --git a/src/spatialdata/_io/io_points.py b/src/spatialdata/_io/io_points.py index 3106c8470..feed13e3e 100644 --- a/src/spatialdata/_io/io_points.py +++ b/src/spatialdata/_io/io_points.py @@ -1,4 +1,3 @@ -import os from collections.abc import MutableMapping from pathlib import Path @@ -13,6 +12,7 @@ overwrite_coordinate_transformations_non_raster, ) from spatialdata._io.format import CurrentPointsFormat, PointsFormats, _parse_version +from spatialdata._types import StoreLike from spatialdata.models import get_axes_names from spatialdata.transformations._utils import ( _get_transformations, @@ -24,17 +24,12 @@ def _read_points( store: str | Path | MutableMapping | zarr.Group, # type: ignore[type-arg] ) -> DaskDataFrame: """Read points from a zarr store.""" - assert isinstance(store, str | Path) - f = zarr.open(store, mode="r") - + f = zarr.open(store, mode="r") if isinstance(store, str | Path | MutableMapping) else store version = _parse_version(f, expect_attrs_key=True) assert version is not None format = PointsFormats[version] - path = os.path.join(f._store.path, f.path, "points.parquet") - # cache on remote file needed for parquet reader to work - # TODO: allow reading in the metadata without caching all the data - points = read_parquet("simplecache::" + path if path.startswith("http") else path) + points = read_parquet(f.store.path, filesystem=f.store.fs) assert isinstance(points, DaskDataFrame) transformations = _get_transformations_from_ngff_dict(f.attrs.asdict()["coordinateTransformations"]) @@ -50,6 +45,7 @@ def write_points( points: DaskDataFrame, group: zarr.Group, name: str, + zarr_container_path: StoreLike, group_type: str = "ngff:points", format: Format = CurrentPointsFormat(), ) -> None: @@ -57,7 +53,8 @@ def write_points( t = _get_transformations(points) points_groups = group.require_group(name) - path = Path(points_groups._store.path) / points_groups.path / "points.parquet" + store = points_groups._store + path = zarr_container_path / points_groups.path / "points.parquet" # The following code iterates through all columns in the 'points' DataFrame. If the column's datatype is # 'category', it checks whether the categories of this column are known. 
If not, it explicitly converts the @@ -70,7 +67,7 @@ def write_points( c = c.cat.as_known() points[column_name] = c - points.to_parquet(path) + points.to_parquet(path, filesystem=getattr(store, "fs", None)) attrs = format.attrs_to_dict(points.attrs) attrs["version"] = format.spatialdata_format_version diff --git a/src/spatialdata/_io/io_raster.py b/src/spatialdata/_io/io_raster.py index 541be3ead..824b17ff0 100644 --- a/src/spatialdata/_io/io_raster.py +++ b/src/spatialdata/_io/io_raster.py @@ -1,9 +1,9 @@ -from pathlib import Path from typing import Any, Literal import dask.array as da import numpy as np import zarr +import zarr.storage from ome_zarr.format import Format from ome_zarr.io import ZarrLocation from ome_zarr.reader import Label, Multiscales, Node, Reader @@ -13,18 +13,14 @@ from ome_zarr.writer import write_labels as write_labels_ngff from ome_zarr.writer import write_multiscale as write_multiscale_ngff from ome_zarr.writer import write_multiscale_labels as write_multiscale_labels_ngff +from upath import UPath from xarray import DataArray, Dataset, DataTree from spatialdata._io._utils import ( _get_transformations_from_ngff_dict, overwrite_coordinate_transformations_raster, ) -from spatialdata._io.format import ( - CurrentRasterFormat, - RasterFormats, - RasterFormatV01, - _parse_version, -) +from spatialdata._io.format import CurrentRasterFormat, RasterFormats, RasterFormatV01, _parse_version from spatialdata._utils import get_pyramid_levels from spatialdata.models._utils import get_channel_names from spatialdata.models.models import ATTRS_KEY @@ -36,19 +32,19 @@ ) -def _read_multiscale(store: str | Path, raster_type: Literal["image", "labels"]) -> DataArray | DataTree: - assert isinstance(store, str | Path) +def _read_multiscale(store: zarr.storage.BaseStore, raster_type: Literal["image", "labels"]) -> DataArray | DataTree: assert raster_type in ["image", "labels"] - - f = zarr.open(store, mode="r") - version = _parse_version(f, expect_attrs_key=True) + if isinstance(store, str | UPath): + raise NotImplementedError("removed in this PR") + group = zarr.group(store=store) + version = _parse_version(group, expect_attrs_key=True) # old spatialdata datasets don't have format metadata for raster elements; this line ensure backwards compatibility, # interpreting the lack of such information as the presence of the format v01 format = RasterFormatV01() if version is None else RasterFormats[version] - f.store.close() + store.close() nodes: list[Node] = [] - image_loc = ZarrLocation(store) + image_loc = ZarrLocation(store, fmt=format) if image_loc.exists(): image_reader = Reader(image_loc)() image_nodes = list(image_reader) diff --git a/src/spatialdata/_io/io_shapes.py b/src/spatialdata/_io/io_shapes.py index c32ce1f34..514912c78 100644 --- a/src/spatialdata/_io/io_shapes.py +++ b/src/spatialdata/_io/io_shapes.py @@ -12,26 +12,17 @@ _write_metadata, overwrite_coordinate_transformations_non_raster, ) -from spatialdata._io.format import ( - CurrentShapesFormat, - ShapesFormats, - ShapesFormatV01, - ShapesFormatV02, - _parse_version, -) +from spatialdata._io.format import CurrentShapesFormat, ShapesFormats, ShapesFormatV01, ShapesFormatV02, _parse_version +from spatialdata._types import StoreLike from spatialdata.models import ShapesModel, get_axes_names -from spatialdata.transformations._utils import ( - _get_transformations, - _set_transformations, -) +from spatialdata.transformations._utils import _get_transformations, _set_transformations def _read_shapes( store: str | Path | 
MutableMapping | zarr.Group, # type: ignore[type-arg] ) -> GeoDataFrame: """Read shapes from a zarr store.""" - assert isinstance(store, str | Path) - f = zarr.open(store, mode="r") + f = zarr.open(store, mode="r") if isinstance(store, str | Path | MutableMapping) else store version = _parse_version(f, expect_attrs_key=True) assert version is not None format = ShapesFormats[version] @@ -50,8 +41,7 @@ def _read_shapes( geometry = from_ragged_array(typ, coords, offsets) geo_df = GeoDataFrame({"geometry": geometry}, index=index) elif isinstance(format, ShapesFormatV02): - path = Path(f._store.path) / f.path / "shapes.parquet" - geo_df = read_parquet(path) + geo_df = read_parquet(f.store.path, filesystem=f.store.fs) else: raise ValueError( f"Unsupported shapes format {format} from version {version}. Please update the spatialdata library." @@ -66,6 +56,7 @@ def write_shapes( shapes: GeoDataFrame, group: zarr.Group, name: str, + zarr_container_path: StoreLike, group_type: str = "ngff:shapes", format: Format = CurrentShapesFormat(), ) -> None: @@ -93,8 +84,11 @@ def write_shapes( attrs = format.attrs_to_dict(geometry) attrs["version"] = format.spatialdata_format_version elif isinstance(format, ShapesFormatV02): - path = Path(shapes_group._store.path) / shapes_group.path / "shapes.parquet" - shapes.to_parquet(path) + store = shapes_group._store + path = zarr_container_path / shapes_group.path / "shapes.parquet" + + # Geopandas only allows path-like objects for local filesystems and not remote ones. + shapes.to_parquet(str(path), filesystem=getattr(store, "fs", None)) attrs = format.attrs_to_dict(shapes.attrs) attrs["version"] = format.spatialdata_format_version diff --git a/src/spatialdata/_io/io_table.py b/src/spatialdata/_io/io_table.py index 92ff64b94..904eeba08 100644 --- a/src/spatialdata/_io/io_table.py +++ b/src/spatialdata/_io/io_table.py @@ -1,4 +1,5 @@ -import os +from __future__ import annotations + from json import JSONDecodeError from typing import Literal @@ -46,22 +47,19 @@ def _read_table( count = 0 for table_name in subgroup: f_elem = subgroup[table_name] - f_elem_store = os.path.join(zarr_store_path, f_elem.path) with handle_read_errors( on_bad_files=on_bad_files, location=f"{subgroup.path}/{table_name}", exc_types=(JSONDecodeError, KeyError, ValueError, ArrayNotFoundError), ): - tables[table_name] = read_anndata_zarr(f_elem_store) + tables[table_name] = read_anndata_zarr(f_elem) - f = zarr.open(f_elem_store, mode="r") - version = _parse_version(f, expect_attrs_key=False) + version = _parse_version(f_elem, expect_attrs_key=False) assert version is not None # since have just one table format, we currently read it but do not use it; if we ever change the format # we can rename the two _ to format and implement the per-format read logic (as we do for shapes) _ = TablesFormats[version] - f.store.close() # # replace with format from above # version = "0.1" diff --git a/src/spatialdata/_io/io_zarr.py b/src/spatialdata/_io/io_zarr.py index 224ef1129..e94f69eb6 100644 --- a/src/spatialdata/_io/io_zarr.py +++ b/src/spatialdata/_io/io_zarr.py @@ -1,47 +1,100 @@ import logging -import os import warnings from json import JSONDecodeError -from pathlib import Path from typing import Literal import zarr from anndata import AnnData +from dask.dataframe import DataFrame as DaskDataFrame +from geopandas import GeoDataFrame from pyarrow import ArrowInvalid +from xarray import DataArray, DataTree from zarr.errors import ArrayNotFoundError, MetadataError from spatialdata._core.spatialdata import 
SpatialData -from spatialdata._io._utils import BadFileHandleMethod, handle_read_errors, ome_zarr_logger +from spatialdata._io._utils import ( + BadFileHandleMethod, + _create_upath, + _open_zarr_store, + handle_read_errors, + ome_zarr_logger, +) from spatialdata._io.io_points import _read_points from spatialdata._io.io_raster import _read_multiscale from spatialdata._io.io_shapes import _read_shapes from spatialdata._io.io_table import _read_table from spatialdata._logging import logger +from spatialdata._types import StoreLike -def _open_zarr_store(store: str | Path | zarr.Group) -> tuple[zarr.Group, str]: +def is_hidden_zarr_entry(name: str) -> bool: + """Return True for hidden entries like '.zgroup' or '.zmetadata'.""" + return name.rpartition("/")[2].startswith(".") + + +def read_image_element(path: StoreLike) -> DataArray | DataTree: + """Read a single image element from a store location. + + Parameters + ---------- + path + Path to the zarr store. + + Returns + ------- + A DataArray or DataTree object. """ - Open a zarr store (on-disk or remote) and return the zarr.Group object and the path to the store. + # stay in sync with ome v4 format spec: + # https://github.com/ome/ome-zarr-py/blob/7d1ae35c97/ome_zarr/format.py#L189-L192 + store = _open_zarr_store(path, normalize_keys=False) + return _read_multiscale(store, raster_type="image") + + +def read_labels_element(path: StoreLike) -> DataArray | DataTree: + """Read a single labels element from a store location. Parameters ---------- - store - Path to the zarr store (on-disk or remote) or a zarr.Group object. + path + Path to the zarr store. Returns ------- - A tuple of the zarr.Group object and the path to the store. + A DataArray or DataTree object. """ - f = store if isinstance(store, zarr.Group) else zarr.open(store, mode="r") - # workaround: .zmetadata is being written as zmetadata (https://github.com/zarr-developers/zarr-python/issues/1121) - if isinstance(store, str | Path) and str(store).startswith("http") and len(f) == 0: - f = zarr.open_consolidated(store, mode="r", metadata_key="zmetadata") - f_store_path = f.store.store.path if isinstance(f.store, zarr.storage.ConsolidatedMetadataStore) else f.store.path - return f, f_store_path + store = _open_zarr_store(path) + return _read_multiscale(store, raster_type="labels") + + +def read_points_element(path: StoreLike) -> DaskDataFrame: + store = _open_zarr_store(path) + return _read_points(store) + + +def read_shapes_element(path: StoreLike) -> GeoDataFrame: + store = _open_zarr_store(path) + return _read_shapes(store) + + +def read_tables_element( + zarr_store_path: StoreLike, + group: zarr.Group, + subgroup: zarr.Group, + tables: dict[str, AnnData], + on_bad_files: Literal[BadFileHandleMethod.ERROR, BadFileHandleMethod.WARN] = BadFileHandleMethod.ERROR, +) -> dict[str, AnnData]: + store = _open_zarr_store(zarr_store_path) + return _read_table( + store, + group, + subgroup, + tables, + on_bad_files, + ) def read_zarr( - store: str | Path | zarr.Group, + store: StoreLike, selection: None | tuple[str] = None, on_bad_files: Literal[BadFileHandleMethod.ERROR, BadFileHandleMethod.WARN] = BadFileHandleMethod.ERROR, ) -> SpatialData: @@ -71,7 +124,8 @@ def read_zarr( ------- A SpatialData object.
""" - f, f_store_path = _open_zarr_store(store) + _store = _open_zarr_store(store) + f = zarr.group(_store) images = {} labels = {} @@ -93,11 +147,9 @@ def read_zarr( group = f["images"] count = 0 for subgroup_name in group: - if Path(subgroup_name).name.startswith("."): - # skip hidden files like .zgroup or .zmetadata + if is_hidden_zarr_entry(subgroup_name): continue f_elem = group[subgroup_name] - f_elem_store = os.path.join(f_store_path, f_elem.path) with handle_read_errors( on_bad_files, location=f"{group.path}/{subgroup_name}", @@ -109,7 +161,7 @@ def read_zarr( TypeError, # instead of ArrayNotFoundError, with dask>=2024.10.0 zarr<=2.18.3 ), ): - element = _read_multiscale(f_elem_store, raster_type="image") + element = read_image_element(f_elem) images[subgroup_name] = element count += 1 logger.debug(f"Found {count} elements in {group}") @@ -125,17 +177,15 @@ def read_zarr( group = f["labels"] count = 0 for subgroup_name in group: - if Path(subgroup_name).name.startswith("."): - # skip hidden files like .zgroup or .zmetadata + if is_hidden_zarr_entry(subgroup_name): continue f_elem = group[subgroup_name] - f_elem_store = os.path.join(f_store_path, f_elem.path) with handle_read_errors( on_bad_files, location=f"{group.path}/{subgroup_name}", exc_types=(JSONDecodeError, KeyError, ValueError, ArrayNotFoundError, TypeError), ): - labels[subgroup_name] = _read_multiscale(f_elem_store, raster_type="labels") + labels[subgroup_name] = read_labels_element(f_elem) count += 1 logger.debug(f"Found {count} elements in {group}") @@ -150,16 +200,14 @@ def read_zarr( count = 0 for subgroup_name in group: f_elem = group[subgroup_name] - if Path(subgroup_name).name.startswith("."): - # skip hidden files like .zgroup or .zmetadata + if is_hidden_zarr_entry(subgroup_name): continue - f_elem_store = os.path.join(f_store_path, f_elem.path) with handle_read_errors( on_bad_files, location=f"{group.path}/{subgroup_name}", exc_types=(JSONDecodeError, KeyError, ArrowInvalid), ): - points[subgroup_name] = _read_points(f_elem_store) + points[subgroup_name] = read_points_element(f_elem) count += 1 logger.debug(f"Found {count} elements in {group}") @@ -172,11 +220,9 @@ def read_zarr( group = f["shapes"] count = 0 for subgroup_name in group: - if Path(subgroup_name).name.startswith("."): - # skip hidden files like .zgroup or .zmetadata + if is_hidden_zarr_entry(subgroup_name): continue f_elem = group[subgroup_name] - f_elem_store = os.path.join(f_store_path, f_elem.path) with handle_read_errors( on_bad_files, location=f"{group.path}/{subgroup_name}", @@ -187,7 +233,7 @@ def read_zarr( ArrayNotFoundError, ), ): - shapes[subgroup_name] = _read_shapes(f_elem_store) + shapes[subgroup_name] = read_shapes_element(f_elem) count += 1 logger.debug(f"Found {count} elements in {group}") if "tables" in selector and "tables" in f: @@ -197,12 +243,11 @@ def read_zarr( exc_types=(JSONDecodeError, MetadataError), ): group = f["tables"] - tables = _read_table(f_store_path, f, group, tables, on_bad_files=on_bad_files) + tables = read_tables_element(_store, f, group, tables, on_bad_files=on_bad_files) if "table" in selector and "table" in f: warnings.warn( - f"Table group found in zarr store at location {f_store_path}. Please update the zarr store to use tables " - f"instead.", + f"Table group found in zarr store at location {store}. 
Please update the zarr store to use tables instead.", DeprecationWarning, stacklevel=2, ) @@ -213,7 +258,7 @@ def read_zarr( exc_types=(JSONDecodeError, MetadataError), ): group = f[subgroup_name] - tables = _read_table(f_store_path, f, group, tables, on_bad_files=on_bad_files) + tables = read_tables_element(store, f, group, tables, on_bad_files=on_bad_files) logger.debug(f"Found {count} elements in {group}") @@ -234,5 +279,5 @@ def read_zarr( tables=tables, attrs=attrs, ) - sdata.path = Path(store) + sdata.path = _create_upath(_store) return sdata diff --git a/src/spatialdata/_types.py b/src/spatialdata/_types.py index 30d623a57..4b6b351d3 100644 --- a/src/spatialdata/_types.py +++ b/src/spatialdata/_types.py @@ -1,9 +1,12 @@ -from typing import Any +from pathlib import Path +from typing import Any, TypeAlias import numpy as np +import zarr +from upath import UPath from xarray import DataArray, DataTree -__all__ = ["ArrayLike", "ColorLike", "DTypeLike", "Raster_T"] +__all__ = ["ArrayLike", "ColorLike", "DTypeLike", "Raster_T", "StoreLike"] from numpy.typing import DTypeLike, NDArray @@ -12,3 +15,5 @@ Raster_T = DataArray | DataTree ColorLike = tuple[float, ...] | str + +StoreLike: TypeAlias = str | Path | UPath | zarr.storage.StoreLike | zarr.Group diff --git a/tests/conftest.py b/tests/conftest.py index cc86f9777..38b2fbb1a 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -440,7 +440,7 @@ def _make_sdata_for_testing_querying_and_aggretation() -> SpatialData: table = TableModel.parse( table, region=["values_circles", "values_polygons"], region_key="region", instance_key="instance_id" ) - sdata.table = table + sdata["table"] = table return sdata diff --git a/tests/core/operations/test_spatialdata_operations.py b/tests/core/operations/test_spatialdata_operations.py index db413af31..3c225dbe1 100644 --- a/tests/core/operations/test_spatialdata_operations.py +++ b/tests/core/operations/test_spatialdata_operations.py @@ -174,7 +174,7 @@ def test_filter_by_coordinate_system_also_table(full_sdata: SpatialData) -> None adata = full_sdata["table"] del adata.uns[TableModel.ATTRS_KEY] del full_sdata.tables["table"] - full_sdata.table = TableModel.parse( + full_sdata["table"] = TableModel.parse( adata, region=["circles", "poly"], region_key="annotated_shapes", diff --git a/tests/core/query/test_relational_query.py b/tests/core/query/test_relational_query.py index f0b4da7e0..1b47606b3 100644 --- a/tests/core/query/test_relational_query.py +++ b/tests/core/query/test_relational_query.py @@ -15,10 +15,14 @@ def test_match_table_to_element(sdata_query_aggregation): - matched_table = match_table_to_element(sdata=sdata_query_aggregation, element_name="values_circles") + matched_table = match_table_to_element( + sdata=sdata_query_aggregation, element_name="values_circles", table_name="table" + ) arr = np.array(list(reversed(sdata_query_aggregation["values_circles"].index))) sdata_query_aggregation["values_circles"].index = arr - matched_table_reversed = match_table_to_element(sdata=sdata_query_aggregation, element_name="values_circles") + matched_table_reversed = match_table_to_element( + sdata=sdata_query_aggregation, element_name="values_circles", table_name="table" + ) assert matched_table.obs.index.tolist() == list(reversed(matched_table_reversed.obs.index.tolist())) # TODO: add tests for labels diff --git a/tests/io/test_multi_table.py b/tests/io/test_multi_table.py index dd43cfa8d..7191f2f86 100644 --- a/tests/io/test_multi_table.py +++ b/tests/io/test_multi_table.py @@ -127,30 +127,6 @@ def 
test_set_table_annotates_spatialelement(self, full_sdata, tmp_path): ) full_sdata.write(tmpdir) - def test_old_accessor_deprecation(self, full_sdata, tmp_path): - # To test self._backed - tmpdir = Path(tmp_path) / "tmp.zarr" - full_sdata.write(tmpdir) - adata0 = _get_table(region="polygon") - - with pytest.warns(DeprecationWarning): - _ = full_sdata.table - with pytest.raises(ValueError): - full_sdata.table = adata0 - with pytest.warns(DeprecationWarning): - del full_sdata.table - with pytest.raises(KeyError): - del full_sdata["table"] - with pytest.warns(DeprecationWarning): - full_sdata.table = adata0 # this gets placed in sdata['table'] - - assert_equal(adata0, full_sdata["table"]) - - del full_sdata["table"] - - full_sdata.tables["my_new_table0"] = adata0 - assert full_sdata.get("table") is None - @pytest.mark.parametrize("region", ["test_shapes", "non_existing"]) def test_single_table(self, tmp_path: str, region: str): tmpdir = Path(tmp_path) / "tmp.zarr" diff --git a/tests/io/test_readwrite.py b/tests/io/test_readwrite.py index ad8c66b4c..9778b529d 100644 --- a/tests/io/test_readwrite.py +++ b/tests/io/test_readwrite.py @@ -1,4 +1,5 @@ import os +import sys import tempfile from collections.abc import Callable from pathlib import Path @@ -252,27 +253,6 @@ def test_incremental_io_on_disk( sdata.delete_element_from_disk(name) sdata.write_element(name) - def test_incremental_io_table_legacy(self, table_single_annotation: SpatialData) -> None: - s = table_single_annotation - t = s["table"][:10, :].copy() - with pytest.raises(ValueError): - s.table = t - del s["table"] - s.table = t - - with tempfile.TemporaryDirectory() as td: - f = os.path.join(td, "data.zarr") - s.write(f) - s2 = SpatialData.read(f) - assert len(s2["table"]) == len(t) - del s2["table"] - s2.table = s["table"] - assert len(s2["table"]) == len(s["table"]) - f2 = os.path.join(td, "data2.zarr") - s2.write(f2) - s3 = SpatialData.read(f2) - assert len(s3["table"]) == len(s2["table"]) - def test_io_and_lazy_loading_points(self, points): with tempfile.TemporaryDirectory() as td: f = os.path.join(td, "data.zarr") @@ -774,6 +754,7 @@ def test_incremental_writing_valid_table_name_invalid_table(tmp_path: Path): invalid_sdata.write_element("valid_name") +@pytest.mark.skipif(sys.platform == "win32", reason="Renaming fails as windows path already sees the name as invalid.") def test_reading_invalid_name(tmp_path: Path): image_name, image = next(iter(_get_images().items())) labels_name, labels = next(iter(_get_labels().items())) @@ -789,8 +770,10 @@ def test_reading_invalid_name(tmp_path: Path): ) valid_sdata.write(tmp_path / "data.zarr") # Circumvent validation at construction time and check validation happens again at writing time. 
- (tmp_path / "data.zarr/points" / points_name).rename(tmp_path / "data.zarr/points" / "has whitespace") - (tmp_path / "data.zarr/shapes" / shapes_name).rename(tmp_path / "data.zarr/shapes" / "non-alnum_#$%&()*+,?@") + (tmp_path / "data.zarr" / "points" / points_name).rename(tmp_path / "data.zarr" / "points" / "has whitespace") + (tmp_path / "data.zarr" / "shapes" / shapes_name).rename( + tmp_path / "data.zarr" / "shapes" / "non-alnum_#$%&()*+,?@" + ) with pytest.raises(ValidationError, match="Cannot construct SpatialData") as exc_info: read_zarr(tmp_path / "data.zarr") diff --git a/tests/io/test_remote.py b/tests/io/test_remote.py new file mode 100644 index 000000000..c625f5eb4 --- /dev/null +++ b/tests/io/test_remote.py @@ -0,0 +1,33 @@ +import pytest +import zarr +from upath import UPath + +from spatialdata import SpatialData + + +class TestRemote: + # Test actual remote datasets from https://spatialdata.scverse.org/en/latest/tutorials/notebooks/datasets/README.html + + @pytest.fixture(params=["merfish", "mibitof", "mibitof_alt"]) + def s3_address(self, request): + urls = { + "merfish": UPath( + "s3://spatialdata/spatialdata-sandbox/merfish.zarr", endpoint_url="https://s3.embl.de", anon=True + ), + "mibitof": UPath( + "s3://spatialdata/spatialdata-sandbox/mibitof.zarr", endpoint_url="https://s3.embl.de", anon=True + ), + "mibitof_alt": "https://dl01.irc.ugent.be/spatial/mibitof/data.zarr/", + } + return urls[request.param] + + def test_remote(self, s3_address): + # TODO: remove selection once support for points, shapes and tables is added + sdata = SpatialData.read(s3_address, selection=("images", "labels")) + assert len(list(sdata.gen_elements())) > 0 + + def test_remote_consolidated(self, s3_address): + urlpath, storage_options = str(s3_address), getattr(s3_address, "storage_options", {}) + root = zarr.open_consolidated(urlpath, mode="r", metadata_key="zmetadata", storage_options=storage_options) + sdata = SpatialData.read(root, selection=("images", "labels")) + assert len(list(sdata.gen_elements())) > 0 diff --git a/tests/io/test_remote_mock.py b/tests/io/test_remote_mock.py new file mode 100644 index 000000000..d83eff9a5 --- /dev/null +++ b/tests/io/test_remote_mock.py @@ -0,0 +1,176 @@ +import os +import shlex +import subprocess +import tempfile +import time +import uuid +from pathlib import Path + +import fsspec +import pytest +from upath import UPath +from upath.implementations.cloud import S3Path + +from spatialdata import SpatialData +from spatialdata.testing import assert_spatial_data_objects_are_identical + +# This mock setup was inspired by https://github.com/fsspec/universal_pathlib/blob/main/upath/tests/conftest.py + + +@pytest.fixture(scope="session") +def s3_server(): + # create a writable local S3 system via moto + if "BOTO_CONFIG" not in os.environ: # pragma: no cover + os.environ["BOTO_CONFIG"] = "/dev/null" + if "AWS_ACCESS_KEY_ID" not in os.environ: # pragma: no cover + os.environ["AWS_ACCESS_KEY_ID"] = "testing" + if "AWS_SECRET_ACCESS_KEY" not in os.environ: # pragma: no cover + os.environ["AWS_SECRET_ACCESS_KEY"] = "testing" + if "AWS_SECURITY_TOKEN" not in os.environ: # pragma: no cover + os.environ["AWS_SECURITY_TOKEN"] = "testing" + if "AWS_SESSION_TOKEN" not in os.environ: # pragma: no cover + os.environ["AWS_SESSION_TOKEN"] = "testing" + if "AWS_DEFAULT_REGION" not in os.environ: # pragma: no cover + os.environ["AWS_DEFAULT_REGION"] = "us-east-1" + requests = pytest.importorskip("requests") + + pytest.importorskip("moto") + + port = 5555 + endpoint_uri = 
f"http://127.0.0.1:{port}/" + proc = subprocess.Popen( + shlex.split(f"moto_server -p {port}"), + stderr=subprocess.DEVNULL, + stdout=subprocess.DEVNULL, + ) + try: + timeout = 5 + while timeout > 0: + try: + r = requests.get(endpoint_uri, timeout=10) + if r.ok: + break + except requests.exceptions.RequestException: # pragma: no cover + pass + timeout -= 0.1 # pragma: no cover + time.sleep(0.1) # pragma: no cover + anon = False + s3so = { + "client_kwargs": {"endpoint_url": endpoint_uri}, + "use_listings_cache": True, + } + yield anon, s3so + finally: + proc.terminate() + proc.wait() + + +def clear_s3(s3_server, location=None): + # clear an s3 bucket of all contents + anon, s3so = s3_server + s3 = fsspec.filesystem("s3", anon=anon, **s3so) + if location and s3.exists(location): + for d, _, keys in s3.walk(location): + for key in keys: + s3.rm(f"{d}/{key}") + s3.invalidate_cache() + + +def upload_to_upath(upath, sdata): + # write the object to disk via a regular path, then copy it to the UPath byte-by-byte + # useful for testing the read and write functionality separately + with tempfile.TemporaryDirectory() as tempdir: + sdata_path = Path(tempdir) / "temp.zarr" + sdata.write(sdata_path) + # for every file in the sdata_path, copy it to the upath + for x in sdata_path.glob("**/*"): + if x.is_file(): + data = x.read_bytes() + destination = upath / x.relative_to(sdata_path) + destination.write_bytes(data) + + +@pytest.fixture(scope="function") +def upath(s3_server): + # make a mock bucket available for testing + pytest.importorskip("s3fs") + anon, s3so = s3_server + s3 = fsspec.filesystem("s3", anon=anon, **s3so) + random_name = uuid.uuid4().hex + bucket_name = f"test_{random_name}" + clear_s3(s3_server, bucket_name) + s3.mkdir(bucket_name) + # here you could write existing test files to s3.upload if needed + s3.invalidate_cache() + + upath = UPath(f"s3://{bucket_name}", anon=anon, **s3so) + yield upath + + +class TestRemoteMock: + def test_is_S3Path(self, upath): + assert isinstance(upath, S3Path) + + def test_upload_sdata(self, upath, full_sdata): + tmpdir = upath / "tmp.zarr" + upload_to_upath(tmpdir, full_sdata) + assert tmpdir.exists() + assert len(list(tmpdir.glob("*"))) == 8 + + def test_creating_file(self, upath) -> None: + file_name = "file1" + p1 = upath / file_name + p1.touch() + contents = [p.name for p in upath.iterdir()] + assert file_name in contents + + @pytest.mark.parametrize( + "sdata_type", + [ + "images", + "labels", + "table_single_annotation", + "table_multiple_annotations", + "points", + "shapes", + ], + ) + def test_reading_mocked_elements(self, upath: UPath, sdata_type: str, request) -> None: + sdata = request.getfixturevalue(sdata_type) + with tempfile.TemporaryDirectory() as tmpdir: + local_path = Path(tmpdir) / "tmp.zarr" + sdata.write(local_path) + local_sdata = SpatialData.read(local_path) + local_len = len(list(local_sdata.gen_elements())) + assert local_len > 0 + remote_path = upath / "tmp.zarr" + upload_to_upath(remote_path, sdata) + remote_sdata = SpatialData.read(remote_path) + assert len(list(remote_sdata.gen_elements())) == local_len + assert_spatial_data_objects_are_identical(local_sdata, remote_sdata) + + @pytest.mark.parametrize( + "sdata_type", + [ + "images", + "labels", + "table_single_annotation", + "table_multiple_annotations", + "points", + "shapes", + ], + ) + def test_writing_mocked_elements(self, upath: UPath, sdata_type: str, request) -> None: + sdata = request.getfixturevalue(sdata_type) + n_elements = len(list(sdata.gen_elements())) + # test 
writing to a remote path + remote_path = upath / "tmp.zarr" + sdata.write(remote_path) + if not sdata_type.startswith("table"): + assert len(list((remote_path / sdata_type).glob("[a-zA-Z]*"))) == n_elements + else: + assert len(list((remote_path / "tables").glob("[a-zA-Z]*"))) == n_elements + + # test reading the remotely written object + remote_sdata = SpatialData.read(remote_path) + assert_spatial_data_objects_are_identical(sdata, remote_sdata)
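Below is a minimal usage sketch of the remote-reading path exercised by `tests/io/test_remote.py` above; the S3 bucket, endpoint and the restriction of `selection` to images and labels are taken directly from that test, and the snippet assumes the optional S3 extras (`fsspec[s3]`/`s3fs`) are installed.

```python
from upath import UPath

from spatialdata import SpatialData

# Public sandbox dataset used in tests/io/test_remote.py (anonymous S3 access).
path = UPath(
    "s3://spatialdata/spatialdata-sandbox/mibitof.zarr",
    endpoint_url="https://s3.embl.de",
    anon=True,
)

# SpatialData.read accepts StoreLike inputs (str | Path | UPath | zarr store | zarr.Group);
# remote UPaths are routed through _open_zarr_store and wrapped in an FSStore.
# Points, shapes and tables over the network are still work in progress, hence the selection.
sdata = SpatialData.read(path, selection=("images", "labels"))
print(sdata)
```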