Skip to content

Commit 70306c6

Browse files
Add type hint when get a SpatialElement item (#939)
* add type hint when get a SpatialElement item * getitem can return an AnnData table * fix precommit --------- Co-authored-by: Luca Marconato <[email protected]>
1 parent eb5a202 commit 70306c6

File tree

3 files changed

+32
-15
lines changed

3 files changed

+32
-15
lines changed

src/spatialdata/_core/_elements.py

Lines changed: 12 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -4,12 +4,13 @@
44

55
from collections import UserDict
66
from collections.abc import Iterable, KeysView, ValuesView
7-
from typing import Any
7+
from typing import TypeVar
88
from warnings import warn
99

1010
from anndata import AnnData
1111
from dask.dataframe import DataFrame as DaskDataFrame
1212
from geopandas import GeoDataFrame
13+
from xarray import DataArray, DataTree
1314

1415
from spatialdata._core.validation import check_key_is_case_insensitively_unique, check_valid_name
1516
from spatialdata._types import Raster_T
@@ -25,8 +26,10 @@
2526
get_model,
2627
)
2728

29+
T = TypeVar("T")
2830

29-
class Elements(UserDict[str, Any]):
31+
32+
class Elements(UserDict[str, T]):
3033
def __init__(self, shared_keys: set[str | None]) -> None:
3134
self._shared_keys = shared_keys
3235
super().__init__()
@@ -49,7 +52,7 @@ def _check_key(key: str, element_keys: Iterable[str], shared_keys: set[str | Non
4952
# Validation raises ValueError, but inappropriate mapping key must raise KeyError.
5053
raise KeyError(*e.args) from e
5154

52-
def __setitem__(self, key: str, value: Any) -> None:
55+
def __setitem__(self, key: str, value: T) -> None:
5356
self._add_shared_key(key)
5457
super().__setitem__(key, value)
5558

@@ -61,12 +64,12 @@ def keys(self) -> KeysView[str]:
6164
"""Return the keys of the Elements."""
6265
return self.data.keys()
6366

64-
def values(self) -> ValuesView[Any]:
67+
def values(self) -> ValuesView[T]:
6568
"""Return the values of the Elements."""
6669
return self.data.values()
6770

6871

69-
class Images(Elements):
72+
class Images(Elements[DataArray | DataTree]):
7073
def __setitem__(self, key: str, value: Raster_T) -> None:
7174
self._check_key(key, self.keys(), self._shared_keys)
7275
schema = get_model(value)
@@ -83,7 +86,7 @@ def __setitem__(self, key: str, value: Raster_T) -> None:
8386
NotImplementedError("TODO: implement for ndim > 4.")
8487

8588

86-
class Labels(Elements):
89+
class Labels(Elements[DataArray | DataTree]):
8790
def __setitem__(self, key: str, value: Raster_T) -> None:
8891
self._check_key(key, self.keys(), self._shared_keys)
8992
schema = get_model(value)
@@ -100,7 +103,7 @@ def __setitem__(self, key: str, value: Raster_T) -> None:
100103
NotImplementedError("TODO: implement for ndim > 3.")
101104

102105

103-
class Shapes(Elements):
106+
class Shapes(Elements[GeoDataFrame]):
104107
def __setitem__(self, key: str, value: GeoDataFrame) -> None:
105108
self._check_key(key, self.keys(), self._shared_keys)
106109
schema = get_model(value)
@@ -110,7 +113,7 @@ def __setitem__(self, key: str, value: GeoDataFrame) -> None:
110113
super().__setitem__(key, value)
111114

112115

113-
class Points(Elements):
116+
class Points(Elements[DaskDataFrame]):
114117
def __setitem__(self, key: str, value: DaskDataFrame) -> None:
115118
self._check_key(key, self.keys(), self._shared_keys)
116119
schema = get_model(value)
@@ -120,7 +123,7 @@ def __setitem__(self, key: str, value: DaskDataFrame) -> None:
120123
super().__setitem__(key, value)
121124

122125

123-
class Tables(Elements):
126+
class Tables(Elements[AnnData]):
124127
def __setitem__(self, key: str, value: AnnData) -> None:
125128
self._check_key(key, self.keys(), self._shared_keys)
126129
schema = get_model(value)

src/spatialdata/_core/spatialdata.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2336,7 +2336,7 @@ def subset(
23362336
)
23372337
return SpatialData(**elements_dict, tables=tables, attrs=self.attrs)
23382338

2339-
def __getitem__(self, item: str) -> SpatialElement:
2339+
def __getitem__(self, item: str) -> SpatialElement | AnnData:
23402340
"""
23412341
Return the element with the given name.
23422342

src/spatialdata/testing.py

Lines changed: 19 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -10,14 +10,19 @@
1010
from xarray.testing import assert_equal
1111

1212
from spatialdata import SpatialData
13-
from spatialdata._core._elements import Elements
13+
from spatialdata._core._elements import Images, Labels, Points, Shapes, Tables
1414
from spatialdata.models import PointsModel
1515
from spatialdata.models._utils import SpatialElement
1616
from spatialdata.transformations.operations import get_transformation
1717

18+
_Elements = Images | Labels | Shapes | Points | Tables
19+
1820

1921
def assert_elements_dict_are_identical(
20-
elements0: Elements, elements1: Elements, check_transformations: bool = True, check_metadata: bool = True
22+
elements0: _Elements,
23+
elements1: _Elements,
24+
check_transformations: bool = True,
25+
check_metadata: bool = True,
2126
) -> None:
2227
"""
2328
Compare two dictionaries of elements and assert that they are identical (except for the order of the keys).
@@ -55,7 +60,10 @@ def assert_elements_dict_are_identical(
5560
element0 = elements0[k]
5661
element1 = elements1[k]
5762
assert_elements_are_identical(
58-
element0, element1, check_transformations=check_transformations, check_metadata=check_metadata
63+
element0,
64+
element1,
65+
check_transformations=check_transformations,
66+
check_metadata=check_metadata,
5967
)
6068

6169

@@ -125,7 +133,10 @@ def assert_elements_are_identical(
125133

126134

127135
def assert_spatial_data_objects_are_identical(
128-
sdata0: SpatialData, sdata1: SpatialData, check_transformations: bool = True, check_metadata: bool = True
136+
sdata0: SpatialData,
137+
sdata1: SpatialData,
138+
check_transformations: bool = True,
139+
check_metadata: bool = True,
129140
) -> None:
130141
"""
131142
Compare two SpatialData objects and assert that they are identical.
@@ -169,5 +180,8 @@ def assert_spatial_data_objects_are_identical(
169180
element0 = sdata0[element_name]
170181
element1 = sdata1[element_name]
171182
assert_elements_are_identical(
172-
element0, element1, check_transformations=check_transformations, check_metadata=check_metadata
183+
element0,
184+
element1,
185+
check_transformations=check_transformations,
186+
check_metadata=check_metadata,
173187
)

0 commit comments

Comments
 (0)