Skip to content

Commit 7c93089

Browse files
authored
use .pint.quantify() as a identity operator (#175)
* don't raise separately for the data * don't raise if we quantify with the same unit * add more tests * fix a bug in the registry detection algorithm It failed to detect unit objects in the attributes * also ignore None values as new units * don't complain if we try to attach the same units * same for dataset * changelog * use assert_units_equal as well * properly link the docs [skip-ci]
1 parent 7a10ea4 commit 7c93089

File tree

5 files changed

+113
-48
lines changed

5 files changed

+113
-48
lines changed

docs/whats-new.rst

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,9 @@ What's new
1212
By `Justus Magin <https://github.com/keewis>`_.
1313
- fix "quantifying" dimension coordinates (:issue:`105`, :pull:`174`).
1414
By `Justus Magin <https://github.com/keewis>`_.
15+
- allow using :py:meth:`DataArray.pint.quantify` and :py:meth:`Dataset.pint.quantify`
16+
as identity operators (:issue:`47`, :pull:`175`).
17+
By `Justus Magin <https://github.com/keewis>`_.
1518

1619
0.2.1 (26 Jul 2021)
1720
-------------------

pint_xarray/accessors.py

Lines changed: 9 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@
22
import itertools
33

44
import pint
5-
from pint import Quantity, Unit
5+
from pint import Unit
66
from xarray import register_dataarray_accessor, register_dataset_accessor
77
from xarray.core.dtypes import NA
88

@@ -71,16 +71,6 @@ def zip_mappings(*mappings, fill_value=None):
7171
return zipped
7272

7373

74-
def merge_mappings(first, *mappings):
75-
result = first.copy()
76-
for mapping in mappings:
77-
result.update(
78-
{key: value for key, value in mapping.items() if value is not None}
79-
)
80-
81-
return result
82-
83-
8474
def units_to_str_or_none(mapping, unit_format):
8575
formatter = str if not unit_format else lambda v: unit_format.format(v)
8676

@@ -109,8 +99,8 @@ def either_dict_or_kwargs(positional, keywords, method_name):
10999

110100

111101
def get_registry(unit_registry, new_units, existing_units):
112-
units = merge_mappings(existing_units, new_units)
113-
registries = {unit._REGISTRY for unit in units.values() if isinstance(unit, Unit)}
102+
units = itertools.chain(new_units.values(), existing_units.values())
103+
registries = {unit._REGISTRY for unit in units if isinstance(unit, Unit)}
114104

115105
if unit_registry is None:
116106
if not registries:
@@ -133,7 +123,7 @@ def get_registry(unit_registry, new_units, existing_units):
133123

134124

135125
def _decide_units(units, registry, unit_attribute):
136-
if units is _default and unit_attribute is _default:
126+
if units is _default and unit_attribute in (None, _default):
137127
# or warn and return None?
138128
raise ValueError("no units given")
139129
elif units in no_unit_values or isinstance(units, Unit):
@@ -321,13 +311,6 @@ def quantify(self, units=_default, unit_registry=None, **unit_kwargs):
321311
array([0.4, 0.9])
322312
Dimensions without coordinates: wavelength
323313
"""
324-
325-
if isinstance(self.da.data, Quantity):
326-
raise ValueError(
327-
f"Cannot attach unit {units} to quantity: data "
328-
f"already has units {self.da.data.units}"
329-
)
330-
331314
if units is None or isinstance(units, (str, pint.Unit)):
332315
if self.da.name in unit_kwargs:
333316
raise ValueError(
@@ -347,11 +330,11 @@ def quantify(self, units=_default, unit_registry=None, **unit_kwargs):
347330
new_units = {}
348331
invalid_units = {}
349332
for name, (unit, attr) in possible_new_units.items():
350-
if unit is not _default or attr is not _default:
333+
if unit not in (_default, None) or attr not in (_default, None):
351334
try:
352335
new_units[name] = _decide_units(unit, registry, attr)
353336
except (ValueError, pint.UndefinedUnitError) as e:
354-
if unit is not _default:
337+
if unit not in (_default, None):
355338
type = "parameter"
356339
reported_unit = unit
357340
else:
@@ -373,7 +356,7 @@ def quantify(self, units=_default, unit_registry=None, **unit_kwargs):
373356
for name, (old, new) in zip_mappings(
374357
existing_units, new_units, fill_value=_default
375358
).items()
376-
if old is not _default and new is not _default
359+
if old is not _default and new is not _default and old != new
377360
}
378361
if overwritten_units:
379362
errors = {
@@ -1062,7 +1045,7 @@ def quantify(self, units=_default, unit_registry=None, **unit_kwargs):
10621045
new_units = {}
10631046
invalid_units = {}
10641047
for name, (unit, attr) in possible_new_units.items():
1065-
if unit is not _default or attr is not _default:
1048+
if unit is not _default or attr not in (None, _default):
10661049
try:
10671050
new_units[name] = _decide_units(unit, registry, attr)
10681051
except (ValueError, pint.UndefinedUnitError) as e:
@@ -1088,7 +1071,7 @@ def quantify(self, units=_default, unit_registry=None, **unit_kwargs):
10881071
for name, (old, new) in zip_mappings(
10891072
existing_units, new_units, fill_value=_default
10901073
).items()
1091-
if old is not _default and new is not _default
1074+
if old is not _default and new is not _default and old != new
10921075
}
10931076
if overwritten_units:
10941077
errors = {

pint_xarray/conversion.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -33,6 +33,9 @@ def array_attach_units(data, unit):
3333
raise ValueError(f"cannot use {unit!r} as a unit")
3434

3535
if isinstance(data, pint.Quantity):
36+
if data.units == unit:
37+
return data
38+
3639
raise ValueError(
3740
f"Cannot attach unit {unit!r} to quantity: data "
3841
f"already has units {data.units}"

pint_xarray/tests/test_accessors.py

Lines changed: 84 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,6 @@
1111
assert_equal,
1212
assert_identical,
1313
assert_units_equal,
14-
raises_regex,
1514
requires_bottleneck,
1615
requires_dask_array,
1716
requires_scipy,
@@ -111,10 +110,62 @@ def test_override_units(self, example_unitless_da, no_unit_value):
111110
with pytest.raises(AttributeError):
112111
result["u"].data.units
113112

114-
def test_error_when_already_units(self, example_quantity_da):
113+
def test_error_when_changing_units(self, example_quantity_da):
115114
da = example_quantity_da
116-
with raises_regex(ValueError, "already has units"):
117-
da.pint.quantify()
115+
with pytest.raises(ValueError, match="already has units"):
116+
da.pint.quantify("s")
117+
118+
def test_attach_no_units(self):
119+
arr = xr.DataArray([1, 2, 3], dims="x")
120+
quantified = arr.pint.quantify()
121+
assert_identical(quantified, arr)
122+
assert_units_equal(quantified, arr)
123+
124+
def test_attach_no_new_units(self):
125+
da = xr.DataArray(unit_registry.Quantity([1, 2, 3], "m"), dims="x")
126+
quantified = da.pint.quantify()
127+
assert_identical(quantified, da)
128+
assert_units_equal(quantified, da)
129+
130+
def test_attach_same_units(self):
131+
da = xr.DataArray(unit_registry.Quantity([1, 2, 3], "m"), dims="x")
132+
quantified = da.pint.quantify("m")
133+
assert_identical(quantified, da)
134+
assert_units_equal(quantified, da)
135+
136+
def test_error_when_changing_units_dimension_coordinates(self):
137+
arr = xr.DataArray(
138+
[1, 2, 3],
139+
dims="x",
140+
coords={"x": ("x", [-1, 0, 1], {"units": unit_registry.Unit("m")})},
141+
)
142+
with pytest.raises(ValueError, match="already has units"):
143+
arr.pint.quantify({"x": "s"})
144+
145+
def test_dimension_coordinate_array(self):
146+
ds = xr.Dataset(coords={"x": ("x", [10], {"units": "m"})})
147+
arr = ds.x
148+
149+
# does not actually quantify because `arr` wraps a IndexVariable
150+
# but we still get a `Unit` in the attrs
151+
q = arr.pint.quantify()
152+
assert isinstance(q.attrs["units"], Unit)
153+
154+
def test_dimension_coordinate_array_already_quantified(self):
155+
ds = xr.Dataset(coords={"x": ("x", [10], {"units": unit_registry.Unit("m")})})
156+
arr = ds.x
157+
158+
with pytest.raises(ValueError):
159+
arr.pint.quantify({"x": "s"})
160+
161+
def test_dimension_coordinate_array_already_quantified_same_units(self):
162+
ds = xr.Dataset(coords={"x": ("x", [10], {"units": unit_registry.Unit("m")})})
163+
arr = ds.x
164+
165+
quantified = arr.pint.quantify({"x": "m"})
166+
167+
assert_identical(quantified, arr)
168+
assert_units_equal(quantified, arr)
118169

119170
def test_error_on_nonsense_units(self, example_unitless_da):
120171
da = example_unitless_da
@@ -135,22 +186,6 @@ def test_parse_integer_inverse(self):
135186
result = da.pint.quantify()
136187
assert result.pint.units == Unit("1 / meter")
137188

138-
def test_dimension_coordinate(self):
139-
ds = xr.Dataset(coords={"x": ("x", [10], {"units": "m"})})
140-
arr = ds.x
141-
142-
# does not actually quantify because `arr` wraps a IndexVariable
143-
# but we still get a `Unit` in the attrs
144-
q = arr.pint.quantify()
145-
assert isinstance(q.attrs["units"], Unit)
146-
147-
def test_dimension_coordinate_already_quantified(self):
148-
ds = xr.Dataset(coords={"x": ("x", [10], {"units": unit_registry.Unit("m")})})
149-
arr = ds.x
150-
151-
with pytest.raises(ValueError):
152-
arr.pint.quantify({"x": "s"})
153-
154189

155190
@pytest.mark.parametrize("formatter", ("", "P", "C"))
156191
@pytest.mark.parametrize("modifier", ("", "~"))
@@ -308,8 +343,35 @@ def test_override_units(self, example_unitless_ds, no_unit_value):
308343
)
309344

310345
def test_error_when_already_units(self, example_quantity_ds):
311-
with raises_regex(ValueError, "already has units"):
312-
example_quantity_ds.pint.quantify({"funds": "pounds"})
346+
with pytest.raises(ValueError, match="already has units"):
347+
example_quantity_ds.pint.quantify({"funds": "kg"})
348+
349+
def test_attach_no_units(self):
350+
ds = xr.Dataset({"a": ("x", [1, 2, 3])})
351+
quantified = ds.pint.quantify()
352+
assert_identical(quantified, ds)
353+
assert_units_equal(quantified, ds)
354+
355+
def test_attach_no_new_units(self):
356+
ds = xr.Dataset({"a": ("x", unit_registry.Quantity([1, 2, 3], "m"))})
357+
quantified = ds.pint.quantify()
358+
359+
assert_identical(quantified, ds)
360+
assert_units_equal(quantified, ds)
361+
362+
def test_attach_same_units(self):
363+
ds = xr.Dataset({"a": ("x", unit_registry.Quantity([1, 2, 3], "m"))})
364+
quantified = ds.pint.quantify({"a": "m"})
365+
366+
assert_identical(quantified, ds)
367+
assert_units_equal(quantified, ds)
368+
369+
def test_error_when_changing_units_dimension_coordinates(self):
370+
ds = xr.Dataset(
371+
coords={"x": ("x", [-1, 0, 1], {"units": unit_registry.Unit("m")})},
372+
)
373+
with pytest.raises(ValueError, match="already has units"):
374+
ds.pint.quantify({"x": "s"})
313375

314376
def test_error_on_nonsense_units(self, example_unitless_ds):
315377
ds = example_unitless_ds

pint_xarray/tests/test_conversion.py

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -78,6 +78,20 @@ class TestArrayFunctions:
7878
"already has units",
7979
id="unit object on quantity",
8080
),
81+
pytest.param(
82+
Unit("m"),
83+
Quantity(np.array([0, 1]), "m"),
84+
Quantity(np.array([0, 1]), "m"),
85+
None,
86+
id="unit object on quantity with same unit",
87+
),
88+
pytest.param(
89+
Unit("mm"),
90+
Quantity(np.array([0, 1]), "m"),
91+
None,
92+
"already has units",
93+
id="unit object on quantity with similar unit",
94+
),
8195
),
8296
)
8397
def test_array_attach_units(self, data, unit, expected, match):

0 commit comments

Comments
 (0)