Skip to content

Commit 5902933

Browse files
authored
simplify the conversion functions using call_on_dataset (#110)
* vendor a short helper function * only attempt to convert to DataArray if the result is a Dataset * also try attaching an empty dict * refactor the conversion functions to use call_on_dataset * catch the ImportError
1 parent 451b639 commit 5902933

File tree

3 files changed

+133
-126
lines changed

3 files changed

+133
-126
lines changed

pint_xarray/compat.py

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,18 @@
1+
import xarray as xr
2+
3+
try:
    from xarray import call_on_dataset
except ImportError:

    def call_on_dataset(func, obj, name, *args, **kwargs):
        """Apply ``func`` to ``obj`` viewed as a dataset.

        Fallback used when ``call_on_dataset`` cannot be imported from
        ``xarray``.

        Parameters
        ----------
        func : callable
            Called as ``func(ds, *args, **kwargs)``.
        obj : object
            A ``DataArray`` is converted with ``obj.to_dataset(name=name)``;
            any other object is passed to ``func`` unchanged.
        name : hashable
            Name given to the array while it is wrapped in a dataset.
        *args, **kwargs
            Forwarded to ``func``.

        Returns
        -------
        object
            The result of ``func``. When ``obj`` was a ``DataArray`` and
            ``func`` returned a ``Dataset``, the variable ``name`` is taken
            back out of it and renamed to ``obj.name``.
        """
        was_array = isinstance(obj, xr.DataArray)
        ds = obj.to_dataset(name=name) if was_array else obj

        result = func(ds, *args, **kwargs)

        # only unwrap if we wrapped the input and still have a dataset
        if was_array and isinstance(result, xr.Dataset):
            result = result.get(name).rename(obj.name)

        return result

pint_xarray/conversion.py

Lines changed: 114 additions & 126 deletions
Original file line numberDiff line numberDiff line change
@@ -3,11 +3,13 @@
33
import pint
44
from xarray import DataArray, Dataset, IndexVariable, Variable
55

6+
from .compat import call_on_dataset
67
from .errors import format_error_message
78

89
no_unit_values = ("none", None)
910
unit_attribute_name = "units"
1011
slice_attributes = ("start", "stop", "step")
12+
temporary_name = "<this-array>"
1113

1214

1315
def array_attach_units(data, unit):
@@ -107,40 +109,49 @@ def attach_units_variable(variable, units):
107109
return new_obj
108110

109111

112+
def dataset_from_variables(variables, coords, attrs):
    """Build a ``Dataset`` from a flat mapping of variables.

    Parameters
    ----------
    variables : mapping
        Mapping of name to variable.
    coords : container
        Names that should become coordinates; every other entry of
        ``variables`` becomes a data variable.
    attrs : mapping
        Passed to ``Dataset`` as ``attrs``.

    Returns
    -------
    Dataset
    """
    data_vars = {}
    coord_vars = {}
    for name, var in variables.items():
        # route each variable to exactly one of the two groups
        target = coord_vars if name in coords else data_vars
        target[name] = var

    return Dataset(data_vars=data_vars, coords=coord_vars, attrs=attrs)
117+
118+
119+
def attach_units_dataset(obj, units):
    """Attach units to every variable of a dataset.

    Parameters
    ----------
    obj : Dataset
        The dataset whose variables (data variables and coordinates) are
        processed.
    units : mapping
        Mapping of variable name to unit. Names missing from the mapping
        are looked up with ``units.get`` and therefore yield ``None``.

    Returns
    -------
    Dataset
        A new dataset built from the converted variables, keeping the
        coordinate names and attributes of ``obj``.

    Raises
    ------
    ValueError
        If ``attach_units_variable`` raised ``ValueError`` for at least one
        variable. The single argument is a dict mapping each failing name
        to ``(unit, error)``.
    """
    with_units = {}
    rejected = {}
    for var_name, variable in obj.variables.items():
        requested = units.get(var_name)
        try:
            result = attach_units_variable(variable, requested)
        except ValueError as error:
            # collect every failure so that all of them can be reported
            rejected[var_name] = (requested, error)
        else:
            with_units[var_name] = result

    if rejected:
        raise ValueError(rejected)

    return dataset_from_variables(with_units, obj._coord_names, obj.attrs)
134+
135+
110136
def attach_units(obj, units):
    """Attach units to a ``DataArray`` or ``Dataset``.

    Parameters
    ----------
    obj : DataArray or Dataset
        The object to attach units to.
    units : mapping
        Mapping of variable name to unit. For a ``DataArray``, the entry
        for ``obj.name`` (if present) is also made available under the
        temporary name used while the array is wrapped in a dataset. The
        mapping passed in is not modified.

    Returns
    -------
    DataArray or Dataset
        The result of calling ``attach_units_dataset`` through
        ``call_on_dataset``.

    Raises
    ------
    ValueError
        If ``obj`` is neither a ``DataArray`` nor a ``Dataset``, or if
        units could not be attached to some variables (the message is
        built by ``format_error_message``).
    """
    if not isinstance(obj, (DataArray, Dataset)):
        raise ValueError(f"cannot attach units to {obj!r}: unknown type")

    is_array = isinstance(obj, DataArray)
    if is_array:
        units = units.copy()
        if obj.name in units:
            units[temporary_name] = units.get(obj.name)

    try:
        new_obj = call_on_dataset(
            attach_units_dataset, obj, name=temporary_name, units=units
        )
    except ValueError as e:
        # attach_units_dataset signals failures with a single dict argument;
        # any other ValueError comes from somewhere else and is propagated
        # unchanged instead of being fed to the formatter
        rejected_vars = e.args[0] if len(e.args) == 1 else None
        if not isinstance(rejected_vars, dict):
            raise

        # report the failure under the array's real name
        if is_array and temporary_name in rejected_vars:
            rejected_vars[obj.name] = rejected_vars.pop(temporary_name)

        raise ValueError(format_error_message(rejected_vars, "attach")) from e

    return new_obj
146157

@@ -192,87 +203,81 @@ def convert_units_variable(variable, units):
192203
return new_obj
193204

194205

206+
def convert_units_dataset(obj, units):
    """Convert the units of every variable of a dataset.

    Parameters
    ----------
    obj : Dataset
        The dataset whose variables (data variables and coordinates) are
        processed.
    units : mapping
        Mapping of variable name to target unit. Names missing from the
        mapping are looked up with ``units.get`` and therefore yield
        ``None``.

    Returns
    -------
    Dataset
        A new dataset built from the converted variables, keeping the
        coordinate names and attributes of ``obj``.

    Raises
    ------
    ValueError
        If ``convert_units_variable`` raised ``ValueError`` or
        ``pint.errors.PintTypeError`` for at least one variable. The single
        argument is a dict mapping each failing name to its error.
    """
    results = {}
    errors = {}
    for var_name, variable in obj.variables.items():
        target_unit = units.get(var_name)
        try:
            new_variable = convert_units_variable(variable, target_unit)
        except (ValueError, pint.errors.PintTypeError) as error:
            # keep going so that every failure is reported at once
            errors[var_name] = error
        else:
            results[var_name] = new_variable

    if errors:
        raise ValueError(errors)

    return dataset_from_variables(results, obj._coord_names, obj.attrs)
220+
221+
195222
def convert_units(obj, units):
    """Convert the units of a ``DataArray`` or ``Dataset``.

    Parameters
    ----------
    obj : DataArray or Dataset
        The object to convert.
    units : mapping
        Mapping of variable name to target unit. For a ``DataArray``, the
        entry for ``obj.name`` (if present) is moved to the temporary name
        used while the array is wrapped in a dataset. The mapping passed in
        is not modified.

    Returns
    -------
    DataArray or Dataset
        The result of calling ``convert_units_dataset`` through
        ``call_on_dataset``.

    Raises
    ------
    ValueError
        If ``obj`` is neither a ``DataArray`` nor a ``Dataset``, or if some
        variables could not be converted (the message is built by
        ``format_error_message``).
    """
    if not isinstance(obj, (DataArray, Dataset)):
        raise ValueError(f"cannot convert object: {obj!r}: unknown type")

    is_array = isinstance(obj, DataArray)
    if is_array:
        units = units.copy()
        if obj.name in units:
            units[temporary_name] = units.pop(obj.name)

    try:
        new_obj = call_on_dataset(
            convert_units_dataset, obj, name=temporary_name, units=units
        )
    except ValueError as e:
        # convert_units_dataset signals failures with a single dict argument;
        # any other ValueError comes from somewhere else and is propagated
        # unchanged instead of being fed to the formatter
        failed = e.args[0] if len(e.args) == 1 else None
        if not isinstance(failed, dict):
            raise

        # report the failure under the array's real name
        if is_array and temporary_name in failed:
            failed[obj.name] = failed.pop(temporary_name)

        raise ValueError(format_error_message(failed, "convert")) from e

    return new_obj
233243

234244

235-
def extract_units(obj):
236-
if isinstance(obj, Dataset):
237-
units = extract_unit_attributes(obj)
238-
dims = obj.dims
239-
units.update(
240-
{
241-
name: array_extract_units(value.data)
242-
for name, value in obj.variables.items()
243-
if name not in dims
244-
}
245-
)
246-
elif isinstance(obj, DataArray):
247-
original_name = obj.name
248-
name = obj.name if obj.name is not None else "<this-array>"
245+
def extract_units_dataset(obj):
    """Map every variable name of a dataset to the units of its data.

    The unit of each variable is whatever ``array_extract_units`` returns
    for ``var.data``.
    """
    extracted = {}
    for var_name, variable in obj.variables.items():
        extracted[var_name] = array_extract_units(variable.data)

    return extracted
249247

250-
ds = obj.rename(name).to_dataset()
251248

252-
units = extract_units(ds)
253-
units[original_name] = units.pop(name)
254-
else:
249+
def extract_units(obj):
    """Collect the units of a ``DataArray`` or ``Dataset``.

    The result starts from the unit attributes returned by
    ``extract_unit_attributes`` and is then overridden by every unit
    extracted from the data that is not ``None``.

    Parameters
    ----------
    obj : DataArray or Dataset
        The object to extract units from.

    Returns
    -------
    dict
        Mapping of variable name to unit. For a ``DataArray``, the array
        itself is listed under ``obj.name``.

    Raises
    ------
    ValueError
        If ``obj`` is neither a ``DataArray`` nor a ``Dataset``.
    """
    if not isinstance(obj, (DataArray, Dataset)):
        raise ValueError(f"unknown type: {type(obj)}")

    merged = extract_unit_attributes(obj).copy()

    data_units = call_on_dataset(extract_units_dataset, obj, name=temporary_name)
    if temporary_name in data_units:
        # undo the temporary renaming of the wrapped array
        data_units[obj.name] = data_units.pop(temporary_name)

    # units found on the data take precedence over unit attributes
    for key, unit in data_units.items():
        if unit is not None:
            merged[key] = unit

    return merged
266263

267-
units = extract_unit_attributes(ds)
268-
units[original_name] = units.pop(name)
269-
elif isinstance(obj, Dataset):
270-
units = {name: var.attrs.get(attr, None) for name, var in obj.variables.items()}
271-
else:
264+
265+
def extract_unit_attributes_dataset(obj, attr="units"):
    """Map every variable name of a dataset to its ``attr`` attribute.

    Variables without that attribute are mapped to ``None``.
    """
    found = {}
    for var_name, variable in obj.variables.items():
        found[var_name] = variable.attrs.get(attr, None)

    return found
267+
268+
269+
def extract_unit_attributes(obj, attr="units"):
    """Collect the unit attributes of a ``DataArray`` or ``Dataset``.

    Parameters
    ----------
    obj : DataArray or Dataset
        The object to read the attributes from.
    attr : str, default: "units"
        Name of the attribute to read from each variable.

    Returns
    -------
    dict
        Mapping of variable name to attribute value (``None`` if the
        attribute is missing). For a ``DataArray``, the array itself is
        listed under ``obj.name``.

    Raises
    ------
    ValueError
        If ``obj`` is neither a ``DataArray`` nor a ``Dataset``.
    """
    supported = isinstance(obj, (DataArray, Dataset))
    if not supported:
        raise ValueError(
            f"cannot retrieve unit attributes from unknown type: {type(obj)}"
        )

    attributes = call_on_dataset(
        extract_unit_attributes_dataset, obj, name=temporary_name, attr=attr
    )

    # undo the temporary renaming of the wrapped array
    if temporary_name in attributes:
        value = attributes.pop(temporary_name)
        attributes[obj.name] = value

    return attributes
277282

278283

@@ -281,51 +286,34 @@ def strip_units_variable(var):
281286
return var.copy(data=data)
282287

283288

284-
def strip_units(obj):
285-
if isinstance(obj, DataArray):
286-
original_name = obj.name
287-
name = obj.name if obj.name is not None else "<this-array>"
288-
ds = obj.rename(name).to_dataset()
289-
stripped = strip_units(ds)
289+
def strip_units_dataset(obj):
    """Strip the units from every variable of a dataset.

    Each variable is passed through ``strip_units_variable``; the result is
    a new dataset with the coordinate names and attributes of ``obj``.
    """
    stripped = {}
    for var_name, variable in obj.variables.items():
        stripped[var_name] = strip_units_variable(variable)

    return dataset_from_variables(stripped, obj._coord_names, obj.attrs)
293+
294+
295+
def strip_units(obj):
    """Strip the units from a ``DataArray`` or ``Dataset``.

    Parameters
    ----------
    obj : DataArray or Dataset
        The object to strip units from.

    Returns
    -------
    DataArray or Dataset
        The result of calling ``strip_units_dataset`` through
        ``call_on_dataset``.

    Raises
    ------
    ValueError
        If ``obj`` is neither a ``DataArray`` nor a ``Dataset``.
    """
    if not isinstance(obj, (DataArray, Dataset)):
        # f-string so that the offending object actually shows up in the message
        raise ValueError(f"cannot strip units from {obj!r}: unknown type")

    return call_on_dataset(strip_units_dataset, obj, name=temporary_name)
309300

310301

311-
def strip_unit_attributes(obj, attr="units"):
312-
if isinstance(obj, DataArray):
313-
original_name = obj.name
314-
name = obj.name if obj.name is not None else "<this-array>"
302+
def strip_unit_attributes_dataset(obj, attr="units"):
    """Return a copy of a dataset without the ``attr`` attribute on its variables.

    The copy is made with ``obj.copy()``; the attribute is then removed
    from every variable of the copy, ignoring variables that do not have it.
    """
    result = obj.copy()
    for variable in result.variables.values():
        # a missing attribute is not an error
        variable.attrs.pop(attr, None)

    return result
317308

318-
stripped = strip_unit_attributes(ds)
319309

320-
new_obj = stripped[name].rename(original_name)
321-
elif isinstance(obj, Dataset):
322-
new_obj = obj.copy()
323-
for var in new_obj.variables.values():
324-
var.attrs.pop(attr, None)
325-
else:
310+
def strip_unit_attributes(obj, attr="units"):
    """Remove the unit attributes from a ``DataArray`` or ``Dataset``.

    Parameters
    ----------
    obj : DataArray or Dataset
        The object to remove the attributes from.
    attr : str, default: "units"
        Name of the attribute to remove from each variable.

    Returns
    -------
    DataArray or Dataset
        The result of calling ``strip_unit_attributes_dataset`` through
        ``call_on_dataset``.

    Raises
    ------
    ValueError
        If ``obj`` is neither a ``DataArray`` nor a ``Dataset``.
    """
    if isinstance(obj, (DataArray, Dataset)):
        stripped = call_on_dataset(
            strip_unit_attributes_dataset, obj, name=temporary_name, attr=attr
        )
        return stripped

    raise ValueError(f"cannot strip unit attributes from unknown type: {type(obj)}")
329317

330318

331319
def slice_extract_units(indexer):

pint_xarray/tests/test_conversion.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -208,6 +208,7 @@ class TestXarrayFunctions:
208208
@pytest.mark.parametrize(
209209
"units",
210210
(
211+
pytest.param({}, id="empty units"),
211212
pytest.param({"a": None, "b": None, "u": None, "x": None}, id="no units"),
212213
pytest.param(
213214
{"a": unit_registry.m, "b": unit_registry.m, "u": None, "x": None},

0 commit comments

Comments
 (0)