Skip to content

Commit f8a4374

Browse files
committed
Quantify/Dequantify working & tested on DataArrays
1 parent 92dcf5a commit f8a4374

File tree

3 files changed

+109
-33
lines changed

3 files changed

+109
-33
lines changed

pintxarray/accessors.py

Lines changed: 46 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
# TODO is it possible to import pint-xarray from within xarray if pint is present?
22
from xarray import (register_dataarray_accessor, register_dataset_accessor,
3-
DataArray, Dataset)
3+
DataArray, Dataset, Variable)
44
from xarray.core.npcompat import IS_NEP18_ACTIVE
55

66
import numpy as np
@@ -69,6 +69,13 @@ def _array_attach_units(data, unit, convert_from=None):
6969
return quantity
7070

7171

72+
def _dequantify_variable(var):
73+
new_var = Variable(dims=var.dims, data=var.data.magnitude,
74+
attrs=var.attrs)
75+
new_var.attrs['units'] = str(var.data.units)
76+
return new_var
77+
78+
7279
@register_dataarray_accessor("pint")
7380
class PintDataArrayAccessor:
7481
"""
@@ -115,17 +122,18 @@ def quantify(self, units=None, unit_registry=None, registry_kwargs=None):
115122
* wavelength (wavelength) np.array 1e-4, 2e-4, 4e-4, 6e-4, 1e-3, 2e-3
116123
"""
117124

125+
# TODO should also quantify coordinates (once explicit indexes ready)
126+
118127
if isinstance(self.da.data, Quantity):
119-
raise ValueError
128+
raise ValueError(f"Cannot attach unit {units} to quantity: data "
129+
f"already has units {self.da.data.units}")
120130

121131
if unit_registry is None:
122132
if registry_kwargs is None:
123133
registry_kwargs = {}
134+
registry_kwargs.update(force_ndarray=True)
135+
# TODO should this registry object then be stored somewhere global?
124136
unit_registry = pint.UnitRegistry(**registry_kwargs)
125-
else:
126-
if registry_kwargs is not None:
127-
raise ValueError("Cannot supply registry kwargs without "
128-
"supplying a registry")
129137

130138
if units is None:
131139
# TODO option to read and decode units according to CF conventions (see MetPy)?
@@ -134,6 +142,7 @@ def quantify(self, units=None, unit_registry=None, registry_kwargs=None):
134142
elif isinstance(units, Unit):
135143
# TODO do we have to check what happens if someone passes a Unit instance
136144
# without creating a unit registry?
145+
# TODO and what happens if they pass in a Unit from a different registry
137146
pass
138147
else:
139148
units = unit_registry.Unit(units)
@@ -142,13 +151,29 @@ def quantify(self, units=None, unit_registry=None, registry_kwargs=None):
142151

143152
# TODO should we (temporarily) remove the attrs here so that they don't become inconsistent?
144153
return DataArray(dims=self.da.dims, data=quantity,
145-
coords=self.da.coords, attrs=self.da.attrs )
154+
coords=self.da.coords, attrs=self.da.attrs)
155+
156+
def dequantify(self):
157+
"""
158+
Removes units from the DataArray and it's coordinates.
159+
160+
Will replace `.attrs['units']` on each variable with a string
161+
representation of the `pint.Unit` instance.
146162
147-
def dequantify(self, encode_cf=True):
148-
da = DataArray(dim=self.da.dims, data=self.da.pint.magnitude,
149-
coords=self.da.coords, attrs=self.da.attrs,
150-
encoding=self.da.encoding)
151-
da.attrs['units'] = self.da.pint.units
163+
Returns
164+
-------
165+
dequantified - DataArray whose array data is unitless, and of the type
166+
that was previously wrapped by `pint.Quantity`.
167+
"""
168+
169+
if not isinstance(self.da.data, Quantity):
170+
raise ValueError("Cannot remove units from data that does not have"
171+
" units")
172+
173+
# TODO also dequantify coords (once explicit indexes ready)
174+
da = DataArray(dims=self.da.dims, data=self.da.pint.magnitude,
175+
coords=self.da.coords, attrs=self.da.attrs)
176+
da.attrs['units'] = str(self.da.data.units)
152177
return da
153178

154179
@property
@@ -173,6 +198,15 @@ def units(self, units):
173198
def dimensionality(self):
174199
return self.da.data.dimensionality
175200

201+
@property
202+
def registry(self):
203+
# TODO is this a bad idea? (see GH issue #1071 in pint)
204+
return self.data._REGISTRY
205+
206+
@registry.setter
207+
def registry(self, _):
208+
raise AttributeError("Don't try to change the registry once created")
209+
176210
def to(self, units):
177211
quantity = self.da.data.to(units)
178212
return DataArray(dim=self.da.dims, data=quantity,

pintxarray/tests/test_accessors.py

Lines changed: 47 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -8,10 +8,10 @@
88

99
from pint import UnitRegistry
1010
from pint.unit import Unit
11-
from pint.errors import DimensionalityError
11+
from pint.errors import UndefinedUnitError, DimensionalityError
1212

1313
from pintxarray.accessors import PintDataArrayAccessor, PintDatasetAccessor
14-
from .utils import extract_units
14+
from .utils import extract_units, raises_regex
1515

1616

1717
# make sure scalars are converted to 0d arrays so quantities can
@@ -37,44 +37,70 @@ def example_quantity_da():
3737

3838

3939
class TestQuantifyDataArray:
40-
@pytest.mark.skip(reason="Not yet implemented")
41-
def test_attach_units_given_registry(self):
42-
...
40+
def test_attach_units_given_registry(self, example_unitless_da):
41+
orig = example_unitless_da
42+
ureg = UnitRegistry(force_ndarray=True)
43+
result = orig.pint.quantify('m', unit_registry=ureg)
44+
assert_array_equal(result.data.magnitude, orig.data)
45+
assert result.data.units == ureg.Unit('m')
4346

4447
def test_attach_units(self, example_unitless_da):
4548
orig = example_unitless_da
4649
result = orig.pint.quantify('m')
4750
assert_array_equal(result.data.magnitude, orig.data)
51+
# TODO better comparisons for when you can't access the unit_registry?
4852
assert str(result.data.units) == 'meter'
4953

50-
def test_attach_units_from_attrs(self):
51-
...
54+
def test_attach_units_from_attrs(self, example_unitless_da):
55+
orig = example_unitless_da
56+
result = orig.pint.quantify()
57+
assert_array_equal(result.data.magnitude, orig.data)
58+
assert str(result.data.units) == 'meter'
5259

53-
def test_attach_unit_class(self):
54-
...
60+
def test_attach_unit_class(self, example_unitless_da):
61+
orig = example_unitless_da
62+
ureg = UnitRegistry(force_ndarray=True)
63+
result = orig.pint.quantify(ureg.Unit('m'), unit_registry=ureg)
64+
assert_array_equal(result.data.magnitude, orig.data)
65+
assert result.data.units == ureg.Unit('m')
5566

5667
def test_error_when_already_units(self, example_quantity_da):
5768
da = example_quantity_da
58-
with pytest.raises(ValueError):
69+
with raises_regex(ValueError, "already has units"):
5970
da.pint.quantify()
6071

61-
def test_error_on_nonsense_units(self):
62-
...
72+
def test_error_on_nonsense_units(self, example_unitless_da):
73+
da = example_unitless_da
74+
with pytest.raises(UndefinedUnitError):
75+
da.pint.quantify(units='aecjhbav')
6376

64-
def test_registry_kwargs(self):
65-
...
77+
def test_registry_kwargs(self, example_unitless_da):
78+
orig = example_unitless_da
79+
result = orig.pint.quantify(registry_kwargs=
80+
{'auto_reduce_dimensions': True})
81+
assert(result.data._REGISTRY.auto_reduce_dimensions) == True
6682

6783

68-
@pytest.mark.skip(reason="Not yet implemented")
6984
class TestDequantifyDataArray:
70-
def test_strip_units(self):
71-
...
85+
def test_strip_units(self, example_quantity_da):
86+
result = example_quantity_da.pint.dequantify()
87+
assert isinstance(result.data, np.ndarray)
88+
assert isinstance(result.coords['x'].data, np.ndarray)
7289

73-
def test_error_if_no_units(self):
74-
...
90+
def test_error_if_no_units(self, example_unitless_da):
91+
with raises_regex(ValueError, "does not have units"):
92+
example_unitless_da.pint.dequantify()
7593

76-
def test_roundtrip_attrs(self):
77-
...
94+
def test_attrs_reinstated(self, example_quantity_da):
95+
da = example_quantity_da
96+
result = da.pint.dequantify()
97+
assert result.attrs['units'] == 'meter'
98+
99+
def test_roundtrip_data(self, example_unitless_da):
100+
orig = example_unitless_da
101+
quantified = orig.pint.quantify()
102+
result = quantified.pint.dequantify()
103+
assert_equal(result, orig)
78104

79105

80106
@pytest.mark.skip(reason="Not yet implemented")

pintxarray/tests/utils.py

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,24 @@
1+
from contextlib import contextmanager
2+
import re
3+
4+
import pytest
5+
16
import xarray as xr
27

38
from pint.quantity import Quantity
49

510

11+
@contextmanager
12+
def raises_regex(error, pattern):
13+
__tracebackhide__ = True
14+
with pytest.raises(error) as excinfo:
15+
yield
16+
message = str(excinfo.value)
17+
if not re.search(pattern, message):
18+
raise AssertionError(
19+
f"exception {excinfo.value!r} did not match pattern {pattern!r}"
20+
)
21+
622
def array_extract_units(obj):
723
if isinstance(obj, (xr.Variable, xr.DataArray, xr.Dataset)):
824
obj = obj.data

0 commit comments

Comments
 (0)