diff --git a/docs/api.rst b/docs/api.rst index bb7cb2f9..f6b7658c 100644 --- a/docs/api.rst +++ b/docs/api.rst @@ -65,6 +65,13 @@ DataArray xarray.DataArray.pint.bfill xarray.DataArray.pint.interpolate_na +Wrapping quantity-unaware functions +----------------------------------- +.. autosummary:: + :toctree: generated/ + + pint_xarray.expects + Testing ------- diff --git a/docs/terminology.rst b/docs/terminology.rst index 4fe6534d..df784fae 100644 --- a/docs/terminology.rst +++ b/docs/terminology.rst @@ -5,6 +5,7 @@ Terminology unit-like A `pint`_ unit definition, as accepted by :py:class:`pint.Unit`. - May be either a :py:class:`str` or a :py:class:`pint.Unit` instance. + May be a :py:class:`str`, a :py:class:`pint.Unit` instance or + :py:obj:`None`. .. _pint: https://pint.readthedocs.io/en/stable diff --git a/docs/whats-new.rst b/docs/whats-new.rst index 1b9dcd2e..08f80293 100644 --- a/docs/whats-new.rst +++ b/docs/whats-new.rst @@ -18,6 +18,8 @@ What's new By `Justus Magin `_. - Switch to using pixi for all dependency management (:pull:`314`). By `Justus Magin `_. +- Added the :py:func:`pint_xarray.expects` decorator to allow wrapping quantity-unaware functions (:issue:`141`, :pull:`316`). + By `Justus Magin `_ and `Tom Nicholas `_. 0.5.1 (10 Aug 2025) ------------------- diff --git a/pint_xarray/__init__.py b/pint_xarray/__init__.py index 7fb5b327..d7890a54 100644 --- a/pint_xarray/__init__.py +++ b/pint_xarray/__init__.py @@ -3,13 +3,14 @@ import pint from pint_xarray import accessors, formatting, testing # noqa: F401 +from pint_xarray._expects import expects from pint_xarray.accessors import default_registry as unit_registry from pint_xarray.accessors import setup_registry from pint_xarray.index import PintIndex try: __version__ = version("pint-xarray") -except Exception: +except Exception: # pragma: no cover # Local copy or not installed with setuptools. # Disable minimum version checks on downstream libraries. __version__ = "999" @@ -23,4 +24,5 @@ "unit_registry", "setup_registry", "PintIndex", + "expects", ] diff --git a/pint_xarray/_expects.py b/pint_xarray/_expects.py new file mode 100644 index 00000000..37c4593a --- /dev/null +++ b/pint_xarray/_expects.py @@ -0,0 +1,234 @@ +import functools +import inspect +import itertools +from inspect import Parameter + +import pint +import pint.testing +import xarray as xr + +from pint_xarray.accessors import get_registry +from pint_xarray.conversion import extract_units +from pint_xarray.itertools import zip_mappings + +variable_parameters = (Parameter.VAR_POSITIONAL, Parameter.VAR_KEYWORD) + + +def expects(*args_units, return_value=None, **kwargs_units): + """ + Decorator which ensures the inputs and outputs of the decorated + function are expressed in the expected units. + + Arguments to the decorated function are checked for the specified + units, converting to those units if necessary, and then stripped + of their units before being passed into the undecorated + function. Therefore the undecorated function should expect + unquantified DataArrays, Datasets, or numpy-like arrays, but with + the values expressed in specific units. + + Parameters + ---------- + func : callable + Function to decorate, which accepts zero or more + xarray.DataArrays or numpy-like arrays as inputs, and may + optionally return one or more xarray.DataArrays or numpy-like + arrays. + *args_units : unit-like or mapping of hashable to unit-like, optional + Units to expect for each positional argument given to func. + + The decorator will first check that arguments passed to the + decorated function possess these specific units (or will + attempt to convert the argument to these units), then will + strip the units before passing the magnitude to the wrapped + function. + + A value of None indicates not to check that argument for units + (suitable for flags and other non-data arguments). + return_value : unit-like or list of unit-like or mapping of hashable to unit-like \ + or list of mapping of hashable to unit-like, optional + The expected units of the returned value(s), either as a + single unit or as a list of units. The decorator will attach + these units to the variables returned from the function. + + A value of None indicates not to attach any units to that + return value (suitable for flags and other non-data results). + **kwargs_units : mapping of hashable to unit-like, optional + Unit to expect for each keyword argument given to func. + + The decorator will first check that arguments passed to the decorated + function possess these specific units (or will attempt to convert the + argument to these units), then will strip the units before passing the + magnitude to the wrapped function. + + A value of None indicates not to check that argument for units (suitable + for flags and other non-data arguments). + + Returns + ------- + return_values : Any + Return values of the wrapped function, either a single value or a tuple + of values. These will be given units according to ``return_value``. + + Raises + ------ + TypeError + If any of the units are not a valid type. + ValueError + If the number of arguments or return values does not match the number of + units specified. Also thrown if any parameter does not have a unit + specified. + + See Also + -------- + pint.wraps + + Examples + -------- + Decorating a function which takes one quantified input, but + returns a non-data value (in this case a boolean). + + >>> @expects("deg C") + ... def above_freezing(temp): + ... return temp > 0 + ... + + Decorating a function which allows any dimensions for the array, but also + accepts an optional `weights` keyword argument, which must be dimensionless. + + >>> @expects(None, weights="dimensionless") + ... def mean(da, weights=None): + ... if weights: + ... return da.weighted(weights=weights).mean() + ... else: + ... return da.mean() + ... + """ + + def outer(func): + signature = inspect.signature(func) + + params_units = signature.bind(*args_units, **kwargs_units) + + missing_params = [ + name + for name, p in signature.parameters.items() + if p.kind not in variable_parameters and name not in params_units.arguments + ] + if missing_params: + raise ValueError( + "Missing units for the following parameters: " + + ", ".join(map(repr, missing_params)) + ) + + @functools.wraps(func) + def wrapper(*args, **kwargs): + nonlocal return_value + + params = signature.bind(*args, **kwargs) + # don't apply defaults, as those can't be quantities and thus must + # already be in the correct units + + spec_units = dict( + enumerate( + itertools.chain.from_iterable( + spec.values() if isinstance(spec, dict) else (spec,) + for spec in params_units.arguments.values() + if spec is not None + ) + ) + ) + params_units_ = dict( + enumerate( + itertools.chain.from_iterable( + ( + extract_units(param) + if isinstance(param, (xr.DataArray, xr.Dataset)) + else (param.units,) + ) + for name, param in params.arguments.items() + if isinstance(param, (xr.DataArray, xr.Dataset, pint.Quantity)) + ) + ) + ) + + ureg = get_registry( + None, + dict(spec_units) if spec_units else {}, + dict(params_units_) if params_units else {}, + ) + + errors = [] + for name, (value, units) in zip_mappings( + params.arguments, params_units.arguments + ): + try: + if units is None: + if isinstance(value, pint.Quantity) or ( + isinstance(value, (xr.DataArray, xr.Dataset)) + and value.pint.units + ): + raise TypeError( + "Passed in a quantity where none was expected" + ) + continue + if isinstance(value, pint.Quantity): + params.arguments[name] = value.m_as(units) + elif isinstance(value, (xr.DataArray, xr.Dataset)): + params.arguments[name] = value.pint.to(units).pint.dequantify() + else: + raise TypeError( + f"Attempting to convert non-quantity {value} to {units}." + ) + except Exception as e: + e.add_note( + f"expects: raised while trying to convert parameter {name}" + ) + errors.append(e) + + if errors: + raise ExceptionGroup("Errors while converting parameters", errors) + + result = func(*params.args, **params.kwargs) + + if (isinstance(result, tuple) ^ isinstance(return_value, tuple)) or ( + isinstance(result, tuple) and len(result) != len(return_value) + ): + raise ValueError("mismatched number of return values") + + if result is None: + return + + n_results = len(result) if isinstance(result, tuple) else 1 + + if not isinstance(result, tuple): + result = (result,) + if not isinstance(return_value, tuple): + return_value = (return_value,) + + final_result = [] + errors = [] + for index, (value, units) in enumerate(zip(result, return_value)): + if units is not None: + try: + if isinstance(value, (xr.Dataset, xr.DataArray)): + value = value.pint.quantify(units) + else: + value = ureg.Quantity(value, units) + except Exception as e: + e.add_note( + f"expects: raised while trying to convert return value {index}" + ) + errors.append(e) + + final_result.append(value) + + if errors: + raise ExceptionGroup("Errors while converting return values", errors) + + if n_results == 1: + return final_result[0] + return tuple(final_result) + + return wrapper + + return outer diff --git a/pint_xarray/itertools.py b/pint_xarray/itertools.py new file mode 100644 index 00000000..f03634d0 --- /dev/null +++ b/pint_xarray/itertools.py @@ -0,0 +1,30 @@ +import itertools +from functools import reduce + + +def separate(predicate, iterable): + evaluated = ((predicate(el), el) for el in iterable) + + key = lambda x: x[0] + grouped = itertools.groupby(sorted(evaluated, key=key), key=key) + + groups = {label: [el for _, el in group] for label, group in grouped} + + return groups[False], groups[True] + + +def unique(iterable): + return list(dict.fromkeys(iterable)) + + +def zip_mappings(*mappings): + def common_keys(a, b): + all_keys = unique(itertools.chain(a.keys(), b.keys())) + intersection = set(a.keys()).intersection(b.keys()) + + return [key for key in all_keys if key in intersection] + + keys = list(reduce(common_keys, mappings)) + + for key in keys: + yield key, tuple(m[key] for m in mappings) diff --git a/pint_xarray/tests/test_expects.py b/pint_xarray/tests/test_expects.py new file mode 100644 index 00000000..8d379fe6 --- /dev/null +++ b/pint_xarray/tests/test_expects.py @@ -0,0 +1,243 @@ +import re + +import pint +import pytest +import xarray as xr + +import pint_xarray + +ureg = pint_xarray.unit_registry + + +class TestExpects: + @pytest.mark.parametrize( + ["values", "units", "expected"], + ( + ((ureg.Quantity(1, "m"), 2), ("mm", None, None), 500), + ((ureg.Quantity(1, "m"), ureg.Quantity(0.5, "s")), ("mm", "ms", None), 2), + ( + (xr.DataArray(4).pint.quantify("km"), 2), + ("m", None, None), + xr.DataArray(2000), + ), + ( + ( + xr.DataArray([4, 2, 0]).pint.quantify("cm"), + xr.DataArray([4, 2, 1]).pint.quantify("mg"), + ), + ("m", "g", None), + xr.DataArray([10, 10, 0]), + ), + ( + (ureg.Quantity(16, "m"), 2, ureg.Quantity(4, "s")), + ("mm", None, "ms"), + 2, + ), + ), + ) + def test_args(self, values, units, expected): + @pint_xarray.expects(*units) + def func(a, b, c=1): + return a / (b * c) + + actual = func(*values) + + if isinstance(actual, xr.DataArray): + xr.testing.assert_identical(actual, expected) + elif isinstance(actual, pint.Quantity): + pint.testing.assert_equal(actual, expected) + else: + assert actual == expected + + @pytest.mark.parametrize( + ["value", "units", "error", "message", "multiple"], + ( + ( + ureg.Quantity(1, "m"), + (None, None), + TypeError, + "Passed in a quantity where none was expected", + True, + ), + (1, ("m", None), TypeError, "Attempting to convert non-quantity", True), + ( + 1, + (None,), + ValueError, + "Missing units for the following parameters: 'b'", + False, + ), + ), + ) + def test_args_error(self, value, units, error, message, multiple): + if multiple: + root_error = ExceptionGroup + root_message = "Errors while converting parameters" + else: + root_error = error + root_message = message + + with pytest.raises(root_error, match=root_message) as excinfo: + + @pint_xarray.expects(*units) + def func(a, b=1): + return a * b + + func(value) + + if not multiple: + return + + group = excinfo.value + assert len(group.exceptions) == 1, f"Found {len(group.exceptions)} exceptions" + exc = group.exceptions[0] + assert isinstance( + exc, error + ), f"Unexpected exception type: {type(exc)}, expected {error}" + if not re.search(message, str(exc)): + raise AssertionError(f"exception {exc!r} did not match pattern {message!r}") + + @pytest.mark.parametrize( + ["values", "units", "expected"], + ( + ( + {"a": ureg.Quantity(1, "m"), "b": 2}, + {"a": "mm", "b": None, "c": None}, + 1000, + ), + ( + {"a": 2, "b": ureg.Quantity(100, "cm")}, + {"a": None, "b": "m", "c": None}, + 4, + ), + ( + {"a": ureg.Quantity(1, "m"), "b": ureg.Quantity(0.5, "s")}, + {"a": "mm", "b": "ms", "c": None}, + 4, + ), + ( + {"a": xr.DataArray(4).pint.quantify("km"), "b": 2}, + {"a": "m", "b": None, "c": None}, + xr.DataArray(4000), + ), + ( + { + "a": xr.DataArray([4, 2, 0]).pint.quantify("cm"), + "b": xr.DataArray([4, 2, 1]).pint.quantify("mg"), + }, + {"a": "m", "b": "g", "c": None}, + xr.DataArray([20, 20, 0]), + ), + ), + ) + def test_kwargs(self, values, units, expected): + @pint_xarray.expects(**units) + def func(a, b, c=2): + return a / b * c + + actual = func(**values) + + if isinstance(actual, xr.DataArray): + xr.testing.assert_identical(actual, expected) + elif isinstance(actual, pint.Quantity): + pint.testing.assert_equal(actual, expected) + else: + assert actual == expected + + @pytest.mark.parametrize( + ["values", "return_value_units", "expected"], + ( + ((1, 2), ("m", "s"), (ureg.Quantity(1, "m"), ureg.Quantity(2, "s"))), + ((1, 2), "m / s", ureg.Quantity(0.5, "m / s")), + ((1, 2), None, 0.5), + ( + (xr.DataArray(2), 2), + ("m", "s"), + (xr.DataArray(2).pint.quantify("m"), ureg.Quantity(2, "s")), + ), + ( + (xr.DataArray(2), 2), + "kg / m^2", + xr.DataArray(1).pint.quantify("kg / m^2"), + ), + ), + ) + def test_return_value(self, values, return_value_units, expected): + multiple = isinstance(return_value_units, tuple) + + @pint_xarray.expects(a=None, b=None, return_value=return_value_units) + def func(a, b): + if multiple: + return a, b + else: + return a / b + + actual = func(*values) + if isinstance(actual, xr.DataArray): + xr.testing.assert_identical(actual, expected) + elif isinstance(actual, pint.Quantity): + pint.testing.assert_equal(actual, expected) + else: + assert actual == expected + + def test_return_value_none(self): + @pint_xarray.expects(None) + def func(a): + return None + + actual = func(1) + assert actual is None + + @pytest.mark.parametrize( + [ + "return_value_units", + "multiple_units", + "error", + "multiple_errors", + "message", + ], + ( + ( + ("m", "s"), + False, + ValueError, + False, + "mismatched number of return values", + ), + ("m", True, ValueError, False, "mismatched number of return values"), + (("m",), True, ValueError, False, "mismatched number of return values"), + (1, False, TypeError, True, "units must be of type"), + ), + ) + def test_return_value_error( + self, return_value_units, multiple_units, error, multiple_errors, message + ): + if multiple_errors: + root_error = ExceptionGroup + root_message = "Errors while converting return values" + else: + root_error = error + root_message = message + + with pytest.raises(root_error, match=root_message) as excinfo: + + @pint_xarray.expects(a=None, b=None, return_value=return_value_units) + def func(a, b): + if multiple_units: + return a, b + else: + return a / b + + func(1, 2) + + if not multiple_errors: + return + + group = excinfo.value + assert len(group.exceptions) == 1, f"Found {len(group.exceptions)} exceptions" + exc = group.exceptions[0] + assert isinstance( + exc, error + ), f"Unexpected exception type: {type(exc)}, expected {error}" + if not re.search(message, str(exc)): + raise AssertionError(f"exception {exc!r} did not match pattern {message!r}") diff --git a/pint_xarray/tests/test_itertools.py b/pint_xarray/tests/test_itertools.py new file mode 100644 index 00000000..61c2169b --- /dev/null +++ b/pint_xarray/tests/test_itertools.py @@ -0,0 +1,48 @@ +import pytest + +from pint_xarray.itertools import separate, unique, zip_mappings + + +@pytest.mark.parametrize( + ["predicate", "iterable"], + ( + (lambda x: x % 2 == 0, range(10)), + (lambda x: x in [0, 2, 3, 5], range(10)), + (lambda x: "s" in x, ["ab", "de", "sf", "fs"]), + ), +) +def test_separate(predicate, iterable): + actual_false, actual_true = separate(predicate, iterable) + + expected_true = [el for el in iterable if predicate(el)] + expected_false = [el for el in iterable if not predicate(el)] + + assert actual_true == expected_true + assert actual_false == expected_false + + +@pytest.mark.parametrize( + ["iterable", "expected"], + ( + ([5, 4, 4, 1, 2, 3, 2, 1], [5, 4, 1, 2, 3]), + (list("dadgafffgaefed"), list("dagfe")), + ), +) +def test_unique(iterable, expected): + actual = unique(iterable) + + assert actual == expected + + +@pytest.mark.parametrize( + ["mappings", "expected"], + ( + (({"a": 1, "b": 2}, {"a": 2, "b": 3}), [("a", (1, 2)), ("b", (2, 3))]), + (({"a": 1, "c": 2}, {"a": 2, "b": 0}), [("a", (1, 2))]), + (({"a": 1, "c": 2}, {"c": 2, "b": 0}), [("c", (2, 2))]), + (({"a": 1}, {"c": 2, "b": 0}), []), + ), +) +def test_zip_mappings(mappings, expected): + actual = list(zip_mappings(*mappings)) + assert actual == expected