File tree Expand file tree Collapse file tree 2 files changed +39
-2
lines changed Expand file tree Collapse file tree 2 files changed +39
-2
lines changed Original file line number Diff line number Diff line change 1
1
# standard library
2
- from typing import Any , Tuple
2
+ from typing import Any , Optional , Tuple , Union
3
3
4
4
5
5
# third-party packages
6
6
from pytest import mark
7
- from typing_extensions import Literal
7
+ from typing_extensions import Annotated , Literal
8
8
9
9
10
10
# submodules
17
17
get_dims ,
18
18
get_dtype ,
19
19
get_field_type ,
20
+ get_repr_type ,
20
21
)
21
22
22
23
57
58
(Name [Any ], "name" ),
58
59
]
59
60
61
+ testdata_repr_type = [
62
+ (int , int ),
63
+ (Annotated [int , "annotation" ], int ),
64
+ (Union [int , float ], int ),
65
+ (Optional [int ], int ),
66
+ ]
67
+
60
68
61
69
# test functions
62
70
@mark .parametrize ("type_, dims" , testdata_dims )
@@ -72,3 +80,8 @@ def test_get_dtype(type_: Any, dtype: Any) -> None:
72
80
@mark .parametrize ("type_, field_type" , testdata_field_type )
73
81
def test_get_field_type (type_ : Any , field_type : Any ) -> None :
74
82
assert get_field_type (type_ ).value == field_type
83
+
84
+
85
+ @mark .parametrize ("type_, repr_type" , testdata_repr_type )
86
+ def test_get_repr_type (type_ : Any , repr_type : Any ) -> None :
87
+ assert get_repr_type (type_ ) == repr_type
Original file line number Diff line number Diff line change @@ -344,6 +344,30 @@ def get_inner(hint: Any, *indexes: int) -> Any:
344
344
return get_inner (get_args (hint )[index ], * indexes )
345
345
346
346
347
+ def get_repr_type (type_ : Any ) -> Any :
348
+ """Parse a type and return an representative type.
349
+
350
+ Example:
351
+ All of the following expressions will be ``True``::
352
+
353
+ get_repr_type(A) == A
354
+ get_repr_type(Annotated[A, ...]) == A
355
+ get_repr_type(Union[A, B, ...]) == A
356
+ get_repr_type(Optional[A]) == A
357
+
358
+ """
359
+
360
+ class Temporary :
361
+ __annotations__ = dict (type = type_ )
362
+
363
+ unannotated = get_type_hints (Temporary )["type" ]
364
+
365
+ if get_origin (unannotated ) is Union :
366
+ return get_args (unannotated )[0 ]
367
+
368
+ return unannotated
369
+
370
+
347
371
def is_str_literal (hint : Any ) -> bool :
348
372
"""Check if a type hint is Literal[str]."""
349
373
args : Any = get_args (hint )
You can’t perform that action at this time.
0 commit comments