Skip to content

Commit 8054b36

Browse files
committed
basic implementation of da.quantify
1 parent 6d498c2 commit 8054b36

File tree

1 file changed

+119
-80
lines changed

1 file changed

+119
-80
lines changed

pintxarray/accessors.py

Lines changed: 119 additions & 80 deletions
Original file line numberDiff line numberDiff line change
@@ -1,31 +1,22 @@
1-
from importlib import import_module
1+
# TODO is it possible to import pint-xarray from within xarray if pint is present?
2+
from xarray import (register_dataarray_accessor, register_dataset_accessor,
3+
DataArray, Dataset)
4+
from xarray.core.npcompat import IS_NEP18_ACTIVE
5+
6+
import numpy as np
27

38
import pint
49
from pint.quantity import Quantity
510
from pint.unit import Unit
6-
# TODO is it possible to import pint-xarray from within xarray if pint is present?
7-
import xarray as xr
8-
import numpy as np
9-
from xarray.core.npcompat import IS_NEP18_ACTIVE
1011

1112

1213
if not hasattr(Quantity, "__array_function__"):
1314
raise ImportError("Imported version of pint does not implement "
14-
"__array_function__ yet")
15+
"__array_function__")
1516

1617
if not IS_NEP18_ACTIVE:
1718
raise ImportError("NUMPY_EXPERIMENTAL_ARRAY_FUNCTION is not enabled")
1819

19-
try:
20-
mpl = import_module("matplotlib")
21-
has_mpl = True
22-
except ImportError:
23-
has_mpl = False
24-
else:
25-
# TODO can we do this without initialising a Unit Registry?
26-
unit_registry = pint.UnitRegistry()
27-
unit_registry.setup_matplotlib(True)
28-
2920

3021
# TODO could/should we overwrite xr.open_dataset and xr.open_mfdataset to make
3122
# them apply units upon loading???
@@ -35,32 +26,28 @@
3526
# TODO type hints
3627
# TODO f-strings
3728

38-
def array_attach_units(data, unit, convert_from=None):
29+
30+
def _array_attach_units(data, unit, convert_from=None):
31+
"""
32+
Internal utility function for attaching units to a numpy-like array,
33+
converting them, or throwing the correct error.
34+
"""
35+
3936
if isinstance(data, Quantity):
4037
if not convert_from:
41-
raise ValueError(
42-
"cannot attach unit {unit} to quantity ({data.units})".format(
43-
unit=unit, data=data
44-
)
45-
)
38+
raise ValueError(f"Cannot attach unit {unit} to quantity: data "
39+
f"already has units {data.units}")
4640
elif isinstance(convert_from, Unit):
4741
data = data.magnitude
4842
elif convert_from is True: # intentionally accept exactly true
4943
if data.check(unit):
5044
convert_from = data.units
5145
data = data.magnitude
5246
else:
53-
raise ValueError(
54-
"cannot convert quantity ({data.units}) to {unit}".format(
55-
unit=unit, data=data
56-
)
57-
)
47+
raise ValueError("Cannot convert quantity from {data.units} "
48+
"to {unit}")
5849
else:
59-
raise ValueError(
60-
"cannot convert from invalid unit {convert_from}".format(
61-
convert_from=convert_from
62-
)
63-
)
50+
raise ValueError("Cannot convert from invalid unit {convert_from}")
6451

6552
# to make sure we also encounter the case of "equal if converted"
6653
if convert_from is not None:
@@ -73,6 +60,7 @@ def array_attach_units(data, unit, convert_from=None):
7360
try:
7461
quantity = data * unit
7562
except np.core._exceptions.UFuncTypeError:
63+
# from @keewis in xarray.tests.test_units - unsure what this checks?
7664
if unit != 1:
7765
raise
7866

@@ -81,77 +69,128 @@ def array_attach_units(data, unit, convert_from=None):
8169
return quantity
8270

8371

84-
# TODO Error checking (that data is actually a quantity etc)
85-
86-
# TODO refactor with an apply_to(da, data_method) function?
72+
@register_dataarray_accessor("pint")
73+
class PintDataArrayAccessor:
74+
"""
75+
Access methods for DataArrays with units using Pint.
8776
77+
Methods and attributes can be accessed through the `.pint` attribute.
78+
"""
8879

89-
@xr.register_dataarray_accessor("pint")
90-
class PintDataArrayAccessor:
9180
def __init__(self, da):
9281
self.da = da
9382

94-
def quantify(self, units=None, unit_registry=None, decode_cf=False):
95-
# TODO read and decode units according to CF conventions (see MetPy)
96-
if not units:
97-
if decode_cf:
98-
# TODO unit = Unit(_decode_cf(self.da.attrs['units']))
99-
raise NotImplementedError
100-
else:
101-
units = Unit(self.da.attrs['units'])
83+
def quantify(self, units=None, unit_registry=None, registry_kwargs=None):
84+
"""
85+
Attaches units to the DataArray.
86+
87+
Units can be specified as a pint.Unit or as a string, which will will
88+
be parsed by the given unit registry. If no units are specified then
89+
the units will be parsed from the `'units'` entry of the DataArray's
90+
`.attrs`. Will raise a ValueError if the DataArray already contains a
91+
unit-aware array.
92+
93+
Parameters
94+
----------
95+
units : pint.Unit or str, optional
96+
Physical units to use for this DataArray. If not provided, will try
97+
to read them from `DataArray.attrs['units']` using pint's parser.
98+
unit_registry : `pint.UnitRegistry`, optional
99+
Unit registry to be used for the units attached to this DataArray.
100+
If not given then a default registry will be created.
101+
registry_kwargs : dict, optional
102+
Keyword arguments to be passed to `pint.UnitRegistry`.
103+
104+
Returns
105+
-------
106+
quantified - DataArray whose wrapped array data will now be a Quantity
107+
array with the specified units.
108+
109+
Examples
110+
--------
111+
>>> da.pint.quantify(units='Hz')
112+
<xarray.DataArray (frequency: 6)>
113+
Quantity([ 0.4, 0.9, 1.7, 4.8, 3.2, 9.1], 'Hz')
114+
Coordinates:
115+
* wavelength (wavelength) np.array 1e-4, 2e-4, 4e-4, 6e-4, 1e-3, 2e-3
116+
"""
117+
118+
if isinstance(self.da.data, Quantity):
119+
raise ValueError
120+
121+
if unit_registry is None:
122+
if registry_kwargs is None:
123+
registry_kwargs = {}
124+
unit_registry = pint.UnitRegistry(**registry_kwargs)
125+
else:
126+
if registry_kwargs is not None:
127+
raise ValueError("Cannot supply registry kwargs without "
128+
"supplying a registry")
129+
130+
if units is None:
131+
# TODO option to read and decode units according to CF conventions (see MetPy)?
132+
attr_units = self.da.attrs['units']
133+
units = unit_registry.parse_expression(attr_units)
134+
elif isinstance(units, Unit):
135+
# TODO do we have to check what happens if someone passes a Unit instance
136+
# without creating a unit registry?
137+
pass
138+
else:
139+
units = unit_registry.Unit(units)
102140

103-
quantity = array_attach_units(self.da.data, units)
104-
# TODO should we (temporarily) remove the attrs here?
105-
return xr.DataArray(dim=self.da.dims, data=quantity,
106-
coords=self.da.coords, attrs=self.da.attrs,
107-
encoding=self.da.encoding)
141+
quantity = _array_attach_units(self.da.data, units, convert_from=None)
142+
143+
# TODO should we (temporarily) remove the attrs here so that they don't become inconsistent?
144+
return DataArray(dims=self.da.dims, data=quantity,
145+
coords=self.da.coords, attrs=self.da.attrs )
108146

109147
def dequantify(self, encode_cf=True):
110-
da = xr.DataArray(dim=self.da.dims, data=self.da.pint.magnitude,
111-
coords=self.da.coords, attrs=self.da.attrs,
112-
encoding=self.da.encoding)
148+
da = DataArray(dim=self.da.dims, data=self.da.pint.magnitude,
149+
coords=self.da.coords, attrs=self.da.attrs,
150+
encoding=self.da.encoding)
113151
da.attrs['units'] = self.da.pint.units
114152
return da
115153

116154
@property
155+
def magnitude(self):
156+
return self.da.data.magnitude
157+
158+
@magnitude.setter
159+
def magnitude(self, da):
160+
self.da = DataArray(dim=self.da.dims, data=da.data,
161+
coords=self.da.coords, attrs=self.da.attrs)
162+
@property
117163
def units(self):
118164
return self.da.data.units
119165

120166
@units.setter
121167
def units(self, units):
122-
quantity = array_attach_units(self.da.data, units)
123-
self.da = xr.DataArray(dim=self.da.dims, data=quantity,
124-
coords=self.da.coords, attrs=self.da.attrs,
125-
encoding=self.da.encoding)
168+
quantity = _array_attach_units(self.da.data, units)
169+
self.da = DataArray(dim=self.da.dims, data=quantity,
170+
coords=self.da.coords, attrs=self.da.attrs)
126171

127172
@property
128-
def magnitude(self):
129-
return self.da.data.magnitude
130-
131-
@magnitude.setter
132-
def magnitude(self, da):
133-
self.da = xr.DataArray(dim=self.da.dims, data=da.data,
134-
coords=self.da.coords, attrs=self.da.attrs,
135-
encoding=self.da.encoding)
173+
def dimensionality(self):
174+
return self.da.data.dimensionality
136175

137176
def to(self, units):
138177
quantity = self.da.data.to(units)
139-
return xr.DataArray(dim=self.da.dims, data=quantity,
140-
coords=self.da.coords, attrs=self.da.attrs,
141-
encoding=self.da.encoding)
178+
return DataArray(dim=self.da.dims, data=quantity,
179+
coords=self.da.coords, attrs=self.da.attrs,
180+
encoding=self.da.encoding)
142181

143182
def to_base_units(self):
144183
quantity = self.da.data.to_base_units()
145-
return xr.DataArray(dim=self.da.dims, data=quantity,
146-
coords=self.da.coords, attrs=self.da.attrs,
147-
encoding=self.da.encoding)
184+
return DataArray(dim=self.da.dims, data=quantity,
185+
coords=self.da.coords, attrs=self.da.attrs,
186+
encoding=self.da.encoding)
148187

149188
# TODO integrate with the uncertainties package here...?
150189
def plus_minus(self, value, error, relative=False):
151190
quantity = self.da.data.plus_minus(value, error, relative)
152-
return xr.DataArray(dim=self.da.dims, data=quantity,
153-
coords=self.da.coords, attrs=self.da.attrs,
154-
encoding=self.da.encoding)
191+
return DataArray(dim=self.da.dims, data=quantity,
192+
coords=self.da.coords, attrs=self.da.attrs,
193+
encoding=self.da.encoding)
155194

156195
def sel(self, indexers=None, method=None, tolerance=None, drop=False,
157196
**indexers_kwargs):
@@ -162,7 +201,7 @@ def loc(self):
162201
...
163202

164203

165-
@xr.register_dataset_accessor("pint")
204+
@register_dataset_accessor("pint")
166205
class PintDatasetAccessor:
167206
def __init__(self, ds):
168207
self.ds = ds
@@ -171,19 +210,19 @@ def quantify(self, unit_registry=None, decode_cf=False):
171210
quantified_vars = {name: da.pint.quantify(unit_registry=unit_registry,
172211
decode_cf=decode_cf)
173212
for name, da in self.ds.items()}
174-
return xr.Dataset(quantified_vars, attrs=self.ds.attrs,
175-
encoding=self.ds.encoding)
213+
return Dataset(quantified_vars, attrs=self.ds.attrs,
214+
encoding=self.ds.encoding)
176215

177216
def dequantify(self):
178217
dequantified_vars = {name: da.pint.to_base_units()
179218
for name, da in self.ds.items()}
180-
return xr.Dataset(dequantified_vars, attrs=self.ds.attrs,
181-
encoding=self.ds.encoding)
219+
return Dataset(dequantified_vars, attrs=self.ds.attrs,
220+
encoding=self.ds.encoding)
182221

183222
def to_base_units(self):
184223
base_vars = {name: da.pint.to_base_units()
185224
for name, da in self.ds.items()}
186-
return xr.Dataset(base_vars, attrs=self.ds.attrs, encoding=self.ds.encoding)
225+
return Dataset(base_vars, attrs=self.ds.attrs, encoding=self.ds.encoding)
187226

188227
# TODO way to change every variable in ds to be expressed in a new units system?
189228

0 commit comments

Comments
 (0)