Skip to content

Commit 23267b0

Browse files
authored
#142 Merge pull request from astropenguin/astropenguin/issue141
Update protocol for public type hints
2 parents 364216e + 761e76c commit 23267b0

File tree

3 files changed

+35
-43
lines changed

3 files changed

+35
-43
lines changed

tests/test_typing.py

Lines changed: 9 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -9,8 +9,8 @@
99

1010
# submodules
1111
from xarray_dataclasses.typing import (
12-
ArrayLike,
1312
Attr,
13+
Collection,
1414
Coord,
1515
Data,
1616
Name,
@@ -34,21 +34,21 @@
3434
(Tuple[()], ()),
3535
(Tuple[X], ("x",)),
3636
(Tuple[X, Y], ("x", "y")),
37-
(ArrayLike[X, Any], ("x",)),
38-
(ArrayLike[Tuple[()], Any], ()),
39-
(ArrayLike[Tuple[X], Any], ("x",)),
40-
(ArrayLike[Tuple[X, Y], Any], ("x", "y")),
37+
(Collection[X, Any], ("x",)),
38+
(Collection[Tuple[()], Any], ()),
39+
(Collection[Tuple[X], Any], ("x",)),
40+
(Collection[Tuple[X, Y], Any], ("x", "y")),
4141
]
4242

4343
testdata_dtype = [
4444
(Any, None),
4545
(NoneType, None),
4646
(Int64, "int64"),
4747
(int, "int"),
48-
(ArrayLike[Any, Any], None),
49-
(ArrayLike[Any, NoneType], None),
50-
(ArrayLike[Any, Int64], "int64"),
51-
(ArrayLike[Any, int], "int"),
48+
(Collection[Any, Any], None),
49+
(Collection[Any, NoneType], None),
50+
(Collection[Any, Int64], "int64"),
51+
(Collection[Any, int], "int"),
5252
]
5353

5454
testdata_field_type = [

xarray_dataclasses/datamodel.py

Lines changed: 14 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,6 @@
1414

1515
# submodules
1616
from .typing import (
17-
ArrayLike,
1817
DataClass,
1918
DataType,
2019
Dims,
@@ -254,27 +253,27 @@ def get_typedarray(
254253
DataArray object with given dims and dtype.
255254
256255
"""
257-
if isinstance(data, ArrayLike):
258-
array = cast(np.ndarray, data)
259-
else:
260-
array = np.asarray(data)
256+
try:
257+
data.__array__
258+
except AttributeError:
259+
data = np.asarray(data)
261260

262261
if dtype is not None:
263-
array = array.astype(dtype, copy=False)
262+
data = data.astype(dtype, copy=False)
264263

265-
if array.ndim == len(dims):
266-
dataarray = xr.DataArray(array, dims=dims)
267-
elif array.ndim == 0 and reference is not None:
268-
dataarray = xr.DataArray(array)
264+
if data.ndim == len(dims):
265+
dataarray = xr.DataArray(data, dims=dims)
266+
elif data.ndim == 0 and reference is not None:
267+
dataarray = xr.DataArray(data)
269268
else:
270269
raise ValueError(
271270
"Could not create a DataArray object from data. "
272-
f"Mismatch between shape {array.shape} and dims {dims}."
271+
f"Mismatch between shape {data.shape} and dims {dims}."
273272
)
274273

275274
if reference is None:
276275
return dataarray
277-
278-
diff_dims = set(reference.dims) - set(dims)
279-
subspace = reference.isel({dim: 0 for dim in diff_dims})
280-
return dataarray.broadcast_like(subspace)
276+
else:
277+
ddims = set(reference.dims) - set(dims)
278+
reference = reference.isel({dim: 0 for dim in ddims})
279+
return dataarray.broadcast_like(reference)

xarray_dataclasses/typing.py

Lines changed: 12 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,7 @@
2323
from typing import (
2424
Any,
2525
ClassVar,
26+
Collection,
2627
Dict,
2728
Hashable,
2829
Optional,
@@ -44,7 +45,6 @@
4445
get_args,
4546
get_origin,
4647
get_type_hints,
47-
runtime_checkable,
4848
)
4949

5050

@@ -65,23 +65,16 @@
6565
Sizes = Dict[str, int]
6666

6767

68-
@runtime_checkable
69-
class ArrayLike(Protocol[TDims, TDtype]):
70-
"""Type hint for array-like objects."""
68+
class Labeled(Protocol[TDims]):
69+
"""Type hint for labeled objects."""
7170

72-
def astype(self: T, dtype: Any) -> T:
73-
"""Method to convert data type of the object."""
74-
...
71+
pass
7572

76-
@property
77-
def ndim(self) -> int:
78-
"""Number of dimensions of the object."""
79-
...
8073

81-
@property
82-
def shape(self) -> Tuple[int, ...]:
83-
"""Shape of the object."""
84-
...
74+
class Collection(Labeled[TDims], Collection[TDtype], Protocol):
75+
"""Type hint for labeled collection objects."""
76+
77+
pass
8578

8679

8780
class DataClass(Protocol[P]):
@@ -138,7 +131,7 @@ class Image(AsDataArray):
138131
139132
"""
140133

141-
Coord = Annotated[Union[ArrayLike[TDims, TDtype], TDtype], FieldType.COORD]
134+
Coord = Annotated[Union[Collection[TDims, TDtype], TDtype], FieldType.COORD]
142135
"""Type hint to define coordinate fields (``Coord[TDims, TDtype]``).
143136
144137
Example:
@@ -189,7 +182,7 @@ class Image(AsDataArray):
189182
190183
"""
191184

192-
Data = Annotated[Union[ArrayLike[TDims, TDtype], TDtype], FieldType.DATA]
185+
Data = Annotated[Union[Collection[TDims, TDtype], TDtype], FieldType.DATA]
193186
"""Type hint to define data fields (``Coordof[TDims, TDtype]``).
194187
195188
Examples:
@@ -267,7 +260,7 @@ def get_dims(type_: Any) -> Dims:
267260
args = get_args(type_)
268261
origin = get_origin(type_)
269262

270-
if origin is ArrayLike:
263+
if origin is Collection:
271264
return get_dims(args[0])
272265

273266
if origin is tuple or origin is Tuple:
@@ -298,7 +291,7 @@ def get_dtype(type_: Any) -> Dtype:
298291
args = get_args(type_)
299292
origin = get_origin(type_)
300293

301-
if origin is ArrayLike:
294+
if origin is Collection:
302295
return get_dtype(args[1])
303296

304297
if origin is Literal:

0 commit comments

Comments
 (0)