1
1
import functools
2
2
import inspect
3
- from typing import Union
3
+ from typing import Any , Union
4
4
5
5
import xarray as xr
6
6
from xarray import DataArray , Dataset
15
15
16
16
17
17
_DEFAULT_KEYS_TO_REWRITE = ("dim" , "coord" , "group" )
18
+ _AXIS_NAMES = ("X" , "Y" , "Z" , "T" )
19
+ _COORD_NAMES = ("longitude" , "latitude" , "vertical" , "time" )
20
+ _COORD_AXIS_MAPPING = dict (zip (_COORD_NAMES , _AXIS_NAMES ))
21
+ _CELL_MEASURES = ("area" , "volume" )
22
+
23
+
24
+ # Define the criteria for coordinate matches
25
+ # Copied from metpy
26
+ # Internally we only use X, Y, Z, T
27
+ # TODO: Metpy adds latitude and longitude separately so we may revert to doing that too
28
+ coordinate_criteria = {
29
+ "standard_name" : {
30
+ "T" : ("time" ,),
31
+ "Z" : (
32
+ "air_pressure" ,
33
+ "height" ,
34
+ "geopotential_height" ,
35
+ "altitude" ,
36
+ "model_level_number" ,
37
+ "atmosphere_ln_pressure_coordinate" ,
38
+ "atmosphere_sigma_coordinate" ,
39
+ "atmosphere_hybrid_sigma_pressure_coordinate" ,
40
+ "atmosphere_hybrid_height_coordinate" ,
41
+ "atmosphere_sleve_coordinate" ,
42
+ "height_above_geopotential_datum" ,
43
+ "height_above_reference_ellipsoid" ,
44
+ "height_above_mean_sea_level" ,
45
+ ),
46
+ "Y" : ("latitude" ,),
47
+ "X" : ("longitude" ,),
48
+ },
49
+ "_CoordinateAxisType" : {
50
+ "T" : ("Time" ,),
51
+ "Z" : ("GeoZ" , "Height" , "Pressure" ),
52
+ "Y" : ("GeoY" , "Lat" ),
53
+ "X" : ("GeoX" , "Lon" ),
54
+ },
55
+ "axis" : {"T" : ("T" ,), "Z" : ("Z" ,), "Y" : ("Y" ,), "X" : ("X" ,)},
56
+ "positive" : {"Z" : ("up" , "down" )},
57
+ "units" : {
58
+ "Y" : (
59
+ "degree_north" ,
60
+ "degree_N" ,
61
+ "degreeN" ,
62
+ "degrees_north" ,
63
+ "degrees_N" ,
64
+ "degreesN" ,
65
+ ),
66
+ "X" : (
67
+ "degree_east" ,
68
+ "degree_E" ,
69
+ "degreeE" ,
70
+ "degrees_east" ,
71
+ "degrees_E" ,
72
+ "degreesE" ,
73
+ ),
74
+ },
75
+ # "regular_expression": {
76
+ # "time": r"time[0-9]*",
77
+ # "vertical": (
78
+ # r"(lv_|bottom_top|sigma|h(ei)?ght|altitude|depth|isobaric|pres|"
79
+ # r"isotherm)[a-z_]*[0-9]*"
80
+ # ),
81
+ # "y": r"y",
82
+ # "latitude": r"x?lat[a-z0-9]*",
83
+ # "x": r"x",
84
+ # "longitude": r"x?lon[a-z0-9]*",
85
+ # },
86
+ }
87
+
88
+
89
+ def _get_axis_coord (var : xr .DataArray , key , error : bool = True , default : Any = None ):
90
+ """
91
+ Translate from axis or coord name to variable name
18
92
93
+ Parameters
94
+ ----------
95
+ var : `xarray.DataArray`
96
+ DataArray belonging to the coordinate to be checked
97
+ key : str, ["X", "Y", "Z", "T", "longitude", "latitude", "vertical", "time"]
98
+ key to check for.
99
+ error : bool
100
+ raise errors when key is not found or interpretable. Use False and provide default
101
+ to replicate dict.get(k, None).
102
+ default: Any
103
+ default value to return when error is False.
104
+
105
+ Returns
106
+ -------
107
+ str, Variable name in parent xarray object that matches axis or coordinate `key`
108
+
109
+ Notes
110
+ -----
111
+ This functions checks for the following attributes in order
112
+ - `standard_name` (CF option)
113
+ - `_CoordinateAxisType` (from THREDDS)
114
+ - `axis` (CF option)
115
+ - `positive` (CF standard for non-pressure vertical coordinate)
116
+
117
+ References
118
+ ----------
119
+ MetPy's parse_cf
120
+ """
121
+
122
+ axis = None
123
+ if key in _COORD_NAMES :
124
+ coord = key
125
+ axis = _COORD_AXIS_MAPPING [key ]
126
+ elif key in _AXIS_NAMES :
127
+ coord = ""
128
+ axis = key
129
+ else :
130
+ if error :
131
+ raise KeyError (f"Did not understand { key } " )
132
+ else :
133
+ return default
134
+
135
+ if axis is None :
136
+ raise AssertionError ("Should be unreachable" )
137
+
138
+ for coord in var .coords :
139
+ for criterion , valid_values in coordinate_criteria .items ():
140
+ if axis in valid_values : # type: ignore
141
+ expected = valid_values [axis ] # type: ignore
142
+ if var .coords [coord ].attrs .get (criterion , None ) in expected :
143
+ return coord
144
+
145
+ if error :
146
+ raise KeyError (f"axis name { key !r} not found!" )
147
+ else :
148
+ return default
149
+
150
+
151
+ def _get_measure (da : xr .DataArray , key : str ):
152
+ """
153
+ TODO: actually interpret da.attrs to get this.
154
+ """
155
+ if key not in _CELL_MEASURES :
156
+ raise ValueError (
157
+ f"Cell measure must be one of { _CELL_MEASURES !r} . Received { key !r} instead."
158
+ )
19
159
20
- def _get_axis_name_mapping (da : xr .DataArray ):
21
- return {"X" : "lon" , "Y" : "lat" , "T" : "time" }
160
+ return {"area" : "cell_area" , "volume" : "cell_volume" }
22
161
23
162
24
163
def _getattr (
@@ -98,12 +237,9 @@ def __getattr__(self, attr):
98
237
)
99
238
100
239
101
- @xr .register_dataarray_accessor ("cf" )
102
- @xr .register_dataset_accessor ("cf" )
103
240
class CFAccessor :
104
241
def __init__ (self , da ):
105
242
self ._obj = da
106
- self ._coords = _get_axis_name_mapping (da )
107
243
108
244
def _process_signature (self , func , args , kwargs , keys ):
109
245
sig = inspect .signature (func , follow_wrapped = False )
@@ -145,12 +281,17 @@ def _rewrite_values_with_axis_names(self, kwargs, keys, var_kws):
145
281
146
282
if isinstance (value , dict ):
147
283
# this for things like isel where **kwargs captures things like T=5
148
- updates [key ] = {self ._coords .get (k , k ): v for k , v in value .items ()}
284
+ updates [key ] = {
285
+ _get_axis_coord (self ._obj , k , False , k ): v
286
+ for k , v in value .items ()
287
+ }
149
288
elif value is Ellipsis :
150
289
pass
151
290
else :
152
291
# things like sum which have dim
153
- updates [key ] = [self ._coords .get (v , v ) for v in value ]
292
+ updates [key ] = [
293
+ _get_axis_coord (self ._obj , v , False , v ) for v in value
294
+ ]
154
295
if len (updates [key ]) == 1 :
155
296
updates [key ] = updates [key ][0 ]
156
297
@@ -163,7 +304,7 @@ def _rewrite_values_with_axis_names(self, kwargs, keys, var_kws):
163
304
for vkw in var_kws :
164
305
if vkw in kwargs :
165
306
maybe_update = {
166
- k : self ._coords . get ( v , v )
307
+ k : _get_axis_coord ( self ._obj , v , False , v )
167
308
for k , v in kwargs [vkw ].items ()
168
309
if k in keys
169
310
}
@@ -177,3 +318,30 @@ def __getattr__(self, attr):
177
318
@property
178
319
def plot (self ):
179
320
return _CFWrappedPlotMethods (self ._obj , self )
321
+
322
+
323
+ @xr .register_dataset_accessor ("cf" )
324
+ class CFDatasetAccessor (CFAccessor ):
325
+ def __getitem__ (self , key ):
326
+ if key in _AXIS_NAMES + _COORD_NAMES :
327
+ return self ._obj [_get_axis_coord (self ._obj , key )]
328
+ elif key in _CELL_MEASURES :
329
+ raise NotImplementedError ("measures not implemented yet." )
330
+ # return self._obj[_get_measure(self._obj)[key]]
331
+ else :
332
+ raise KeyError (f"DataArray.cf does not understand the key { key } " )
333
+
334
+ # def __getitem__(self, key):
335
+ # raise AttributeError("Dataset.cf does not support [] indexing or __getitem__")
336
+
337
+
338
+ @xr .register_dataarray_accessor ("cf" )
339
+ class CFDataArrayAccessor (CFAccessor ):
340
+ def __getitem__ (self , key ):
341
+ if key in _AXIS_NAMES + _COORD_NAMES :
342
+ return self ._obj [_get_axis_coord (self ._obj , key )]
343
+ elif key in _CELL_MEASURES :
344
+ raise NotImplementedError ("measures not implemented yet." )
345
+ # return self._obj[_get_measure(self._obj)[key]]
346
+ else :
347
+ raise KeyError (f"DataArray.cf does not understand the key { key } " )
0 commit comments