Skip to content

Commit 92dcf5a

Browse files
committed
first tests
1 parent c7a7c7f commit 92dcf5a

File tree

3 files changed

+192
-0
lines changed

3 files changed

+192
-0
lines changed

pintxarray/tests/__init__.py

Whitespace-only changes.

pintxarray/tests/test_accessors.py

Lines changed: 112 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,112 @@
1+
import pytest
2+
3+
import xarray as xr
4+
from xarray.testing import assert_equal
5+
6+
import numpy as np
7+
from numpy.testing import assert_array_equal
8+
9+
from pint import UnitRegistry
10+
from pint.unit import Unit
11+
from pint.errors import DimensionalityError
12+
13+
from pintxarray.accessors import PintDataArrayAccessor, PintDatasetAccessor
14+
from .utils import extract_units
15+
16+
17+
# make sure scalars are converted to 0d arrays so quantities can
18+
# always be treated like ndarrays
19+
unit_registry = UnitRegistry(force_ndarray=True)
20+
Quantity = unit_registry.Quantity
21+
22+
23+
@pytest.fixture()
24+
def example_unitless_da():
25+
array = np.linspace(0, 10, 20)
26+
x = np.arange(20)
27+
da = xr.DataArray(data=array, dims="x", coords={"x": x})
28+
da.attrs['units'] = 'm'
29+
da.coords['x'].attrs['units'] = 's'
30+
return da
31+
32+
@pytest.fixture()
33+
def example_quantity_da():
34+
array = np.linspace(0, 10, 20) * unit_registry.m
35+
x = np.arange(20) * unit_registry.s
36+
return xr.DataArray(data=array, dims="x", coords={"x": x})
37+
38+
39+
class TestQuantifyDataArray:
40+
@pytest.mark.skip(reason="Not yet implemented")
41+
def test_attach_units_given_registry(self):
42+
...
43+
44+
def test_attach_units(self, example_unitless_da):
45+
orig = example_unitless_da
46+
result = orig.pint.quantify('m')
47+
assert_array_equal(result.data.magnitude, orig.data)
48+
assert str(result.data.units) == 'meter'
49+
50+
def test_attach_units_from_attrs(self):
51+
...
52+
53+
def test_attach_unit_class(self):
54+
...
55+
56+
def test_error_when_already_units(self, example_quantity_da):
57+
da = example_quantity_da
58+
with pytest.raises(ValueError):
59+
da.pint.quantify()
60+
61+
def test_error_on_nonsense_units(self):
62+
...
63+
64+
def test_registry_kwargs(self):
65+
...
66+
67+
68+
@pytest.mark.skip(reason="Not yet implemented")
69+
class TestDequantifyDataArray:
70+
def test_strip_units(self):
71+
...
72+
73+
def test_error_if_no_units(self):
74+
...
75+
76+
def test_roundtrip_attrs(self):
77+
...
78+
79+
80+
@pytest.mark.skip(reason="Not yet implemented")
81+
class TestPropertiesDataArray:
82+
def test_units(self):
83+
...
84+
85+
86+
@pytest.mark.skip(reason="Not yet implemented")
87+
class TestConversionDataArray:
88+
def test_units(self):
89+
...
90+
91+
92+
@pytest.mark.skip(reason="Not yet implemented")
93+
class TestUncertainties:
94+
def test_units(self):
95+
...
96+
97+
98+
@pytest.mark.skip(reason="Not yet implemented")
99+
class TestQuantifyDataSet:
100+
def test_attach_units(self):
101+
...
102+
103+
def test_attach_str_units(self):
104+
...
105+
106+
def test_attach_units_from_registry(self):
107+
...
108+
109+
110+
@pytest.mark.skip(reason="Not yet implemented")
111+
class TestIndexing:
112+
...

pintxarray/tests/utils.py

Lines changed: 80 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,80 @@
1+
import xarray as xr
2+
3+
from pint.quantity import Quantity
4+
5+
6+
def array_extract_units(obj):
7+
if isinstance(obj, (xr.Variable, xr.DataArray, xr.Dataset)):
8+
obj = obj.data
9+
10+
try:
11+
return obj.units
12+
except AttributeError:
13+
return None
14+
15+
def extract_units(obj):
16+
if isinstance(obj, xr.Dataset):
17+
vars_units = {
18+
name: array_extract_units(value) for name, value in obj.data_vars.items()
19+
}
20+
coords_units = {
21+
name: array_extract_units(value) for name, value in obj.coords.items()
22+
}
23+
24+
units = {**vars_units, **coords_units}
25+
elif isinstance(obj, xr.DataArray):
26+
vars_units = {obj.name: array_extract_units(obj)}
27+
coords_units = {
28+
name: array_extract_units(value) for name, value in obj.coords.items()
29+
}
30+
31+
units = {**vars_units, **coords_units}
32+
elif isinstance(obj, xr.Variable):
33+
vars_units = {None: array_extract_units(obj.data)}
34+
35+
units = {**vars_units}
36+
elif isinstance(obj, Quantity):
37+
vars_units = {None: array_extract_units(obj)}
38+
39+
units = {**vars_units}
40+
else:
41+
units = {}
42+
43+
return units
44+
45+
def assert_units_equal(a, b):
46+
__tracebackhide__ = True
47+
assert extract_units(a) == extract_units(b)
48+
49+
def assert_equal_with_units(a, b):
50+
# works like xr.testing.assert_equal, but also explicitly checks units
51+
# so, it is more like assert_identical
52+
__tracebackhide__ = True
53+
54+
if isinstance(a, xr.Dataset) or isinstance(b, xr.Dataset):
55+
a_units = extract_units(a)
56+
b_units = extract_units(b)
57+
58+
a_without_units = strip_units(a)
59+
b_without_units = strip_units(b)
60+
61+
assert a_without_units.equals(b_without_units), formatting.diff_dataset_repr(
62+
a, b, "equals"
63+
)
64+
assert a_units == b_units
65+
else:
66+
a = a if not isinstance(a, (xr.DataArray, xr.Variable)) else a.data
67+
b = b if not isinstance(b, (xr.DataArray, xr.Variable)) else b.data
68+
69+
assert type(a) == type(b) or (
70+
isinstance(a, Quantity) and isinstance(b, Quantity)
71+
)
72+
73+
# workaround until pint implements allclose in __array_function__
74+
if isinstance(a, Quantity) or isinstance(b, Quantity):
75+
assert (
76+
hasattr(a, "magnitude") and hasattr(b, "magnitude")
77+
) and np.allclose(a.magnitude, b.magnitude, equal_nan=True)
78+
assert (hasattr(a, "units") and hasattr(b, "units")) and a.units == b.units
79+
else:
80+
assert np.allclose(a, b, equal_nan=True)

0 commit comments

Comments
 (0)