Skip to content

Commit 31ceff2

Browse files
sol1105malmans2
andauthored
Bugfix for issue #191 part 2 (overlooking attributes in xarray.DataArray.encoding) (#264)
* Bugfix for issue #191 part 2 - CFAccessor no longer overlooks cell_measures, formula_terms and bounds when associated attributes are stored in xarray.DataArray.encoding rather than xarray.DataArray.attrs - Added test and test data * Replaced netCDF dataset with dummy dataset. * Execute drop_bounds only for xarray.Datasets * Added test incl. dataset for drop_bounds * Removed unnecessary attributes from the new datasets 'ambig' and 'vert'. * use chainmap rather than nested get Co-authored-by: Mattia Almansi <[email protected]>
1 parent a5aa1d6 commit 31ceff2

File tree

3 files changed

+215
-11
lines changed

3 files changed

+215
-11
lines changed

cf_xarray/accessor.py

Lines changed: 20 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -309,8 +309,9 @@ def _get_measure(obj: Union[DataArray, Dataset], key: str) -> List[str]:
309309
results = set()
310310
for var in obj.variables:
311311
da = obj[var]
312-
if "cell_measures" in da.attrs:
313-
attr = da.attrs["cell_measures"]
312+
attrs_or_encoding = ChainMap(da.attrs, da.encoding)
313+
if "cell_measures" in attrs_or_encoding:
314+
attr = attrs_or_encoding["cell_measures"]
314315
measures = parse_cell_methods_attr(attr)
315316
if key in measures:
316317
results.update([measures[key]])
@@ -339,8 +340,9 @@ def _get_bounds(obj: Union[DataArray, Dataset], key: str) -> List[str]:
339340

340341
results = set()
341342
for var in apply_mapper(_get_all, obj, key, error=False, default=[key]):
342-
if "bounds" in obj[var].attrs:
343-
results |= {obj[var].attrs["bounds"]}
343+
attrs_or_encoding = ChainMap(obj[var].attrs, obj[var].encoding)
344+
if "bounds" in attrs_or_encoding:
345+
results |= {attrs_or_encoding["bounds"]}
344346

345347
return list(results)
346348

@@ -627,8 +629,10 @@ def drop_bounds(names):
627629
# actual variable. It seems practical to ignore them when indexing
628630
# with a scalar key. Hopefully these will soon get decoded to IntervalIndex
629631
# and we can move on...
630-
if scalar_key:
631-
bounds = {obj[k].attrs.get("bounds", None) for k in names}
632+
if not isinstance(obj, DataArray) and scalar_key:
633+
bounds = set()
634+
for name in names:
635+
bounds.update(obj.cf.bounds.get(name, []))
632636
names = set(names) - bounds
633637
return names
634638

@@ -1364,12 +1368,16 @@ def cell_measures(self) -> Dict[str, List[str]]:
13641368
"""
13651369

13661370
obj = self._obj
1367-
all_attrs = [da.attrs.get("cell_measures", "") for da in obj.coords.values()]
1371+
all_attrs = [
1372+
ChainMap(da.attrs, da.encoding).get("cell_measures", "")
1373+
for da in obj.coords.values()
1374+
]
13681375
if isinstance(obj, DataArray):
1369-
all_attrs += [obj.attrs.get("cell_measures", "")]
1376+
all_attrs += [ChainMap(obj.attrs, obj.encoding).get("cell_measures", "")]
13701377
elif isinstance(obj, Dataset):
13711378
all_attrs += [
1372-
da.attrs.get("cell_measures", "") for da in obj.data_vars.values()
1379+
ChainMap(da.attrs, da.encoding).get("cell_measures", "")
1380+
for da in obj.data_vars.values()
13731381
]
13741382

13751383
keys = {}
@@ -2144,12 +2152,13 @@ def formula_terms(self) -> Dict[str, str]:
21442152
{parametric_coord_name: {standard_term_name: variable_name}}
21452153
"""
21462154
da = self._obj
2147-
if "formula_terms" not in da.attrs:
2155+
if "formula_terms" not in ChainMap(da.attrs, da.encoding):
21482156
var = da[_single(_get_dims)(da, "Z")[0]]
21492157
else:
21502158
var = da
2159+
21512160
terms = {}
2152-
formula_terms = var.attrs.get("formula_terms", "")
2161+
formula_terms = ChainMap(var.attrs, var.encoding).get("formula_terms", "")
21532162
for mapping in re.sub(r"\s*:\s*", ":", formula_terms).split():
21542163
key, value = mapping.split(":")
21552164
terms[key] = value

cf_xarray/datasets.py

Lines changed: 169 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -291,3 +291,172 @@
291291
},
292292
name="basin",
293293
)
294+
295+
296+
ambig = xr.Dataset(
297+
data_vars={},
298+
coords={
299+
"lat": ("lat", np.zeros(5)),
300+
"lon": ("lon", np.zeros(5)),
301+
"vertices_latitude": (["lat", "bnds"], np.zeros((5, 2))),
302+
"vertices_longitude": (["lon", "bnds"], np.zeros((5, 2))),
303+
},
304+
)
305+
ambig["lat"].attrs = {
306+
"bounds": "vertices_latitude",
307+
"units": "degrees_north",
308+
"standard_name": "latitude",
309+
"axis": "Y",
310+
}
311+
ambig["lon"].attrs = {
312+
"bounds": "vertices_longitude",
313+
"units": "degrees_east",
314+
"standard_name": "longitude",
315+
"axis": "X",
316+
}
317+
ambig["vertices_latitude"].attrs = {
318+
"units": "degrees_north",
319+
}
320+
ambig["vertices_longitude"].attrs = {
321+
"units": "degrees_east",
322+
}
323+
324+
325+
vert = xr.Dataset.from_dict(
326+
{
327+
"coords": {
328+
"lat": {
329+
"dims": ("lat",),
330+
"attrs": {
331+
"standard_name": "latitude",
332+
"axis": "Y",
333+
"bounds": "lat_bnds",
334+
"units": "degrees_north",
335+
},
336+
"data": [0.0, 1.0],
337+
},
338+
"lon": {
339+
"dims": ("lon",),
340+
"attrs": {
341+
"standard_name": "longitude",
342+
"axis": "X",
343+
"bounds": "lon_bnds",
344+
"units": "degrees_east",
345+
},
346+
"data": [0.0, 1.0],
347+
},
348+
"lev": {
349+
"dims": ("lev",),
350+
"attrs": {
351+
"standard_name": "atmosphere_hybrid_sigma_pressure_coordinate",
352+
"formula": "p = ap + b*ps",
353+
"formula_terms": "ap: ap b: b ps: ps",
354+
"postitive": "down",
355+
"axis": "Z",
356+
"bounds": "lev_bnds",
357+
},
358+
"data": [0.0, 1.0],
359+
},
360+
"time": {
361+
"dims": ("time",),
362+
"attrs": {
363+
"standard_name": "time",
364+
"axis:": "T",
365+
"bounds": "time_bnds",
366+
"units": "days since 1850-01-01",
367+
"calendar": "proleptic_gregorian",
368+
},
369+
"data": [0.5],
370+
},
371+
"lat_bnds": {
372+
"dims": (
373+
"lat",
374+
"bnds",
375+
),
376+
"attrs": {
377+
"units": "degrees_north",
378+
},
379+
"data": [[0.0, 0.5], [0.5, 1.0]],
380+
},
381+
"lon_bnds": {
382+
"dims": (
383+
"lon",
384+
"bnds",
385+
),
386+
"attrs": {
387+
"units": "degrees_east",
388+
},
389+
"data": [[0.0, 0.5], [0.5, 1.0]],
390+
},
391+
"lev_bnds": {
392+
"dims": (
393+
"lev",
394+
"bnds",
395+
),
396+
"attrs": {
397+
"standard_name": "atmosphere_hybrid_sigma_pressure_coordinate",
398+
"formula": "p = ap + b*ps",
399+
"formula_terms": "ap: ap b: b ps: ps",
400+
},
401+
"data": [[0.0, 0.5], [0.5, 1.0]],
402+
},
403+
"time_bnds": {
404+
"dims": ("time", "bnds"),
405+
"attrs": {
406+
"units": "days since 1850-01-01",
407+
"calendar": "proleptic_gregorian",
408+
},
409+
"data": [[0.0, 1.0]],
410+
},
411+
"ap": {
412+
"dims": ("lev",),
413+
"data": [0.0, 0.0],
414+
},
415+
"b": {
416+
"dims": ("lev",),
417+
"data": [1.0, 0.9],
418+
},
419+
"ap_bnds": {
420+
"dims": (
421+
"lev",
422+
"bnds",
423+
),
424+
"data": [[0.0, 0.0], [0.0, 0.0]],
425+
},
426+
"b_bnds": {
427+
"dims": (
428+
"lev",
429+
"bnds",
430+
),
431+
"data": [[1.0, 0.95], [0.95, 0.9]],
432+
},
433+
},
434+
"dims": {"time": 1, "lev": 2, "lat": 2, "lon": 2, "bnds": 2},
435+
"data_vars": {
436+
"o3": {
437+
"dims": ("time", "lev", "lat", "lon"),
438+
"attrs": {
439+
"cell_methods": "area: time: mean",
440+
"cell_measures": "area: areacella",
441+
"missing_value": 1e20,
442+
"_FillValue": 1e20,
443+
},
444+
"data": np.ones(8, dtype=np.float32).reshape((1, 2, 2, 2)),
445+
},
446+
"areacella": {
447+
"dims": ("lat", "lon"),
448+
"attrs": {
449+
"standard_name": "cell_area",
450+
"cell_methods": "area: sum",
451+
"missing_value": 1e20,
452+
"_FillValue": 1e20,
453+
},
454+
"data": np.ones(4, dtype=np.float32).reshape((2, 2)),
455+
},
456+
"ps": {
457+
"dims": ("time", "lat", "lon"),
458+
"data": np.ones(4, dtype=np.float32).reshape((1, 2, 2)),
459+
},
460+
},
461+
}
462+
)

cf_xarray/tests/test_accessor.py

Lines changed: 26 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,7 @@
1616

1717
from ..datasets import (
1818
airds,
19+
ambig,
1920
anc,
2021
basin,
2122
ds_no_attrs,
@@ -24,6 +25,7 @@
2425
multiple,
2526
popds,
2627
romsds,
28+
vert,
2729
)
2830
from . import raise_if_dask_computes, requires_pint
2931

@@ -211,6 +213,30 @@ def test_standard_names():
211213
assert dsnew.cf.standard_names == dict(a=["a", "b"])
212214

213215

216+
def test_drop_bounds():
217+
assert ambig.cf["latitude"].name == "lat"
218+
assert ambig.cf["longitude"].name == "lon"
219+
assert ambig.cf.bounds["latitude"] == ["vertices_latitude"]
220+
assert ambig.cf.bounds["longitude"] == ["vertices_longitude"]
221+
222+
223+
def test_accessor_getattr_and_describe():
224+
ds_verta = vert.set_coords(
225+
(
226+
"ps",
227+
"areacella",
228+
)
229+
)
230+
ds_vertb = xr.decode_cf(vert, decode_coords="all")
231+
232+
assert ds_verta.cf.cell_measures == ds_vertb.cf.cell_measures
233+
assert ds_verta.o3.cf.cell_measures == ds_vertb.o3.cf.cell_measures
234+
assert ds_verta.cf.formula_terms == ds_vertb.cf.formula_terms
235+
assert ds_verta.o3.cf.formula_terms == ds_vertb.o3.cf.formula_terms
236+
assert ds_verta.cf.bounds == ds_vertb.cf.bounds
237+
assert str(ds_verta.cf) == str(ds_vertb.cf)
238+
239+
214240
def test_getitem_standard_name():
215241
actual = airds.cf["air_temperature"]
216242
expected = airds["air"]

0 commit comments

Comments
 (0)