Skip to content

Commit 56bac67

Browse files
authored
Support for bounds. (#68)
1. Unfortunately, we cannot return bounds variables with `ds.cf["latitude"]`, because DataArrays cannot have coordinate variables with dimensions that are not on the core variable. Actually, if we go through `_to_temp_dataset`, assign the bounds to coords, and then go back through `_from_temp_dataset`, this will work, but I am pretty sure that is not intended. 2. `ds.cf[["air"]]` will return bounds variables for dimensions if appropriate.
1 parent 61cf9e9 commit 56bac67

File tree

3 files changed

+137
-64
lines changed

3 files changed

+137
-64
lines changed

cf_xarray/accessor.py

Lines changed: 115 additions & 62 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@
77
from typing import (
88
Any,
99
Callable,
10+
Dict,
1011
Hashable,
1112
Iterable,
1213
List,
@@ -896,42 +897,72 @@ def get_standard_names(self) -> List[str]:
896897
]
897898
)
898899

899-
def get_associated_variable_names(self, name: Hashable) -> Dict[str, List[str]]:
    """
    Returns a dict mapping
        1. "ancillary_variables"
        2. "bounds"
        3. "cell_measures"
        4. "coordinates"
    to a list of variable names referred to in the appropriate attribute

    The attributes are looked up first in ``.attrs`` and then in ``.encoding``
    of the variable ``name``. In addition, the "bounds" attribute of every
    dimension of ``name`` is consulted (``.attrs`` only).

    Parameters
    ----------
    name: Hashable
        Name of the variable whose CF attributes are inspected.

    Returns
    -------
    Dict with keys "ancillary_variables", "cell_measures", "coordinates", "bounds"
        Each value is a list of variable names. Names that are referred to
        but are not present in the object are dropped after emitting a
        ``UserWarning``.
    """
    keys = ["ancillary_variables", "cell_measures", "coordinates", "bounds"]
    coords: Dict[str, List[str]] = {k: [] for k in keys}
    attrs_or_encoding = ChainMap(self._obj[name].attrs, self._obj[name].encoding)

    if "coordinates" in attrs_or_encoding:
        # split() (no argument) tolerates repeated/trailing blanks, which
        # split(" ") would turn into bogus empty-string variable names.
        coords["coordinates"] = attrs_or_encoding["coordinates"].split()

    if "cell_measures" in attrs_or_encoding:
        # NOTE(review): _get_measure is assumed to return an iterable of
        # names per measure (the results are flattened) -- confirm.
        coords["cell_measures"] = list(
            itertools.chain.from_iterable(
                _get_measure(self._obj[name], measure)
                for measure in _CELL_MEASURES
                if measure in attrs_or_encoding["cell_measures"]
            )
        )

    # Ancillary variables are only reported when the wrapped object is a Dataset.
    if "ancillary_variables" in attrs_or_encoding and isinstance(
        self._obj, Dataset
    ):
        coords["ancillary_variables"] = attrs_or_encoding[
            "ancillary_variables"
        ].split()

    if "bounds" in attrs_or_encoding:
        coords["bounds"] = [attrs_or_encoding["bounds"]]

    for dim in self._obj[name].dims:
        dbounds = self._obj[dim].attrs.get("bounds", None)
        # When ``name`` is itself a dimension coordinate its bounds were
        # already recorded above; do not list the same variable twice.
        if dbounds and dbounds not in coords["bounds"]:
            coords["bounds"].append(dbounds)

    allvars = itertools.chain.from_iterable(coords.values())
    missing = set(allvars) - set(self._maybe_to_dataset().variables)
    if missing:
        warnings.warn(
            f"Variables {missing!r} not found in object but are referred to in the CF attributes.",
            UserWarning,
        )
        # Filter instead of list.remove() so that *every* occurrence of a
        # missing name is dropped, not just the first one.
        coords = {
            k: [v for v in names if v not in missing]
            for k, names in coords.items()
        }

    return coords
936967

937968
def __getitem__(self, key: Union[str, List[str]]):
@@ -981,8 +1012,12 @@ def __getitem__(self, key: Union[str, List[str]]):
9811012
allnames = varnames + coords
9821013

9831014
try:
984-
for name in varnames:
985-
coords.extend(self.get_associated_variable_names(name))
1015+
for name in allnames:
1016+
extravars = self.get_associated_variable_names(name)
1017+
# we cannot return bounds variables with scalar keys
1018+
if scalar_key:
1019+
extravars.pop("bounds")
1020+
coords.extend(itertools.chain(*extravars.values()))
9861021

9871022
if isinstance(self._obj, DataArray):
9881023
ds = self._obj._to_temp_dataset()
@@ -1036,47 +1071,6 @@ def _maybe_to_dataarray(self, obj=None):
10361071
else:
10371072
return obj
10381073

1039-
def add_bounds(self, dims: Union[Hashable, Iterable[Hashable]]):
1040-
"""
1041-
Returns a new object with bounds variables. The bounds values are guessed assuming
1042-
equal spacing on either side of a coordinate label.
1043-
1044-
Parameters
1045-
----------
1046-
dims: Hashable or Iterable[Hashable]
1047-
Either a single dimension name or a list of dimension names.
1048-
1049-
Returns
1050-
-------
1051-
DataArray or Dataset with bounds variables added and appropriate "bounds" attribute set.
1052-
1053-
Notes
1054-
-----
1055-
1056-
The bounds variables are automatically named f"{dim}_bounds" where ``dim``
1057-
is a dimension name.
1058-
"""
1059-
if isinstance(dims, Hashable):
1060-
dimensions = (dims,)
1061-
else:
1062-
dimensions = dims
1063-
1064-
bad_dims: Set[Hashable] = set(dimensions) - set(self._obj.dims)
1065-
if bad_dims:
1066-
raise ValueError(
1067-
f"{bad_dims!r} are not dimensions in the underlying object."
1068-
)
1069-
1070-
obj = self._maybe_to_dataset(self._obj.copy(deep=True))
1071-
for dim in dimensions:
1072-
bname = f"{dim}_bounds"
1073-
if bname in obj.variables:
1074-
raise ValueError(f"Bounds variable name {bname!r} will conflict!")
1075-
obj.coords[bname] = _guess_bounds_dim(obj[dim].reset_coords(drop=True))
1076-
obj[dim].attrs["bounds"] = bname
1077-
1078-
return self._maybe_to_dataarray(obj)
1079-
10801074
def rename_like(
10811075
self, other: Union[DataArray, Dataset]
10821076
) -> Union[DataArray, Dataset]:
@@ -1169,7 +1163,66 @@ def guess_coord_axis(self, verbose: bool = False) -> Union[DataArray, Dataset]:
11691163

11701164
@xr.register_dataset_accessor("cf")
11711165
class CFDatasetAccessor(CFAccessor):
1172-
pass
1166+
def get_bounds(self, key: str) -> DataArray:
    """
    Get bounds variable corresponding to key.

    Parameters
    ----------
    key: str
        Name of variable whose bounds are desired

    Returns
    -------
    DataArray

    Raises
    ------
    KeyError
        If the variable that ``key`` resolves to has no "bounds" entry in
        its ``attrs`` or ``encoding``.
    """
    # Resolve ``key`` through the mapper; fall back to ``key`` itself when
    # the mapper yields nothing (``error=False, default=[key]``).
    name = apply_mapper(
        _get_axis_coord_single, self._obj, key, error=False, default=[key]
    )[0]
    # Look in attrs first, then encoding, in line with how
    # get_associated_variable_names locates a variable's own bounds.
    attrs_or_encoding = ChainMap(self._obj[name].attrs, self._obj[name].encoding)
    try:
        bounds = attrs_or_encoding["bounds"]
    except KeyError:
        raise KeyError(
            f"Variable {name!r} (requested as {key!r}) has no 'bounds' attribute."
        ) from None
    obj = self._maybe_to_dataset()
    return obj[bounds]
1185+
1186+
def add_bounds(self, dims: Union[Hashable, Iterable[Hashable]]):
    """
    Returns a new object with bounds variables. The bounds values are guessed assuming
    equal spacing on either side of a coordinate label.

    Parameters
    ----------
    dims: Hashable or Iterable[Hashable]
        Either a single dimension name or a collection of dimension names.
        A hashable iterable (e.g. a tuple) is treated as a single name only
        if it is itself a dimension of the object; otherwise it is treated
        as a collection of names.

    Returns
    -------
    New object with bounds variables added and appropriate "bounds" attribute set.
    The underlying object is deep-copied and is not modified.

    Raises
    ------
    ValueError
        If any requested name is not a dimension of the underlying object, or
        if a variable named f"{dim}_bounds" already exists.

    Notes
    -----

    The bounds variables are automatically named f"{dim}_bounds" where ``dim``
    is a dimension name.
    """
    if isinstance(dims, str) or not isinstance(dims, Iterable):
        dimensions = (dims,)
    elif isinstance(dims, Hashable) and dims in self._obj.dims:
        # A hashable iterable that really is a dimension name.
        dimensions = (dims,)
    else:
        # Materialise once: ``dims`` may be a one-shot iterator and is
        # traversed twice below (validation, then the main loop).
        dimensions = tuple(dims)

    bad_dims: Set[Hashable] = set(dimensions) - set(self._obj.dims)
    if bad_dims:
        raise ValueError(
            f"{bad_dims!r} are not dimensions in the underlying object."
        )

    obj = self._maybe_to_dataset(self._obj.copy(deep=True))
    for dim in dimensions:
        bname = f"{dim}_bounds"
        if bname in obj.variables:
            raise ValueError(f"Bounds variable name {bname!r} will conflict!")
        # Guess bounds from the bare coordinate (other coords dropped).
        obj.coords[bname] = _guess_bounds_dim(obj[dim].reset_coords(drop=True))
        obj[dim].attrs["bounds"] = bname

    return self._maybe_to_dataarray(obj)
11731226

11741227

11751228
@xr.register_dataarray_accessor("cf")

cf_xarray/tests/test_accessor.py

Lines changed: 21 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -356,7 +356,7 @@ def test_plot_xincrease_yincrease():
356356

357357

358358
@pytest.mark.parametrize("dims", ["lat", "time", ["lat", "lon"]])
359-
@pytest.mark.parametrize("obj", [airds, airds.air])
359+
@pytest.mark.parametrize("obj", [airds])
360360
def test_add_bounds(obj, dims):
361361
expected = dict()
362362
expected["lat"] = xr.concat(
@@ -399,6 +399,26 @@ def test_add_bounds(obj, dims):
399399
assert_allclose(added[name].reset_coords(drop=True), expected[dim])
400400

401401

402+
def test_bounds():
    """Bounds variables accompany list-key lookups and are reachable via get_bounds."""
    with_bounds = airds.copy(deep=True).cf.add_bounds("lat")

    # A list key returns a Dataset, so the bounds variable can tag along.
    selected = with_bounds.cf[["lat"]]
    reference = with_bounds[["lat", "lat_bounds"]]
    assert_identical(selected, reference)

    # Bounds of the dimensions of a data variable are attached as coords.
    selected = with_bounds.cf[["air"]]
    assert "lat_bounds" in selected.coords

    # can't associate bounds variable when providing scalar keys
    # i.e. when DataArrays are returned
    selected = with_bounds.cf["lat"]
    reference = with_bounds["lat"]
    assert_identical(selected, reference)

    # Explicit accessor for the bounds variable itself.
    selected = with_bounds.cf.get_bounds("lat")
    reference = with_bounds["lat_bounds"]
    assert_identical(selected, reference)
420+
421+
402422
def test_docstring():
403423
assert "One of ('X'" in airds.cf.groupby.__doc__
404424
assert "One or more of ('X'" in airds.cf.mean.__doc__

doc/api.rst

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,6 @@ DataArray
1616
:toctree: generated/
1717
:template: autosummary/accessor_method.rst
1818

19-
DataArray.cf.add_bounds
2019
DataArray.cf.describe
2120
DataArray.cf.get_standard_names
2221
DataArray.cf.get_valid_keys
@@ -36,6 +35,7 @@ Dataset
3635

3736
Dataset.cf.add_bounds
3837
Dataset.cf.describe
38+
Dataset.cf.get_bounds
3939
Dataset.cf.get_standard_names
4040
Dataset.cf.get_valid_keys
4141
Dataset.cf.guess_coord_axis

0 commit comments

Comments
 (0)