Skip to content

Commit 4d23fbf

Browse files
committed
#137 Update evaluation of dataclass field types
1 parent 596726d commit 4d23fbf

File tree

1 file changed

+19
-7
lines changed

1 file changed

+19
-7
lines changed

xarray_dataclasses/datamodel.py

Lines changed: 19 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -145,7 +145,7 @@ class DataModel:
145145
def from_dataclass(cls, dataclass: AnyDataClass[P]) -> "DataModel":
146146
"""Create a data model from a dataclass or its object."""
147147
model = cls()
148-
eval_field_types(dataclass)
148+
eval_dataclass(dataclass)
149149

150150
for field in dataclass.__dataclass_fields__.values():
151151
value = getattr(dataclass, field.name, field.default)
@@ -167,13 +167,25 @@ def from_dataclass(cls, dataclass: AnyDataClass[P]) -> "DataModel":
167167

168168

169169
# runtime functions
170-
def eval_field_types(dataclass: AnyDataClass[P]) -> None:
171-
"""Evaluate field types of a dataclass or its object."""
172-
hints = get_type_hints(dataclass, include_extras=True) # type: ignore
170+
def eval_dataclass(dataclass: AnyDataClass[P]) -> None:
171+
"""Evaluate field types of a dataclass."""
172+
if not is_dataclass(dataclass):
173+
raise TypeError("Not a dataclass or its object.")
173174

174-
for field in dataclass.__dataclass_fields__.values():
175-
if isinstance(field.type, str):
176-
field.type = hints[field.name]
175+
fields = dataclass.__dataclass_fields__.values()
176+
177+
# do nothing if field types are already evaluated
178+
if not any(isinstance(field.type, str) for field in fields):
179+
return
180+
181+
# otherwise, replace field types with evaluated types
182+
if not isinstance(dataclass, type):
183+
dataclass = type(dataclass)
184+
185+
types = get_type_hints(dataclass, include_extras=True)
186+
187+
for field in fields:
188+
field.type = types[field.name]
177189

178190

179191
def typedarray(

0 commit comments

Comments
 (0)