Skip to content

Commit 2bd4930

Browse files
authored
Add getitem and actually parse attrs (#12)
* Add getitem and actually parse attrs Closes #3 Partially implements #4 * lint. * Better test coverage.
1 parent fa882fc commit 2bd4930

File tree

2 files changed

+204
-9
lines changed

2 files changed

+204
-9
lines changed

cf_xarray/accessor.py

Lines changed: 177 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
import functools
22
import inspect
3-
from typing import Union
3+
from typing import Any, Union
44

55
import xarray as xr
66
from xarray import DataArray, Dataset
@@ -15,10 +15,149 @@
1515

1616

1717
_DEFAULT_KEYS_TO_REWRITE = ("dim", "coord", "group")

# Recognised axis names and the coordinate names that alias them. The two
# tuples are position-aligned: _COORD_NAMES[i] maps onto _AXIS_NAMES[i].
_AXIS_NAMES = ("X", "Y", "Z", "T")
_COORD_NAMES = ("longitude", "latitude", "vertical", "time")
_COORD_AXIS_MAPPING = dict(zip(_COORD_NAMES, _AXIS_NAMES))
_CELL_MEASURES = ("area", "volume")


# Criteria for coordinate matches (copied from metpy).
# Layout: {attribute name: {axis name: attribute values that identify the axis}}
# Internally we only use X, Y, Z, T
# TODO: Metpy adds latitude and longitude separately so we may revert to doing that too
coordinate_criteria = {
    "standard_name": {
        "T": ("time",),
        "Z": (
            "air_pressure",
            "height",
            "geopotential_height",
            "altitude",
            "model_level_number",
            "atmosphere_ln_pressure_coordinate",
            "atmosphere_sigma_coordinate",
            "atmosphere_hybrid_sigma_pressure_coordinate",
            "atmosphere_hybrid_height_coordinate",
            "atmosphere_sleve_coordinate",
            "height_above_geopotential_datum",
            "height_above_reference_ellipsoid",
            "height_above_mean_sea_level",
        ),
        "Y": ("latitude",),
        "X": ("longitude",),
    },
    "_CoordinateAxisType": {
        "T": ("Time",),
        "Z": ("GeoZ", "Height", "Pressure"),
        "Y": ("GeoY", "Lat"),
        "X": ("GeoX", "Lon"),
    },
    # Each axis is identified by an ``axis`` attribute equal to its own name.
    "axis": {name: (name,) for name in reversed(_AXIS_NAMES)},
    "positive": {"Z": ("up", "down")},
    "units": {
        "Y": (
            "degree_north",
            "degree_N",
            "degreeN",
            "degrees_north",
            "degrees_N",
            "degreesN",
        ),
        "X": (
            "degree_east",
            "degree_E",
            "degreeE",
            "degrees_east",
            "degrees_E",
            "degreesE",
        ),
    },
    # "regular_expression": {
    #     "time": r"time[0-9]*",
    #     "vertical": (
    #         r"(lv_|bottom_top|sigma|h(ei)?ght|altitude|depth|isobaric|pres|"
    #         r"isotherm)[a-z_]*[0-9]*"
    #     ),
    #     "y": r"y",
    #     "latitude": r"x?lat[a-z0-9]*",
    #     "x": r"x",
    #     "longitude": r"x?lon[a-z0-9]*",
    # },
}
87+
88+
89+
def _get_axis_coord(var: xr.DataArray, key, error: bool = True, default: Any = None):
    """
    Translate from axis or coord name to variable name

    Parameters
    ----------
    var : `xarray.DataArray`
        Object whose ``coords`` are searched. The accessors in this module pass
        the wrapped object, so only ``var.coords`` and each coordinate's
        ``attrs`` are relied upon.
    key : str, ["X", "Y", "Z", "T", "longitude", "latitude", "vertical", "time"]
        key to check for.
    error : bool
        raise errors when key is not found or interpretable. Use False and provide default
        to replicate dict.get(k, None).
    default: Any
        default value to return when error is False.

    Returns
    -------
    str, Variable name in parent xarray object that matches axis or coordinate `key`.
    When ``error`` is False and nothing matches, ``default`` is returned instead.

    Raises
    ------
    KeyError
        If ``error`` is True and either ``key`` is not a recognised axis or
        coordinate name, or no coordinate of ``var`` matches it.

    Notes
    -----
    Coordinates are visited in the iteration order of ``var.coords``. For each
    coordinate, the following attributes are compared against
    ``coordinate_criteria`` in order, and the first coordinate with any
    matching attribute is returned:

    - `standard_name` (CF option)
    - `_CoordinateAxisType` (from THREDDS)
    - `axis` (CF option)
    - `positive` (CF standard for non-pressure vertical coordinate)
    - `units` (only defined for the X and Y axes)

    References
    ----------
    MetPy's parse_cf
    """
    # Normalise ``key`` to one of the axis names used by ``coordinate_criteria``.
    # Tuple membership (rather than a dict lookup) keeps unhashable keys from
    # raising TypeError here.
    if key in _COORD_NAMES:
        axis = _COORD_AXIS_MAPPING[key]
    elif key in _AXIS_NAMES:
        axis = key
    elif error:
        raise KeyError(f"Did not understand {key}")
    else:
        return default

    for coord in var.coords:
        attrs = var.coords[coord].attrs
        for criterion, valid_values in coordinate_criteria.items():
            # Not every criterion applies to every axis (e.g. "positive" is Z only).
            if axis in valid_values:  # type: ignore
                expected = valid_values[axis]  # type: ignore
                if attrs.get(criterion, None) in expected:
                    return coord

    if error:
        raise KeyError(f"axis name {key!r} not found!")
    return default
149+
150+
151+
def _get_measure(da: xr.DataArray, key: str):
152+
"""
153+
TODO: actually interpret da.attrs to get this.
154+
"""
155+
if key not in _CELL_MEASURES:
156+
raise ValueError(
157+
f"Cell measure must be one of {_CELL_MEASURES!r}. Received {key!r} instead."
158+
)
19159

20-
def _get_axis_name_mapping(da: xr.DataArray):
21-
return {"X": "lon", "Y": "lat", "T": "time"}
160+
return {"area": "cell_area", "volume": "cell_volume"}
22161

23162

24163
def _getattr(
@@ -98,12 +237,9 @@ def __getattr__(self, attr):
98237
)
99238

100239

101-
@xr.register_dataarray_accessor("cf")
102-
@xr.register_dataset_accessor("cf")
103240
class CFAccessor:
104241
def __init__(self, da):
105242
self._obj = da
106-
self._coords = _get_axis_name_mapping(da)
107243

108244
def _process_signature(self, func, args, kwargs, keys):
109245
sig = inspect.signature(func, follow_wrapped=False)
@@ -145,12 +281,17 @@ def _rewrite_values_with_axis_names(self, kwargs, keys, var_kws):
145281

146282
if isinstance(value, dict):
147283
# this for things like isel where **kwargs captures things like T=5
148-
updates[key] = {self._coords.get(k, k): v for k, v in value.items()}
284+
updates[key] = {
285+
_get_axis_coord(self._obj, k, False, k): v
286+
for k, v in value.items()
287+
}
149288
elif value is Ellipsis:
150289
pass
151290
else:
152291
# things like sum which have dim
153-
updates[key] = [self._coords.get(v, v) for v in value]
292+
updates[key] = [
293+
_get_axis_coord(self._obj, v, False, v) for v in value
294+
]
154295
if len(updates[key]) == 1:
155296
updates[key] = updates[key][0]
156297

@@ -163,7 +304,7 @@ def _rewrite_values_with_axis_names(self, kwargs, keys, var_kws):
163304
for vkw in var_kws:
164305
if vkw in kwargs:
165306
maybe_update = {
166-
k: self._coords.get(v, v)
307+
k: _get_axis_coord(self._obj, v, False, v)
167308
for k, v in kwargs[vkw].items()
168309
if k in keys
169310
}
@@ -177,3 +318,30 @@ def __getattr__(self, attr):
177318
@property
178319
def plot(self):
179320
return _CFWrappedPlotMethods(self._obj, self)
321+
322+
323+
@xr.register_dataset_accessor("cf")
class CFDatasetAccessor(CFAccessor):
    """Accessor registered as ``Dataset.cf``; adds ``[]`` lookup by axis or coordinate name."""

    def __getitem__(self, key):
        """
        Return the variable of the wrapped Dataset that matches ``key``.

        Parameters
        ----------
        key : str
            One of the names in ``_AXIS_NAMES`` or ``_COORD_NAMES``.

        Raises
        ------
        KeyError
            If ``key`` is not a recognised name, or no coordinate matches it.
        NotImplementedError
            If ``key`` is a cell measure; those are not supported yet.
        """
        if key in _AXIS_NAMES + _COORD_NAMES:
            return self._obj[_get_axis_coord(self._obj, key)]
        elif key in _CELL_MEASURES:
            raise NotImplementedError("measures not implemented yet.")
            # return self._obj[_get_measure(self._obj)[key]]
        else:
            # Message names Dataset.cf: this class is the Dataset accessor.
            raise KeyError(f"Dataset.cf does not understand the key {key}")
333+
334+
# def __getitem__(self, key):
335+
# raise AttributeError("Dataset.cf does not support [] indexing or __getitem__")
336+
337+
338+
@xr.register_dataarray_accessor("cf")
class CFDataArrayAccessor(CFAccessor):
    """Accessor registered as ``DataArray.cf``; adds ``[]`` lookup by axis or coordinate name."""

    def __getitem__(self, key):
        """
        Return the variable of the wrapped DataArray that matches ``key``.

        ``key`` must be one of ``_AXIS_NAMES`` or ``_COORD_NAMES``. Cell
        measures raise NotImplementedError and anything else raises KeyError.
        """
        if key in _AXIS_NAMES + _COORD_NAMES:
            variable_name = _get_axis_coord(self._obj, key)
            return self._obj[variable_name]

        if key in _CELL_MEASURES:
            # TODO: return self._obj[_get_measure(self._obj)[key]] once measures are parsed.
            raise NotImplementedError("measures not implemented yet.")

        raise KeyError(f"DataArray.cf does not understand the key {key}")

cf_xarray/tests/test_accessor.py

Lines changed: 27 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -100,3 +100,30 @@ def test_dataarray_plot(obj):
100100
@pytest.mark.parametrize("obj", datasets)
def test_dataset_plot(obj):
    """Placeholder: this test currently makes no assertions."""
103+
104+
105+
@pytest.mark.parametrize("obj", objects)
@pytest.mark.parametrize(
    ["key", "expected_key"],
    [
        ("X", "lon"),
        ("Y", "lat"),
        ("T", "time"),
        ("longitude", "lon"),
        ("latitude", "lat"),
        ("time", "time"),
    ],
)
def test_getitem(obj, key, expected_key):
    """Indexing ``.cf`` by axis or coordinate name returns the same variable as its real name."""
    assert_identical(obj.cf[key], obj[expected_key])
121+
122+
123+
@pytest.mark.parametrize("obj", objects)
def test_getitem_errors(obj):
    """``.cf[...]`` raises KeyError for unknown keys and for axes no coordinate identifies."""
    # A key that is neither an axis name nor a coordinate name.
    with pytest.raises(KeyError):
        obj.cf["XX"]

    # Stripping the attrs mutates the underlying variable. ``objects`` is shared
    # by every parametrized test in this module, so work on a deep copy to keep
    # the change from leaking into other tests.
    obj = obj.copy(deep=True)
    obj.lon.attrs = {}
    with pytest.raises(KeyError):
        obj.cf["X"]

0 commit comments

Comments
 (0)