Skip to content

Commit ede3a9f

Browse files
authored
Make _guess_coord_axis work with pint arrays (#247)
1 parent 770f40c commit ede3a9f

File tree

4 files changed

+57
-14
lines changed

4 files changed

+57
-14
lines changed

cf_xarray/accessor.py

Lines changed: 16 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -220,13 +220,13 @@ def _get_custom_criteria(
220220
return list(results)
221221

222222

223-
def _get_axis_coord(var: Union[DataArray, Dataset], key: str) -> List[str]:
223+
def _get_axis_coord(obj: Union[DataArray, Dataset], key: str) -> List[str]:
224224
"""
225225
Translate from axis or coord name to variable name
226226
227227
Parameters
228228
----------
229-
var : DataArray, Dataset
229+
obj : DataArray, Dataset
230230
DataArray belonging to the coordinate to be checked
231231
key : str, ["X", "Y", "Z", "T", "longitude", "latitude", "vertical", "time"]
232232
key to check for.
@@ -260,25 +260,29 @@ def _get_axis_coord(var: Union[DataArray, Dataset], key: str) -> List[str]:
260260
)
261261

262262
search_in = set()
263-
if "coordinates" in var.encoding:
264-
search_in.update(var.encoding["coordinates"].split(" "))
265-
if "coordinates" in var.attrs:
266-
search_in.update(var.attrs["coordinates"].split(" "))
263+
if "coordinates" in obj.encoding:
264+
search_in.update(obj.encoding["coordinates"].split(" "))
265+
if "coordinates" in obj.attrs:
266+
search_in.update(obj.attrs["coordinates"].split(" "))
267267
if not search_in:
268-
search_in = set(var.coords)
268+
search_in = set(obj.coords)
269269

270270
# maybe only do this for key in _AXIS_NAMES?
271-
search_in.update(var.indexes)
271+
search_in.update(obj.indexes)
272272

273+
search_in = search_in & set(obj.coords)
273274
results: Set = set()
274275
for coord in search_in:
276+
var = obj.coords[coord]
275277
if key in coordinate_criteria:
276278
for criterion, expected in coordinate_criteria[key].items():
277-
if (
278-
coord in var.coords
279-
and var.coords[coord].attrs.get(criterion, None) in expected
280-
):
279+
if var.attrs.get(criterion, None) in expected:
281280
results.update((coord,))
281+
if criterion == "units":
282+
# deal with pint-backed objects
283+
units = getattr(var.data, "units", None)
284+
if units in expected:
285+
results.update((coord,))
282286
return list(results)
283287

284288

cf_xarray/tests/__init__.py

Lines changed: 25 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,7 @@
1+
import importlib
12
import re
23
from contextlib import contextmanager
4+
from distutils import version
35

46
import dask
57
import pytest
@@ -39,3 +41,26 @@ def __call__(self, dsk, keys, **kwargs):
3941
def raise_if_dask_computes(max_computes=0):
4042
scheduler = CountingScheduler(max_computes)
4143
return dask.config.set(scheduler=scheduler)
44+
45+
46+
def _importorskip(modname, minversion=None):
47+
try:
48+
mod = importlib.import_module(modname)
49+
has = True
50+
if minversion is not None:
51+
if LooseVersion(mod.__version__) < LooseVersion(minversion):
52+
raise ImportError("Minimum version not satisfied")
53+
except ImportError:
54+
has = False
55+
func = pytest.mark.skipif(not has, reason=f"requires {modname}")
56+
return has, func
57+
58+
59+
def LooseVersion(vstring):
60+
# Our development version is something like '0.10.9+aac7bfc'
61+
# This function just ignored the git commit id.
62+
vstring = vstring.split("+")[0]
63+
return version.LooseVersion(vstring)
64+
65+
66+
has_pint, requires_pint = _importorskip("pint")

cf_xarray/tests/test_accessor.py

Lines changed: 14 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -24,7 +24,7 @@
2424
popds,
2525
romsds,
2626
)
27-
from . import raise_if_dask_computes
27+
from . import raise_if_dask_computes, requires_pint
2828

2929
mpl.use("Agg")
3030

@@ -149,6 +149,19 @@ def test_coordinates():
149149
assert actual == expected
150150

151151

152+
@requires_pint
153+
def test_coordinates_quantified():
154+
# note: import order is important
155+
from .. import units # noqa
156+
157+
pytest.importorskip("pint_xarray")
158+
159+
quantified = popds.pint.quantify()
160+
assert_identical(
161+
quantified.cf[["latitude"]].pint.dequantify(), popds.cf[["latitude"]]
162+
)
163+
164+
152165
def test_cell_measures():
153166
ds = airds.copy(deep=True)
154167
ds["foo"] = xr.DataArray(ds["cell_area"], attrs=dict(standard_name="foo_std_name"))

doc/whats-new.rst

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3,8 +3,9 @@
33
What's New
44
----------
55

6-
v0.6.0 (unreleased)
6+
v0.6.1 (unreleased)
77
===================
8+
- Support detecting pint-backed Variables with units-based criteria. By `Deepak Cherian`_.
89

910
v0.6.0 (June 29, 2021)
1011
======================

0 commit comments

Comments
 (0)