Skip to content

Commit cbfbcfd

Browse files
authored
Allow adding bounds to general coordinate vars (#293)
1 parent 8dfa801 commit cbfbcfd

File tree

2 files changed

+23
-17
lines changed

2 files changed

+23
-17
lines changed

cf_xarray/accessor.py

Lines changed: 15 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -2057,7 +2057,7 @@ def add_bounds(self, keys: str | Iterable[str]):
20572057
Parameters
20582058
----------
20592059
keys : str or Iterable[str]
2060-
Either a single key or a list of keys corresponding to dimensions.
2060+
Either a single variable name or a list of variable names.
20612061
20622062
Returns
20632063
-------
@@ -2070,8 +2070,8 @@ def add_bounds(self, keys: str | Iterable[str]):
20702070
20712071
Notes
20722072
-----
2073-
The bounds variables are automatically named ``f"{dim}_bounds"`` where ``dim``
2074-
is a dimension name.
2073+
The bounds variables are automatically named ``f"{var}_bounds"`` where ``var``
2074+
is a variable name.
20752075
20762076
Examples
20772077
--------
@@ -2085,25 +2085,26 @@ def add_bounds(self, keys: str | Iterable[str]):
20852085
if isinstance(keys, str):
20862086
keys = [keys]
20872087

2088-
dimensions = set()
2088+
variables = set()
20892089
for key in keys:
2090-
dimensions.update(
2091-
apply_mapper(_get_dims, self._obj, key, error=False, default=[key])
2090+
variables.update(
2091+
apply_mapper(_get_all, self._obj, key, error=False, default=[key])
20922092
)
20932093

2094-
bad_dims: set[str] = dimensions - set(self._obj.dims)
2095-
if bad_dims:
2094+
obj = self._maybe_to_dataset(self._obj.copy(deep=True))
2095+
2096+
bad_vars: set[str] = variables - set(obj.variables)
2097+
if bad_vars:
20962098
raise ValueError(
2097-
f"{bad_dims!r} are not dimensions in the underlying object."
2099+
f"{bad_vars!r} are not variables in the underlying object."
20982100
)
20992101

2100-
obj = self._maybe_to_dataset(self._obj.copy(deep=True))
2101-
for dim in dimensions:
2102-
bname = f"{dim}_bounds"
2102+
for var in variables:
2103+
bname = f"{var}_bounds"
21032104
if bname in obj.variables:
21042105
raise ValueError(f"Bounds variable name {bname!r} will conflict!")
2105-
obj.coords[bname] = _guess_bounds_dim(obj[dim].reset_coords(drop=True))
2106-
obj[dim].attrs["bounds"] = bname
2106+
obj.coords[bname] = _guess_bounds_dim(obj[var].reset_coords(drop=True))
2107+
obj[var].attrs["bounds"] = bname
21072108

21082109
return self._maybe_to_dataarray(obj)
21092110

cf_xarray/tests/test_accessor.py

Lines changed: 8 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -699,9 +699,10 @@ def test_plot_xincrease_yincrease():
699699
assert lim[0] > lim[1]
700700

701701

702-
@pytest.mark.parametrize("dims", ["lat", "time", ["lat", "lon"]])
703-
@pytest.mark.parametrize("obj", [airds])
704-
def test_add_bounds(obj, dims):
702+
@pytest.mark.parametrize("dims", ["time2", "lat", "time", ["lat", "lon"]])
703+
def test_add_bounds(dims):
704+
obj = airds.copy(deep=True)
705+
705706
expected = {}
706707
expected["lat"] = xr.concat(
707708
[
@@ -728,10 +729,12 @@ def test_add_bounds(obj, dims):
728729
],
729730
dim="bounds",
730731
)
732+
expected["time2"] = expected["time"]
731733
expected["lat"].attrs.clear()
732734
expected["lon"].attrs.clear()
733735
expected["time"].attrs.clear()
734736

737+
obj.coords["time2"] = obj.time
735738
added = obj.cf.add_bounds(dims)
736739
if isinstance(dims, str):
737740
dims = (dims,)
@@ -742,6 +745,8 @@ def test_add_bounds(obj, dims):
742745
assert added[dim].attrs["bounds"] == name
743746
assert_allclose(added[name].reset_coords(drop=True), expected[dim])
744747

748+
749+
def test_add_bounds_multiple():
745750
# Test multiple dimensions
746751
assert not {"x1_bounds", "x2_bounds"} <= set(multiple.variables)
747752
assert {"x1_bounds", "x2_bounds"} <= set(multiple.cf.add_bounds("X").variables)

0 commit comments

Comments
 (0)