Skip to content

Commit 706a3c6

Browse files
committed
#137 Add get_field_type
1 parent 22a221f commit 706a3c6

File tree

2 files changed

+46
-8
lines changed

2 files changed

+46
-8
lines changed

tests/test_typing.py

Lines changed: 22 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,16 @@
88

99

1010
# submodules
11-
from xarray_dataclasses.typing import ArrayLike, get_dims, get_dtype
11+
from xarray_dataclasses.typing import (
12+
ArrayLike,
13+
Attr,
14+
Coord,
15+
Data,
16+
Name,
17+
get_dims,
18+
get_dtype,
19+
get_field_type,
20+
)
1221

1322

1423
# type hints
@@ -41,6 +50,13 @@
4150
(ArrayLike[Any, int], "int"),
4251
]
4352

53+
testdata_field_type = [
54+
(Attr[Any], "attr"),
55+
(Coord[Any, Any], "coord"),
56+
(Data[Any, Any], "data"),
57+
(Name[Any], "name"),
58+
]
59+
4460

4561
# test functions
4662
@mark.parametrize("type_, dims", testdata_dims)
@@ -51,3 +67,8 @@ def test_get_dims(type_: Any, dims: Any) -> None:
5167
@mark.parametrize("type_, dtype", testdata_dtype)
5268
def test_get_dtype(type_: Any, dtype: Any) -> None:
5369
assert get_dtype(type_) == dtype
70+
71+
72+
@mark.parametrize("type_, field_type", testdata_field_type)
73+
def test_get_field_type(type_: Any, field_type: Any) -> None:
74+
assert get_field_type(type_).value == field_type

xarray_dataclasses/typing.py

Lines changed: 24 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -19,7 +19,7 @@
1919

2020
# standard library
2121
from dataclasses import Field
22-
from enum import auto, Enum
22+
from enum import Enum
2323
from typing import (
2424
Any,
2525
ClassVar,
@@ -52,22 +52,22 @@
5252
class FieldType(Enum):
5353
"""Annotation of xarray-related field hints."""
5454

55-
ATTR = auto()
55+
ATTR = "attr"
5656
"""Annotation of attribute field hints."""
5757

58-
COORD = auto()
58+
COORD = "coord"
5959
"""Annotation of coordinate field hints."""
6060

61-
COORDOF = auto()
61+
COORDOF = "coordof"
6262
"""Annotation of coordinate field hints."""
6363

64-
DATA = auto()
64+
DATA = "data"
6565
"""Annotation of data (variable) field hints."""
6666

67-
DATAOF = auto()
67+
DATAOF = "dataof"
6868
"""Annotation of data (variable) field hints."""
6969

70-
NAME = auto()
70+
NAME = "name"
7171
"""Annotation of name field hints."""
7272

7373
def annotates(self, hint: Any) -> bool:
@@ -318,6 +318,23 @@ def get_dtype(type_: Any) -> Dtype:
318318
raise ValueError(f"Could not convert {type_!r} to dtype.")
319319

320320

321+
def get_field_type(type_: Any) -> FieldType:
322+
"""Parse a type and return a field type if it exists."""
323+
if FieldType.ATTR.annotates(type_):
324+
return FieldType.ATTR
325+
326+
if FieldType.COORD.annotates(type_):
327+
return FieldType.COORD
328+
329+
if FieldType.DATA.annotates(type_):
330+
return FieldType.DATA
331+
332+
if FieldType.NAME.annotates(type_):
333+
return FieldType.NAME
334+
335+
raise TypeError(f"Could not find any field type in {type_!r}.")
336+
337+
321338
def get_inner(hint: Any, *indexes: int) -> Any:
322339
"""Return an inner type hint by indexes."""
323340
if not indexes:

0 commit comments

Comments
 (0)