Skip to content

Commit 6d81630

Browse files
committed
Use tuple of mappers.
This lets us use multiple mappers per kwarg. This will be needed for groupby.
1 parent c54d9b3 commit 6d81630

File tree

1 file changed

+34
-17
lines changed

1 file changed

+34
-17
lines changed

cf_xarray/accessor.py

Lines changed: 34 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -265,15 +265,15 @@ def _get_measure(da: Union[DataArray, Dataset], key: str) -> List[str]:
265265
# TODO: Make the values of this a tuple,
266266
# so that multiple mappers can be used for a single key
267267
# We need this for groupby("T.month") and groupby("latitude") for example.
268-
_DEFAULT_KEY_MAPPERS: Mapping[str, Mapper] = {
269-
"dim": _get_axis_coord,
270-
"dims": _get_axis_coord, # is this necessary?
271-
"coords": _get_axis_coord, # interp
272-
"indexers": _get_axis_coord, # sel, isel
273-
"dims_or_levels": _get_axis_coord, # reset_index
274-
"coord": _get_axis_coord_single,
275-
"group": _get_axis_coord_single,
276-
"weights": _get_measure_variable, # type: ignore
268+
_DEFAULT_KEY_MAPPERS: Mapping[str, Tuple[Mapper, ...]] = {
269+
"dim": (_get_axis_coord,),
270+
"dims": (_get_axis_coord,), # is this necessary?
271+
"coords": (_get_axis_coord,), # interp
272+
"indexers": (_get_axis_coord,), # sel, isel
273+
"dims_or_levels": (_get_axis_coord,), # reset_index
274+
"coord": (_get_axis_coord_single,),
275+
"group": (_get_axis_coord_single,),
276+
"weights": (_get_measure_variable,), # type: ignore
277277
}
278278

279279

@@ -498,7 +498,7 @@ def __call__(self, *args, **kwargs):
498498
obj=self._obj,
499499
attr="plot",
500500
accessor=self.accessor,
501-
key_mappers=dict.fromkeys(self._keys, _get_axis_coord_single),
501+
key_mappers=dict.fromkeys(self._keys, (_get_axis_coord_single,)),
502502
)
503503
return self._plot_decorator(plot)(*args, **kwargs)
504504

@@ -510,7 +510,7 @@ def __getattr__(self, attr):
510510
obj=self._obj.plot,
511511
attr=attr,
512512
accessor=self.accessor,
513-
key_mappers=dict.fromkeys(self._keys, _get_axis_coord_single),
513+
key_mappers=dict.fromkeys(self._keys, (_get_axis_coord_single,)),
514514
# TODO: "extra_decorator" is more complex than I would like it to be.
515515
# Not sure if there is a better way though
516516
extra_decorator=self._plot_decorator,
@@ -525,7 +525,13 @@ class CFAccessor:
525525
def __init__(self, da):
526526
self._obj = da
527527

528-
def _process_signature(self, func: Callable, args, kwargs, key_mappers):
528+
def _process_signature(
529+
self,
530+
func: Callable,
531+
args,
532+
kwargs,
533+
key_mappers: MutableMapping[str, Tuple[Mapper, ...]],
534+
):
529535
"""
530536
Processes a function's signature, args, kwargs:
531537
1. Binds *args so that everthing is a Mapping from kwarg name to values
@@ -559,7 +565,12 @@ def _process_signature(self, func: Callable, args, kwargs, key_mappers):
559565

560566
return arguments
561567

562-
def _rewrite_values(self, kwargs, key_mappers: dict, var_kws):
568+
def _rewrite_values(
569+
self,
570+
kwargs,
571+
key_mappers: MutableMapping[str, Tuple[Mapper, ...]],
572+
var_kws: Tuple[str, ...],
573+
):
563574
"""
564575
Rewrites the values in a Mapping from kwarg to value.
565576
@@ -582,11 +593,11 @@ def _rewrite_values(self, kwargs, key_mappers: dict, var_kws):
582593

583594
# allow multiple return values here.
584595
# these are valid for .sel, .isel, .coarsen
585-
key_mappers.update(dict.fromkeys(var_kws, _get_axis_coord))
596+
key_mappers.update(dict.fromkeys(var_kws, (_get_axis_coord,)))
586597

587598
for key in set(key_mappers) & set(kwargs):
588599
value = kwargs[key]
589-
mapper = key_mappers[key]
600+
mappers = key_mappers[key]
590601

591602
if isinstance(value, str):
592603
value = [value]
@@ -601,6 +612,7 @@ def _rewrite_values(self, kwargs, key_mappers: dict, var_kws):
601612
*[
602613
dict.fromkeys(apply_mapper(mapper, self._obj, k, False, k), v)
603614
for k, v in value.items()
615+
for mapper in mappers
604616
]
605617
)
606618

@@ -609,7 +621,11 @@ def _rewrite_values(self, kwargs, key_mappers: dict, var_kws):
609621

610622
else:
611623
# things like sum which have dim
612-
newvalue = [apply_mapper(mapper, self._obj, v, False, v) for v in value]
624+
newvalue = [
625+
apply_mapper(mapper, self._obj, v, False, v)
626+
for v in value
627+
for mapper in mappers
628+
]
613629
# Mappers return list by default
614630
# for input dim=["lat", "X"], newvalue=[["lat"], ["lon"]],
615631
# so we deal with that here.
@@ -632,9 +648,10 @@ def _rewrite_values(self, kwargs, key_mappers: dict, var_kws):
632648
maybe_update = {
633649
# TODO: this is assuming key_mappers[k] is always
634650
# _get_axis_coord_single
635-
k: apply_mapper(key_mappers[k], self._obj, v)[0]
651+
k: apply_mapper(mapper, self._obj, v)[0]
636652
for k, v in kwargs[vkw].items()
637653
if k in key_mappers
654+
for mapper in key_mappers[k]
638655
}
639656
kwargs[vkw].update(maybe_update)
640657

0 commit comments

Comments
 (0)