from dataclasses import dataclass

import numpy as np
import pandas as pd
from xarray.groupers import EncodedGroups, UniqueGrouper


@dataclass
class FlagGrouper(UniqueGrouper):
    """Grouper for variables that carry ``flag_values`` and ``flag_meanings``.

    Groups are formed from the entries of the ``flag_values`` attribute, and
    the resulting group index is labelled with the matching words taken from
    the ``flag_meanings`` attribute.
    """

    def factorize(self, group) -> EncodedGroups:
        """Encode ``group`` using its flag attributes.

        Parameters
        ----------
        group
            Object to group by. Its ``attrs`` must contain both
            ``flag_values`` and ``flag_meanings``.

        Returns
        -------
        EncodedGroups
            ``codes`` and ``group_indices`` as produced by
            ``UniqueGrouper.factorize`` with ``labels`` set to the flag
            values; ``full_index`` is built from the flag meanings. The
            ``flag_values`` and ``flag_meanings`` attrs are removed from
            ``codes``.

        Raises
        ------
        ValueError
            If ``flag_values`` or ``flag_meanings`` is missing from
            ``group.attrs``.

        Notes
        -----
        This method overwrites ``self.labels`` with the flag values.
        """
        attrs = group.attrs
        if "flag_values" not in attrs or "flag_meanings" not in attrs:
            # The trailing space matters: adjacent literals are joined verbatim.
            raise ValueError(
                "FlagGrouper can only be used with flag variables that have "
                "`flag_values` and `flag_meanings` specified in attrs."
            )

        values = np.array(attrs["flag_values"])
        # flag_meanings is a blank-separated list; a bare split() tolerates
        # repeated whitespace instead of yielding empty labels.
        full_index = pd.Index(attrs["flag_meanings"].split())

        # The parent class factorizes against ``self.labels``.
        self.labels = values

        # TODO: we could optimize here, since `group` is already factorized,
        # but there are subtleties. For example, the attrs must be up to date,
        # any value that is not in flag_values will cause an error, etc.
        ret = super().factorize(group)

        # The codes no longer represent flag values, so drop the flag attrs.
        ret.codes.attrs.pop("flag_values", None)
        ret.codes.attrs.pop("flag_meanings", None)

        return EncodedGroups(
            codes=ret.codes,
            full_index=full_index,
            group_indices=ret.group_indices,
        )
import numpy as np
import pytest
import xarray as xr
from xarray.testing import assert_identical

from cf_xarray.datasets import flag_excl
from cf_xarray.groupers import FlagGrouper


def test_flag_grouper():
    """FlagGrouper labels groups by flag meaning and rejects incomplete attrs."""
    dataset = flag_excl.to_dataset().set_coords("flag_var").copy(deep=True)
    dataset["foo"] = ("time", np.arange(8))

    grouped = dataset.groupby(flag_var=FlagGrouper()).mean()

    # Reference: a plain groupby, relabelled by hand with the flag meanings.
    reference = dataset.groupby("flag_var").mean()
    reference["flag_var"] = ["flag_1", "flag_2", "flag_3"]
    reference["flag_var"].attrs["standard_name"] = "flag_mutual_exclusive"
    assert_identical(grouped, reference)

    # Without flag_values the grouper must refuse to run.
    del dataset["flag_var"].attrs["flag_values"]
    with pytest.raises(ValueError):
        dataset.groupby(flag_var=FlagGrouper())

    # Likewise when only flag_meanings is missing.
    dataset["flag_var"].attrs["flag_values"] = [0, 1, 2]
    del dataset["flag_var"].attrs["flag_meanings"]
    with pytest.raises(ValueError):
        dataset.groupby(flag_var=FlagGrouper())


@pytest.mark.parametrize(
    "values",
    [
        [1, 2],
        [1, 2, 3],  # value out of range of flag_values
    ],
)
def test_flag_grouper_optimized(values):
    """Codes returned by ``factorize`` match the data, with -1 for values above 2."""
    flag_attrs = {"flag_values": [0, 1, 2], "flag_meanings": "a b c"}
    ds = xr.Dataset({"foo": ("x", values, flag_attrs)})
    encoded = FlagGrouper().factorize(ds["foo"])

    # Turn the input variable into the expected codes in place: values above
    # 2 become -1 and the flag attrs are stripped.
    foo = ds["foo"]
    foo.data[foo.data > 2] = -1
    for key in ("flag_meanings", "flag_values"):
        del foo.attrs[key]

    assert_identical(encoded.codes, ds["foo"])
a/doc/api.rst b/doc/api.rst index 493a01a9..bc88fc86 100644 --- a/doc/api.rst +++ b/doc/api.rst @@ -21,13 +21,20 @@ Geometries ---------- .. autosummary:: :toctree: generated/ - geometry.decode_geometries + geometry.encode_geometries geometry.shapely_to_cf geometry.cf_to_shapely geometry.GeometryNames + +Groupers +-------- +.. autosummary:: + :toctree: generated/ + groupers.FlagGrouper + .. currentmodule:: xarray DataArray diff --git a/doc/flags.md b/doc/flags.md index 5e37df36..2e95b6e4 100644 --- a/doc/flags.md +++ b/doc/flags.md @@ -60,6 +60,38 @@ You can also check whether a DataArray has the appropriate attributes to be reco da.cf.is_flag_variable ``` +## GroupBy + +Flag variables, such as the one above, are naturally used for GroupBy operations. +cf-xarray provides a `FlagGrouper` that understands the `flag_meanings` and `flag_values` attributes. + +Let's load an example dataset where the `flag_var` array has the needed attributes. + +```{code-cell} +import cf_xarray as cfxr +import numpy as np + +from cf_xarray.datasets import flag_excl + +ds = flag_excl.to_dataset().set_coords('flag_var') +ds["foo"] = ("time", np.arange(8)) +ds.flag_var +``` + +Now use the {py:class}`~cf_xarray.groupers.FlagGrouper` to group by this flag variable: + +```{code-cell} +from cf_xarray.groupers import FlagGrouper + +ds.groupby(flag_var=FlagGrouper()).mean() +``` + +Note how the output coordinate has the values from `flag_meanings`! + +```{seealso} +See the Xarray docs on using [Grouper objects](https://docs.xarray.dev/en/stable/user-guide/groupby.html#grouper-objects). +``` + ## Flag Masks ```{warning}