
Commit 48bb91e

Merge branch 'main' into pandas-3.0
2 parents 1768e7e + d013fe0

File tree

10 files changed, +136 -27 lines


docs/installation.md

Lines changed: 1 addition & 1 deletion
@@ -1,6 +1,6 @@
 # Installation
 
-`spatialdata` requires Python version >= 3.9 to run and the installation time requires a few minutes on a standard desktop computer.
+`spatialdata` requires Python (the minimum required version is specified in PyPI) to run and the installation time requires a few minutes on a standard desktop computer.
 
 ## PyPI
 

docs/tutorials/notebooks

Submodule notebooks updated 53 files

src/spatialdata/_core/query/spatial_query.py

Lines changed: 12 additions & 15 deletions
@@ -5,7 +5,6 @@
 from functools import singledispatch
 from typing import TYPE_CHECKING, Any
 
-import dask.array as da
 import dask.dataframe as dd
 import numpy as np
 from dask.dataframe import DataFrame as DaskDataFrame
@@ -385,7 +384,7 @@ def _bounding_box_mask_points(
     axes: tuple[str, ...],
     min_coordinate: list[Number] | ArrayLike,
     max_coordinate: list[Number] | ArrayLike,
-) -> da.Array:
+) -> list[ArrayLike]:
     """Compute a mask that is true for the points inside axis-aligned bounding boxes.
 
     Parameters
@@ -427,12 +426,9 @@ def _bounding_box_mask_points(
                 continue
             min_value = min_coordinate[box, axis_index]
             max_value = max_coordinate[box, axis_index]
-            box_masks.append(
-                points[axis_name].gt(min_value).to_dask_array(lengths=True)
-                & points[axis_name].lt(max_value).to_dask_array(lengths=True)
-            )
-        bounding_box_mask = da.stack(box_masks, axis=-1)
-        in_bounding_box_masks.append(da.all(bounding_box_mask, axis=1))
+            box_masks.append(points[axis_name].gt(min_value).compute() & points[axis_name].lt(max_value).compute())
+        bounding_box_mask = np.stack(box_masks, axis=-1)
+        in_bounding_box_masks.append(np.all(bounding_box_mask, axis=1))
     return in_bounding_box_masks
 
 
@@ -673,19 +669,20 @@ def _(
     )
 
     if not (len_df := len(in_intrinsic_bounding_box)) == (len_bb := len(min_coordinate)):
-        raise ValueError(f"Number of dataframes `{len_df}` is not equal to the number of bounding boxes `{len_bb}`.")
+        raise ValueError(
+            f"Length of list of dataframes `{len_df}` is not equal to the number of bounding boxes axes `{len_bb}`."
+        )
     points_in_intrinsic_bounding_box: list[DaskDataFrame | None] = []
     points_pd = points.compute()
     attrs = points.attrs.copy()
-    for mask in in_intrinsic_bounding_box:
-        if mask.sum() == 0:
+    for mask_np in in_intrinsic_bounding_box:
+        if mask_np.sum() == 0:
             points_in_intrinsic_bounding_box.append(None)
         else:
             # TODO there is a problem when mixing dask dataframe graph with dask array graph. Need to compute for now.
             # we can't compute either mask or points as when we calculate either one of them
             # test_query_points_multiple_partitions will fail as the mask will be used to index each partition.
             # However, if we compute and then create the dask array again we get the mixed dask graph problem.
-            mask_np = mask.compute()
             filtered_pd = points_pd[mask_np]
             points_filtered = dd.from_pandas(filtered_pd, npartitions=points.npartitions)
             points_filtered.attrs.update(attrs)
@@ -724,9 +721,9 @@ def _(
             min_coordinate=min_c,  # type: ignore[arg-type]
             max_coordinate=max_c,  # type: ignore[arg-type]
         )
-        if len(bounding_box_mask) == 1:
-            bounding_box_mask = bounding_box_mask[0]
-        bounding_box_indices = np.where(bounding_box_mask.compute())[0]
+        if len(bounding_box_mask) != 1:
+            raise ValueError(f"Expected a single mask, got {len(bounding_box_mask)} masks. Please report this bug.")
+        bounding_box_indices = np.where(bounding_box_mask[0])[0]
 
         if len(bounding_box_indices) == 0:
             output.append(None)
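
A minimal sketch (not library code) of the numpy-based masking strategy the hunks above switch to; the toy DataFrame and box coordinates are invented for illustration:

import dask.dataframe as dd
import numpy as np
import pandas as pd

points = dd.from_pandas(
    pd.DataFrame({"x": [0.0, 2.0, 5.0], "y": [1.0, 3.0, 9.0]}), npartitions=2
)
min_coordinate = np.array([[1.0, 0.0]])  # one bounding box, axes ("x", "y")
max_coordinate = np.array([[6.0, 4.0]])

box_masks = []
for axis_index, axis_name in enumerate(("x", "y")):
    # .compute() materializes each per-axis boolean Series as a pandas object,
    # sidestepping the mixed dask-dataframe/dask-array graph problem noted in
    # the TODO comment above
    box_masks.append(
        points[axis_name].gt(min_coordinate[0, axis_index]).compute()
        & points[axis_name].lt(max_coordinate[0, axis_index]).compute()
    )
mask_np = np.all(np.stack(box_masks, axis=-1), axis=1)
print(mask_np)  # [False  True False]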

src/spatialdata/_core/spatialdata.py

Lines changed: 5 additions & 2 deletions
@@ -16,6 +16,7 @@
 from dask.dataframe import Scalar, read_parquet
 from geopandas import GeoDataFrame
 from shapely import MultiPolygon, Polygon
+from upath import UPath
 from xarray import DataArray, DataTree
 from zarr.errors import GroupNotFoundError
 
@@ -1810,15 +1811,17 @@ def tables(self, tables: dict[str, AnnData]) -> None:
 
     @staticmethod
     def read(
-        file_path: Path | str, selection: tuple[str] | None = None, reconsolidate_metadata: bool = False
+        file_path: str | Path | UPath | zarr.Group,
+        selection: tuple[str] | None = None,
+        reconsolidate_metadata: bool = False,
     ) -> SpatialData:
         """
         Read a SpatialData object from a Zarr storage (on-disk or remote).
 
         Parameters
         ----------
         file_path
-            The path or URL to the Zarr storage.
+            The path, URL, or zarr.Group to the Zarr storage.
         selection
             The elements to read (images, labels, points, shapes, table). If None, all elements are read.
         reconsolidate_metadata
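
A short usage sketch of the widened signature ("data.zarr" is a placeholder path; the same calls are exercised by the new test in tests/io/test_readwrite.py below):

import zarr
from upath import UPath
from spatialdata import SpatialData

sdata = SpatialData.read("data.zarr")  # str
sdata = SpatialData.read(UPath("data.zarr"))  # UPath
sdata = SpatialData.read(zarr.open_group("data.zarr", mode="r"))  # zarr.Group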

src/spatialdata/_io/_utils.py

Lines changed: 2 additions & 2 deletions
@@ -470,8 +470,8 @@ def _resolve_zarr_store(
     if isinstance(path, zarr.Group):
         # if the input is a zarr.Group, wrap it with a store
         if isinstance(path.store, LocalStore):
-            # create a simple FSStore if the store is a LocalStore with just the path
-            return FsspecStore(os.path.join(path.store.path, path.path), **kwargs)
+            store_path = UPath(path.store.root) / path.path
+            return LocalStore(store_path.path)
         if isinstance(path.store, FsspecStore):
             # if the store within the zarr.Group is an FSStore, return it
             # but extend the path of the store with that of the zarr.Group
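
For context, a sketch of the path resolution the new branch performs, assuming zarr-python v3 where a local group's store is a LocalStore exposing its root directory as .root ("data.zarr" is a placeholder path):

import zarr
from upath import UPath
from zarr.storage import LocalStore

group = zarr.open_group("data.zarr", mode="r")
if isinstance(group.store, LocalStore):
    # join the store's root directory with the group's subpath, then
    # rewrap it as a LocalStore rooted at the group itself
    store_path = UPath(group.store.root) / group.path
    store = LocalStore(store_path.path)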

src/spatialdata/_io/io_zarr.py

Lines changed: 4 additions & 3 deletions
@@ -11,6 +11,7 @@
 from geopandas import GeoDataFrame
 from ome_zarr.format import Format
 from pyarrow import ArrowInvalid
+from upath import UPath
 from zarr.errors import ArrayNotFoundError
 
 from spatialdata._core.spatialdata import SpatialData
@@ -120,7 +121,7 @@ def get_raster_format_for_read(
 
 
 def read_zarr(
-    store: str | Path,
+    store: str | Path | UPath | zarr.Group,
     selection: None | tuple[str] = None,
     on_bad_files: Literal[BadFileHandleMethod.ERROR, BadFileHandleMethod.WARN] = BadFileHandleMethod.ERROR,
 ) -> SpatialData:
@@ -130,7 +131,7 @@ def read_zarr(
     Parameters
     ----------
     store
-        Path to the zarr store (on-disk or remote).
+        Path, URL, or zarr.Group to the zarr store (on-disk or remote).
 
     selection
         List of elements to read from the zarr store (images, labels, points, shapes, table). If None, all elements are
@@ -228,7 +229,7 @@ def read_zarr(
         tables=tables,
         attrs=attrs,
     )
-    sdata.path = Path(store)
+    sdata.path = resolved_store.root
     return sdata
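
read_zarr accepts the same widened input types; a usage sketch with a placeholder path (the top-level import matches the one used in the test file below):

import zarr
from spatialdata import read_zarr

sdata = read_zarr(zarr.open_group("data.zarr", mode="r"))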

Lines changed: 20 additions & 3 deletions
@@ -1,4 +1,21 @@
-try:
+from __future__ import annotations
+
+from typing import TYPE_CHECKING, Any
+
+import spatialdata
+
+if TYPE_CHECKING:
     from spatialdata.dataloader.datasets import ImageTilesDataset
-except ImportError:
-    ImageTilesDataset = None  # type: ignore[assignment, misc]
+
+__all__ = [
+    "ImageTilesDataset",
+]
+
+
+def __getattr__(attr_name: str) -> ImageTilesDataset | Any:
+    if attr_name == "ImageTilesDataset":
+        from spatialdata.dataloader.datasets import ImageTilesDataset
+
+        return ImageTilesDataset
+
+    return getattr(spatialdata.dataloader, attr_name)
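
A generic sketch of the lazy-import pattern introduced above (a PEP 562 module-level __getattr__); mypackage.heavy and HeavyClass are hypothetical names:

from __future__ import annotations

from typing import TYPE_CHECKING, Any

if TYPE_CHECKING:
    # seen only by type checkers, never imported at runtime
    from mypackage.heavy import HeavyClass

__all__ = ["HeavyClass"]


def __getattr__(attr_name: str) -> Any:
    # runs only when normal attribute lookup on the module fails,
    # so the heavy import is deferred until first real access
    if attr_name == "HeavyClass":
        from mypackage.heavy import HeavyClass

        return HeavyClass
    raise AttributeError(f"module {__name__!r} has no attribute {attr_name!r}")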

src/spatialdata/models/models.py

Lines changed: 4 additions & 0 deletions
@@ -239,6 +239,10 @@ def parse(
                 chunks=chunks,
             )
             _parse_transformations(data, parsed_transform)
+        else:
+            # Chunk single scale images
+            if chunks is not None:
+                data = data.chunk(chunks=chunks)
         cls()._check_chunk_size_not_too_large(data)
         # recompute coordinates for (multiscale) spatial image
         return compute_coordinates(data)
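
A usage sketch of the new behavior: the chunks argument now also rechunks single-scale images (array shape and expected chunk size taken from the new test in tests/models/test_models.py below):

import numpy as np
from spatialdata.models import Image2DModel

arr = np.zeros((1, 10, 10))  # (c, y, x)
img = Image2DModel.parse(arr, chunks=(1, 5, 5))
assert img.data.chunksize == (1, 5, 5)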

tests/io/test_readwrite.py

Lines changed: 28 additions & 0 deletions
@@ -11,6 +11,7 @@
 import zarr
 from anndata import AnnData
 from numpy.random import default_rng
+from upath import UPath
 from zarr.errors import GroupNotFoundError
 
 from spatialdata import SpatialData, deepcopy, read_zarr
@@ -963,3 +964,30 @@ def test_can_read_sdata_with_reconsolidation(full_sdata, sdata_container_format:
 
     new_sdata = SpatialData.read(path, reconsolidate_metadata=True)
     assert_spatial_data_objects_are_identical(full_sdata, new_sdata)
+
+
+def test_read_sdata(tmp_path: Path, points: SpatialData) -> None:
+    sdata_path = tmp_path / "sdata.zarr"
+    points.write(sdata_path)
+
+    # path as Path
+    sdata_from_path = SpatialData.read(sdata_path)
+    assert sdata_from_path.path == sdata_path
+
+    # path as str
+    sdata_from_str = SpatialData.read(str(sdata_path))
+    assert sdata_from_str.path == sdata_path
+
+    # path as UPath
+    sdata_from_upath = SpatialData.read(UPath(sdata_path))
+    assert sdata_from_upath.path == sdata_path
+
+    # path as zarr Group
+    zarr_group = zarr.open_group(sdata_path, mode="r")
+    sdata_from_zarr_group = SpatialData.read(zarr_group)
+    assert sdata_from_zarr_group.path == sdata_path
+
+    # Assert all read methods produce identical SpatialData objects
+    assert_spatial_data_objects_are_identical(sdata_from_path, sdata_from_str)
+    assert_spatial_data_objects_are_identical(sdata_from_path, sdata_from_upath)
+    assert_spatial_data_objects_are_identical(sdata_from_path, sdata_from_zarr_group)

tests/models/test_models.py

Lines changed: 59 additions & 0 deletions
@@ -195,6 +195,65 @@ def test_raster_schema(
         with pytest.raises(ValueError):
             model.parse(image, **kwargs)
 
+    @pytest.mark.parametrize(
+        "model,chunks,expected",
+        [
+            (Labels2DModel, None, (10, 10)),
+            (Labels2DModel, 5, (5, 5)),
+            (Labels2DModel, (5, 5), (5, 5)),
+            (Labels2DModel, {"x": 5, "y": 5}, (5, 5)),
+            (Labels3DModel, None, (2, 10, 10)),
+            (Labels3DModel, 5, (2, 5, 5)),
+            (Labels3DModel, (2, 5, 5), (2, 5, 5)),
+            (Labels3DModel, {"z": 2, "x": 5, "y": 5}, (2, 5, 5)),
+            (Image2DModel, None, (1, 10, 10)),  # Image2D Models always have a c dimension
+            (Image2DModel, 5, (1, 5, 5)),
+            (Image2DModel, (1, 5, 5), (1, 5, 5)),
+            (Image2DModel, {"c": 1, "x": 5, "y": 5}, (1, 5, 5)),
+            (Image3DModel, None, (1, 2, 10, 10)),  # Image3D models have z in addition, so 4 total dimensions
+            (Image3DModel, 5, (1, 2, 5, 5)),
+            (Image3DModel, (1, 2, 5, 5), (1, 2, 5, 5)),
+            (
+                Image3DModel,
+                {"c": 1, "z": 2, "x": 5, "y": 5},
+                (1, 2, 5, 5),
+            ),
+        ],
+    )
+    def test_raster_models_parse_with_chunks_parameter(self, model, chunks, expected):
+        image: ArrayLike = np.arange(100).reshape((10, 10))
+        if model in [Labels3DModel, Image3DModel]:
+            image = np.stack([image] * 2)
+
+        if model in [Image2DModel, Image3DModel]:
+            image = np.expand_dims(image, axis=0)
+
+        # parse as numpy array
+        # single scale
+        x_ss = model.parse(image, chunks=chunks)
+        assert x_ss.data.chunksize == expected
+        # multi scale
+        x_ms = model.parse(image, chunks=chunks, scale_factors=(2,))
+        assert x_ms["scale0"]["image"].data.chunksize == expected
+
+        # parse as dask array
+        dask_image = from_array(image)
+        # single scale
+        y_ss = model.parse(dask_image, chunks=chunks)
+        assert y_ss.data.chunksize == expected
+        # multi scale
+        y_ms = model.parse(dask_image, chunks=chunks, scale_factors=(2,))
+        assert y_ms["scale0"]["image"].data.chunksize == expected
+
+        # parse as DataArray
+        data_array = DataArray(image, dims=model.dims.dims)
+        # single scale
+        z_ss = model.parse(data_array, chunks=chunks)
+        assert z_ss.data.chunksize == expected
+        # multi scale
+        z_ms = model.parse(data_array, chunks=chunks, scale_factors=(2,))
+        assert z_ms["scale0"]["image"].data.chunksize == expected
+
     @pytest.mark.parametrize("model", [Labels2DModel, Labels3DModel])
     def test_labels_model_with_multiscales(self, model):
         # Passing "scale_factors" should generate multiscales with a "method" appropriate for labels
