8
8
from typing import (
9
9
Callable ,
10
10
Hashable ,
11
- Iterable ,
12
11
List ,
13
12
Mapping ,
14
13
MutableMapping ,
113
112
coordinate_criteria ["long_name" ] = coordinate_criteria ["standard_name" ]
114
113
115
114
# Type for Mapper functions
116
- Mapper = Callable [
117
- [Union [xr .DataArray , xr .Dataset ], str , bool , str ],
118
- Union [List [Optional [str ]], DataArray ], # this sucks
119
- ]
115
+ Mapper = Callable [[Union [xr .DataArray , xr .Dataset ], str ], List [Optional [str ]]]
120
116
121
117
122
118
def _strip_none_list (lst : List [Optional [str ]]) -> List [str ]:
123
119
""" The mappers can return [None]. Strip that when necessary. Keeps mypy happy."""
124
120
return [item for item in lst if item != [None ]] # type: ignore
125
121
126
122
127
- def mapper (valid_keys : Iterable [str ]):
123
+ def apply_mapper (
124
+ mapper : Mapper ,
125
+ obj : Union [xr .DataArray , xr .Dataset ],
126
+ key : str ,
127
+ error : bool = True ,
128
+ default : str = None ,
129
+ ) -> List [Optional [str ]]:
128
130
"""
129
- Decorator for mapping functions that does error handling / returning defaults.
131
+ Applies a mapping function; does error handling / returning defaults.
130
132
"""
131
133
132
- # This decorator Inception is sponsored by
133
- # https://realpython.com/primer-on-python-decorators/#decorators-with-arguments
134
- def decorator (func ):
135
- @functools .wraps (func )
136
- def wrapper (
137
- obj : Union [xr .DataArray , xr .Dataset ],
138
- key : str ,
139
- error : bool = True ,
140
- default : str = None ,
141
- ) -> List [Optional [str ]]:
142
- """
143
- This decorator will add `error` and `default` kwargs to the decorated Mapper function.
144
- """
145
- if key not in valid_keys :
146
- if error :
147
- raise KeyError (
148
- f"cf_xarray did not understand key { key !r} . Expected one of { valid_keys !r} "
149
- )
150
- else :
151
- return [default ]
152
-
153
- try :
154
- results = func (obj , key )
155
- except Exception as e :
156
- if error :
157
- raise e
158
- else :
159
- results = None
160
-
161
- if not results :
162
- if error :
163
- raise KeyError (f"Attributes to select { key !r} not found!" )
164
- else :
165
- return [default ]
166
- else :
167
- return list (results )
168
-
169
- return wrapper
134
+ try :
135
+ results = mapper (obj , key )
136
+ except Exception as e :
137
+ if error :
138
+ raise e
139
+ else :
140
+ results = None # type: ignore
170
141
171
- return decorator
142
+ if not results :
143
+ if error :
144
+ raise KeyError (f"Attributes to select { key !r} not found!" )
145
+ else :
146
+ return [default ]
147
+ else :
148
+ return list (results )
172
149
173
150
174
151
def _get_axis_coord_single (
175
- var : Union [xr .DataArray , xr .Dataset ],
176
- key : str ,
177
- error : bool = True ,
178
- default : str = None ,
179
- ) -> Optional [str ]:
152
+ var : Union [xr .DataArray , xr .Dataset ], key : str ,
153
+ ) -> List [Optional [str ]]:
180
154
""" Helper method for when we really want only one result per key. """
181
- results = _get_axis_coord (var , key , error , default )
155
+ results = _get_axis_coord (var , key )
182
156
if len (results ) > 1 :
183
157
raise ValueError (
184
158
f"Multiple results for { key !r} found: { results !r} . Is this valid CF? Please open an issue."
185
159
)
186
- else :
187
- return results [0 ]
160
+ return results
188
161
189
162
190
- @mapper (valid_keys = _COORD_NAMES + _AXIS_NAMES )
191
163
def _get_axis_coord (
192
164
var : Union [xr .DataArray , xr .Dataset ], key : str ,
193
165
) -> List [Optional [str ]]:
@@ -223,6 +195,12 @@ def _get_axis_coord(
223
195
MetPy's parse_cf
224
196
"""
225
197
198
+ valid_keys = _COORD_NAMES + _AXIS_NAMES
199
+ if key not in valid_keys :
200
+ raise KeyError (
201
+ f"cf_xarray did not understand key { key !r} . Expected one of { valid_keys !r} "
202
+ )
203
+
226
204
if "coordinates" in var .encoding :
227
205
search_in = var .encoding ["coordinates" ].split (" " )
228
206
elif "coordinates" in var .attrs :
@@ -242,13 +220,15 @@ def _get_axis_coord(
242
220
243
221
def _get_measure_variable (
244
222
da : xr .DataArray , key : str , error : bool = True , default : str = None
245
- ) -> DataArray :
223
+ ) -> List [ DataArray ] :
246
224
""" tiny wrapper since xarray does not support providing str for weights."""
247
- return da [_get_measure (da , key , error , default )[0 ]]
225
+ varnames = _strip_none_list (apply_mapper (_get_measure , da , key , error , default ))
226
+ if len (varnames ) > 1 :
227
+ raise ValueError (f"Multiple measures found for key { key !r} : { varnames !r} ." )
228
+ return [da [varnames [0 ]]]
248
229
249
230
250
- @mapper (valid_keys = _CELL_MEASURES )
251
- def _get_measure (da : xr .DataArray , key : str ) -> List [Optional [str ]]:
231
+ def _get_measure (da : Union [xr .DataArray , xr .Dataset ], key : str ) -> List [Optional [str ]]:
252
232
"""
253
233
Translate from cell measures ("area" or "volume") to appropriate variable name.
254
234
This function interprets the ``cell_measures`` attribute on DataArrays.
@@ -275,6 +255,12 @@ def _get_measure(da: xr.DataArray, key: str) -> List[Optional[str]]:
275
255
if "cell_measures" not in da .attrs :
276
256
raise KeyError ("'cell_measures' not present in 'attrs'." )
277
257
258
+ valid_keys = _CELL_MEASURES
259
+ if key not in valid_keys :
260
+ raise KeyError (
261
+ f"cf_xarray did not understand key { key !r} . Expected one of { valid_keys !r} "
262
+ )
263
+
278
264
attr = da .attrs ["cell_measures" ]
279
265
strings = [s .strip () for s in attr .strip ().split (":" )]
280
266
if len (strings ) % 2 != 0 :
@@ -372,7 +358,7 @@ def _getattr(
372
358
newmap = dict ()
373
359
unused_keys = set (attribute .keys ())
374
360
for key in _AXIS_NAMES + _COORD_NAMES :
375
- value = _get_axis_coord ( obj , key , error = False )
361
+ value = apply_mapper ( _get_axis_coord , obj , key , error = False )
376
362
unused_keys -= set (value )
377
363
if value != [None ]:
378
364
good_values = set (value ) & set (obj .dims )
@@ -596,7 +582,9 @@ def _rewrite_values(self, kwargs, key_mappers: dict, var_kws):
596
582
# where xi_* have attrs["axis"] = "X"
597
583
updates [key ] = ChainMap (
598
584
* [
599
- dict .fromkeys (mapper (self ._obj , k , False , k ), v )
585
+ dict .fromkeys (
586
+ apply_mapper (mapper , self ._obj , k , False , k ), v
587
+ )
600
588
for k , v in value .items ()
601
589
]
602
590
)
@@ -606,16 +594,18 @@ def _rewrite_values(self, kwargs, key_mappers: dict, var_kws):
606
594
607
595
else :
608
596
# things like sum which have dim
609
- newvalue = [mapper (self ._obj , v , False , v ) for v in value ]
610
- if len (newvalue ) == 1 :
611
- # works for groupby("time")
612
- newvalue = newvalue [0 ]
597
+ newvalue = [
598
+ apply_mapper (mapper , self ._obj , v , False , v ) for v in value
599
+ ]
600
+ # Mappers return list by default
601
+ # for input dim=["lat", "X"], newvalue=[["lat"], ["lon"]],
602
+ # so we deal with that here.
603
+ unpacked = list (itertools .chain (* newvalue ))
604
+ if len (unpacked ) == 1 :
605
+ # handle 'group'
606
+ updates [key ] = unpacked [0 ]
613
607
else :
614
- # Mappers return list by default
615
- # for input dim=["lat", "X"], newvalue=[["lat"], ["lon"]],
616
- # so we deal with that here.
617
- newvalue = list (itertools .chain (* newvalue ))
618
- updates [key ] = newvalue
608
+ updates [key ] = unpacked
619
609
620
610
kwargs .update (updates )
621
611
@@ -627,7 +617,9 @@ def _rewrite_values(self, kwargs, key_mappers: dict, var_kws):
627
617
for vkw in var_kws :
628
618
if vkw in kwargs :
629
619
maybe_update = {
630
- k : _get_axis_coord_single (self ._obj , v , False , v )
620
+ # TODO: this is assuming key_mappers[k] is always
621
+ # _get_axis_coord_single
622
+ k : apply_mapper (key_mappers [k ], self ._obj , v )[0 ]
631
623
for k , v in kwargs [vkw ].items ()
632
624
if k in key_mappers
633
625
}
@@ -654,20 +646,18 @@ def describe(self):
654
646
"""
655
647
text = "Axes:\n "
656
648
for key in _AXIS_NAMES :
657
- text += f"\t { key } : { _get_axis_coord ( self ._obj , key , error = False )} \n "
649
+ text += f"\t { key } : { apply_mapper ( _get_axis_coord , self ._obj , key , error = False )} \n "
658
650
659
651
text += "\n Coordinates:\n "
660
652
for key in _COORD_NAMES :
661
- text += f"\t { key } : { _get_axis_coord ( self ._obj , key , error = False )} \n "
653
+ text += f"\t { key } : { apply_mapper ( _get_axis_coord , self ._obj , key , error = False )} \n "
662
654
663
655
text += "\n Cell Measures:\n "
664
656
for measure in _CELL_MEASURES :
665
657
if isinstance (self ._obj , xr .Dataset ):
666
658
text += f"\t { measure } : unsupported\n "
667
659
else :
668
- text += (
669
- f"\t { measure } : { _get_measure (self ._obj , measure , error = False )} \n "
670
- )
660
+ text += f"\t { measure } : { apply_mapper (_get_measure , self ._obj , measure , error = False )} \n "
671
661
672
662
text += "\n Standard Names:\n "
673
663
if isinstance (self ._obj , xr .DataArray ):
@@ -694,13 +684,13 @@ def get_valid_keys(self) -> Set[str]:
694
684
varnames = [
695
685
key
696
686
for key in _AXIS_NAMES + _COORD_NAMES
697
- if _get_axis_coord ( self ._obj , key , error = False ) != [None ]
687
+ if apply_mapper ( _get_axis_coord , self ._obj , key , error = False ) != [None ]
698
688
]
699
689
with suppress (NotImplementedError ):
700
690
measures = [
701
691
key
702
692
for key in _CELL_MEASURES
703
- if _get_measure ( self ._obj , key , error = False ) != [None ]
693
+ if apply_mapper ( _get_measure , self ._obj , key , error = False ) != [None ]
704
694
]
705
695
if measures :
706
696
varnames .extend (measures )
@@ -727,19 +717,14 @@ def __getitem__(self, key: Union[str, List[str]]):
727
717
successful = dict .fromkeys (key , False )
728
718
for k in key :
729
719
if k in _AXIS_NAMES + _COORD_NAMES :
730
- names = _get_axis_coord (self ._obj , k )
720
+ names = _strip_none_list ( _get_axis_coord (self ._obj , k ) )
731
721
successful [k ] = bool (names )
732
- coords .extend (_strip_none_list ( names ) )
722
+ coords .extend (names )
733
723
elif k in _CELL_MEASURES :
734
- if isinstance (self ._obj , xr .Dataset ):
735
- raise NotImplementedError (
736
- "Invalid key {k!r}. Cell measures not implemented for Dataset yet."
737
- )
738
- else :
739
- measure = _get_measure (self ._obj , k )
740
- successful [k ] = bool (measure )
741
- if measure :
742
- varnames .extend (measure )
724
+ measure = _strip_none_list (_get_measure (self ._obj , k ))
725
+ successful [k ] = bool (measure )
726
+ if measure :
727
+ varnames .extend (measure )
743
728
elif not isinstance (self ._obj , xr .DataArray ):
744
729
stdnames = _filter_by_standard_names (self ._obj , k )
745
730
successful [k ] = bool (stdnames )
@@ -766,7 +751,7 @@ def __getitem__(self, key: Union[str, List[str]]):
766
751
for measure in _CELL_MEASURES
767
752
if measure in attrs ["cell_measures" ]
768
753
]
769
- coords .extend (* _strip_none_list (measures ))
754
+ coords .extend (_strip_none_list (* measures ))
770
755
771
756
if isinstance (self ._obj , xr .Dataset ) and "ancillary_variables" in attrs :
772
757
anames = attrs ["ancillary_variables" ].split (" " )
0 commit comments