Skip to content

Commit 1bcc6c7

Browse files
committed
test that expects correctly converts args
1 parent cfcab93 commit 1bcc6c7

File tree

1 file changed

+69
-0
lines changed

1 file changed

+69
-0
lines changed

pint_xarray/tests/test_expects.py

Lines changed: 69 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,69 @@
1+
import re
2+
3+
import pint
4+
import pytest
5+
import xarray as xr
6+
7+
import pint_xarray
8+
9+
ureg = pint_xarray.unit_registry
10+
11+
12+
class TestExpects:
13+
@pytest.mark.parametrize(
14+
["values", "units", "expected"],
15+
(
16+
((ureg.Quantity(1, "m"), 2), ("mm", None), 500),
17+
((ureg.Quantity(1, "m"), ureg.Quantity(0.5, "s")), ("mm", "ms"), 2),
18+
((xr.DataArray(4).pint.quantify("km"), 2), ("m", None), xr.DataArray(2000)),
19+
(
20+
(
21+
xr.DataArray([4, 2, 0]).pint.quantify("cm"),
22+
xr.DataArray([4, 2, 1]).pint.quantify("mg"),
23+
),
24+
("m", "g"),
25+
xr.DataArray([10, 10, 0]),
26+
),
27+
),
28+
)
29+
def test_args(self, values, units, expected):
30+
@pint_xarray.expects(*units)
31+
def func(a, b):
32+
return a / b
33+
34+
actual = func(*values)
35+
36+
if isinstance(actual, xr.DataArray):
37+
xr.testing.assert_identical(actual, expected)
38+
elif isinstance(actual, pint.Quantity):
39+
pint.testing.assert_equal(actual, expected)
40+
else:
41+
assert actual == expected
42+
43+
@pytest.mark.parametrize(
44+
["value", "units", "errors", "message"],
45+
(
46+
(
47+
ureg.Quantity(1, "m"),
48+
None,
49+
ValueError,
50+
"quantity where none was expected",
51+
),
52+
(1, "m", ValueError, "Attempting to convert non-quantity"),
53+
),
54+
)
55+
def test_args_error(self, value, units, errors, message):
56+
with pytest.raises(
57+
ExceptionGroup, match="Errors while converting parameters"
58+
) as excinfo:
59+
60+
@pint_xarray.expects(units)
61+
def func(a):
62+
return a
63+
64+
func(value)
65+
group = excinfo.value
66+
assert len(group.exceptions) == 1, f"Found {len(group.exceptions)} exceptions"
67+
exc = group.exceptions[0]
68+
if not re.search(message, str(exc)):
69+
raise AssertionError(f"exception {exc!r} did not match pattern {message!r}")

0 commit comments

Comments
 (0)