@@ -265,15 +265,15 @@ def _get_measure(da: Union[DataArray, Dataset], key: str) -> List[str]:
265
265
# TODO: Make the values of this a tuple,
266
266
# so that multiple mappers can be used for a single key
267
267
# We need this for groupby("T.month") and groupby("latitude") for example.
268
- _DEFAULT_KEY_MAPPERS : Mapping [str , Mapper ] = {
269
- "dim" : _get_axis_coord ,
270
- "dims" : _get_axis_coord , # is this necessary?
271
- "coords" : _get_axis_coord , # interp
272
- "indexers" : _get_axis_coord , # sel, isel
273
- "dims_or_levels" : _get_axis_coord , # reset_index
274
- "coord" : _get_axis_coord_single ,
275
- "group" : _get_axis_coord_single ,
276
- "weights" : _get_measure_variable , # type: ignore
268
+ _DEFAULT_KEY_MAPPERS : Mapping [str , Tuple [ Mapper , ...] ] = {
269
+ "dim" : ( _get_axis_coord ,) ,
270
+ "dims" : ( _get_axis_coord ,) , # is this necessary?
271
+ "coords" : ( _get_axis_coord ,) , # interp
272
+ "indexers" : ( _get_axis_coord ,) , # sel, isel
273
+ "dims_or_levels" : ( _get_axis_coord ,) , # reset_index
274
+ "coord" : ( _get_axis_coord_single ,) ,
275
+ "group" : ( _get_axis_coord_single ,) ,
276
+ "weights" : ( _get_measure_variable ,) , # type: ignore
277
277
}
278
278
279
279
@@ -498,7 +498,7 @@ def __call__(self, *args, **kwargs):
498
498
obj = self ._obj ,
499
499
attr = "plot" ,
500
500
accessor = self .accessor ,
501
- key_mappers = dict .fromkeys (self ._keys , _get_axis_coord_single ),
501
+ key_mappers = dict .fromkeys (self ._keys , ( _get_axis_coord_single ,) ),
502
502
)
503
503
return self ._plot_decorator (plot )(* args , ** kwargs )
504
504
@@ -510,7 +510,7 @@ def __getattr__(self, attr):
510
510
obj = self ._obj .plot ,
511
511
attr = attr ,
512
512
accessor = self .accessor ,
513
- key_mappers = dict .fromkeys (self ._keys , _get_axis_coord_single ),
513
+ key_mappers = dict .fromkeys (self ._keys , ( _get_axis_coord_single ,) ),
514
514
# TODO: "extra_decorator" is more complex than I would like it to be.
515
515
# Not sure if there is a better way though
516
516
extra_decorator = self ._plot_decorator ,
@@ -525,7 +525,13 @@ class CFAccessor:
525
525
def __init__ (self , da ):
526
526
self ._obj = da
527
527
528
- def _process_signature (self , func : Callable , args , kwargs , key_mappers ):
528
+ def _process_signature (
529
+ self ,
530
+ func : Callable ,
531
+ args ,
532
+ kwargs ,
533
+ key_mappers : MutableMapping [str , Tuple [Mapper , ...]],
534
+ ):
529
535
"""
530
536
Processes a function's signature, args, kwargs:
531
537
1. Binds *args so that everthing is a Mapping from kwarg name to values
@@ -559,7 +565,12 @@ def _process_signature(self, func: Callable, args, kwargs, key_mappers):
559
565
560
566
return arguments
561
567
562
- def _rewrite_values (self , kwargs , key_mappers : dict , var_kws ):
568
+ def _rewrite_values (
569
+ self ,
570
+ kwargs ,
571
+ key_mappers : MutableMapping [str , Tuple [Mapper , ...]],
572
+ var_kws : Tuple [str , ...],
573
+ ):
563
574
"""
564
575
Rewrites the values in a Mapping from kwarg to value.
565
576
@@ -582,11 +593,11 @@ def _rewrite_values(self, kwargs, key_mappers: dict, var_kws):
582
593
583
594
# allow multiple return values here.
584
595
# these are valid for .sel, .isel, .coarsen
585
- key_mappers .update (dict .fromkeys (var_kws , _get_axis_coord ))
596
+ key_mappers .update (dict .fromkeys (var_kws , ( _get_axis_coord ,) ))
586
597
587
598
for key in set (key_mappers ) & set (kwargs ):
588
599
value = kwargs [key ]
589
- mapper = key_mappers [key ]
600
+ mappers = key_mappers [key ]
590
601
591
602
if isinstance (value , str ):
592
603
value = [value ]
@@ -601,6 +612,7 @@ def _rewrite_values(self, kwargs, key_mappers: dict, var_kws):
601
612
* [
602
613
dict .fromkeys (apply_mapper (mapper , self ._obj , k , False , k ), v )
603
614
for k , v in value .items ()
615
+ for mapper in mappers
604
616
]
605
617
)
606
618
@@ -609,7 +621,11 @@ def _rewrite_values(self, kwargs, key_mappers: dict, var_kws):
609
621
610
622
else :
611
623
# things like sum which have dim
612
- newvalue = [apply_mapper (mapper , self ._obj , v , False , v ) for v in value ]
624
+ newvalue = [
625
+ apply_mapper (mapper , self ._obj , v , False , v )
626
+ for v in value
627
+ for mapper in mappers
628
+ ]
613
629
# Mappers return list by default
614
630
# for input dim=["lat", "X"], newvalue=[["lat"], ["lon"]],
615
631
# so we deal with that here.
@@ -632,9 +648,10 @@ def _rewrite_values(self, kwargs, key_mappers: dict, var_kws):
632
648
maybe_update = {
633
649
# TODO: this is assuming key_mappers[k] is always
634
650
# _get_axis_coord_single
635
- k : apply_mapper (key_mappers [ k ] , self ._obj , v )[0 ]
651
+ k : apply_mapper (mapper , self ._obj , v )[0 ]
636
652
for k , v in kwargs [vkw ].items ()
637
653
if k in key_mappers
654
+ for mapper in key_mappers [k ]
638
655
}
639
656
kwargs [vkw ].update (maybe_update )
640
657
0 commit comments