7
7
from typing import (
8
8
Callable ,
9
9
Hashable ,
10
+ Iterable ,
10
11
List ,
11
12
Mapping ,
12
13
MutableMapping ,
@@ -122,6 +123,50 @@ def _strip_none_list(lst: List[Optional[str]]) -> List[str]:
122
123
return [item for item in lst if item != [None ]] # type: ignore
123
124
124
125
126
+ def mapper (valid_keys : Iterable [str ]):
127
+ """
128
+ Decorator for mapping functions that does error handling / returning defaults.
129
+ """
130
+
131
+ # This decorator Inception is sponsored by
132
+ # https://realpython.com/primer-on-python-decorators/#decorators-with-arguments
133
+ def decorator (func ):
134
+ @functools .wraps (func )
135
+ def wrapper (
136
+ obj : Union [xr .DataArray , xr .Dataset ],
137
+ key : str ,
138
+ error : bool = True ,
139
+ default : str = None ,
140
+ ) -> List [Optional [str ]]:
141
+ if key not in valid_keys :
142
+ if error :
143
+ raise KeyError (
144
+ f"cf_xarray did not understand key { key !r} . Expected one of { valid_keys !r} "
145
+ )
146
+ else :
147
+ return [default ]
148
+
149
+ try :
150
+ results = func (obj , key )
151
+ except Exception as e :
152
+ if error :
153
+ raise e
154
+ else :
155
+ results = None
156
+
157
+ if not results :
158
+ if error :
159
+ raise KeyError (f"Attributes to select { key !r} not found!" )
160
+ else :
161
+ return [default ]
162
+ else :
163
+ return list (results )
164
+
165
+ return wrapper
166
+
167
+ return decorator
168
+
169
+
125
170
def _get_axis_coord_single (
126
171
var : Union [xr .DataArray , xr .Dataset ],
127
172
key : str ,
@@ -138,11 +183,9 @@ def _get_axis_coord_single(
138
183
return results [0 ]
139
184
140
185
186
+ @mapper (valid_keys = _COORD_NAMES + _AXIS_NAMES )
141
187
def _get_axis_coord (
142
- var : Union [xr .DataArray , xr .Dataset ],
143
- key : str ,
144
- error : bool = True ,
145
- default : str = None ,
188
+ var : Union [xr .DataArray , xr .Dataset ], key : str ,
146
189
) -> List [Optional [str ]]:
147
190
"""
148
191
Translate from axis or coord name to variable name
@@ -176,12 +219,6 @@ def _get_axis_coord(
176
219
MetPy's parse_cf
177
220
"""
178
221
179
- if key not in _COORD_NAMES and key not in _AXIS_NAMES :
180
- if error :
181
- raise KeyError (f"Did not understand key { key !r} " )
182
- else :
183
- return [default ]
184
-
185
222
if "coordinates" in var .encoding :
186
223
search_in = var .encoding ["coordinates" ].split (" " )
187
224
elif "coordinates" in var .attrs :
@@ -196,26 +233,18 @@ def _get_axis_coord(
196
233
expected = valid_values [key ]
197
234
if var .coords [coord ].attrs .get (criterion , None ) in expected :
198
235
results .update ((coord ,))
199
-
200
- if not results :
201
- if error :
202
- raise KeyError (f"axis name { key !r} not found!" )
203
- else :
204
- return [default ]
205
- else :
206
- return list (results )
236
+ return list (results )
207
237
208
238
209
239
def _get_measure_variable (
210
240
da : xr .DataArray , key : str , error : bool = True , default : str = None
211
241
) -> DataArray :
212
242
""" tiny wrapper since xarray does not support providing str for weights."""
213
- return da [_get_measure (da , key , error , default )]
243
+ return da [_get_measure (da , key , error , default )[ 0 ] ]
214
244
215
245
216
- def _get_measure (
217
- da : xr .DataArray , key : str , error : bool = True , default : str = None
218
- ) -> Optional [str ]:
246
+ @mapper (valid_keys = _CELL_MEASURES )
247
+ def _get_measure (da : xr .DataArray , key : str ) -> List [Optional [str ]]:
219
248
"""
220
249
Translate from cell measures ("area" or "volume") to appropriate variable name.
221
250
This function interprets the ``cell_measures`` attribute on DataArrays.
@@ -238,36 +267,16 @@ def _get_measure(
238
267
"""
239
268
if not isinstance (da , DataArray ):
240
269
raise NotImplementedError ("Measures not implemented for Datasets yet." )
241
- if key not in _CELL_MEASURES :
242
- if error :
243
- raise ValueError (
244
- f"Cell measure must be one of { _CELL_MEASURES !r} . Received { key !r} instead."
245
- )
246
- else :
247
- return default
248
270
249
271
if "cell_measures" not in da .attrs :
250
- if error :
251
- raise KeyError ("'cell_measures' not present in 'attrs'." )
252
- else :
253
- return default
272
+ raise KeyError ("'cell_measures' not present in 'attrs'." )
254
273
255
274
attr = da .attrs ["cell_measures" ]
256
275
strings = [s .strip () for s in attr .strip ().split (":" )]
257
276
if len (strings ) % 2 != 0 :
258
- if error :
259
- raise ValueError (f"attrs['cell_measures'] = { attr !r} is malformed." )
260
- else :
261
- return default
277
+ raise ValueError (f"attrs['cell_measures'] = { attr !r} is malformed." )
262
278
measures = dict (zip (strings [slice (0 , None , 2 )], strings [slice (1 , None , 2 )]))
263
- if key not in measures :
264
- if error :
265
- raise KeyError (
266
- f"Cell measure { key !r} not found. Please use .cf.describe() to see a list of key names that can be interpreted."
267
- )
268
- else :
269
- return default
270
- return measures [key ]
279
+ return [measures .get (key , None )]
271
280
272
281
273
282
#: Default mappers for common keys.
@@ -688,10 +697,10 @@ def get_valid_keys(self) -> Set[str]:
688
697
measures = [
689
698
key
690
699
for key in _CELL_MEASURES
691
- if _get_measure (self ._obj , key , error = False ) is not None
700
+ if _get_measure (self ._obj , key , error = False , default = None ) != [ None ]
692
701
]
693
702
if measures :
694
- varnames .append ( * measures )
703
+ varnames .extend ( measures )
695
704
696
705
if not isinstance (self ._obj , xr .DataArray ):
697
706
varnames .extend (_get_list_standard_names (self ._obj ))
@@ -727,7 +736,7 @@ def __getitem__(self, key: Union[str, List[str]]):
727
736
measure = _get_measure (self ._obj , k )
728
737
successful [k ] = bool (measure )
729
738
if measure :
730
- varnames .append (measure )
739
+ varnames .extend (measure )
731
740
elif not isinstance (self ._obj , xr .DataArray ):
732
741
stdnames = _filter_by_standard_names (self ._obj , k )
733
742
successful [k ] = bool (stdnames )
@@ -740,7 +749,9 @@ def __getitem__(self, key: Union[str, List[str]]):
740
749
741
750
try :
742
751
# TODO: make this a get_auxiliary_variables function
743
- # make sure to set coordinate variables referred to in "coordinates" attribute
752
+ # 1. set coordinate variables referred to in "coordinates" attribute
753
+ # 2. set measures variables as coordinates
754
+ # 3. set ancillary variables as coordinates
744
755
for name in varnames :
745
756
attrs = self ._obj [name ].attrs
746
757
if "coordinates" in attrs :
@@ -752,7 +763,7 @@ def __getitem__(self, key: Union[str, List[str]]):
752
763
for measure in _CELL_MEASURES
753
764
if measure in attrs ["cell_measures" ]
754
765
]
755
- coords .extend (_strip_none_list (measures ))
766
+ coords .extend (* _strip_none_list (measures ))
756
767
757
768
if isinstance (self ._obj , xr .Dataset ) and "ancillary_variables" in attrs :
758
769
anames = attrs ["ancillary_variables" ].split (" " )
0 commit comments