Skip to content

Commit 3c1ca8f

Browse files
committed
Cleanup mappers by adding a mapper decorator
I am not sure this is a good change. The advantage is that the mapper decorator consolidates a lot of the error handling. The actual mapper functions are a lot cleaner and easier to reason about. The not-so-nice bit is that the decorator adds a couple of kwargs (error, default) to the mapper. This is confusing since these kwargs are not in the mapper's signature. As a side effect, _get_measure was cleaned up.
1 parent 0f27fb2 commit 3c1ca8f

File tree

1 file changed

+61
-50
lines changed

1 file changed

+61
-50
lines changed

cf_xarray/accessor.py

Lines changed: 61 additions & 50 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@
77
from typing import (
88
Callable,
99
Hashable,
10+
Iterable,
1011
List,
1112
Mapping,
1213
MutableMapping,
@@ -122,6 +123,50 @@ def _strip_none_list(lst: List[Optional[str]]) -> List[str]:
122123
return [item for item in lst if item != [None]] # type: ignore
123124

124125

126+
def mapper(valid_keys: Iterable[str]):
127+
"""
128+
Decorator for mapping functions that does error handling / returning defaults.
129+
"""
130+
131+
# This decorator Inception is sponsored by
132+
# https://realpython.com/primer-on-python-decorators/#decorators-with-arguments
133+
def decorator(func):
134+
@functools.wraps(func)
135+
def wrapper(
136+
obj: Union[xr.DataArray, xr.Dataset],
137+
key: str,
138+
error: bool = True,
139+
default: str = None,
140+
) -> List[Optional[str]]:
141+
if key not in valid_keys:
142+
if error:
143+
raise KeyError(
144+
f"cf_xarray did not understand key {key!r}. Expected one of {valid_keys!r}"
145+
)
146+
else:
147+
return [default]
148+
149+
try:
150+
results = func(obj, key)
151+
except Exception as e:
152+
if error:
153+
raise e
154+
else:
155+
results = None
156+
157+
if not results:
158+
if error:
159+
raise KeyError(f"Attributes to select {key!r} not found!")
160+
else:
161+
return [default]
162+
else:
163+
return list(results)
164+
165+
return wrapper
166+
167+
return decorator
168+
169+
125170
def _get_axis_coord_single(
126171
var: Union[xr.DataArray, xr.Dataset],
127172
key: str,
@@ -138,11 +183,9 @@ def _get_axis_coord_single(
138183
return results[0]
139184

140185

186+
@mapper(valid_keys=_COORD_NAMES + _AXIS_NAMES)
141187
def _get_axis_coord(
142-
var: Union[xr.DataArray, xr.Dataset],
143-
key: str,
144-
error: bool = True,
145-
default: str = None,
188+
var: Union[xr.DataArray, xr.Dataset], key: str,
146189
) -> List[Optional[str]]:
147190
"""
148191
Translate from axis or coord name to variable name
@@ -176,12 +219,6 @@ def _get_axis_coord(
176219
MetPy's parse_cf
177220
"""
178221

179-
if key not in _COORD_NAMES and key not in _AXIS_NAMES:
180-
if error:
181-
raise KeyError(f"Did not understand key {key!r}")
182-
else:
183-
return [default]
184-
185222
if "coordinates" in var.encoding:
186223
search_in = var.encoding["coordinates"].split(" ")
187224
elif "coordinates" in var.attrs:
@@ -196,26 +233,18 @@ def _get_axis_coord(
196233
expected = valid_values[key]
197234
if var.coords[coord].attrs.get(criterion, None) in expected:
198235
results.update((coord,))
199-
200-
if not results:
201-
if error:
202-
raise KeyError(f"axis name {key!r} not found!")
203-
else:
204-
return [default]
205-
else:
206-
return list(results)
236+
return list(results)
207237

208238

209239
def _get_measure_variable(
210240
da: xr.DataArray, key: str, error: bool = True, default: str = None
211241
) -> DataArray:
212242
""" tiny wrapper since xarray does not support providing str for weights."""
213-
return da[_get_measure(da, key, error, default)]
243+
return da[_get_measure(da, key, error, default)[0]]
214244

215245

216-
def _get_measure(
217-
da: xr.DataArray, key: str, error: bool = True, default: str = None
218-
) -> Optional[str]:
246+
@mapper(valid_keys=_CELL_MEASURES)
247+
def _get_measure(da: xr.DataArray, key: str) -> List[Optional[str]]:
219248
"""
220249
Translate from cell measures ("area" or "volume") to appropriate variable name.
221250
This function interprets the ``cell_measures`` attribute on DataArrays.
@@ -238,36 +267,16 @@ def _get_measure(
238267
"""
239268
if not isinstance(da, DataArray):
240269
raise NotImplementedError("Measures not implemented for Datasets yet.")
241-
if key not in _CELL_MEASURES:
242-
if error:
243-
raise ValueError(
244-
f"Cell measure must be one of {_CELL_MEASURES!r}. Received {key!r} instead."
245-
)
246-
else:
247-
return default
248270

249271
if "cell_measures" not in da.attrs:
250-
if error:
251-
raise KeyError("'cell_measures' not present in 'attrs'.")
252-
else:
253-
return default
272+
raise KeyError("'cell_measures' not present in 'attrs'.")
254273

255274
attr = da.attrs["cell_measures"]
256275
strings = [s.strip() for s in attr.strip().split(":")]
257276
if len(strings) % 2 != 0:
258-
if error:
259-
raise ValueError(f"attrs['cell_measures'] = {attr!r} is malformed.")
260-
else:
261-
return default
277+
raise ValueError(f"attrs['cell_measures'] = {attr!r} is malformed.")
262278
measures = dict(zip(strings[slice(0, None, 2)], strings[slice(1, None, 2)]))
263-
if key not in measures:
264-
if error:
265-
raise KeyError(
266-
f"Cell measure {key!r} not found. Please use .cf.describe() to see a list of key names that can be interpreted."
267-
)
268-
else:
269-
return default
270-
return measures[key]
279+
return [measures.get(key, None)]
271280

272281

273282
#: Default mappers for common keys.
@@ -688,10 +697,10 @@ def get_valid_keys(self) -> Set[str]:
688697
measures = [
689698
key
690699
for key in _CELL_MEASURES
691-
if _get_measure(self._obj, key, error=False) is not None
700+
if _get_measure(self._obj, key, error=False, default=None) != [None]
692701
]
693702
if measures:
694-
varnames.append(*measures)
703+
varnames.extend(measures)
695704

696705
if not isinstance(self._obj, xr.DataArray):
697706
varnames.extend(_get_list_standard_names(self._obj))
@@ -727,7 +736,7 @@ def __getitem__(self, key: Union[str, List[str]]):
727736
measure = _get_measure(self._obj, k)
728737
successful[k] = bool(measure)
729738
if measure:
730-
varnames.append(measure)
739+
varnames.extend(measure)
731740
elif not isinstance(self._obj, xr.DataArray):
732741
stdnames = _filter_by_standard_names(self._obj, k)
733742
successful[k] = bool(stdnames)
@@ -740,7 +749,9 @@ def __getitem__(self, key: Union[str, List[str]]):
740749

741750
try:
742751
# TODO: make this a get_auxiliary_variables function
743-
# make sure to set coordinate variables referred to in "coordinates" attribute
752+
# 1. set coordinate variables referred to in "coordinates" attribute
753+
# 2. set measures variables as coordinates
754+
# 3. set ancillary variables as coordinates
744755
for name in varnames:
745756
attrs = self._obj[name].attrs
746757
if "coordinates" in attrs:
@@ -752,7 +763,7 @@ def __getitem__(self, key: Union[str, List[str]]):
752763
for measure in _CELL_MEASURES
753764
if measure in attrs["cell_measures"]
754765
]
755-
coords.extend(_strip_none_list(measures))
766+
coords.extend(*_strip_none_list(measures))
756767

757768
if isinstance(self._obj, xr.Dataset) and "ancillary_variables" in attrs:
758769
anames = attrs["ancillary_variables"].split(" ")

0 commit comments

Comments
 (0)