14
14
)
15
15
16
16
17
- _DEFAULT_KEYS_TO_REWRITE = ("dim" , "coord" , "group" )
18
17
_AXIS_NAMES = ("X" , "Y" , "Z" , "T" )
19
18
_COORD_NAMES = ("longitude" , "latitude" , "vertical" , "time" )
20
19
_COORD_AXIS_MAPPING = dict (zip (_COORD_NAMES , _AXIS_NAMES ))
21
20
_CELL_MEASURES = ("area" , "volume" )
22
21
23
-
24
22
# Define the criteria for coordinate matches
25
23
# Copied from metpy
26
24
# Internally we only use X, Y, Z, T
@@ -148,24 +146,54 @@ def _get_axis_coord(var: xr.DataArray, key, error: bool = True, default: Any = N
148
146
return default
149
147
150
148
151
- def _get_measure (da : xr .DataArray , key : str ):
149
+ def _get_measure_variable (
150
+ da : xr .DataArray , key : str , error : bool = True , default : Any = None
151
+ ) -> DataArray :
152
+ """ tiny wrapper since xarray does not support providing str for weights."""
153
+ return da [_get_measure (da , key , error , default )]
154
+
155
+
156
+ def _get_measure (da : xr .DataArray , key : str , error : bool = True , default : Any = None ):
152
157
"""
153
- TODO: actually interpret da.attrs to get this .
158
+ Interprets 'cell_measures' .
154
159
"""
160
+ if not isinstance (da , DataArray ):
161
+ raise NotImplementedError ("Measures not implemented for Datasets yet." )
155
162
if key not in _CELL_MEASURES :
156
- raise ValueError (
157
- f"Cell measure must be one of { _CELL_MEASURES !r} . Received { key !r} instead."
158
- )
163
+ if error :
164
+ raise ValueError (
165
+ f"Cell measure must be one of { _CELL_MEASURES !r} . Received { key !r} instead."
166
+ )
167
+ else :
168
+ return default
169
+
170
+ if "cell_measures" not in da .attrs :
171
+ if error :
172
+ raise KeyError ("'cell_measures' not present in 'attrs'." )
173
+ else :
174
+ return default
159
175
160
- return {"area" : "cell_area" , "volume" : "cell_volume" }
176
+ attr = da .attrs ["cell_measures" ]
177
+ strings = [s .strip () for s in attr .strip ().split (":" )]
178
+ if len (strings ) % 2 != 0 :
179
+ if error :
180
+ raise ValueError (f"attrs['cell_measures'] = { attr !r} is malformed." )
181
+ else :
182
+ return default
183
+ measures = dict (zip (strings [slice (0 , None , 2 )], strings [slice (1 , None , 2 )]))
184
+ return measures [key ]
185
+
186
+
187
+ _DEFAULT_KEY_MAPPERS : dict = dict .fromkeys (("dim" , "coord" , "group" ), _get_axis_coord )
188
+ _DEFAULT_KEY_MAPPERS ["weights" ] = _get_measure_variable
161
189
162
190
163
191
def _getattr (
164
192
obj : Union [DataArray , Dataset ],
165
193
attr : str ,
166
194
accessor : "CFAccessor" ,
195
+ key_mappers : dict ,
167
196
wrap_classes = False ,
168
- keys = _DEFAULT_KEYS_TO_REWRITE ,
169
197
):
170
198
"""
171
199
Common getattr functionality.
@@ -186,8 +214,7 @@ def _getattr(
186
214
187
215
@functools .wraps (func )
188
216
def wrapper (* args , ** kwargs ):
189
- arguments = accessor ._process_signature (func , args , kwargs , keys = keys )
190
-
217
+ arguments = accessor ._process_signature (func , args , kwargs , key_mappers )
191
218
result = func (** arguments )
192
219
if wrap_classes and isinstance (result , _WRAPPED_CLASSES ):
193
220
result = _CFWrappedClass (result , accessor )
@@ -203,7 +230,6 @@ def __init__(self, towrap, accessor: "CFAccessor"):
203
230
204
231
Parameters
205
232
----------
206
-
207
233
obj : DataArray, Dataset
208
234
towrap : Resample, GroupBy, Coarsen, Rolling, Weighted
209
235
Instance of xarray class that is being wrapped.
@@ -216,7 +242,12 @@ def __repr__(self):
216
242
return "--- CF-xarray wrapped \n " + repr (self .wrapped )
217
243
218
244
def __getattr__ (self , attr ):
219
- return _getattr (obj = self .wrapped , attr = attr , accessor = self .accessor )
245
+ return _getattr (
246
+ obj = self .wrapped ,
247
+ attr = attr ,
248
+ accessor = self .accessor ,
249
+ key_mappers = _DEFAULT_KEY_MAPPERS ,
250
+ )
220
251
221
252
222
253
class _CFWrappedPlotMethods :
@@ -227,21 +258,27 @@ def __init__(self, obj, accessor):
227
258
228
259
def __call__ (self , * args , ** kwargs ):
229
260
plot = _getattr (
230
- obj = self ._obj , attr = "plot" , accessor = self .accessor , keys = self ._keys
261
+ obj = self ._obj ,
262
+ attr = "plot" ,
263
+ accessor = self .accessor ,
264
+ key_mappers = dict .fromkeys (self ._keys , _get_axis_coord ),
231
265
)
232
266
return plot (* args , ** kwargs )
233
267
234
268
def __getattr__ (self , attr ):
235
269
return _getattr (
236
- obj = self ._obj .plot , attr = attr , accessor = self .accessor , keys = self ._keys
270
+ obj = self ._obj .plot ,
271
+ attr = attr ,
272
+ accessor = self .accessor ,
273
+ key_mappers = dict .fromkeys (self ._keys , _get_axis_coord ),
237
274
)
238
275
239
276
240
277
class CFAccessor :
241
278
def __init__ (self , da ):
242
279
self ._obj = da
243
280
244
- def _process_signature (self , func , args , kwargs , keys ):
281
+ def _process_signature (self , func , args , kwargs , key_mappers ):
245
282
sig = inspect .signature (func , follow_wrapped = False )
246
283
247
284
# Catch things like .isel(T=5).
@@ -254,9 +291,10 @@ def _process_signature(self, func, args, kwargs, keys):
254
291
255
292
if args or kwargs :
256
293
bound = sig .bind (* args , ** kwargs )
257
- arguments = self ._rewrite_values_with_axis_names (
258
- bound .arguments , keys , tuple (var_kws )
294
+ arguments = self ._rewrite_values (
295
+ bound .arguments , key_mappers , tuple (var_kws )
259
296
)
297
+ print (arguments )
260
298
else :
261
299
arguments = {}
262
300
@@ -270,33 +308,32 @@ def _process_signature(self, func, args, kwargs, keys):
270
308
271
309
return arguments
272
310
273
- def _rewrite_values_with_axis_names (self , kwargs , keys , var_kws ):
274
- """ rewrites 'dim' for example. """
275
- updates = {}
276
- for key in tuple (keys ) + tuple (var_kws ):
311
+ def _rewrite_values (self , kwargs , key_mappers : dict , var_kws ):
312
+ """ rewrites 'dim' for example using 'mapper' """
313
+ updates : dict = {}
314
+ key_mappers .update (dict .fromkeys (var_kws , _get_axis_coord ))
315
+ for key , mapper in key_mappers .items ():
277
316
value = kwargs .get (key , None )
278
- if value :
317
+ if value is not None :
279
318
if isinstance (value , str ):
280
319
value = [value ]
281
320
282
321
if isinstance (value , dict ):
283
322
# this for things like isel where **kwargs captures things like T=5
284
323
updates [key ] = {
285
- _get_axis_coord (self ._obj , k , False , k ): v
286
- for k , v in value .items ()
324
+ mapper (self ._obj , k , False , k ): v for k , v in value .items ()
287
325
}
288
326
elif value is Ellipsis :
289
327
pass
290
328
else :
291
329
# things like sum which have dim
292
- updates [key ] = [
293
- _get_axis_coord (self ._obj , v , False , v ) for v in value
294
- ]
330
+ updates [key ] = [mapper (self ._obj , v , False , v ) for v in value ]
295
331
if len (updates [key ]) == 1 :
296
332
updates [key ] = updates [key ][0 ]
297
333
298
334
kwargs .update (updates )
299
335
336
+ # TODO: is there a way to merge this with above?
300
337
# maybe the keys we are looking for are in kwargs.
301
338
# For example, this happens with DataArray.plot(),
302
339
# where the signature is obscured and kwargs is
@@ -306,14 +343,20 @@ def _rewrite_values_with_axis_names(self, kwargs, keys, var_kws):
306
343
maybe_update = {
307
344
k : _get_axis_coord (self ._obj , v , False , v )
308
345
for k , v in kwargs [vkw ].items ()
309
- if k in keys
346
+ if k in key_mappers
310
347
}
311
348
kwargs [vkw ].update (maybe_update )
312
349
313
350
return kwargs
314
351
315
352
def __getattr__ (self , attr ):
316
- return _getattr (obj = self ._obj , attr = attr , accessor = self , wrap_classes = True )
353
+ return _getattr (
354
+ obj = self ._obj ,
355
+ attr = attr ,
356
+ accessor = self ,
357
+ key_mappers = _DEFAULT_KEY_MAPPERS ,
358
+ wrap_classes = True ,
359
+ )
317
360
318
361
@property
319
362
def plot (self ):
@@ -326,7 +369,7 @@ def __getitem__(self, key):
326
369
if key in _AXIS_NAMES + _COORD_NAMES :
327
370
return self ._obj [_get_axis_coord (self ._obj , key )]
328
371
elif key in _CELL_MEASURES :
329
- raise NotImplementedError ("measures not implemented yet." )
372
+ raise NotImplementedError ("measures not implemented for Dataset yet." )
330
373
# return self._obj[_get_measure(self._obj)[key]]
331
374
else :
332
375
raise KeyError (f"DataArray.cf does not understand the key { key } " )
@@ -341,7 +384,6 @@ def __getitem__(self, key):
341
384
if key in _AXIS_NAMES + _COORD_NAMES :
342
385
return self ._obj [_get_axis_coord (self ._obj , key )]
343
386
elif key in _CELL_MEASURES :
344
- raise NotImplementedError ("measures not implemented yet." )
345
- # return self._obj[_get_measure(self._obj)[key]]
387
+ return self ._obj [_get_measure (self ._obj , key )]
346
388
else :
347
389
raise KeyError (f"DataArray.cf does not understand the key { key } " )
0 commit comments