Skip to content

Commit 596726d

Browse files
committed
#137 Add get_repr_type
1 parent 706a3c6 commit 596726d

File tree

2 files changed

+39
-2
lines changed

2 files changed

+39
-2
lines changed

tests/test_typing.py

Lines changed: 15 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,10 @@
11
# standard library
2-
from typing import Any, Tuple
2+
from typing import Any, Optional, Tuple, Union
33

44

55
# third-party packages
66
from pytest import mark
7-
from typing_extensions import Literal
7+
from typing_extensions import Annotated, Literal
88

99

1010
# submodules
@@ -17,6 +17,7 @@
1717
get_dims,
1818
get_dtype,
1919
get_field_type,
20+
get_repr_type,
2021
)
2122

2223

@@ -57,6 +58,13 @@
5758
(Name[Any], "name"),
5859
]
5960

61+
testdata_repr_type = [
62+
(int, int),
63+
(Annotated[int, "annotation"], int),
64+
(Union[int, float], int),
65+
(Optional[int], int),
66+
]
67+
6068

6169
# test functions
6270
@mark.parametrize("type_, dims", testdata_dims)
@@ -72,3 +80,8 @@ def test_get_dtype(type_: Any, dtype: Any) -> None:
7280
@mark.parametrize("type_, field_type", testdata_field_type)
7381
def test_get_field_type(type_: Any, field_type: Any) -> None:
7482
assert get_field_type(type_).value == field_type
83+
84+
85+
@mark.parametrize("type_, repr_type", testdata_repr_type)
86+
def test_get_repr_type(type_: Any, repr_type: Any) -> None:
87+
assert get_repr_type(type_) == repr_type

xarray_dataclasses/typing.py

Lines changed: 24 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -344,6 +344,30 @@ def get_inner(hint: Any, *indexes: int) -> Any:
344344
return get_inner(get_args(hint)[index], *indexes)
345345

346346

347+
def get_repr_type(type_: Any) -> Any:
348+
"""Parse a type and return an representative type.
349+
350+
Example:
351+
All of the following expressions will be ``True``::
352+
353+
get_repr_type(A) == A
354+
get_repr_type(Annotated[A, ...]) == A
355+
get_repr_type(Union[A, B, ...]) == A
356+
get_repr_type(Optional[A]) == A
357+
358+
"""
359+
360+
class Temporary:
361+
__annotations__ = dict(type=type_)
362+
363+
unannotated = get_type_hints(Temporary)["type"]
364+
365+
if get_origin(unannotated) is Union:
366+
return get_args(unannotated)[0]
367+
368+
return unannotated
369+
370+
347371
def is_str_literal(hint: Any) -> bool:
348372
"""Check if a type hint is Literal[str]."""
349373
args: Any = get_args(hint)

0 commit comments

Comments
 (0)