Skip to content

Commit fc1419d

Browse files
authored
support 'cell_measures' for DataArrays and rewrite rewriting logic (#14)
Closes #8. To get this to work, `_rewrite_values` now takes a `key_mappers` dictionary that maps each key to be rewritten to a function that determines the replacement value. For example, key_mappers={"dim": _get_axis_coord} will replace kwargs["dim"] = "X" with the result of _get_axis_coord(da, "X"). For weighted, I use key_mappers={"weights": _get_measure_variable}.
1 parent 2bd4930 commit fc1419d

File tree

3 files changed

+104
-33
lines changed

3 files changed

+104
-33
lines changed

README.md

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -21,4 +21,14 @@ ds.air.cf.groupby("T").var("Y")
2121
)
2222
2323
ds.air.isel(lat=[0, 1], lon=1).cf.plot.line(x="T", hue="Y")
24+
25+
ds.air.attrs["cell_measures"] = "area: cell_area"
26+
ds.coords["cell_area"] = (
27+
xr.DataArray(np.cos(ds.cf["latitude"] * np.pi / 180))
28+
* xr.ones_like(ds.cf["longitude"])
29+
* 105e3
30+
* 110e3
31+
)
32+
ds.air.cf.weighted("area").sum("latitude")
33+
2434
```

cf_xarray/accessor.py

Lines changed: 75 additions & 33 deletions
Original file line numberDiff line numberDiff line change
@@ -14,13 +14,11 @@
1414
)
1515

1616

17-
_DEFAULT_KEYS_TO_REWRITE = ("dim", "coord", "group")
1817
_AXIS_NAMES = ("X", "Y", "Z", "T")
1918
_COORD_NAMES = ("longitude", "latitude", "vertical", "time")
2019
_COORD_AXIS_MAPPING = dict(zip(_COORD_NAMES, _AXIS_NAMES))
2120
_CELL_MEASURES = ("area", "volume")
2221

23-
2422
# Define the criteria for coordinate matches
2523
# Copied from metpy
2624
# Internally we only use X, Y, Z, T
@@ -148,24 +146,54 @@ def _get_axis_coord(var: xr.DataArray, key, error: bool = True, default: Any = N
148146
return default
149147

150148

151-
def _get_measure(da: xr.DataArray, key: str):
149+
def _get_measure_variable(
    da: xr.DataArray, key: str, error: bool = True, default: Any = None
) -> DataArray:
    """
    Return the variable of ``da`` selected by the cell measure ``key``.

    Thin wrapper around ``_get_measure``: xarray does not support providing a
    str for weights, so whatever ``_get_measure`` returns is used to index
    ``da`` and the indexed result is handed back instead.

    ``error`` and ``default`` are forwarded unchanged to ``_get_measure``.
    """
    measure_name = _get_measure(da, key, error, default)
    return da[measure_name]
154+
155+
156+
def _get_measure(da: xr.DataArray, key: str, error: bool = True, default: Any = None):
    """
    Interprets 'cell_measures'.

    Looks up ``key`` in ``da.attrs["cell_measures"]``, which is read as a
    whitespace-separated list of ``"measure: variable_name"`` pairs, e.g.
    ``"area: cell_area volume: cell_volume"``.

    Parameters
    ----------
    da : DataArray
        Object whose ``attrs`` are inspected.
    key : str
        Measure to look up; must be one of ``_CELL_MEASURES``.
    error : bool
        If True, raise when the measure cannot be resolved. If False,
        return ``default`` instead.
    default : Any
        Value returned on failure when ``error`` is False.

    Returns
    -------
    The variable name paired with ``key`` in the attribute, or ``default``.

    Raises
    ------
    NotImplementedError
        If ``da`` is not a DataArray.
    ValueError
        If ``key`` is not a known measure or the attribute is malformed
        (only when ``error`` is True).
    KeyError
        If ``"cell_measures"`` is missing from ``attrs`` or does not list
        ``key`` (only when ``error`` is True).
    """
    if not isinstance(da, DataArray):
        raise NotImplementedError("Measures not implemented for Datasets yet.")

    if key not in _CELL_MEASURES:
        if error:
            raise ValueError(
                f"Cell measure must be one of {_CELL_MEASURES!r}. Received {key!r} instead."
            )
        return default

    if "cell_measures" not in da.attrs:
        if error:
            raise KeyError("'cell_measures' not present in 'attrs'.")
        return default

    attr = da.attrs["cell_measures"]
    # Splitting on ":" alone glues a variable name to the next measure name
    # ("area: a volume: v" -> ["area", "a volume", "v"]), so treat the colon
    # as whitespace and pair up the resulting words instead.
    tokens = attr.replace(":", " ").split()
    if not tokens or len(tokens) % 2 != 0:
        if error:
            raise ValueError(f"attrs['cell_measures'] = {attr!r} is malformed.")
        return default

    measures = dict(zip(tokens[::2], tokens[1::2]))
    if key not in measures:
        # A valid measure that this attribute simply does not list must honour
        # error/default like every other failure path.
        if error:
            raise KeyError(
                f"Cell measure {key!r} not found in attrs['cell_measures'] = {attr!r}."
            )
        return default
    return measures[key]
185+
186+
187+
# Keyword-argument name -> function used to rewrite the value passed for it.
_DEFAULT_KEY_MAPPERS: dict = {
    "dim": _get_axis_coord,
    "coord": _get_axis_coord,
    "group": _get_axis_coord,
    "weights": _get_measure_variable,
}
161189

162190

163191
def _getattr(
164192
obj: Union[DataArray, Dataset],
165193
attr: str,
166194
accessor: "CFAccessor",
195+
key_mappers: dict,
167196
wrap_classes=False,
168-
keys=_DEFAULT_KEYS_TO_REWRITE,
169197
):
170198
"""
171199
Common getattr functionality.
@@ -186,8 +214,7 @@ def _getattr(
186214

187215
@functools.wraps(func)
188216
def wrapper(*args, **kwargs):
189-
arguments = accessor._process_signature(func, args, kwargs, keys=keys)
190-
217+
arguments = accessor._process_signature(func, args, kwargs, key_mappers)
191218
result = func(**arguments)
192219
if wrap_classes and isinstance(result, _WRAPPED_CLASSES):
193220
result = _CFWrappedClass(result, accessor)
@@ -203,7 +230,6 @@ def __init__(self, towrap, accessor: "CFAccessor"):
203230
204231
Parameters
205232
----------
206-
207233
obj : DataArray, Dataset
208234
towrap : Resample, GroupBy, Coarsen, Rolling, Weighted
209235
Instance of xarray class that is being wrapped.
@@ -216,7 +242,12 @@ def __repr__(self):
216242
return "--- CF-xarray wrapped \n" + repr(self.wrapped)
217243

218244
def __getattr__(self, attr):
    """Resolve ``attr`` on the wrapped object through ``_getattr``."""
    return _getattr(
        obj=self.wrapped,
        attr=attr,
        accessor=self.accessor,
        # Hand over a copy: _rewrite_values calls key_mappers.update(...),
        # which would otherwise add entries to the module-level defaults
        # shared by every accessor.
        key_mappers=dict(_DEFAULT_KEY_MAPPERS),
    )
220251

221252

222253
class _CFWrappedPlotMethods:
@@ -227,21 +258,27 @@ def __init__(self, obj, accessor):
227258

228259
def __call__(self, *args, **kwargs):
    """Fetch ``plot`` from the stored object via ``_getattr`` and call it."""
    # Every key in self._keys is rewritten with _get_axis_coord.
    plot = _getattr(
        self._obj,
        "plot",
        self.accessor,
        dict.fromkeys(self._keys, _get_axis_coord),
    )
    return plot(*args, **kwargs)
233267

234268
def __getattr__(self, attr):
    """Resolve ``attr`` on the stored object's ``plot`` via ``_getattr``."""
    # Every key in self._keys is rewritten with _get_axis_coord.
    return _getattr(
        self._obj.plot,
        attr,
        self.accessor,
        dict.fromkeys(self._keys, _get_axis_coord),
    )
238275

239276

240277
class CFAccessor:
241278
def __init__(self, da):
242279
self._obj = da
243280

244-
def _process_signature(self, func, args, kwargs, keys):
281+
def _process_signature(self, func, args, kwargs, key_mappers):
245282
sig = inspect.signature(func, follow_wrapped=False)
246283

247284
# Catch things like .isel(T=5).
@@ -254,9 +291,10 @@ def _process_signature(self, func, args, kwargs, keys):
254291

255292
if args or kwargs:
256293
bound = sig.bind(*args, **kwargs)
257-
arguments = self._rewrite_values_with_axis_names(
258-
bound.arguments, keys, tuple(var_kws)
294+
arguments = self._rewrite_values(
295+
bound.arguments, key_mappers, tuple(var_kws)
259296
)
297+
print(arguments)
260298
else:
261299
arguments = {}
262300

@@ -270,33 +308,32 @@ def _process_signature(self, func, args, kwargs, keys):
270308

271309
return arguments
272310

273-
def _rewrite_values_with_axis_names(self, kwargs, keys, var_kws):
274-
""" rewrites 'dim' for example. """
275-
updates = {}
276-
for key in tuple(keys) + tuple(var_kws):
311+
def _rewrite_values(self, kwargs, key_mappers: dict, var_kws):
312+
""" rewrites 'dim' for example using 'mapper' """
313+
updates: dict = {}
314+
key_mappers.update(dict.fromkeys(var_kws, _get_axis_coord))
315+
for key, mapper in key_mappers.items():
277316
value = kwargs.get(key, None)
278-
if value:
317+
if value is not None:
279318
if isinstance(value, str):
280319
value = [value]
281320

282321
if isinstance(value, dict):
283322
# this for things like isel where **kwargs captures things like T=5
284323
updates[key] = {
285-
_get_axis_coord(self._obj, k, False, k): v
286-
for k, v in value.items()
324+
mapper(self._obj, k, False, k): v for k, v in value.items()
287325
}
288326
elif value is Ellipsis:
289327
pass
290328
else:
291329
# things like sum which have dim
292-
updates[key] = [
293-
_get_axis_coord(self._obj, v, False, v) for v in value
294-
]
330+
updates[key] = [mapper(self._obj, v, False, v) for v in value]
295331
if len(updates[key]) == 1:
296332
updates[key] = updates[key][0]
297333

298334
kwargs.update(updates)
299335

336+
# TODO: is there a way to merge this with above?
300337
# maybe the keys we are looking for are in kwargs.
301338
# For example, this happens with DataArray.plot(),
302339
# where the signature is obscured and kwargs is
@@ -306,14 +343,20 @@ def _rewrite_values_with_axis_names(self, kwargs, keys, var_kws):
306343
maybe_update = {
307344
k: _get_axis_coord(self._obj, v, False, v)
308345
for k, v in kwargs[vkw].items()
309-
if k in keys
346+
if k in key_mappers
310347
}
311348
kwargs[vkw].update(maybe_update)
312349

313350
return kwargs
314351

315352
def __getattr__(self, attr):
    """
    Resolve ``attr`` on the underlying object through ``_getattr``.

    ``wrap_classes=True`` is passed so that results which are instances of
    the wrapped xarray classes come back wrapped in ``_CFWrappedClass``.
    """
    return _getattr(
        obj=self._obj,
        attr=attr,
        accessor=self,
        # Hand over a copy: _rewrite_values calls key_mappers.update(...),
        # which would otherwise add entries to the module-level defaults
        # shared by every accessor.
        key_mappers=dict(_DEFAULT_KEY_MAPPERS),
        wrap_classes=True,
    )
317360

318361
@property
319362
def plot(self):
@@ -326,7 +369,7 @@ def __getitem__(self, key):
326369
if key in _AXIS_NAMES + _COORD_NAMES:
327370
return self._obj[_get_axis_coord(self._obj, key)]
328371
elif key in _CELL_MEASURES:
329-
raise NotImplementedError("measures not implemented yet.")
372+
raise NotImplementedError("measures not implemented for Dataset yet.")
330373
# return self._obj[_get_measure(self._obj)[key]]
331374
else:
332375
raise KeyError(f"DataArray.cf does not understand the key {key}")
@@ -341,7 +384,6 @@ def __getitem__(self, key):
341384
if key in _AXIS_NAMES + _COORD_NAMES:
342385
return self._obj[_get_axis_coord(self._obj, key)]
343386
elif key in _CELL_MEASURES:
344-
raise NotImplementedError("measures not implemented yet.")
345-
# return self._obj[_get_measure(self._obj)[key]]
387+
return self._obj[_get_measure(self._obj, key)]
346388
else:
347389
raise KeyError(f"DataArray.cf does not understand the key {key}")

cf_xarray/tests/test_accessor.py

Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
import matplotlib as mpl
22
import matplotlib.pyplot as plt
3+
import numpy as np
34
import pytest
45
import xarray as xr
56
from xarray.testing import assert_identical
@@ -10,6 +11,10 @@
1011

1112
mpl.use("Agg")
1213
ds = xr.tutorial.open_dataset("air_temperature").isel(time=slice(4), lon=slice(50))
14+
ds.air.attrs["cell_measures"] = "area: cell_area"
15+
ds.coords["cell_area"] = (
16+
xr.DataArray(np.cos(ds.lat * np.pi / 180)) * xr.ones_like(ds.lon) * 105e3 * 110e3
17+
)
1318
datasets = [ds, ds.chunk({"lat": 5})]
1419
dataarrays = [ds.air, ds.air.chunk({"lat": 5})]
1520
objects = datasets + dataarrays
@@ -57,6 +62,15 @@ def test_wrapped_classes(obj, attr, xrkwargs, cfkwargs):
5762
assert_identical(expected, actual)
5863

5964

65+
@pytest.mark.parametrize("obj", dataarrays)
def test_weighted(obj):
    """``.cf.weighted("area")`` must match weighting by ``obj["cell_area"]``."""
    # weights are checked for nans
    with raise_if_dask_computes(max_computes=2):
        weights = obj["cell_area"]
        expected = obj.weighted(weights).sum("lat")
        actual = obj.cf.weighted("area").sum("Y")
    assert_identical(expected, actual)
72+
73+
6074
@pytest.mark.parametrize("obj", objects)
6175
def test_kwargs_methods(obj):
6276
with raise_if_dask_computes():
@@ -112,6 +126,11 @@ def test_dataset_plot(obj):
112126
("longitude", "lon"),
113127
("latitude", "lat"),
114128
("time", "time"),
129+
pytest.param(
130+
"area",
131+
"cell_area",
132+
marks=pytest.mark.xfail(reason="measures not implemented for dataset"),
133+
),
115134
),
116135
)
117136
def test_getitem(obj, key, expected_key):

0 commit comments

Comments
 (0)