Skip to content

Commit 71d0ba7

Browse files
committed
#135 Make typing.DataClass generic
1 parent d2a5b8e commit 71d0ba7

File tree

4 files changed

+58
-63
lines changed

4 files changed

+58
-63
lines changed

xarray_dataclasses/dataarray.py

Lines changed: 22 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -2,10 +2,18 @@
22

33

44
# standard library
5-
from dataclasses import Field
65
from functools import partial, wraps
76
from types import MethodType
8-
from typing import Any, Callable, Dict, Optional, Type, TypeVar, Union, overload
7+
from typing import (
8+
Any,
9+
Callable,
10+
ClassVar,
11+
Optional,
12+
Type,
13+
TypeVar,
14+
Union,
15+
overload,
16+
)
917

1018

1119
# dependencies
@@ -18,7 +26,7 @@
1826
# submodules
1927
from .datamodel import DataModel
2028
from .dataoptions import DataOptions
21-
from .typing import DataType, Order, Shape, Sizes
29+
from .typing import DataClass, DataClassFields, DataType, Order, Shape, Sizes
2230

2331

2432
# constants
@@ -30,29 +38,20 @@
3038
TDataArray = TypeVar("TDataArray", bound=xr.DataArray)
3139

3240

33-
class DataClass(Protocol[P]):
34-
"""Type hint for a dataclass object."""
41+
class OptionedClass(Protocol[P, TDataArray]):
42+
"""Type hint for dataclass objects with options."""
3543

3644
def __init__(self, *args: P.args, **kwargs: P.kwargs) -> None:
3745
...
3846

39-
__dataclass_fields__: Dict[str, Field[Any]]
40-
41-
42-
class DataArrayClass(Protocol[P, TDataArray]):
43-
"""Type hint for a dataclass object with a DataArray factory."""
44-
45-
def __init__(self, *args: P.args, **kwargs: P.kwargs) -> None:
46-
...
47-
48-
__dataclass_fields__: Dict[str, Field[Any]]
47+
__dataclass_fields__: ClassVar[DataClassFields]
4948
__dataoptions__: DataOptions[TDataArray]
5049

5150

5251
# runtime functions
5352
@overload
5453
def asdataarray(
55-
dataclass: DataArrayClass[Any, TDataArray],
54+
dataclass: OptionedClass[P, TDataArray],
5655
reference: Optional[DataType] = None,
5756
dataoptions: DataOptions[Any] = DEFAULT_OPTIONS,
5857
) -> TDataArray:
@@ -61,7 +60,7 @@ def asdataarray(
6160

6261
@overload
6362
def asdataarray(
64-
dataclass: DataClass[Any],
63+
dataclass: DataClass[P],
6564
reference: Optional[DataType] = None,
6665
dataoptions: DataOptions[TDataArray] = DEFAULT_OPTIONS,
6766
) -> TDataArray:
@@ -127,7 +126,7 @@ def __init__(self, func: Any) -> None:
127126
def __get__(
128127
self,
129128
obj: Any,
130-
cls: Type[DataArrayClass[P, TDataArray]],
129+
cls: Type[OptionedClass[P, TDataArray]],
131130
) -> Callable[P, TDataArray]:
132131
...
133132

@@ -163,7 +162,7 @@ def new(cls: Any, *args: Any, **kwargs: Any) -> Any:
163162
@overload
164163
@classmethod
165164
def shaped(
166-
cls: Type[DataArrayClass[P, TDataArray]],
165+
cls: Type[OptionedClass[P, TDataArray]],
167166
func: Callable[[Shape], np.ndarray],
168167
shape: Union[Shape, Sizes],
169168
**kwargs: Any,
@@ -209,7 +208,7 @@ def shaped(
209208
@overload
210209
@classmethod
211210
def empty(
212-
cls: Type[DataArrayClass[P, TDataArray]],
211+
cls: Type[OptionedClass[P, TDataArray]],
213212
shape: Union[Shape, Sizes],
214213
order: Order = "C",
215214
**kwargs: Any,
@@ -251,7 +250,7 @@ def empty(
251250
@overload
252251
@classmethod
253252
def zeros(
254-
cls: Type[DataArrayClass[P, TDataArray]],
253+
cls: Type[OptionedClass[P, TDataArray]],
255254
shape: Union[Shape, Sizes],
256255
order: Order = "C",
257256
**kwargs: Any,
@@ -293,7 +292,7 @@ def zeros(
293292
@overload
294293
@classmethod
295294
def ones(
296-
cls: Type[DataArrayClass[P, TDataArray]],
295+
cls: Type[OptionedClass[P, TDataArray]],
297296
shape: Union[Shape, Sizes],
298297
order: Order = "C",
299298
**kwargs: Any,
@@ -335,7 +334,7 @@ def ones(
335334
@overload
336335
@classmethod
337336
def full(
338-
cls: Type[DataArrayClass[P, TDataArray]],
337+
cls: Type[OptionedClass[P, TDataArray]],
339338
shape: Union[Shape, Sizes],
340339
fill_value: Any,
341340
order: Order = "C",

xarray_dataclasses/datamodel.py

Lines changed: 7 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -3,13 +3,13 @@
33

44
# standard library
55
from dataclasses import Field, dataclass, field, is_dataclass
6-
from typing import Any, Dict, Hashable, Optional, Type, cast
6+
from typing import Any, Dict, Hashable, Optional, Type, Union, cast
77

88

99
# dependencies
1010
import numpy as np
1111
import xarray as xr
12-
from typing_extensions import TypedDict, get_type_hints
12+
from typing_extensions import ParamSpec, TypedDict, get_type_hints
1313

1414

1515
# submodules
@@ -28,6 +28,8 @@
2828

2929

3030
# type hints
31+
P = ParamSpec("P")
32+
AnyDataClass = Union[Type[DataClass[P]], DataClass[P]]
3133
DimsDtype = TypedDict("DimsDtype", dims=Dims, dtype=Dtype)
3234

3335

@@ -45,7 +47,7 @@ class Data:
4547
type: DimsDtype
4648
"""Type (dims and dtype) of the field."""
4749

48-
factory: Optional[Type[DataClass]] = None
50+
factory: Any = None
4951
"""Factory dataclass to create a DataArray object."""
5052

5153
def __call__(self, reference: Optional[DataType] = None) -> xr.DataArray:
@@ -137,7 +139,7 @@ class DataModel:
137139
"""Model of the name fields."""
138140

139141
@classmethod
140-
def from_dataclass(cls, dataclass: DataClass) -> "DataModel":
142+
def from_dataclass(cls, dataclass: AnyDataClass[P]) -> "DataModel":
141143
"""Create a data model from a dataclass or its object."""
142144
model = cls()
143145
eval_field_types(dataclass)
@@ -162,7 +164,7 @@ def from_dataclass(cls, dataclass: DataClass) -> "DataModel":
162164

163165

164166
# runtime functions
165-
def eval_field_types(dataclass: DataClass) -> None:
167+
def eval_field_types(dataclass: AnyDataClass[P]) -> None:
166168
"""Evaluate field types of a dataclass or its object."""
167169
hints = get_type_hints(dataclass, include_extras=True) # type: ignore
168170

xarray_dataclasses/dataset.py

Lines changed: 13 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -2,10 +2,9 @@
22

33

44
# standard library
5-
from dataclasses import Field
65
from functools import partial, wraps
76
from types import MethodType
8-
from typing import Any, Callable, Dict, Optional, Type, TypeVar, overload
7+
from typing import Any, Callable, ClassVar, Dict, Optional, Type, TypeVar, overload
98

109

1110
# dependencies
@@ -18,7 +17,7 @@
1817
# submodules
1918
from .datamodel import DataModel
2019
from .dataoptions import DataOptions
21-
from .typing import DataType, Order, Shape, Sizes
20+
from .typing import DataClass, DataClassFields, DataType, Order, Shape, Sizes
2221

2322

2423
# constants
@@ -30,29 +29,20 @@
3029
TDataset = TypeVar("TDataset", bound=xr.Dataset)
3130

3231

33-
class DataClass(Protocol[P]):
34-
"""Type hint for a dataclass object."""
32+
class OptionedClass(Protocol[P, TDataset]):
33+
"""Type hint for dataclass objects with options."""
3534

3635
def __init__(self, *args: P.args, **kwargs: P.kwargs) -> None:
3736
...
3837

39-
__dataclass_fields__: Dict[str, Field[Any]]
40-
41-
42-
class DatasetClass(Protocol[P, TDataset]):
43-
"""Type hint for a dataclass object with a Dataset factory."""
44-
45-
def __init__(self, *args: P.args, **kwargs: P.kwargs) -> None:
46-
...
47-
48-
__dataclass_fields__: Dict[str, Field[Any]]
38+
__dataclass_fields__: ClassVar[DataClassFields]
4939
__dataoptions__: DataOptions[TDataset]
5040

5141

5242
# runtime functions and classes
5343
@overload
5444
def asdataset(
55-
dataclass: DatasetClass[Any, TDataset],
45+
dataclass: OptionedClass[P, TDataset],
5646
reference: Optional[DataType] = None,
5747
dataoptions: DataOptions[Any] = DEFAULT_OPTIONS,
5848
) -> TDataset:
@@ -61,7 +51,7 @@ def asdataset(
6151

6252
@overload
6353
def asdataset(
64-
dataclass: DataClass[Any],
54+
dataclass: DataClass[P],
6555
reference: Optional[DataType] = None,
6656
dataoptions: DataOptions[TDataset] = DEFAULT_OPTIONS,
6757
) -> TDataset:
@@ -125,7 +115,7 @@ def __init__(self, func: Callable[..., Any]) -> None:
125115
def __get__(
126116
self,
127117
obj: Any,
128-
cls: Type[DatasetClass[P, TDataset]],
118+
cls: Type[OptionedClass[P, TDataset]],
129119
) -> Callable[P, TDataset]:
130120
...
131121

@@ -161,7 +151,7 @@ def new(cls: Any, *args: Any, **kwargs: Any) -> Any:
161151
@overload
162152
@classmethod
163153
def shaped(
164-
cls: Type[DatasetClass[P, TDataset]],
154+
cls: Type[OptionedClass[P, TDataset]],
165155
func: Callable[[Shape], np.ndarray],
166156
sizes: Sizes,
167157
**kwargs: Any,
@@ -208,7 +198,7 @@ def shaped(
208198
@overload
209199
@classmethod
210200
def empty(
211-
cls: Type[DatasetClass[P, TDataset]],
201+
cls: Type[OptionedClass[P, TDataset]],
212202
sizes: Sizes,
213203
order: Order = "C",
214204
**kwargs: Any,
@@ -250,7 +240,7 @@ def empty(
250240
@overload
251241
@classmethod
252242
def zeros(
253-
cls: Type[DatasetClass[P, TDataset]],
243+
cls: Type[OptionedClass[P, TDataset]],
254244
sizes: Sizes,
255245
order: Order = "C",
256246
**kwargs: Any,
@@ -292,7 +282,7 @@ def zeros(
292282
@overload
293283
@classmethod
294284
def ones(
295-
cls: Type[DatasetClass[P, TDataset]],
285+
cls: Type[OptionedClass[P, TDataset]],
296286
sizes: Sizes,
297287
order: Order = "C",
298288
**kwargs: Any,
@@ -334,7 +324,7 @@ def ones(
334324
@overload
335325
@classmethod
336326
def full(
337-
cls: Type[DatasetClass[P, TDataset]],
327+
cls: Type[OptionedClass[P, TDataset]],
338328
sizes: Sizes,
339329
fill_value: Any,
340330
order: Order = "C",

xarray_dataclasses/typing.py

Lines changed: 16 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -22,7 +22,7 @@
2222
from enum import auto, Enum
2323
from typing import (
2424
Any,
25-
Callable,
25+
ClassVar,
2626
Dict,
2727
Hashable,
2828
Optional,
@@ -38,6 +38,7 @@
3838
from typing_extensions import (
3939
Annotated,
4040
Literal,
41+
ParamSpec,
4142
Protocol,
4243
get_args,
4344
get_origin,
@@ -74,21 +75,25 @@ def annotates(self, hint: Any) -> bool:
7475

7576

7677
# type hints
78+
P = ParamSpec("P")
79+
T = TypeVar("T")
80+
TDataClass = TypeVar("TDataClass", bound="DataClass[Any]")
81+
TDims = TypeVar("TDims", covariant=True)
82+
TDtype = TypeVar("TDtype", covariant=True)
83+
TName = TypeVar("TName", bound=Hashable)
84+
85+
DataClassFields = Dict[str, Field[Any]]
7786
DataType = Union[xr.DataArray, xr.Dataset]
7887
Dims = Tuple[str, ...]
7988
Dtype = Optional[str]
8089
Order = Literal["C", "F"]
8190
Shape = Union[Sequence[int], int]
8291
Sizes = Dict[str, int]
83-
T = TypeVar("T")
84-
TDims = TypeVar("TDims", covariant=True)
85-
TDtype = TypeVar("TDtype", covariant=True)
86-
TName = TypeVar("TName", bound=Hashable)
8792

8893

8994
@runtime_checkable
9095
class ArrayLike(Protocol[TDims, TDtype]):
91-
"""Type hint of array-like objects."""
96+
"""Type hint for array-like objects."""
9297

9398
def astype(self: T, dtype: Any) -> T:
9499
"""Method to convert data type of the object."""
@@ -105,14 +110,13 @@ def shape(self) -> Tuple[int, ...]:
105110
...
106111

107112

108-
class DataClass(Protocol):
109-
"""Type hint of dataclasses or their objects."""
110-
111-
__init__: Callable[..., None]
112-
__dataclass_fields__: Dict[str, Field[Any]]
113+
class DataClass(Protocol[P]):
114+
"""Type hint for dataclass objects."""
113115

116+
def __init__(self, *args: P.args, **kwargs: P.kwargs) -> None:
117+
...
114118

115-
TDataClass = TypeVar("TDataClass", bound=DataClass)
119+
__dataclass_fields__: ClassVar[DataClassFields]
116120

117121

118122
Attr = Annotated[T, FieldType.ATTR]

0 commit comments

Comments
 (0)