Skip to content

Commit d9ca190

Browse files
authored
expects decorator (#316)
* add `separate` * function to get unique values while order-preserving * util to zip mappings * reimplement `zip_mappings` to be more robust * implement `expects` * test that expects correctly converts args * check that default args work, too * support checking for single errors, as well * check that units in kwargs work * raise an error for all parameters without unit spec * check that the return value units are attached properly * use `ureg.Quantity` instead of `unit.m_from` which should be `unit.from_`, anyways * check that return values can not have units * check that functions can not return a result * check for various errors when returning results * don't cover the version fallback * add api docs * errors → error * check the error type * change the raised error to `TypeError` * copy the docstring from #143 * styling * see also * changelog * terminology * extend the tests * raise an error if the return value is unexpectedly `None` * more explicitly select the error types * more tests * add a dev env
1 parent 3877341 commit d9ca190

File tree

10 files changed

+985
-9
lines changed

10 files changed

+985
-9
lines changed

docs/api.rst

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -65,6 +65,13 @@ DataArray
6565
xarray.DataArray.pint.bfill
6666
xarray.DataArray.pint.interpolate_na
6767

68+
Wrapping quantity-unaware functions
69+
-----------------------------------
70+
.. autosummary::
71+
:toctree: generated/
72+
73+
pint_xarray.expects
74+
6875
Testing
6976
-------
7077

docs/terminology.rst

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@ Terminology
55

66
unit-like
77
A `pint`_ unit definition, as accepted by :py:class:`pint.Unit`.
8-
May be either a :py:class:`str` or a :py:class:`pint.Unit` instance.
8+
May be a :py:class:`str`, a :py:class:`pint.Unit` instance or
9+
:py:obj:`None`.
910

1011
.. _pint: https://pint.readthedocs.io/en/stable

docs/whats-new.rst

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,8 @@ What's new
1818
By `Justus Magin <https://github.com/keewis>`_.
1919
- Switch to using pixi for all dependency management (:pull:`314`).
2020
By `Justus Magin <https://github.com/keewis>`_.
21+
- Added the :py:func:`pint_xarray.expects` decorator to allow wrapping quantity-unaware functions (:issue:`141`, :pull:`316`).
22+
By `Justus Magin <https://github.com/keewis>`_ and `Tom Nicholas <https://github.com/TomNicholas>`_.
2123

2224
0.5.1 (10 Aug 2025)
2325
-------------------

pint_xarray/__init__.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3,13 +3,14 @@
33
import pint
44

55
from pint_xarray import accessors, formatting, testing # noqa: F401
6+
from pint_xarray._expects import expects
67
from pint_xarray.accessors import default_registry as unit_registry
78
from pint_xarray.accessors import setup_registry
89
from pint_xarray.index import PintIndex
910

1011
try:
1112
__version__ = version("pint-xarray")
12-
except Exception:
13+
except Exception: # pragma: no cover
1314
# Local copy or not installed with setuptools.
1415
# Disable minimum version checks on downstream libraries.
1516
__version__ = "999"
@@ -23,4 +24,5 @@
2324
"unit_registry",
2425
"setup_registry",
2526
"PintIndex",
27+
"expects",
2628
]

pint_xarray/_expects.py

Lines changed: 260 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,260 @@
1+
import functools
2+
import inspect
3+
import itertools
4+
from inspect import Parameter
5+
6+
import pint
7+
import pint.testing
8+
import xarray as xr
9+
10+
from pint_xarray.accessors import get_registry
11+
from pint_xarray.conversion import extract_units
12+
from pint_xarray.itertools import zip_mappings
13+
14+
variable_parameters = (Parameter.VAR_POSITIONAL, Parameter.VAR_KEYWORD)
15+
16+
17+
def _number_of_results(result):
18+
if isinstance(result, tuple):
19+
return len(result)
20+
elif result is None:
21+
return 0
22+
else:
23+
return 1
24+
25+
26+
def expects(*args_units, return_value=None, **kwargs_units):
27+
"""
28+
Decorator which ensures the inputs and outputs of the decorated
29+
function are expressed in the expected units.
30+
31+
Arguments to the decorated function are checked for the specified
32+
units, converting to those units if necessary, and then stripped
33+
of their units before being passed into the undecorated
34+
function. Therefore the undecorated function should expect
35+
unquantified DataArrays, Datasets, or numpy-like arrays, but with
36+
the values expressed in specific units.
37+
38+
Parameters
39+
----------
40+
func : callable
41+
Function to decorate, which accepts zero or more
42+
xarray.DataArrays or numpy-like arrays as inputs, and may
43+
optionally return one or more xarray.DataArrays or numpy-like
44+
arrays.
45+
*args_units : unit-like or mapping of hashable to unit-like, optional
46+
Units to expect for each positional argument given to func.
47+
48+
The decorator will first check that arguments passed to the
49+
decorated function possess these specific units (or will
50+
attempt to convert the argument to these units), then will
51+
strip the units before passing the magnitude to the wrapped
52+
function.
53+
54+
A value of None indicates not to check that argument for units
55+
(suitable for flags and other non-data arguments).
56+
return_value : unit-like or list of unit-like or mapping of hashable to unit-like \
57+
or list of mapping of hashable to unit-like, optional
58+
The expected units of the returned value(s), either as a
59+
single unit or as a list of units. The decorator will attach
60+
these units to the variables returned from the function.
61+
62+
A value of None indicates not to attach any units to that
63+
return value (suitable for flags and other non-data results).
64+
**kwargs_units : mapping of hashable to unit-like, optional
65+
Unit to expect for each keyword argument given to func.
66+
67+
The decorator will first check that arguments passed to the decorated
68+
function possess these specific units (or will attempt to convert the
69+
argument to these units), then will strip the units before passing the
70+
magnitude to the wrapped function.
71+
72+
A value of None indicates not to check that argument for units (suitable
73+
for flags and other non-data arguments).
74+
75+
Returns
76+
-------
77+
return_values : Any
78+
Return values of the wrapped function, either a single value or a tuple
79+
of values. These will be given units according to ``return_value``.
80+
81+
Raises
82+
------
83+
TypeError
84+
If any of the units are not a valid type.
85+
ValueError
86+
If the number of arguments or return values does not match the number of
87+
units specified. Also thrown if any parameter does not have a unit
88+
specified.
89+
90+
See Also
91+
--------
92+
pint.wraps
93+
94+
Examples
95+
--------
96+
Decorating a function which takes one quantified input, but
97+
returns a non-data value (in this case a boolean).
98+
99+
>>> @expects("deg C")
100+
... def above_freezing(temp):
101+
... return temp > 0
102+
...
103+
104+
Decorating a function which allows any dimensions for the array, but also
105+
accepts an optional `weights` keyword argument, which must be dimensionless.
106+
107+
>>> @expects(None, weights="dimensionless")
108+
... def mean(da, weights=None):
109+
... if weights:
110+
... return da.weighted(weights=weights).mean()
111+
... else:
112+
... return da.mean()
113+
...
114+
"""
115+
116+
def outer(func):
117+
signature = inspect.signature(func)
118+
119+
params_units = signature.bind(*args_units, **kwargs_units)
120+
121+
missing_params = [
122+
name
123+
for name, p in signature.parameters.items()
124+
if p.kind not in variable_parameters and name not in params_units.arguments
125+
]
126+
if missing_params:
127+
raise ValueError(
128+
"Missing units for the following parameters: "
129+
+ ", ".join(map(repr, missing_params))
130+
)
131+
132+
n_expected_results = _number_of_results(return_value)
133+
134+
@functools.wraps(func)
135+
def wrapper(*args, **kwargs):
136+
nonlocal return_value
137+
138+
params = signature.bind(*args, **kwargs)
139+
# don't apply defaults, as those can't be quantities and thus must
140+
# already be in the correct units
141+
142+
spec_units = dict(
143+
enumerate(
144+
itertools.chain.from_iterable(
145+
spec.values() if isinstance(spec, dict) else (spec,)
146+
for spec in params_units.arguments.values()
147+
if spec is not None
148+
)
149+
)
150+
)
151+
params_units_ = dict(
152+
enumerate(
153+
itertools.chain.from_iterable(
154+
(
155+
extract_units(param)
156+
if isinstance(param, (xr.DataArray, xr.Dataset))
157+
else (param.units,)
158+
)
159+
for name, param in params.arguments.items()
160+
if isinstance(param, (xr.DataArray, xr.Dataset, pint.Quantity))
161+
)
162+
)
163+
)
164+
165+
ureg = get_registry(
166+
None,
167+
dict(spec_units) if spec_units else {},
168+
dict(params_units_) if params_units else {},
169+
)
170+
171+
errors = []
172+
for name, (value, units) in zip_mappings(
173+
params.arguments, params_units.arguments
174+
):
175+
try:
176+
if units is None:
177+
if isinstance(value, pint.Quantity) or (
178+
isinstance(value, (xr.DataArray, xr.Dataset))
179+
and value.pint.units
180+
):
181+
raise TypeError(
182+
"Passed in a quantity where none was expected"
183+
)
184+
continue
185+
if isinstance(value, pint.Quantity):
186+
params.arguments[name] = value.m_as(units)
187+
elif isinstance(value, (xr.DataArray, xr.Dataset)):
188+
params.arguments[name] = value.pint.to(units).pint.dequantify()
189+
else:
190+
raise TypeError(
191+
f"Attempting to convert non-quantity {value} to {units}."
192+
)
193+
except (
194+
TypeError,
195+
pint.errors.UndefinedUnitError,
196+
pint.errors.DimensionalityError,
197+
) as e:
198+
e.add_note(
199+
f"expects: raised while trying to convert parameter {name}"
200+
)
201+
errors.append(e)
202+
203+
if errors:
204+
raise ExceptionGroup("Errors while converting parameters", errors)
205+
206+
result = func(*params.args, **params.kwargs)
207+
208+
n_results = _number_of_results(result)
209+
if return_value is not None and (
210+
(isinstance(result, tuple) ^ isinstance(return_value, tuple))
211+
or (n_results != n_expected_results)
212+
):
213+
message = "mismatched number of return values:"
214+
if n_results != n_expected_results:
215+
message += f" expected {n_expected_results} but got {n_results}."
216+
elif isinstance(result, tuple) and not isinstance(return_value, tuple):
217+
message += (
218+
" expected a single return value but got a 1-sized tuple."
219+
)
220+
else:
221+
message += (
222+
" expected a 1-sized tuple but got a single return value."
223+
)
224+
raise ValueError(message)
225+
226+
if result is None:
227+
return
228+
229+
if not isinstance(result, tuple):
230+
result = (result,)
231+
if not isinstance(return_value, tuple):
232+
return_value = (return_value,)
233+
234+
final_result = []
235+
errors = []
236+
for index, (value, units) in enumerate(zip(result, return_value)):
237+
if units is not None:
238+
try:
239+
if isinstance(value, (xr.Dataset, xr.DataArray)):
240+
value = value.pint.quantify(units)
241+
else:
242+
value = ureg.Quantity(value, units)
243+
except Exception as e:
244+
e.add_note(
245+
f"expects: raised while trying to convert return value {index}"
246+
)
247+
errors.append(e)
248+
249+
final_result.append(value)
250+
251+
if errors:
252+
raise ExceptionGroup("Errors while converting return values", errors)
253+
254+
if n_results == 1:
255+
return final_result[0]
256+
return tuple(final_result)
257+
258+
return wrapper
259+
260+
return outer

pint_xarray/itertools.py

Lines changed: 30 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,30 @@
1+
import itertools
2+
from functools import reduce
3+
4+
5+
def separate(predicate, iterable):
6+
evaluated = ((predicate(el), el) for el in iterable)
7+
8+
key = lambda x: x[0]
9+
grouped = itertools.groupby(sorted(evaluated, key=key), key=key)
10+
11+
groups = {label: [el for _, el in group] for label, group in grouped}
12+
13+
return groups[False], groups[True]
14+
15+
16+
def unique(iterable):
17+
return list(dict.fromkeys(iterable))
18+
19+
20+
def zip_mappings(*mappings):
21+
def common_keys(a, b):
22+
all_keys = unique(itertools.chain(a.keys(), b.keys()))
23+
intersection = set(a.keys()).intersection(b.keys())
24+
25+
return [key for key in all_keys if key in intersection]
26+
27+
keys = list(reduce(common_keys, mappings))
28+
29+
for key in keys:
30+
yield key, tuple(m[key] for m in mappings)

0 commit comments

Comments
 (0)