Skip to content

Commit 52f714c

Browse files
committed
#91 Use dict in DataModel
1 parent b2d18d2 commit 52f714c

File tree

4 files changed

+123
-107
lines changed

4 files changed

+123
-107
lines changed

tests/test_datamodel.py

Lines changed: 51 additions & 35 deletions
Original file line numberDiff line numberDiff line change
@@ -53,57 +53,73 @@ class ColorImage:
5353

5454
# test functions
5555
def test_xaxis_attr() -> None:
56-
assert xaxis_model.attr[0].name == "units"
57-
assert xaxis_model.attr[0].value == "pixel"
58-
assert xaxis_model.attr[0].type == "builtins.str"
56+
item = next(iter(xaxis_model.attr.values()))
57+
assert item.key == "units"
58+
assert item.value == "pixel"
59+
assert item.type == "builtins.str"
5960

6061

6162
def test_xaxis_data() -> None:
62-
assert xaxis_model.data[0].name == "data"
63-
assert xaxis_model.data[0].type == {"dims": ("x",), "dtype": "int"}
64-
assert xaxis_model.data[0].factory is None
63+
item = next(iter(xaxis_model.data.values()))
64+
assert item.key == "data"
65+
assert item.type == {"dims": ("x",), "dtype": "int"}
66+
assert item.factory is None
6567

6668

6769
def test_yaxis_attr() -> None:
68-
assert yaxis_model.attr[0].name == "units"
69-
assert yaxis_model.attr[0].value == "pixel"
70-
assert yaxis_model.attr[0].type == "builtins.str"
70+
item = next(iter(yaxis_model.attr.values()))
71+
assert item.key == "units"
72+
assert item.value == "pixel"
73+
assert item.type == "builtins.str"
7174

7275

7376
def test_yaxis_data() -> None:
74-
assert yaxis_model.data[0].name == "data"
75-
assert yaxis_model.data[0].type == {"dims": ("y",), "dtype": "int"}
76-
assert yaxis_model.data[0].factory is None
77+
item = next(iter(yaxis_model.data.values()))
78+
assert item.key == "data"
79+
assert item.type == {"dims": ("y",), "dtype": "int"}
80+
assert item.factory is None
7781

7882

79-
def test_matrix_coord() -> None:
80-
assert image_model.coord[0].name == "mask"
81-
assert image_model.coord[0].type == {"dims": ("x", "y"), "dtype": "bool"}
82-
assert image_model.coord[0].factory is None
83+
def test_image_coord() -> None:
84+
items = iter(image_model.coord.values())
8385

84-
assert image_model.coord[1].name == "x"
85-
assert image_model.coord[1].type == {"dims": ("x",), "dtype": "int"}
86-
assert image_model.coord[1].factory is XAxis
86+
item = next(items)
87+
assert item.key == "mask"
88+
assert item.type == {"dims": ("x", "y"), "dtype": "bool"}
89+
assert item.factory is None
8790

88-
assert image_model.coord[2].name == "y"
89-
assert image_model.coord[2].type == {"dims": ("y",), "dtype": "int"}
90-
assert image_model.coord[2].factory is YAxis
91+
item = next(items)
92+
assert item.key == "x"
93+
assert item.type == {"dims": ("x",), "dtype": "int"}
94+
assert item.factory is XAxis
9195

92-
93-
def test_matrix_data() -> None:
94-
assert image_model.data[0].name == "data"
95-
assert image_model.data[0].type == {"dims": ("x", "y"), "dtype": "float"}
96+
item = next(items)
97+
assert item.key == "y"
98+
assert item.type == {"dims": ("y",), "dtype": "int"}
99+
assert item.factory is YAxis
96100

97101

98102
def test_image_data() -> None:
99-
assert color_model.data[0].name == "red"
100-
assert color_model.data[0].type == {"dims": ("x", "y"), "dtype": "float"}
101-
assert color_model.data[0].factory is Image
103+
item = next(iter(image_model.data.values()))
104+
assert item.key == "data"
105+
assert item.type == {"dims": ("x", "y"), "dtype": "float"}
106+
assert item.factory is None
107+
108+
109+
def test_color_data() -> None:
110+
items = iter(color_model.data.values())
111+
112+
item = next(items)
113+
assert item.key == "red"
114+
assert item.type == {"dims": ("x", "y"), "dtype": "float"}
115+
assert item.factory is Image
102116

103-
assert color_model.data[1].name == "green"
104-
assert color_model.data[1].type == {"dims": ("x", "y"), "dtype": "float"}
105-
assert color_model.data[1].factory is Image
117+
item = next(items)
118+
assert item.key == "green"
119+
assert item.type == {"dims": ("x", "y"), "dtype": "float"}
120+
assert item.factory is Image
106121

107-
assert color_model.data[2].name == "blue"
108-
assert color_model.data[2].type == {"dims": ("x", "y"), "dtype": "float"}
109-
assert color_model.data[2].factory is Image
122+
item = next(items)
123+
assert item.key == "blue"
124+
assert item.type == {"dims": ("x", "y"), "dtype": "float"}
125+
assert item.factory is Image

xarray_dataclasses/dataarray.py

Lines changed: 21 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -113,21 +113,23 @@ def asdataarray(
113113
pass
114114

115115
model = DataModel.from_dataclass(dataclass)
116-
dataarray = dataoptions.factory(model.data[0](reference))
116+
item = next(iter(model.data.values()))
117+
dataarray = dataoptions.factory(item(reference))
117118

118-
for coord in model.coord:
119-
if coord.name in dataarray.dims:
120-
dataarray.coords.update({coord.name: coord(dataarray)})
119+
for item in model.coord.values():
120+
if item.key in dataarray.dims:
121+
dataarray.coords[item.key] = item(dataarray)
121122

122-
for coord in model.coord:
123-
if coord.name not in dataarray.dims:
124-
dataarray.coords.update({coord.name: coord(dataarray)})
123+
for item in model.coord.values():
124+
if item.key not in dataarray.dims:
125+
dataarray.coords[item.key] = item(dataarray)
125126

126-
for attr in model.attr:
127-
dataarray.attrs.update({attr.name: attr()})
127+
for item in model.attr.values():
128+
dataarray.attrs[item.key] = item()
128129

129-
for name in model.name:
130-
dataarray.name = name()
130+
if model.name:
131+
item = next(iter(model.name.values()))
132+
dataarray.name = item()
131133

132134
return dataarray
133135

@@ -175,11 +177,10 @@ def empty(
175177
176178
"""
177179
model = DataModel.from_dataclass(cls)
178-
name = model.data[0].name
179-
dims = model.data[0].type["dims"]
180+
name, item = next(iter(model.data.items()))
180181

181182
if isinstance(shape, dict):
182-
shape = tuple(shape[dim] for dim in dims)
183+
shape = tuple(shape[dim] for dim in item.type["dims"])
183184

184185
data = np.empty(shape, order=order)
185186
return asdataarray(cls(**{name: data}, **kwargs))
@@ -204,11 +205,10 @@ def zeros(
204205
205206
"""
206207
model = DataModel.from_dataclass(cls)
207-
name = model.data[0].name
208-
dims = model.data[0].type["dims"]
208+
name, item = next(iter(model.data.items()))
209209

210210
if isinstance(shape, dict):
211-
shape = tuple(shape[dim] for dim in dims)
211+
shape = tuple(shape[dim] for dim in item.type["dims"])
212212

213213
data = np.zeros(shape, order=order)
214214
return asdataarray(cls(**{name: data}, **kwargs))
@@ -233,11 +233,10 @@ def ones(
233233
234234
"""
235235
model = DataModel.from_dataclass(cls)
236-
name = model.data[0].name
237-
dims = model.data[0].type["dims"]
236+
name, item = next(iter(model.data.items()))
238237

239238
if isinstance(shape, dict):
240-
shape = tuple(shape[dim] for dim in dims)
239+
shape = tuple(shape[dim] for dim in item.type["dims"])
241240

242241
data = np.ones(shape, order=order)
243242
return asdataarray(cls(**{name: data}, **kwargs))
@@ -264,11 +263,10 @@ def full(
264263
265264
"""
266265
model = DataModel.from_dataclass(cls)
267-
name = model.data[0].name
268-
dims = model.data[0].type["dims"]
266+
name, item = next(iter(model.data.items()))
269267

270268
if isinstance(shape, dict):
271-
shape = tuple(shape[dim] for dim in dims)
269+
shape = tuple(shape[dim] for dim in item.type["dims"])
272270

273271
data = np.full(shape, fill_value, order=order)
274272
return asdataarray(cls(**{name: data}, **kwargs))

xarray_dataclasses/datamodel.py

Lines changed: 29 additions & 27 deletions
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@
33

44
# standard library
55
from dataclasses import Field, dataclass, field, is_dataclass
6-
from typing import Any, Hashable, List, Optional, Type, cast
6+
from typing import Any, Dict, Hashable, Optional, Type, cast
77

88

99
# dependencies
@@ -37,7 +37,7 @@
3737
class Data:
3838
"""Field model for data-related fields."""
3939

40-
name: Hashable
40+
key: Hashable
4141
"""Name of the field."""
4242

4343
value: Any
@@ -77,18 +77,20 @@ def from_field(cls, field: Field[Any], value: Any, of: bool) -> "Data":
7777

7878
dataclass = get_inner(hint, 0)
7979
model = DataModel.from_dataclass(dataclass)
80+
data_item = next(iter(model.data.values()))
8081

8182
if not model.name:
82-
return cls(field.name, value, model.data[0].type, dataclass)
83+
return cls(field.name, value, data_item.type, dataclass)
8384
else:
84-
return cls(model.name[0].value, value, model.data[0].type, dataclass)
85+
name_item = next(iter(model.name.values()))
86+
return cls(name_item.value, value, data_item.type, dataclass)
8587

8688

8789
@dataclass(frozen=True)
8890
class General:
8991
"""Field model for general fields."""
9092

91-
name: Hashable
93+
key: Hashable
9294
"""Name of the field."""
9395

9496
value: Any
@@ -123,16 +125,16 @@ def from_field(cls, field: Field[Any], value: Any) -> "General":
123125
class DataModel:
124126
"""Model for dataclasses or their objects."""
125127

126-
attr: List[General] = field(default_factory=list)
128+
attr: Dict[str, General] = field(default_factory=dict)
127129
"""Model of the attribute fields."""
128130

129-
coord: List[Data] = field(default_factory=list)
131+
coord: Dict[str, Data] = field(default_factory=dict)
130132
"""Model of the coordinate fields."""
131133

132-
data: List[Data] = field(default_factory=list)
134+
data: Dict[str, Data] = field(default_factory=dict)
133135
"""Model of the data fields."""
134136

135-
name: List[General] = field(default_factory=list)
137+
name: Dict[str, General] = field(default_factory=dict)
136138
"""Model of the name fields."""
137139

138140
@classmethod
@@ -141,21 +143,21 @@ def from_dataclass(cls, dataclass: DataClass) -> "DataModel":
141143
model = cls()
142144
eval_field_types(dataclass)
143145

144-
for field_ in dataclass.__dataclass_fields__.values():
145-
value = getattr(dataclass, field_.name, field_.default)
146-
147-
if FieldType.ATTR.annotates(field_.type):
148-
model.attr.append(General.from_field(field_, value))
149-
elif FieldType.COORD.annotates(field_.type):
150-
model.coord.append(Data.from_field(field_, value, False))
151-
elif FieldType.COORDOF.annotates(field_.type):
152-
model.coord.append(Data.from_field(field_, value, True))
153-
elif FieldType.DATA.annotates(field_.type):
154-
model.data.append(Data.from_field(field_, value, False))
155-
elif FieldType.DATAOF.annotates(field_.type):
156-
model.data.append(Data.from_field(field_, value, True))
157-
elif FieldType.NAME.annotates(field_.type):
158-
model.name.append(General.from_field(field_, value))
146+
for field in dataclass.__dataclass_fields__.values():
147+
value = getattr(dataclass, field.name, field.default)
148+
149+
if FieldType.ATTR.annotates(field.type):
150+
model.attr[field.name] = General.from_field(field, value)
151+
elif FieldType.COORD.annotates(field.type):
152+
model.coord[field.name] = Data.from_field(field, value, False)
153+
elif FieldType.COORDOF.annotates(field.type):
154+
model.coord[field.name] = Data.from_field(field, value, True)
155+
elif FieldType.DATA.annotates(field.type):
156+
model.data[field.name] = Data.from_field(field, value, False)
157+
elif FieldType.DATAOF.annotates(field.type):
158+
model.data[field.name] = Data.from_field(field, value, True)
159+
elif FieldType.NAME.annotates(field.type):
160+
model.name[field.name] = General.from_field(field, value)
159161

160162
return model
161163

@@ -165,9 +167,9 @@ def eval_field_types(dataclass: DataClass) -> None:
165167
"""Evaluate field types of a dataclass or its object."""
166168
hints = get_type_hints(dataclass, include_extras=True) # type: ignore
167169

168-
for field_ in dataclass.__dataclass_fields__.values():
169-
if isinstance(field_.type, str):
170-
field_.type = hints[field_.name]
170+
for field in dataclass.__dataclass_fields__.values():
171+
if isinstance(field.type, str):
172+
field.type = hints[field.name]
171173

172174

173175
def typedarray(

xarray_dataclasses/dataset.py

Lines changed: 22 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -115,19 +115,19 @@ def asdataset(
115115
model = DataModel.from_dataclass(dataclass)
116116
dataset = dataoptions.factory()
117117

118-
for data in model.data:
119-
dataset.update({data.name: data(reference)})
118+
for item in model.data.values():
119+
dataset[item.key] = item(reference)
120120

121-
for coord in model.coord:
122-
if coord.name in dataset.dims:
123-
dataset.coords.update({coord.name: coord(dataset)})
121+
for item in model.coord.values():
122+
if item.key in dataset.dims:
123+
dataset.coords[item.key] = item(dataset)
124124

125-
for coord in model.coord:
126-
if coord.name not in dataset.dims:
127-
dataset.coords.update({coord.name: coord(dataset)})
125+
for item in model.coord.values():
126+
if item.key not in dataset.dims:
127+
dataset.coords[item.key] = item(dataset)
128128

129-
for attr in model.attr:
130-
dataset.attrs.update({attr.name: attr()})
129+
for item in model.attr.values():
130+
dataset.attrs[item.key] = item()
131131

132132
return dataset
133133

@@ -177,9 +177,9 @@ def empty(
177177
model = DataModel.from_dataclass(cls)
178178
data_vars: Dict[str, Any] = {}
179179

180-
for data in model.data:
181-
shape = tuple(sizes[dim] for dim in data.type["dims"])
182-
data_vars[data.name] = np.empty(shape, order=order)
180+
for name, item in model.data.items():
181+
shape = tuple(sizes[dim] for dim in item.type["dims"])
182+
data_vars[name] = np.empty(shape, order=order)
183183

184184
return asdataset(cls(**data_vars, **kwargs))
185185

@@ -205,9 +205,9 @@ def zeros(
205205
model = DataModel.from_dataclass(cls)
206206
data_vars: Dict[str, Any] = {}
207207

208-
for data in model.data:
209-
shape = tuple(sizes[dim] for dim in data.type["dims"])
210-
data_vars[data.name] = np.zeros(shape, order=order)
208+
for name, item in model.data.items():
209+
shape = tuple(sizes[dim] for dim in item.type["dims"])
210+
data_vars[name] = np.zeros(shape, order=order)
211211

212212
return asdataset(cls(**data_vars, **kwargs))
213213

@@ -233,9 +233,9 @@ def ones(
233233
model = DataModel.from_dataclass(cls)
234234
data_vars: Dict[str, Any] = {}
235235

236-
for data in model.data:
237-
shape = tuple(sizes[dim] for dim in data.type["dims"])
238-
data_vars[data.name] = np.ones(shape, order=order)
236+
for name, item in model.data.items():
237+
shape = tuple(sizes[dim] for dim in item.type["dims"])
238+
data_vars[name] = np.ones(shape, order=order)
239239

240240
return asdataset(cls(**data_vars, **kwargs))
241241

@@ -263,8 +263,8 @@ def full(
263263
model = DataModel.from_dataclass(cls)
264264
data_vars: Dict[str, Any] = {}
265265

266-
for data in model.data:
267-
shape = tuple(sizes[dim] for dim in data.type["dims"])
268-
data_vars[data.name] = np.full(shape, fill_value, order=order)
266+
for name, item in model.data.items():
267+
shape = tuple(sizes[dim] for dim in item.type["dims"])
268+
data_vars[name] = np.full(shape, fill_value, order=order)
269269

270270
return asdataset(cls(**data_vars, **kwargs))

0 commit comments

Comments
 (0)