Skip to content

Commit 214ab6b

Browse files
committed
#162 Merge pull request from astropenguin/astropenguin/issue147
2 parents 8508b6e + d326869 commit 214ab6b

File tree

3 files changed

+43
-3
lines changed

3 files changed

+43
-3
lines changed

tests/test_typing.py

Lines changed: 24 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,7 @@
1919
get_dims,
2020
get_dtype,
2121
get_ftype,
22+
get_name,
2223
)
2324

2425

@@ -71,6 +72,24 @@
7172
(Union[Ann[Any, "other"], Ann[Any, "any"]], FType.OTHER),
7273
]
7374

75+
testdata_name = [
76+
(Attr[Any], None),
77+
(Data[Any, Any], None),
78+
(Coord[Any, Any], None),
79+
(Name[Any], None),
80+
(Any, None),
81+
(Ann[Attr[Any], "attr"], "attr"),
82+
(Ann[Data[Any, Any], "data"], "data"),
83+
(Ann[Coord[Any, Any], "coord"], "coord"),
84+
(Ann[Name[Any], "name"], "name"),
85+
(Ann[Any, "other"], None),
86+
(Union[Ann[Attr[Any], "attr"], Ann[Any, "any"]], "attr"),
87+
(Union[Ann[Data[Any, Any], "data"], Ann[Any, "any"]], "data"),
88+
(Union[Ann[Coord[Any, Any], "coord"], Ann[Any, "any"]], "coord"),
89+
(Union[Ann[Name[Any], "name"], Ann[Any, "any"]], "name"),
90+
(Union[Ann[Any, "other"], Ann[Any, "any"]], None),
91+
]
92+
7493

7594
# test functions
7695
@mark.parametrize("tp, dims", testdata_dims)
@@ -86,3 +105,8 @@ def test_get_dtype(tp: Any, dtype: Any) -> None:
86105
@mark.parametrize("tp, ftype", testdata_ftype)
87106
def test_get_ftype(tp: Any, ftype: Any) -> None:
88107
assert get_ftype(tp) == ftype
108+
109+
110+
@mark.parametrize("tp, name", testdata_name)
111+
def test_get_name(tp: Any, name: Any) -> None:
112+
assert get_name(tp) == name

xarray_dataclasses/datamodel.py

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -26,6 +26,7 @@
2626
get_dims,
2727
get_dtype,
2828
get_ftype,
29+
get_name,
2930
)
3031

3132

@@ -209,10 +210,11 @@ def eval_dataclass(dataclass: AnyDataClass[PInit]) -> None:
209210
def get_entry(field: AnyField, value: Any) -> Optional[AnyEntry]:
210211
"""Create an entry from a field and its value."""
211212
ftype = get_ftype(field.type)
213+
name = get_name(field.type, field.name)
212214

213215
if ftype is FType.ATTR or ftype is FType.NAME:
214216
return AttrEntry(
215-
name=field.name,
217+
name=name,
216218
tag=ftype.value,
217219
value=value,
218220
type=get_annotated(field.type),
@@ -221,14 +223,14 @@ def get_entry(field: AnyField, value: Any) -> Optional[AnyEntry]:
221223
if ftype is FType.COORD or ftype is FType.DATA:
222224
try:
223225
return DataEntry(
224-
name=field.name,
226+
name=name,
225227
tag=ftype.value,
226228
base=get_dataclass(field.type),
227229
value=value,
228230
)
229231
except TypeError:
230232
return DataEntry(
231-
name=field.name,
233+
name=name,
232234
tag=ftype.value,
233235
dims=get_dims(field.type),
234236
dtype=get_dtype(field.type),

xarray_dataclasses/typing.py

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -347,3 +347,17 @@ def get_ftype(tp: Any, default: FType = FType.OTHER) -> FType:
347347
return get_annotations(tp)[0]
348348
except TypeError:
349349
return default
350+
351+
352+
def get_name(tp: Any, default: Hashable = None) -> Hashable:
353+
"""Extract a name if found or return given default."""
354+
try:
355+
annotations = get_annotations(tp)[1:]
356+
except TypeError:
357+
return default
358+
359+
for annotation in annotations:
360+
if isinstance(annotation, Hashable):
361+
return annotation
362+
363+
return default

0 commit comments

Comments
 (0)