diff --git a/src/spatialdata/_core/spatialdata.py b/src/spatialdata/_core/spatialdata.py index 48f6386c..77844fcd 100644 --- a/src/spatialdata/_core/spatialdata.py +++ b/src/spatialdata/_core/spatialdata.py @@ -1179,6 +1179,7 @@ def write( overwrite: bool = False, consolidate_metadata: bool = True, format: SpatialDataFormat | list[SpatialDataFormat] | None = None, + shapes_geometry_encoding: Literal["WKB", "geoarrow"] = "WKB", ) -> None: """ Write the `SpatialData` object to a Zarr store. @@ -1223,6 +1224,7 @@ def write( element_name=element_name, overwrite=False, format=format, + shapes_geometry_encoding=shapes_geometry_encoding, ) if self.path != file_path: @@ -1241,6 +1243,7 @@ def _write_element( element_name: str, overwrite: bool, format: SpatialDataFormat | list[SpatialDataFormat] | None = None, + shapes_geometry_encoding: Literal["WKB", "geoarrow"] = "WKB", ) -> None: if not isinstance(zarr_container_path, Path): raise ValueError( @@ -1266,7 +1269,13 @@ def _write_element( elif element_type == "points": write_points(points=element, group=element_type_group, name=element_name, format=parsed["points"]) elif element_type == "shapes": - write_shapes(shapes=element, group=element_type_group, name=element_name, format=parsed["shapes"]) + write_shapes( + shapes=element, + group=element_type_group, + name=element_name, + format=parsed["shapes"], + geometry_encoding=shapes_geometry_encoding, + ) elif element_type == "tables": write_table(table=element, group=element_type_group, name=element_name, format=parsed["tables"]) else: diff --git a/src/spatialdata/_io/io_shapes.py b/src/spatialdata/_io/io_shapes.py index c32ce1f3..fc57b436 100644 --- a/src/spatialdata/_io/io_shapes.py +++ b/src/spatialdata/_io/io_shapes.py @@ -1,5 +1,6 @@ from collections.abc import MutableMapping from pathlib import Path +from typing import Literal import numpy as np import zarr @@ -68,6 +69,7 @@ def write_shapes( name: str, group_type: str = "ngff:shapes", format: Format = CurrentShapesFormat(), + geometry_encoding: Literal["WKB", "geoarrow"] = "WKB", ) -> None: import numcodecs @@ -94,7 +96,7 @@ def write_shapes( attrs["version"] = format.spatialdata_format_version elif isinstance(format, ShapesFormatV02): path = Path(shapes_group._store.path) / shapes_group.path / "shapes.parquet" - shapes.to_parquet(path) + shapes.to_parquet(path, geometry_encoding=geometry_encoding) attrs = format.attrs_to_dict(shapes.attrs) attrs["version"] = format.spatialdata_format_version