Skip to content

Commit c9d66f5

Browse files
authored
Fix application of multiple mappers. (#58)
* Fix application of multiple mappers. apply_mapper can now take a tuple of mappers. It will apply them sequentially. Mappers are expected to return [] if no mapping is possible. A list of results e.g. [[], ["lon"]] is then unpacked to ["lon"] and returned. If none of the mappers apply or if more than one mapper can apply, an error is raised. * Update cf_xarray/accessor.py
1 parent be4b4a5 commit c9d66f5

File tree

1 file changed

+44
-30
lines changed

1 file changed

+44
-30
lines changed

cf_xarray/accessor.py

Lines changed: 44 additions & 30 deletions
Original file line numberDiff line numberDiff line change
@@ -116,7 +116,7 @@
116116

117117

118118
def apply_mapper(
119-
mapper: Mapper,
119+
mappers: Union[Mapper, Tuple[Mapper, ...]],
120120
obj: Union[DataArray, Dataset],
121121
key: str,
122122
error: bool = True,
@@ -129,34 +129,45 @@ def apply_mapper(
129129
It should return a list in all other cases including when there are no
130130
results for a good key.
131131
"""
132+
if default is None:
133+
default = []
132134

133-
def _maybe_return_default():
134-
"""
135-
Used when mapper raises an error or returns empty list.
136-
Sets a default if possible else sets []
137-
"""
135+
def _apply_single_mapper(mapper):
136+
137+
try:
138+
results = mapper(obj, key)
139+
except Exception as e:
140+
if error:
141+
raise e
142+
else:
143+
results = []
144+
return results
145+
146+
if not isinstance(mappers, Iterable):
147+
mappers = (mappers,)
148+
149+
# apply a sequence of mappers
150+
# if the mapper fails, it *should* return an empty list
151+
# if the mapper raises an error, that is processed based on `error`
152+
results = []
153+
for mapper in mappers:
154+
results.append(_apply_single_mapper(mapper))
155+
156+
nresults = sum([bool(v) for v in results])
157+
if nresults > 1:
158+
raise KeyError(
159+
f"Multiple mappers succeeded with key {key!r}.\nI was using mappers: {mappers!r}."
160+
f"I received results: {results!r}.\nPlease open an issue."
161+
)
162+
if nresults == 0:
138163
if error:
139164
raise KeyError(
140165
f"cf-xarray cannot interpret key {key!r}. Perhaps some needed attributes are missing."
141166
)
142-
if default:
143-
results = [default]
144167
else:
145-
results = []
146-
return results
147-
148-
try:
149-
results = mapper(obj, key)
150-
except Exception as e:
151-
if error:
152-
raise e
153-
else:
154-
results = _maybe_return_default()
155-
156-
if not results:
157-
results = _maybe_return_default()
158-
159-
return results
168+
# none of the mappers worked. Return the default
169+
return default
170+
return list(itertools.chain(*results))
160171

161172

162173
def _get_axis_coord_single(var: Union[DataArray, Dataset], key: str,) -> List[str]:
@@ -629,9 +640,10 @@ def _rewrite_values(
629640
# where xi_* have attrs["axis"] = "X"
630641
updates[key] = ChainMap(
631642
*[
632-
dict.fromkeys(apply_mapper(mapper, self._obj, k, False, k), v)
643+
dict.fromkeys(
644+
apply_mapper(mappers, self._obj, k, False, [k]), v
645+
)
633646
for k, v in value.items()
634-
for mapper in mappers
635647
]
636648
)
637649

@@ -641,9 +653,8 @@ def _rewrite_values(
641653
else:
642654
# things like sum which have dim
643655
newvalue = [
644-
apply_mapper(mapper, self._obj, v, False, v)
656+
apply_mapper(mappers, self._obj, v, error=False, default=[v])
645657
for v in value
646-
for mapper in mappers
647658
]
648659
# Mappers return list by default
649660
# for input dim=["lat", "X"], newvalue=[["lat"], ["lon"]],
@@ -695,18 +706,21 @@ def describe(self):
695706
"""
696707
text = "Axes:\n"
697708
for key in _AXIS_NAMES:
698-
text += f"\t{key}: {apply_mapper(_get_axis_coord, self._obj, key, error=False)}\n"
709+
axes = apply_mapper(_get_axis_coord, self._obj, key, error=False)
710+
text += f"\t{key}: {axes}\n"
699711

700712
text += "\nCoordinates:\n"
701713
for key in _COORD_NAMES:
702-
text += f"\t{key}: {apply_mapper(_get_axis_coord, self._obj, key, error=False)}\n"
714+
coords = apply_mapper(_get_axis_coord, self._obj, key, error=False)
715+
text += f"\t{key}: {coords}\n"
703716

704717
text += "\nCell Measures:\n"
705718
for measure in _CELL_MEASURES:
706719
if isinstance(self._obj, Dataset):
707720
text += f"\t{measure}: unsupported\n"
708721
else:
709-
text += f"\t{measure}: {apply_mapper(_get_measure, self._obj, measure, error=False)}\n"
722+
measures = apply_mapper(_get_measure, self._obj, measure, error=False)
723+
text += f"\t{measure}: {measures}\n"
710724

711725
text += "\nStandard Names:\n"
712726
if isinstance(self._obj, DataArray):

0 commit comments

Comments
 (0)