1
1
import functools
2
2
import inspect
3
- from typing import Any , Union
3
+ from typing import Any , List , Optional , Set , Union
4
4
5
5
import xarray as xr
6
6
from xarray import DataArray , Dataset
44
44
"Y" : ("latitude" ,),
45
45
"X" : ("longitude" ,),
46
46
},
47
+ "long_name" : {"T" : ("time" ,)},
47
48
"_CoordinateAxisType" : {
48
49
"T" : ("Time" ,),
49
50
"Z" : ("GeoZ" , "Height" , "Pressure" ),
84
85
}
85
86
86
87
87
- def _get_axis_coord (var : xr .DataArray , key , error : bool = True , default : Any = None ):
88
+ def _get_axis_coord_single (var , key , * args ):
89
+ """ Helper method for when we really want only one result per key. """
90
+ results = _get_axis_coord (var , key , * args )
91
+ if len (results ) > 1 :
92
+ raise ValueError (
93
+ "Multiple results for {key!r} found: {results!r}. Is this valid CF? Please open an issue."
94
+ )
95
+ else :
96
+ return results [0 ]
97
+
98
+
99
+ def _get_axis_coord (
100
+ var : Union [xr .DataArray , xr .Dataset ],
101
+ key : str ,
102
+ error : bool = True ,
103
+ default : Optional [str ] = None ,
104
+ ) -> List [Optional [str ]]:
88
105
"""
89
106
Translate from axis or coord name to variable name
90
107
91
108
Parameters
92
109
----------
93
- var : `xarray. DataArray`
110
+ var : DataArray, Dataset
94
111
DataArray belonging to the coordinate to be checked
95
112
key : str, ["X", "Y", "Z", "T", "longitude", "latitude", "vertical", "time"]
96
113
key to check for.
@@ -102,7 +119,7 @@ def _get_axis_coord(var: xr.DataArray, key, error: bool = True, default: Any = N
102
119
103
120
Returns
104
121
-------
105
- str, Variable name in parent xarray object that matches axis or coordinate `key`
122
+ List[ str] , Variable name(s) in parent xarray object that matches axis or coordinate `key`
106
123
107
124
Notes
108
125
-----
@@ -128,22 +145,33 @@ def _get_axis_coord(var: xr.DataArray, key, error: bool = True, default: Any = N
128
145
if error :
129
146
raise KeyError (f"Did not understand { key } " )
130
147
else :
131
- return default
148
+ return [ default ]
132
149
133
150
if axis is None :
134
151
raise AssertionError ("Should be unreachable" )
135
152
136
- for coord in var .coords :
153
+ if "coordinates" in var .encoding :
154
+ search_in = var .encoding ["coordinates" ].split (" " )
155
+ elif "coordinates" in var .attrs :
156
+ search_in = var .attrs ["coordinates" ].split (" " )
157
+ else :
158
+ search_in = list (var .coords )
159
+
160
+ results : Set = set ()
161
+ for coord in search_in :
137
162
for criterion , valid_values in coordinate_criteria .items ():
138
163
if axis in valid_values : # type: ignore
139
164
expected = valid_values [axis ] # type: ignore
140
165
if var .coords [coord ].attrs .get (criterion , None ) in expected :
141
- return coord
166
+ results . update (( coord ,))
142
167
143
- if error :
144
- raise KeyError (f"axis name { key !r} not found!" )
168
+ if not results :
169
+ if error :
170
+ raise KeyError (f"axis name { key !r} not found!" )
171
+ else :
172
+ return [default ]
145
173
else :
146
- return default
174
+ return list ( results )
147
175
148
176
149
177
def _get_measure_variable (
@@ -184,7 +212,9 @@ def _get_measure(da: xr.DataArray, key: str, error: bool = True, default: Any =
184
212
return measures [key ]
185
213
186
214
187
- _DEFAULT_KEY_MAPPERS : dict = dict .fromkeys (("dim" , "coord" , "group" ), _get_axis_coord )
215
+ _DEFAULT_KEY_MAPPERS : dict = dict .fromkeys (
216
+ ("dim" , "coord" , "group" ), _get_axis_coord_single
217
+ )
188
218
_DEFAULT_KEY_MAPPERS ["weights" ] = _get_measure_variable
189
219
190
220
@@ -261,7 +291,7 @@ def __call__(self, *args, **kwargs):
261
291
obj = self ._obj ,
262
292
attr = "plot" ,
263
293
accessor = self .accessor ,
264
- key_mappers = dict .fromkeys (self ._keys , _get_axis_coord ),
294
+ key_mappers = dict .fromkeys (self ._keys , _get_axis_coord_single ),
265
295
)
266
296
return plot (* args , ** kwargs )
267
297
@@ -270,7 +300,7 @@ def __getattr__(self, attr):
270
300
obj = self ._obj .plot ,
271
301
attr = attr ,
272
302
accessor = self .accessor ,
273
- key_mappers = dict .fromkeys (self ._keys , _get_axis_coord ),
303
+ key_mappers = dict .fromkeys (self ._keys , _get_axis_coord_single ),
274
304
)
275
305
276
306
@@ -294,7 +324,6 @@ def _process_signature(self, func, args, kwargs, key_mappers):
294
324
arguments = self ._rewrite_values (
295
325
bound .arguments , key_mappers , tuple (var_kws )
296
326
)
297
- print (arguments )
298
327
else :
299
328
arguments = {}
300
329
@@ -311,7 +340,7 @@ def _process_signature(self, func, args, kwargs, key_mappers):
311
340
def _rewrite_values (self , kwargs , key_mappers : dict , var_kws ):
312
341
""" rewrites 'dim' for example using 'mapper' """
313
342
updates : dict = {}
314
- key_mappers .update (dict .fromkeys (var_kws , _get_axis_coord ))
343
+ key_mappers .update (dict .fromkeys (var_kws , _get_axis_coord_single ))
315
344
for key , mapper in key_mappers .items ():
316
345
value = kwargs .get (key , None )
317
346
if value is not None :
@@ -341,7 +370,7 @@ def _rewrite_values(self, kwargs, key_mappers: dict, var_kws):
341
370
for vkw in var_kws :
342
371
if vkw in kwargs :
343
372
maybe_update = {
344
- k : _get_axis_coord (self ._obj , v , False , v )
373
+ k : _get_axis_coord_single (self ._obj , v , False , v )
345
374
for k , v in kwargs [vkw ].items ()
346
375
if k in key_mappers
347
376
}
@@ -367,22 +396,21 @@ def plot(self):
367
396
class CFDatasetAccessor (CFAccessor ):
368
397
def __getitem__ (self , key ):
369
398
if key in _AXIS_NAMES + _COORD_NAMES :
370
- return self ._obj [_get_axis_coord (self ._obj , key )]
399
+ varnames = _get_axis_coord (self ._obj , key )
400
+ return self ._obj .reset_coords ()[varnames ].set_coords (varnames )
371
401
elif key in _CELL_MEASURES :
372
402
raise NotImplementedError ("measures not implemented for Dataset yet." )
373
403
# return self._obj[_get_measure(self._obj)[key]]
374
404
else :
375
405
raise KeyError (f"DataArray.cf does not understand the key { key } " )
376
406
377
- # def __getitem__(self, key):
378
- # raise AttributeError("Dataset.cf does not support [] indexing or __getitem__")
379
-
380
407
381
408
@xr .register_dataarray_accessor ("cf" )
382
409
class CFDataArrayAccessor (CFAccessor ):
383
410
def __getitem__ (self , key ):
384
411
if key in _AXIS_NAMES + _COORD_NAMES :
385
- return self ._obj [_get_axis_coord (self ._obj , key )]
412
+ varname = _get_axis_coord_single (self ._obj , key )
413
+ return self ._obj [varname ].reset_coords (drop = True )
386
414
elif key in _CELL_MEASURES :
387
415
return self ._obj [_get_measure (self ._obj , key )]
388
416
else :
0 commit comments