Skip to content

Commit b2794dc

Browse files
authored
Look for & prioritize "coordinates" attribute when searching. (#16)
* Look for & prioritize "coordinates" attribute when searching. * fix tests. * Apply suggestions from code review
1 parent fc1419d commit b2794dc

File tree

2 files changed

+79
-21
lines changed

2 files changed

+79
-21
lines changed

cf_xarray/accessor.py

Lines changed: 49 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
import functools
22
import inspect
3-
from typing import Any, Union
3+
from typing import Any, List, Optional, Set, Union
44

55
import xarray as xr
66
from xarray import DataArray, Dataset
@@ -44,6 +44,7 @@
4444
"Y": ("latitude",),
4545
"X": ("longitude",),
4646
},
47+
"long_name": {"T": ("time",)},
4748
"_CoordinateAxisType": {
4849
"T": ("Time",),
4950
"Z": ("GeoZ", "Height", "Pressure"),
@@ -84,13 +85,29 @@
8485
}
8586

8687

87-
def _get_axis_coord(var: xr.DataArray, key, error: bool = True, default: Any = None):
88+
def _get_axis_coord_single(var, key, *args):
89+
""" Helper method for when we really want only one result per key. """
90+
results = _get_axis_coord(var, key, *args)
91+
if len(results) > 1:
92+
raise ValueError(
93+
"Multiple results for {key!r} found: {results!r}. Is this valid CF? Please open an issue."
94+
)
95+
else:
96+
return results[0]
97+
98+
99+
def _get_axis_coord(
100+
var: Union[xr.DataArray, xr.Dataset],
101+
key: str,
102+
error: bool = True,
103+
default: Optional[str] = None,
104+
) -> List[Optional[str]]:
88105
"""
89106
Translate from axis or coord name to variable name
90107
91108
Parameters
92109
----------
93-
var : `xarray.DataArray`
110+
var : DataArray, Dataset
94111
DataArray belonging to the coordinate to be checked
95112
key : str, ["X", "Y", "Z", "T", "longitude", "latitude", "vertical", "time"]
96113
key to check for.
@@ -102,7 +119,7 @@ def _get_axis_coord(var: xr.DataArray, key, error: bool = True, default: Any = N
102119
103120
Returns
104121
-------
105-
str, Variable name in parent xarray object that matches axis or coordinate `key`
122+
List[str], Variable name(s) in parent xarray object that matches axis or coordinate `key`
106123
107124
Notes
108125
-----
@@ -128,22 +145,33 @@ def _get_axis_coord(var: xr.DataArray, key, error: bool = True, default: Any = N
128145
if error:
129146
raise KeyError(f"Did not understand {key}")
130147
else:
131-
return default
148+
return [default]
132149

133150
if axis is None:
134151
raise AssertionError("Should be unreachable")
135152

136-
for coord in var.coords:
153+
if "coordinates" in var.encoding:
154+
search_in = var.encoding["coordinates"].split(" ")
155+
elif "coordinates" in var.attrs:
156+
search_in = var.attrs["coordinates"].split(" ")
157+
else:
158+
search_in = list(var.coords)
159+
160+
results: Set = set()
161+
for coord in search_in:
137162
for criterion, valid_values in coordinate_criteria.items():
138163
if axis in valid_values: # type: ignore
139164
expected = valid_values[axis] # type: ignore
140165
if var.coords[coord].attrs.get(criterion, None) in expected:
141-
return coord
166+
results.update((coord,))
142167

143-
if error:
144-
raise KeyError(f"axis name {key!r} not found!")
168+
if not results:
169+
if error:
170+
raise KeyError(f"axis name {key!r} not found!")
171+
else:
172+
return [default]
145173
else:
146-
return default
174+
return list(results)
147175

148176

149177
def _get_measure_variable(
@@ -184,7 +212,9 @@ def _get_measure(da: xr.DataArray, key: str, error: bool = True, default: Any =
184212
return measures[key]
185213

186214

187-
_DEFAULT_KEY_MAPPERS: dict = dict.fromkeys(("dim", "coord", "group"), _get_axis_coord)
215+
_DEFAULT_KEY_MAPPERS: dict = dict.fromkeys(
216+
("dim", "coord", "group"), _get_axis_coord_single
217+
)
188218
_DEFAULT_KEY_MAPPERS["weights"] = _get_measure_variable
189219

190220

@@ -261,7 +291,7 @@ def __call__(self, *args, **kwargs):
261291
obj=self._obj,
262292
attr="plot",
263293
accessor=self.accessor,
264-
key_mappers=dict.fromkeys(self._keys, _get_axis_coord),
294+
key_mappers=dict.fromkeys(self._keys, _get_axis_coord_single),
265295
)
266296
return plot(*args, **kwargs)
267297

@@ -270,7 +300,7 @@ def __getattr__(self, attr):
270300
obj=self._obj.plot,
271301
attr=attr,
272302
accessor=self.accessor,
273-
key_mappers=dict.fromkeys(self._keys, _get_axis_coord),
303+
key_mappers=dict.fromkeys(self._keys, _get_axis_coord_single),
274304
)
275305

276306

@@ -294,7 +324,6 @@ def _process_signature(self, func, args, kwargs, key_mappers):
294324
arguments = self._rewrite_values(
295325
bound.arguments, key_mappers, tuple(var_kws)
296326
)
297-
print(arguments)
298327
else:
299328
arguments = {}
300329

@@ -311,7 +340,7 @@ def _process_signature(self, func, args, kwargs, key_mappers):
311340
def _rewrite_values(self, kwargs, key_mappers: dict, var_kws):
312341
""" rewrites 'dim' for example using 'mapper' """
313342
updates: dict = {}
314-
key_mappers.update(dict.fromkeys(var_kws, _get_axis_coord))
343+
key_mappers.update(dict.fromkeys(var_kws, _get_axis_coord_single))
315344
for key, mapper in key_mappers.items():
316345
value = kwargs.get(key, None)
317346
if value is not None:
@@ -341,7 +370,7 @@ def _rewrite_values(self, kwargs, key_mappers: dict, var_kws):
341370
for vkw in var_kws:
342371
if vkw in kwargs:
343372
maybe_update = {
344-
k: _get_axis_coord(self._obj, v, False, v)
373+
k: _get_axis_coord_single(self._obj, v, False, v)
345374
for k, v in kwargs[vkw].items()
346375
if k in key_mappers
347376
}
@@ -367,22 +396,21 @@ def plot(self):
367396
class CFDatasetAccessor(CFAccessor):
368397
def __getitem__(self, key):
369398
if key in _AXIS_NAMES + _COORD_NAMES:
370-
return self._obj[_get_axis_coord(self._obj, key)]
399+
varnames = _get_axis_coord(self._obj, key)
400+
return self._obj.reset_coords()[varnames].set_coords(varnames)
371401
elif key in _CELL_MEASURES:
372402
raise NotImplementedError("measures not implemented for Dataset yet.")
373403
# return self._obj[_get_measure(self._obj)[key]]
374404
else:
375405
raise KeyError(f"DataArray.cf does not understand the key {key}")
376406

377-
# def __getitem__(self, key):
378-
# raise AttributeError("Dataset.cf does not support [] indexing or __getitem__")
379-
380407

381408
@xr.register_dataarray_accessor("cf")
382409
class CFDataArrayAccessor(CFAccessor):
383410
def __getitem__(self, key):
384411
if key in _AXIS_NAMES + _COORD_NAMES:
385-
return self._obj[_get_axis_coord(self._obj, key)]
412+
varname = _get_axis_coord_single(self._obj, key)
413+
return self._obj[varname].reset_coords(drop=True)
386414
elif key in _CELL_MEASURES:
387415
return self._obj[_get_measure(self._obj, key)]
388416
else:

cf_xarray/tests/test_accessor.py

Lines changed: 30 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -135,6 +135,8 @@ def test_dataset_plot(obj):
135135
)
136136
def test_getitem(obj, key, expected_key):
137137
actual = obj.cf[key]
138+
if isinstance(obj, xr.Dataset):
139+
expected_key = [expected_key]
138140
expected = obj[expected_key]
139141
assert_identical(actual, expected)
140142

@@ -146,3 +148,31 @@ def test_getitem_errors(obj,):
146148
obj.lon.attrs = {}
147149
with pytest.raises(KeyError):
148150
obj.cf["X"]
151+
152+
153+
def test_getitem_uses_coordinates():
154+
# POP-like dataset
155+
ds = xr.Dataset()
156+
ds.coords["TLONG"] = (("nlat", "nlon"), np.ones((20, 30)), {"axis": "X"})
157+
ds.coords["TLAT"] = (("nlat", "nlon"), 2 * np.ones((20, 30)), {"axis": "Y"})
158+
ds.coords["ULONG"] = (("nlat", "nlon"), 0.5 * np.ones((20, 30)), {"axis": "X"})
159+
ds.coords["ULAT"] = (("nlat", "nlon"), 2.5 * np.ones((20, 30)), {"axis": "Y"})
160+
ds["UVEL"] = (
161+
("nlat", "nlon"),
162+
np.ones((20, 30)) * 15,
163+
{"coordinates": "ULONG ULAT"},
164+
)
165+
ds["TEMP"] = (
166+
("nlat", "nlon"),
167+
np.ones((20, 30)) * 15,
168+
{"coordinates": "TLONG TLAT"},
169+
)
170+
171+
assert_identical(
172+
ds.cf["X"], ds.reset_coords()[["ULONG", "TLONG"]].set_coords(["ULONG", "TLONG"])
173+
)
174+
assert_identical(
175+
ds.cf["Y"], ds.reset_coords()[["ULAT", "TLAT"]].set_coords(["ULAT", "TLAT"])
176+
)
177+
assert_identical(ds.UVEL.cf["X"], ds["ULONG"].reset_coords(drop=True))
178+
assert_identical(ds.TEMP.cf["X"], ds["TLONG"].reset_coords(drop=True))

0 commit comments

Comments
 (0)