Skip to content

Commit 11000bd

Browse files
committed
Avoid silly mapper decorator.
Use an `apply_mapper` function instead. It applies a mapper function and centralizes the error handling and default-return logic that the decorator used to add, which is cleaner.
1 parent 16a0060 commit 11000bd

File tree

1 file changed

+75
-90
lines changed

1 file changed

+75
-90
lines changed

cf_xarray/accessor.py

Lines changed: 75 additions & 90 deletions
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,6 @@
88
from typing import (
99
Callable,
1010
Hashable,
11-
Iterable,
1211
List,
1312
Mapping,
1413
MutableMapping,
@@ -113,81 +112,54 @@
113112
coordinate_criteria["long_name"] = coordinate_criteria["standard_name"]
114113

115114
# Type for Mapper functions
116-
Mapper = Callable[
117-
[Union[xr.DataArray, xr.Dataset], str, bool, str],
118-
Union[List[Optional[str]], DataArray], # this sucks
119-
]
115+
Mapper = Callable[[Union[xr.DataArray, xr.Dataset], str], List[Optional[str]]]
120116

121117

122118
def _strip_none_list(lst: List[Optional[str]]) -> List[str]:
123119
""" The mappers can return [None]. Strip that when necessary. Keeps mypy happy."""
124120
return [item for item in lst if item != [None]] # type: ignore
125121

126122

127-
def mapper(valid_keys: Iterable[str]):
123+
def apply_mapper(
124+
mapper: Mapper,
125+
obj: Union[xr.DataArray, xr.Dataset],
126+
key: str,
127+
error: bool = True,
128+
default: str = None,
129+
) -> List[Optional[str]]:
128130
"""
129-
Decorator for mapping functions that does error handling / returning defaults.
131+
Applies a mapping function; does error handling / returning defaults.
130132
"""
131133

132-
# This decorator Inception is sponsored by
133-
# https://realpython.com/primer-on-python-decorators/#decorators-with-arguments
134-
def decorator(func):
135-
@functools.wraps(func)
136-
def wrapper(
137-
obj: Union[xr.DataArray, xr.Dataset],
138-
key: str,
139-
error: bool = True,
140-
default: str = None,
141-
) -> List[Optional[str]]:
142-
"""
143-
This decorator will add `error` and `default` kwargs to the decorated Mapper function.
144-
"""
145-
if key not in valid_keys:
146-
if error:
147-
raise KeyError(
148-
f"cf_xarray did not understand key {key!r}. Expected one of {valid_keys!r}"
149-
)
150-
else:
151-
return [default]
152-
153-
try:
154-
results = func(obj, key)
155-
except Exception as e:
156-
if error:
157-
raise e
158-
else:
159-
results = None
160-
161-
if not results:
162-
if error:
163-
raise KeyError(f"Attributes to select {key!r} not found!")
164-
else:
165-
return [default]
166-
else:
167-
return list(results)
168-
169-
return wrapper
134+
try:
135+
results = mapper(obj, key)
136+
except Exception as e:
137+
if error:
138+
raise e
139+
else:
140+
results = None # type: ignore
170141

171-
return decorator
142+
if not results:
143+
if error:
144+
raise KeyError(f"Attributes to select {key!r} not found!")
145+
else:
146+
return [default]
147+
else:
148+
return list(results)
172149

173150

174151
def _get_axis_coord_single(
175-
var: Union[xr.DataArray, xr.Dataset],
176-
key: str,
177-
error: bool = True,
178-
default: str = None,
179-
) -> Optional[str]:
152+
var: Union[xr.DataArray, xr.Dataset], key: str,
153+
) -> List[Optional[str]]:
180154
""" Helper method for when we really want only one result per key. """
181-
results = _get_axis_coord(var, key, error, default)
155+
results = _get_axis_coord(var, key)
182156
if len(results) > 1:
183157
raise ValueError(
184158
f"Multiple results for {key!r} found: {results!r}. Is this valid CF? Please open an issue."
185159
)
186-
else:
187-
return results[0]
160+
return results
188161

189162

190-
@mapper(valid_keys=_COORD_NAMES + _AXIS_NAMES)
191163
def _get_axis_coord(
192164
var: Union[xr.DataArray, xr.Dataset], key: str,
193165
) -> List[Optional[str]]:
@@ -223,6 +195,12 @@ def _get_axis_coord(
223195
MetPy's parse_cf
224196
"""
225197

198+
valid_keys = _COORD_NAMES + _AXIS_NAMES
199+
if key not in valid_keys:
200+
raise KeyError(
201+
f"cf_xarray did not understand key {key!r}. Expected one of {valid_keys!r}"
202+
)
203+
226204
if "coordinates" in var.encoding:
227205
search_in = var.encoding["coordinates"].split(" ")
228206
elif "coordinates" in var.attrs:
@@ -242,13 +220,15 @@ def _get_axis_coord(
242220

243221
def _get_measure_variable(
244222
da: xr.DataArray, key: str, error: bool = True, default: str = None
245-
) -> DataArray:
223+
) -> List[DataArray]:
246224
""" tiny wrapper since xarray does not support providing str for weights."""
247-
return da[_get_measure(da, key, error, default)[0]]
225+
varnames = _strip_none_list(apply_mapper(_get_measure, da, key, error, default))
226+
if len(varnames) > 1:
227+
raise ValueError(f"Multiple measures found for key {key!r}: {varnames!r}.")
228+
return [da[varnames[0]]]
248229

249230

250-
@mapper(valid_keys=_CELL_MEASURES)
251-
def _get_measure(da: xr.DataArray, key: str) -> List[Optional[str]]:
231+
def _get_measure(da: Union[xr.DataArray, xr.Dataset], key: str) -> List[Optional[str]]:
252232
"""
253233
Translate from cell measures ("area" or "volume") to appropriate variable name.
254234
This function interprets the ``cell_measures`` attribute on DataArrays.
@@ -275,6 +255,12 @@ def _get_measure(da: xr.DataArray, key: str) -> List[Optional[str]]:
275255
if "cell_measures" not in da.attrs:
276256
raise KeyError("'cell_measures' not present in 'attrs'.")
277257

258+
valid_keys = _CELL_MEASURES
259+
if key not in valid_keys:
260+
raise KeyError(
261+
f"cf_xarray did not understand key {key!r}. Expected one of {valid_keys!r}"
262+
)
263+
278264
attr = da.attrs["cell_measures"]
279265
strings = [s.strip() for s in attr.strip().split(":")]
280266
if len(strings) % 2 != 0:
@@ -372,7 +358,7 @@ def _getattr(
372358
newmap = dict()
373359
unused_keys = set(attribute.keys())
374360
for key in _AXIS_NAMES + _COORD_NAMES:
375-
value = _get_axis_coord(obj, key, error=False)
361+
value = apply_mapper(_get_axis_coord, obj, key, error=False)
376362
unused_keys -= set(value)
377363
if value != [None]:
378364
good_values = set(value) & set(obj.dims)
@@ -596,7 +582,9 @@ def _rewrite_values(self, kwargs, key_mappers: dict, var_kws):
596582
# where xi_* have attrs["axis"] = "X"
597583
updates[key] = ChainMap(
598584
*[
599-
dict.fromkeys(mapper(self._obj, k, False, k), v)
585+
dict.fromkeys(
586+
apply_mapper(mapper, self._obj, k, False, k), v
587+
)
600588
for k, v in value.items()
601589
]
602590
)
@@ -606,16 +594,18 @@ def _rewrite_values(self, kwargs, key_mappers: dict, var_kws):
606594

607595
else:
608596
# things like sum which have dim
609-
newvalue = [mapper(self._obj, v, False, v) for v in value]
610-
if len(newvalue) == 1:
611-
# works for groupby("time")
612-
newvalue = newvalue[0]
597+
newvalue = [
598+
apply_mapper(mapper, self._obj, v, False, v) for v in value
599+
]
600+
# Mappers return list by default
601+
# for input dim=["lat", "X"], newvalue=[["lat"], ["lon"]],
602+
# so we deal with that here.
603+
unpacked = list(itertools.chain(*newvalue))
604+
if len(unpacked) == 1:
605+
# handle 'group'
606+
updates[key] = unpacked[0]
613607
else:
614-
# Mappers return list by default
615-
# for input dim=["lat", "X"], newvalue=[["lat"], ["lon"]],
616-
# so we deal with that here.
617-
newvalue = list(itertools.chain(*newvalue))
618-
updates[key] = newvalue
608+
updates[key] = unpacked
619609

620610
kwargs.update(updates)
621611

@@ -627,7 +617,9 @@ def _rewrite_values(self, kwargs, key_mappers: dict, var_kws):
627617
for vkw in var_kws:
628618
if vkw in kwargs:
629619
maybe_update = {
630-
k: _get_axis_coord_single(self._obj, v, False, v)
620+
# TODO: this is assuming key_mappers[k] is always
621+
# _get_axis_coord_single
622+
k: apply_mapper(key_mappers[k], self._obj, v)[0]
631623
for k, v in kwargs[vkw].items()
632624
if k in key_mappers
633625
}
@@ -654,20 +646,18 @@ def describe(self):
654646
"""
655647
text = "Axes:\n"
656648
for key in _AXIS_NAMES:
657-
text += f"\t{key}: {_get_axis_coord(self._obj, key, error=False)}\n"
649+
text += f"\t{key}: {apply_mapper(_get_axis_coord, self._obj, key, error=False)}\n"
658650

659651
text += "\nCoordinates:\n"
660652
for key in _COORD_NAMES:
661-
text += f"\t{key}: {_get_axis_coord(self._obj, key, error=False)}\n"
653+
text += f"\t{key}: {apply_mapper(_get_axis_coord, self._obj, key, error=False)}\n"
662654

663655
text += "\nCell Measures:\n"
664656
for measure in _CELL_MEASURES:
665657
if isinstance(self._obj, xr.Dataset):
666658
text += f"\t{measure}: unsupported\n"
667659
else:
668-
text += (
669-
f"\t{measure}: {_get_measure(self._obj, measure, error=False)}\n"
670-
)
660+
text += f"\t{measure}: {apply_mapper(_get_measure, self._obj, measure, error=False)}\n"
671661

672662
text += "\nStandard Names:\n"
673663
if isinstance(self._obj, xr.DataArray):
@@ -694,13 +684,13 @@ def get_valid_keys(self) -> Set[str]:
694684
varnames = [
695685
key
696686
for key in _AXIS_NAMES + _COORD_NAMES
697-
if _get_axis_coord(self._obj, key, error=False) != [None]
687+
if apply_mapper(_get_axis_coord, self._obj, key, error=False) != [None]
698688
]
699689
with suppress(NotImplementedError):
700690
measures = [
701691
key
702692
for key in _CELL_MEASURES
703-
if _get_measure(self._obj, key, error=False) != [None]
693+
if apply_mapper(_get_measure, self._obj, key, error=False) != [None]
704694
]
705695
if measures:
706696
varnames.extend(measures)
@@ -727,19 +717,14 @@ def __getitem__(self, key: Union[str, List[str]]):
727717
successful = dict.fromkeys(key, False)
728718
for k in key:
729719
if k in _AXIS_NAMES + _COORD_NAMES:
730-
names = _get_axis_coord(self._obj, k)
720+
names = _strip_none_list(_get_axis_coord(self._obj, k))
731721
successful[k] = bool(names)
732-
coords.extend(_strip_none_list(names))
722+
coords.extend(names)
733723
elif k in _CELL_MEASURES:
734-
if isinstance(self._obj, xr.Dataset):
735-
raise NotImplementedError(
736-
"Invalid key {k!r}. Cell measures not implemented for Dataset yet."
737-
)
738-
else:
739-
measure = _get_measure(self._obj, k)
740-
successful[k] = bool(measure)
741-
if measure:
742-
varnames.extend(measure)
724+
measure = _strip_none_list(_get_measure(self._obj, k))
725+
successful[k] = bool(measure)
726+
if measure:
727+
varnames.extend(measure)
743728
elif not isinstance(self._obj, xr.DataArray):
744729
stdnames = _filter_by_standard_names(self._obj, k)
745730
successful[k] = bool(stdnames)
@@ -766,7 +751,7 @@ def __getitem__(self, key: Union[str, List[str]]):
766751
for measure in _CELL_MEASURES
767752
if measure in attrs["cell_measures"]
768753
]
769-
coords.extend(*_strip_none_list(measures))
754+
coords.extend(_strip_none_list(*measures))
770755

771756
if isinstance(self._obj, xr.Dataset) and "ancillary_variables" in attrs:
772757
anames = attrs["ancillary_variables"].split(" ")

0 commit comments

Comments
 (0)