Skip to content

Commit 2e64014

Browse files
Better 3D and 2.5D support for raster and vector data (#366)
* better 3D and 2.5D support for raster and vector data * fix mypy; workaround qt test (command double registration) * fix deprecated API * remove dask pin * cleanup * add tests for 3D points, 2.5D shapes; improve contribution guide * bump min spatialdata
1 parent 5daf882 commit 2e64014

File tree

9 files changed

+306
-29
lines changed

9 files changed

+306
-29
lines changed

.gitignore

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -117,3 +117,4 @@ tests/plots/generated
117117

118118
# Local temp script for testing user bugs (Luca)
119119
temp/
120+
uv.lock

docs/contributing.md

Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,25 @@
11
# Contributing guide
22

33
Please refer to the [contribution guide from the `spatialdata` repository](https://github.com/scverse/spatialdata/blob/main/docs/contributing.md).
4+
5+
## Debugging napari GUI tests
6+
7+
To visually inspect what a test is rendering in napari:
8+
9+
1. Change `make_napari_viewer()` to `make_napari_viewer(show=True)`
10+
2. Add `napari.run()` before the end of the test (before the assertions)
11+
12+
Example:
13+
14+
```python
15+
import napari
16+
17+
18+
def test_my_visualization(make_napari_viewer):
19+
viewer = make_napari_viewer(show=True)
20+
# ... setup code ...
21+
napari.run()
22+
# assertions...
23+
```
24+
25+
Remember to revert these changes before committing.

setup.cfg

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -54,7 +54,7 @@ install_requires =
5454
scipy
5555
shapely
5656
scikit-learn
57-
spatialdata>=0.7.0dev0
57+
spatialdata>=0.7.0dev1
5858
superqt
5959
typing_extensions>=4.8.0
6060
vispy

src/napari_spatialdata/_viewer.py

Lines changed: 18 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -18,12 +18,11 @@
1818
from spatialdata import get_element_annotators, get_element_instances
1919
from spatialdata._core.query.relational_query import _left_join_spatialelement_table
2020
from spatialdata._types import ArrayLike
21-
from spatialdata.models import PointsModel, ShapesModel, TableModel, force_2d, get_channel_names
21+
from spatialdata.models import PointsModel, ShapesModel, TableModel, force_2d, get_axes_names, get_channel_names
2222
from spatialdata.transformations import Affine, Identity
2323

2424
from napari_spatialdata._model import DataModel
2525
from napari_spatialdata.constants import config
26-
from napari_spatialdata.constants.config import CIRCLES_AS_POINTS
2726
from napari_spatialdata.utils._utils import (
2827
_adjust_channels_order,
2928
_get_ellipses_from_circles,
@@ -470,7 +469,7 @@ def get_sdata_image(self, sdata: SpatialData, key: str, selected_cs: str, multi:
470469
if multi:
471470
original_name = original_name[: original_name.rfind("_")]
472471

473-
affine = _get_transform(sdata.images[original_name], selected_cs)
472+
affine = _get_transform(sdata.images[original_name], selected_cs, include_z=True)
474473
rgb_image, rgb = _adjust_channels_order(element=sdata.images[original_name])
475474

476475
channels = ("RGB(A)",) if rgb else get_channel_names(sdata.images[original_name])
@@ -517,6 +516,7 @@ def get_sdata_circles(self, sdata: SpatialData, key: str, selected_cs: str, mult
517516
df = sdata.shapes[original_name]
518517
affine = _get_transform(sdata.shapes[original_name], selected_cs)
519518

519+
# 2.5D circles not supported yet
520520
xy = np.array([df.geometry.x, df.geometry.y]).T
521521
yx = np.fliplr(xy)
522522
radii = df.radius.to_numpy()
@@ -541,10 +541,10 @@ def get_sdata_circles(self, sdata: SpatialData, key: str, selected_cs: str, mult
541541
version = get_napari_version()
542542
kwargs: dict[str, Any] = (
543543
{"edge_width": 0.0}
544-
if version <= packaging.version.parse("0.4.20") or not CIRCLES_AS_POINTS
544+
if version <= packaging.version.parse("0.4.20") or not config.CIRCLES_AS_POINTS
545545
else {"border_width": 0.0}
546546
)
547-
if CIRCLES_AS_POINTS:
547+
if config.CIRCLES_AS_POINTS:
548548
layer = Points(
549549
yx,
550550
name=key,
@@ -556,7 +556,7 @@ def get_sdata_circles(self, sdata: SpatialData, key: str, selected_cs: str, mult
556556
assert affine is not None
557557
self._adjust_radii_of_points_layer(layer=layer, affine=affine)
558558
else:
559-
if version <= packaging.version.parse("0.4.20") or not CIRCLES_AS_POINTS:
559+
if version <= packaging.version.parse("0.4.20") or not config.CIRCLES_AS_POINTS:
560560
kwargs |= {"edge_color": "white"}
561561
else:
562562
kwargs |= {"border_color": "white"}
@@ -597,7 +597,8 @@ def get_sdata_shapes(self, sdata: SpatialData, key: str, selected_cs: str, multi
597597
original_name = original_name[: original_name.rfind("_")]
598598

599599
df = sdata.shapes[original_name]
600-
affine = _get_transform(sdata.shapes[original_name], selected_cs)
600+
include_z = not config.PROJECT_2_5D_SHAPES_TO_2D
601+
affine = _get_transform(sdata.shapes[original_name], selected_cs, include_z=include_z)
601602

602603
# when mulitpolygons are present, we select the largest ones
603604
if "MultiPolygon" in np.unique(df.geometry.type):
@@ -609,7 +610,7 @@ def get_sdata_shapes(self, sdata: SpatialData, key: str, selected_cs: str, multi
609610
df = df.sort_index() # reset the index to the first order
610611

611612
simplify = len(df) > config.POLYGON_THRESHOLD
612-
polygons, indices = _get_polygons_properties(df, simplify)
613+
polygons, indices = _get_polygons_properties(df, simplify, include_z=include_z)
613614

614615
# this will only work for polygons and not for multipolygons
615616
polygons = _transform_coordinates(polygons, f=lambda x: x[::-1])
@@ -662,7 +663,7 @@ def get_sdata_labels(self, sdata: SpatialData, key: str, selected_cs: str, multi
662663
original_name = original_name[: original_name.rfind("_")]
663664

664665
indices = get_element_instances(sdata.labels[original_name])
665-
affine = _get_transform(sdata.labels[original_name], selected_cs)
666+
affine = _get_transform(sdata.labels[original_name], selected_cs, include_z=True)
666667
rgb_labels, _ = _adjust_channels_order(element=sdata.labels[original_name])
667668

668669
adata, table_name, table_names = self._get_table_data(sdata, original_name)
@@ -706,8 +707,10 @@ def get_sdata_points(self, sdata: SpatialData, key: str, selected_cs: str, multi
706707
if multi:
707708
original_name = original_name[: original_name.rfind("_")]
708709

710+
axes = get_axes_names(sdata.points[original_name])
709711
points = sdata.points[original_name].compute()
710-
affine = _get_transform(sdata.points[original_name], selected_cs)
712+
include_z = "z" in axes and not config.PROJECT_3D_POINTS_TO_2D
713+
affine = _get_transform(sdata.points[original_name], selected_cs, include_z=include_z)
711714
adata, table_name, table_names = self._get_table_data(sdata, original_name)
712715

713716
if len(points) < config.POINT_THRESHOLD:
@@ -727,14 +730,16 @@ def get_sdata_points(self, sdata: SpatialData, key: str, selected_cs: str, multi
727730
_, adata = _left_join_spatialelement_table(
728731
{"points": {original_name: subsample_points}}, sdata[table_name], match_rows="left"
729732
)
730-
xy = subsample_points[["y", "x"]].values
731-
np.fliplr(xy)
733+
axes = sorted(axes, reverse=True)
734+
if not include_z and "z" in axes:
735+
axes.remove("z")
736+
coords = subsample_points[axes].values
732737
# radii_size = _calc_default_radii(self.viewer, sdata, selected_cs)
733738
radii_size = 3
734739
version = get_napari_version()
735740
kwargs = {"edge_width": 0.0} if version <= packaging.version.parse("0.4.20") else {"border_width": 0.0}
736741
layer = Points(
737-
xy,
742+
coords,
738743
name=key,
739744
size=radii_size * 2,
740745
affine=affine,

src/napari_spatialdata/constants/config.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4,3 +4,5 @@
44
N_SHAPES_WARNING_THRESHOLD = 10000
55
POINT_SIZE_SCATTERPLOT_WIDGET = 6
66
CIRCLES_AS_POINTS = True
7+
PROJECT_3D_POINTS_TO_2D = True
8+
PROJECT_2_5D_SHAPES_TO_2D = True

src/napari_spatialdata/utils/_utils.py

Lines changed: 7 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -181,15 +181,20 @@ def _transform_coordinates(data: list[Any], f: Callable[..., Any]) -> list[Any]:
181181
return [[f(xy) for xy in sublist] for sublist in data]
182182

183183

184-
def _get_transform(element: SpatialElement, coordinate_system_name: str | None = None) -> None | ArrayLike:
184+
def _get_transform(
185+
element: SpatialElement, coordinate_system_name: str | None = None, include_z: bool | None = None
186+
) -> None | ArrayLike:
185187
if not isinstance(element, DataArray | DataTree | DaskDataFrame | GeoDataFrame):
186188
raise RuntimeError("Cannot get transform for {type(element)}")
187189

188190
transformations = get_transformation(element, get_all=True)
189191
cs = transformations.keys().__iter__().__next__() if coordinate_system_name is None else coordinate_system_name
190192
ct = transformations.get(cs)
191193
if ct:
192-
return ct.to_affine_matrix(input_axes=("y", "x"), output_axes=("y", "x"))
194+
axes_element = get_axes_names(element)
195+
include_z = include_z and "z" in axes_element
196+
axes_transformation = ("z", "y", "x") if include_z else ("y", "x")
197+
return ct.to_affine_matrix(input_axes=axes_transformation, output_axes=axes_transformation)
193198
return None
194199

195200

Lines changed: 31 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -1,18 +1,37 @@
11
from geopandas import GeoDataFrame
2+
from spatialdata.models import get_axes_names
23

4+
# type aliases, only used in this module
5+
Coord2D = tuple[float, float]
6+
Coord3D = tuple[float, float, float]
7+
Polygon2D = list[Coord2D]
8+
Polygon3D = list[Coord3D]
9+
Polygon = Polygon2D | Polygon3D
310

4-
def _get_polygons_properties(df: GeoDataFrame, simplify: bool) -> tuple[list[list[tuple[float, float]]], list[int]]:
5-
indices = []
6-
polygons = []
711

8-
if simplify:
9-
for i in range(0, len(df)):
10-
indices.append(df.iloc[i].name)
11-
# This can be removed once napari is sped up in the plotting. It changes the shapes only very slightly
12-
polygons.append(list(df.geometry.iloc[i].exterior.simplify(tolerance=2).coords))
13-
else:
14-
for i in range(0, len(df)):
15-
indices.append(df.iloc[i].name)
16-
polygons.append(list(df.geometry.iloc[i].exterior.coords))
12+
def _get_polygons_properties(df: GeoDataFrame, simplify: bool, include_z: bool) -> tuple[list[Polygon], list[int]]:
13+
# assumes no "Polygon Z": z is in separate column if present
14+
indices: list[int] = []
15+
polygons: list[Polygon] = []
16+
17+
axes = get_axes_names(df)
18+
add_z = include_z and "z" in axes
19+
20+
for i in range(len(df)):
21+
indices.append(int(df.index[i]))
22+
23+
if simplify:
24+
xy = list(df.geometry.iloc[i].exterior.simplify(tolerance=2).coords)
25+
else:
26+
xy = list(df.geometry.iloc[i].exterior.coords)
27+
28+
coords: Polygon2D | Polygon3D
29+
if add_z:
30+
z_val = float(df.iloc[i].z.item() if hasattr(df.iloc[i].z, "item") else df.iloc[i].z)
31+
coords = [(x, y, z_val) for x, y in xy]
32+
else:
33+
coords = xy
34+
35+
polygons.append(coords)
1736

1837
return polygons, indices

tests/conftest.py

Lines changed: 63 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -9,19 +9,23 @@
99
from pathlib import Path
1010
from typing import Any
1111

12+
import geopandas as gpd
1213
import napari
1314
import numpy as np
1415
import pandas as pd
1516
import pytest
1617
from anndata import AnnData
18+
from dask.dataframe import from_pandas
1719
from loguru import logger
1820
from matplotlib.testing.compare import compare_images
1921
from scipy import ndimage as ndi
22+
from shapely import MultiPolygon, Polygon
2023
from skimage import data
2124
from spatialdata import SpatialData
2225
from spatialdata._types import ArrayLike
2326
from spatialdata.datasets import blobs
24-
from spatialdata.models import TableModel
27+
from spatialdata.models import PointsModel, ShapesModel, TableModel
28+
from spatialdata.transformations import Identity, set_transformation
2529

2630
from napari_spatialdata.utils._test_utils import export_figure, save_image
2731

@@ -259,3 +263,61 @@ def caplog(caplog):
259263
def always_sync(monkeypatch, request):
260264
if request.node.get_closest_marker("use_thread_loader") is None:
261265
monkeypatch.setattr("napari_spatialdata._sdata_widgets.PROBLEMATIC_NUMPY_MACOS", True)
266+
267+
268+
@pytest.fixture
269+
def sdata_3d_points() -> SpatialData:
270+
"""Create a SpatialData object with 3D points (x, y, z coordinates)."""
271+
n_points = 10
272+
rng = np.random.default_rng(SEED)
273+
df = pd.DataFrame(
274+
{
275+
"x": rng.uniform(0, 100, n_points),
276+
"y": rng.uniform(0, 100, n_points),
277+
"z": rng.uniform(0, 50, n_points),
278+
}
279+
)
280+
dask_df = from_pandas(df, npartitions=1)
281+
points = PointsModel.parse(dask_df)
282+
set_transformation(points, {"global": Identity()}, set_all=True)
283+
284+
return SpatialData(points={"points_3d": points})
285+
286+
287+
@pytest.fixture
288+
def sdata_2_5d_shapes() -> SpatialData:
289+
"""Create a SpatialData object with 2.5D shapes (3 layers at different z, polygons + multipolygons)."""
290+
shapes = {}
291+
292+
geometries = []
293+
z_values = []
294+
indices = []
295+
for i, z_val in enumerate([0.0, 10.0, 20.0]):
296+
# Add simple polygons (triangles and quadrilaterals)
297+
poly1 = Polygon([(10 + i * 5, 10), (20 + i * 5, 10), (15 + i * 5, 20)])
298+
poly2 = Polygon([(30 + i * 5, 30), (40 + i * 5, 30), (40 + i * 5, 40), (30 + i * 5, 40)])
299+
geometries.extend([poly1, poly2])
300+
indices.extend([0, 1])
301+
z_values.extend([z_val] * 2)
302+
303+
# Add a multipolygon (two separate polygon parts)
304+
multi_poly = MultiPolygon(
305+
[
306+
Polygon([(50 + i * 5, 10), (60 + i * 5, 10), (55 + i * 5, 20)]),
307+
Polygon([(50 + i * 5, 30), (60 + i * 5, 30), (60 + i * 5, 40), (50 + i * 5, 40)]),
308+
]
309+
)
310+
geometries.append(multi_poly)
311+
indices.append(2)
312+
z_values.append(z_val)
313+
314+
gdf = gpd.GeoDataFrame(
315+
{"z": z_values, "geometry": geometries},
316+
index=indices,
317+
)
318+
319+
shape_element = ShapesModel.parse(gdf)
320+
set_transformation(shape_element, {"global": Identity()}, set_all=True)
321+
shapes["shapes_2.5d"] = shape_element
322+
323+
return SpatialData(shapes=shapes)

0 commit comments

Comments
 (0)