Skip to content

Commit a9a29cf

Browse files
authored
Merge pull request #118 from astropenguin/astropenguin/issue91
Use name-field value as data-variable name
2 parents 64959df + 789f9ec commit a9a29cf

File tree

5 files changed

+134
-129
lines changed

5 files changed

+134
-129
lines changed

tests/test_datamodel.py

Lines changed: 50 additions & 48 deletions
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,7 @@
99

1010
# submodules
1111
from xarray_dataclasses.datamodel import DataModel
12-
from xarray_dataclasses.typing import Attr, Coord, Coordof, Data, Dataof, Name
12+
from xarray_dataclasses.typing import Attr, Coord, Coordof, Data, Dataof
1313

1414

1515
# type hints
@@ -22,14 +22,12 @@
2222
class XAxis:
2323
data: Data[X, int]
2424
units: Attr[str] = "pixel"
25-
name: Name[str] = "x axis"
2625

2726

2827
@dataclass
2928
class YAxis:
3029
data: Data[Y, int]
3130
units: Attr[str] = "pixel"
32-
name: Name[str] = "y axis"
3331

3432

3533
@dataclass
@@ -55,69 +53,73 @@ class ColorImage:
5553

5654
# test functions
5755
def test_xaxis_attr() -> None:
58-
assert xaxis_model.attr[0].name == "units"
59-
assert xaxis_model.attr[0].value == "pixel"
60-
assert xaxis_model.attr[0].type == "builtins.str"
56+
item = next(iter(xaxis_model.attr.values()))
57+
assert item.name == "units"
58+
assert item.value == "pixel"
59+
assert item.type == "builtins.str"
6160

6261

6362
def test_xaxis_data() -> None:
64-
assert xaxis_model.data[0].name == "data"
65-
assert xaxis_model.data[0].type == {"dims": ("x",), "dtype": "int"}
66-
assert xaxis_model.data[0].factory is None
67-
68-
69-
def test_xaxis_name() -> None:
70-
assert xaxis_model.name[0].name == "name"
71-
assert xaxis_model.name[0].value == "x axis"
72-
assert xaxis_model.name[0].type == "builtins.str"
63+
item = next(iter(xaxis_model.data.values()))
64+
assert item.name == "data"
65+
assert item.type == {"dims": ("x",), "dtype": "int"}
66+
assert item.factory is None
7367

7468

7569
def test_yaxis_attr() -> None:
76-
assert yaxis_model.attr[0].name == "units"
77-
assert yaxis_model.attr[0].value == "pixel"
78-
assert yaxis_model.attr[0].type == "builtins.str"
70+
item = next(iter(yaxis_model.attr.values()))
71+
assert item.name == "units"
72+
assert item.value == "pixel"
73+
assert item.type == "builtins.str"
7974

8075

8176
def test_yaxis_data() -> None:
82-
assert yaxis_model.data[0].name == "data"
83-
assert yaxis_model.data[0].type == {"dims": ("y",), "dtype": "int"}
84-
assert yaxis_model.data[0].factory is None
77+
item = next(iter(yaxis_model.data.values()))
78+
assert item.name == "data"
79+
assert item.type == {"dims": ("y",), "dtype": "int"}
80+
assert item.factory is None
8581

8682

87-
def test_yaxis_name() -> None:
88-
assert yaxis_model.name[0].name == "name"
89-
assert yaxis_model.name[0].value == "y axis"
90-
assert yaxis_model.name[0].type == "builtins.str"
83+
def test_image_coord() -> None:
84+
items = iter(image_model.coord.values())
9185

86+
item = next(items)
87+
assert item.name == "mask"
88+
assert item.type == {"dims": ("x", "y"), "dtype": "bool"}
89+
assert item.factory is None
9290

93-
def test_matrix_coord() -> None:
94-
assert image_model.coord[0].name == "mask"
95-
assert image_model.coord[0].type == {"dims": ("x", "y"), "dtype": "bool"}
96-
assert image_model.coord[0].factory is None
91+
item = next(items)
92+
assert item.name == "x"
93+
assert item.type == {"dims": ("x",), "dtype": "int"}
94+
assert item.factory is XAxis
9795

98-
assert image_model.coord[1].name == "x"
99-
assert image_model.coord[1].type == {"dims": ("x",), "dtype": "int"}
100-
assert image_model.coord[1].factory is XAxis
96+
item = next(items)
97+
assert item.name == "y"
98+
assert item.type == {"dims": ("y",), "dtype": "int"}
99+
assert item.factory is YAxis
101100

102-
assert image_model.coord[2].name == "y"
103-
assert image_model.coord[2].type == {"dims": ("y",), "dtype": "int"}
104-
assert image_model.coord[2].factory is YAxis
105101

102+
def test_image_data() -> None:
103+
item = next(iter(image_model.data.values()))
104+
assert item.name == "data"
105+
assert item.type == {"dims": ("x", "y"), "dtype": "float"}
106+
assert item.factory is None
106107

107-
def test_matrix_data() -> None:
108-
assert image_model.data[0].name == "data"
109-
assert image_model.data[0].type == {"dims": ("x", "y"), "dtype": "float"}
110108

109+
def test_color_data() -> None:
110+
items = iter(color_model.data.values())
111111

112-
def test_image_data() -> None:
113-
assert color_model.data[0].name == "red"
114-
assert color_model.data[0].type == {"dims": ("x", "y"), "dtype": "float"}
115-
assert color_model.data[0].factory is Image
112+
item = next(items)
113+
assert item.name == "red"
114+
assert item.type == {"dims": ("x", "y"), "dtype": "float"}
115+
assert item.factory is Image
116116

117-
assert color_model.data[1].name == "green"
118-
assert color_model.data[1].type == {"dims": ("x", "y"), "dtype": "float"}
119-
assert color_model.data[1].factory is Image
117+
item = next(items)
118+
assert item.name == "green"
119+
assert item.type == {"dims": ("x", "y"), "dtype": "float"}
120+
assert item.factory is Image
120121

121-
assert color_model.data[2].name == "blue"
122-
assert color_model.data[2].type == {"dims": ("x", "y"), "dtype": "float"}
123-
assert color_model.data[2].factory is Image
122+
item = next(items)
123+
assert item.name == "blue"
124+
assert item.type == {"dims": ("x", "y"), "dtype": "float"}
125+
assert item.factory is Image

xarray_dataclasses/dataarray.py

Lines changed: 21 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -113,21 +113,23 @@ def asdataarray(
113113
pass
114114

115115
model = DataModel.from_dataclass(dataclass)
116-
dataarray = dataoptions.factory(model.data[0](reference))
116+
item = next(iter(model.data.values()))
117+
dataarray = dataoptions.factory(item(reference))
117118

118-
for coord in model.coord:
119-
if coord.name in dataarray.dims:
120-
dataarray.coords.update({coord.name: coord(dataarray)})
119+
for item in model.coord.values():
120+
if item.name in dataarray.dims:
121+
dataarray.coords[item.name] = item(dataarray)
121122

122-
for coord in model.coord:
123-
if coord.name not in dataarray.dims:
124-
dataarray.coords.update({coord.name: coord(dataarray)})
123+
for item in model.coord.values():
124+
if item.name not in dataarray.dims:
125+
dataarray.coords[item.name] = item(dataarray)
125126

126-
for attr in model.attr:
127-
dataarray.attrs.update({attr.name: attr()})
127+
for item in model.attr.values():
128+
dataarray.attrs[item.name] = item()
128129

129-
for name in model.name:
130-
dataarray.name = name()
130+
if model.name:
131+
item = next(iter(model.name.values()))
132+
dataarray.name = item()
131133

132134
return dataarray
133135

@@ -175,11 +177,10 @@ def empty(
175177
176178
"""
177179
model = DataModel.from_dataclass(cls)
178-
name = model.data[0].name
179-
dims = model.data[0].type["dims"]
180+
name, item = next(iter(model.data.items()))
180181

181182
if isinstance(shape, dict):
182-
shape = tuple(shape[dim] for dim in dims)
183+
shape = tuple(shape[dim] for dim in item.type["dims"])
183184

184185
data = np.empty(shape, order=order)
185186
return asdataarray(cls(**{name: data}, **kwargs))
@@ -204,11 +205,10 @@ def zeros(
204205
205206
"""
206207
model = DataModel.from_dataclass(cls)
207-
name = model.data[0].name
208-
dims = model.data[0].type["dims"]
208+
name, item = next(iter(model.data.items()))
209209

210210
if isinstance(shape, dict):
211-
shape = tuple(shape[dim] for dim in dims)
211+
shape = tuple(shape[dim] for dim in item.type["dims"])
212212

213213
data = np.zeros(shape, order=order)
214214
return asdataarray(cls(**{name: data}, **kwargs))
@@ -233,11 +233,10 @@ def ones(
233233
234234
"""
235235
model = DataModel.from_dataclass(cls)
236-
name = model.data[0].name
237-
dims = model.data[0].type["dims"]
236+
name, item = next(iter(model.data.items()))
238237

239238
if isinstance(shape, dict):
240-
shape = tuple(shape[dim] for dim in dims)
239+
shape = tuple(shape[dim] for dim in item.type["dims"])
241240

242241
data = np.ones(shape, order=order)
243242
return asdataarray(cls(**{name: data}, **kwargs))
@@ -264,11 +263,10 @@ def full(
264263
265264
"""
266265
model = DataModel.from_dataclass(cls)
267-
name = model.data[0].name
268-
dims = model.data[0].type["dims"]
266+
name, item = next(iter(model.data.items()))
269267

270268
if isinstance(shape, dict):
271-
shape = tuple(shape[dim] for dim in dims)
269+
shape = tuple(shape[dim] for dim in item.type["dims"])
272270

273271
data = np.full(shape, fill_value, order=order)
274272
return asdataarray(cls(**{name: data}, **kwargs))

xarray_dataclasses/datamodel.py

Lines changed: 37 additions & 34 deletions
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@
33

44
# standard library
55
from dataclasses import Field, dataclass, field, is_dataclass
6-
from typing import Any, List, Optional, Type, cast
6+
from typing import Any, Dict, Hashable, Optional, Type, cast
77

88

99
# dependencies
@@ -37,7 +37,7 @@
3737
class Data:
3838
"""Field model for data-related fields."""
3939

40-
name: str
40+
name: Hashable
4141
"""Name of the field."""
4242

4343
value: Any
@@ -71,23 +71,26 @@ def from_field(cls, field: Field[Any], value: Any, of: bool) -> "Data":
7171
"""Create a field model from a dataclass field and a value."""
7272
hint = unannotate(field.type)
7373

74-
if of:
75-
dataclass = get_inner(hint, 0)
76-
data = DataModel.from_dataclass(dataclass).data[0]
77-
return cls(field.name, value, data.type, dataclass)
74+
if not of:
75+
type: DimsDtype = {"dims": get_dims(hint), "dtype": get_dtype(hint)}
76+
return cls(field.name, value, type)
77+
78+
dataclass = get_inner(hint, 0)
79+
model = DataModel.from_dataclass(dataclass)
80+
data_item = next(iter(model.data.values()))
81+
82+
if not model.name:
83+
return cls(field.name, value, data_item.type, dataclass)
7884
else:
79-
return cls(
80-
field.name,
81-
value,
82-
{"dims": get_dims(hint), "dtype": get_dtype(hint)},
83-
)
85+
name_item = next(iter(model.name.values()))
86+
return cls(name_item.value, value, data_item.type, dataclass)
8487

8588

8689
@dataclass(frozen=True)
8790
class General:
8891
"""Field model for general fields."""
8992

90-
name: str
93+
name: Hashable
9194
"""Name of the field."""
9295

9396
value: Any
@@ -122,16 +125,16 @@ def from_field(cls, field: Field[Any], value: Any) -> "General":
122125
class DataModel:
123126
"""Model for dataclasses or their objects."""
124127

125-
attr: List[General] = field(default_factory=list)
128+
attr: Dict[str, General] = field(default_factory=dict)
126129
"""Model of the attribute fields."""
127130

128-
coord: List[Data] = field(default_factory=list)
131+
coord: Dict[str, Data] = field(default_factory=dict)
129132
"""Model of the coordinate fields."""
130133

131-
data: List[Data] = field(default_factory=list)
134+
data: Dict[str, Data] = field(default_factory=dict)
132135
"""Model of the data fields."""
133136

134-
name: List[General] = field(default_factory=list)
137+
name: Dict[str, General] = field(default_factory=dict)
135138
"""Model of the name fields."""
136139

137140
@classmethod
@@ -140,21 +143,21 @@ def from_dataclass(cls, dataclass: DataClass) -> "DataModel":
140143
model = cls()
141144
eval_field_types(dataclass)
142145

143-
for field_ in dataclass.__dataclass_fields__.values():
144-
value = getattr(dataclass, field_.name, field_.default)
145-
146-
if FieldType.ATTR.annotates(field_.type):
147-
model.attr.append(General.from_field(field_, value))
148-
elif FieldType.COORD.annotates(field_.type):
149-
model.coord.append(Data.from_field(field_, value, False))
150-
elif FieldType.COORDOF.annotates(field_.type):
151-
model.coord.append(Data.from_field(field_, value, True))
152-
elif FieldType.DATA.annotates(field_.type):
153-
model.data.append(Data.from_field(field_, value, False))
154-
elif FieldType.DATAOF.annotates(field_.type):
155-
model.data.append(Data.from_field(field_, value, True))
156-
elif FieldType.NAME.annotates(field_.type):
157-
model.name.append(General.from_field(field_, value))
146+
for field in dataclass.__dataclass_fields__.values():
147+
value = getattr(dataclass, field.name, field.default)
148+
149+
if FieldType.ATTR.annotates(field.type):
150+
model.attr[field.name] = General.from_field(field, value)
151+
elif FieldType.COORD.annotates(field.type):
152+
model.coord[field.name] = Data.from_field(field, value, False)
153+
elif FieldType.COORDOF.annotates(field.type):
154+
model.coord[field.name] = Data.from_field(field, value, True)
155+
elif FieldType.DATA.annotates(field.type):
156+
model.data[field.name] = Data.from_field(field, value, False)
157+
elif FieldType.DATAOF.annotates(field.type):
158+
model.data[field.name] = Data.from_field(field, value, True)
159+
elif FieldType.NAME.annotates(field.type):
160+
model.name[field.name] = General.from_field(field, value)
158161

159162
return model
160163

@@ -164,9 +167,9 @@ def eval_field_types(dataclass: DataClass) -> None:
164167
"""Evaluate field types of a dataclass or its object."""
165168
hints = get_type_hints(dataclass, include_extras=True) # type: ignore
166169

167-
for field_ in dataclass.__dataclass_fields__.values():
168-
if isinstance(field_.type, str):
169-
field_.type = hints[field_.name]
170+
for field in dataclass.__dataclass_fields__.values():
171+
if isinstance(field.type, str):
172+
field.type = hints[field.name]
170173

171174

172175
def typedarray(

0 commit comments

Comments
 (0)