44
55from collections import UserDict
66from collections .abc import Iterable , KeysView , ValuesView
7- from typing import Any
7+ from typing import TypeVar
88from warnings import warn
99
1010from anndata import AnnData
1111from dask .dataframe import DataFrame as DaskDataFrame
1212from geopandas import GeoDataFrame
13+ from xarray import DataArray , DataTree
1314
1415from spatialdata ._core .validation import check_key_is_case_insensitively_unique , check_valid_name
1516from spatialdata ._types import Raster_T
2526 get_model ,
2627)
2728
29+ T = TypeVar ("T" )
2830
29- class Elements (UserDict [str , Any ]):
31+
32+ class Elements (UserDict [str , T ]):
3033 def __init__ (self , shared_keys : set [str | None ]) -> None :
3134 self ._shared_keys = shared_keys
3235 super ().__init__ ()
@@ -49,7 +52,7 @@ def _check_key(key: str, element_keys: Iterable[str], shared_keys: set[str | Non
4952 # Validation raises ValueError, but inappropriate mapping key must raise KeyError.
5053 raise KeyError (* e .args ) from e
5154
52- def __setitem__ (self , key : str , value : Any ) -> None :
55+ def __setitem__ (self , key : str , value : T ) -> None :
5356 self ._add_shared_key (key )
5457 super ().__setitem__ (key , value )
5558
@@ -61,12 +64,12 @@ def keys(self) -> KeysView[str]:
6164 """Return the keys of the Elements."""
6265 return self .data .keys ()
6366
64- def values (self ) -> ValuesView [Any ]:
67+ def values (self ) -> ValuesView [T ]:
6568 """Return the values of the Elements."""
6669 return self .data .values ()
6770
6871
69- class Images (Elements ):
72+ class Images (Elements [ DataArray | DataTree ] ):
7073 def __setitem__ (self , key : str , value : Raster_T ) -> None :
7174 self ._check_key (key , self .keys (), self ._shared_keys )
7275 schema = get_model (value )
@@ -83,7 +86,7 @@ def __setitem__(self, key: str, value: Raster_T) -> None:
8386 NotImplementedError ("TODO: implement for ndim > 4." )
8487
8588
86- class Labels (Elements ):
89+ class Labels (Elements [ DataArray | DataTree ] ):
8790 def __setitem__ (self , key : str , value : Raster_T ) -> None :
8891 self ._check_key (key , self .keys (), self ._shared_keys )
8992 schema = get_model (value )
@@ -100,7 +103,7 @@ def __setitem__(self, key: str, value: Raster_T) -> None:
100103 NotImplementedError ("TODO: implement for ndim > 3." )
101104
102105
103- class Shapes (Elements ):
106+ class Shapes (Elements [ GeoDataFrame ] ):
104107 def __setitem__ (self , key : str , value : GeoDataFrame ) -> None :
105108 self ._check_key (key , self .keys (), self ._shared_keys )
106109 schema = get_model (value )
@@ -110,7 +113,7 @@ def __setitem__(self, key: str, value: GeoDataFrame) -> None:
110113 super ().__setitem__ (key , value )
111114
112115
113- class Points (Elements ):
116+ class Points (Elements [ DaskDataFrame ] ):
114117 def __setitem__ (self , key : str , value : DaskDataFrame ) -> None :
115118 self ._check_key (key , self .keys (), self ._shared_keys )
116119 schema = get_model (value )
@@ -120,7 +123,7 @@ def __setitem__(self, key: str, value: DaskDataFrame) -> None:
120123 super ().__setitem__ (key , value )
121124
122125
123- class Tables (Elements ):
126+ class Tables (Elements [ AnnData ] ):
124127 def __setitem__ (self , key : str , value : AnnData ) -> None :
125128 self ._check_key (key , self .keys (), self ._shared_keys )
126129 schema = get_model (value )
0 commit comments