Skip to content

Commit c656abd

Browse files
authored
Add FlagGrouper (#556)
* Add FlagGrouper Closes #472 * fix * fix doc build * update * cleanup * try again
1 parent 59167f4 commit c656abd

File tree

6 files changed

+121
-3
lines changed

6 files changed

+121
-3
lines changed

cf_xarray/__init__.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,6 @@
99
from .options import set_options # noqa
1010
from .utils import _get_version
1111

12-
from . import geometry # noqa
12+
from . import geometry, groupers # noqa
1313

1414
__version__ = _get_version()

cf_xarray/groupers.py

Lines changed: 34 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,34 @@
1+
from dataclasses import dataclass
2+
3+
import numpy as np
4+
import pandas as pd
5+
from xarray.groupers import EncodedGroups, UniqueGrouper
6+
7+
8+
@dataclass
class FlagGrouper(UniqueGrouper):
    """Grouper for flag variables described by ``flag_values`` / ``flag_meanings``.

    The grouped variable must carry both attributes. The resulting groups are
    labelled with the entries of ``flag_meanings`` rather than with the raw
    flag values.
    """

    def factorize(self, group) -> EncodedGroups:
        """Encode ``group`` into integer codes labelled by its flag meanings.

        Parameters
        ----------
        group
            Variable to group by. ``group.attrs`` must contain ``flag_values``
            and ``flag_meanings`` (a whitespace-separated string).

        Returns
        -------
        EncodedGroups
            The codes and group indices produced by
            ``UniqueGrouper.factorize`` (with ``flag_values`` and
            ``flag_meanings`` removed from the codes' attrs), and a
            ``full_index`` built from the flag meanings.

        Raises
        ------
        ValueError
            If ``flag_values`` or ``flag_meanings`` is missing from
            ``group.attrs``.
        """
        attrs = group.attrs
        if "flag_values" not in attrs or "flag_meanings" not in attrs:
            raise ValueError(
                "FlagGrouper can only be used with flag variables that have "
                "`flag_values` and `flag_meanings` specified in attrs."
            )

        values = np.array(attrs["flag_values"])
        # ``split()`` without an argument tolerates repeated, leading, or
        # trailing whitespace, which ``split(" ")`` would turn into empty
        # labels.
        full_index = pd.Index(attrs["flag_meanings"].split())

        # The parent grouper factorizes against ``self.labels``.
        self.labels = values

        # TODO: we could optimize here, since `group` is already factorized,
        # but there are subtleties. For example, the attrs must be up to date,
        # any value that is not in flag_values will cause an error, etc.
        ret = super().factorize(group)

        # The returned codes must not carry the flag attrs; use a default so
        # the cleanup never raises if they are already absent.
        ret.codes.attrs.pop("flag_values", None)
        ret.codes.attrs.pop("flag_meanings", None)

        return EncodedGroups(
            codes=ret.codes,
            full_index=full_index,
            group_indices=ret.group_indices,
        )

cf_xarray/tests/test_groupers.py

Lines changed: 45 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,45 @@
1+
import numpy as np
2+
import pytest
3+
import xarray as xr
4+
from xarray.testing import assert_identical
5+
6+
from cf_xarray.datasets import flag_excl
7+
from cf_xarray.groupers import FlagGrouper
8+
9+
10+
def test_flag_grouper():
    """Check FlagGrouper against a plain groupby and its missing-attr errors."""
    base = flag_excl.to_dataset().set_coords("flag_var")
    ds = base.copy(deep=True)
    ds["foo"] = ("time", np.arange(8))

    # Grouping with FlagGrouper should match a plain groupby whose group
    # coordinate has been relabelled with the flag meanings.
    grouper = FlagGrouper()
    actual = ds.groupby(flag_var=grouper).mean()

    expected = ds.groupby("flag_var").mean()
    expected["flag_var"] = ["flag_1", "flag_2", "flag_3"]
    expected["flag_var"].attrs["standard_name"] = "flag_mutual_exclusive"
    assert_identical(actual, expected)

    # Missing `flag_values` must be rejected.
    flag_attrs = ds.flag_var.attrs
    del flag_attrs["flag_values"]
    with pytest.raises(ValueError):
        ds.groupby(flag_var=FlagGrouper())

    # Restore `flag_values`; missing `flag_meanings` must be rejected too.
    flag_attrs["flag_values"] = [0, 1, 2]
    del flag_attrs["flag_meanings"]
    with pytest.raises(ValueError):
        ds.groupby(flag_var=FlagGrouper())
27+
28+
29+
@pytest.mark.parametrize(
    "values",
    [
        [1, 2],
        [1, 2, 3],  # value out of range of flag_values
    ],
)
def test_flag_grouper_optimized(values):
    """Check the codes returned by ``FlagGrouper.factorize``.

    The codes are expected to equal the input values, with any value greater
    than 2 (i.e. outside ``flag_values``) replaced by -1 and the flag attrs
    removed.
    """
    ds = xr.Dataset(
        {"foo": ("x", values, {"flag_values": [0, 1, 2], "flag_meanings": "a b c"})}
    )
    ret = FlagGrouper().factorize(ds.foo)

    # Build the expectation on an independent copy so the input dataset is
    # not mutated and the assertion compares against `expected` itself.
    expected = ds.foo.copy()
    expected.data[expected.data > 2] = -1
    del expected.attrs["flag_meanings"]
    del expected.attrs["flag_values"]
    assert_identical(ret.codes, expected)

ci/doc.yml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,7 @@ dependencies:
88
- netcdf4
99
- pooch
1010
- xarray
11-
- sphinx
11+
- sphinx<8
1212
- sphinx-copybutton
1313
- numpydoc
1414
- sphinx-autosummary-accessors

doc/api.rst

Lines changed: 8 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -21,13 +21,20 @@ Geometries
2121
----------
2222
.. autosummary::
2323
:toctree: generated/
24-
2524
geometry.decode_geometries
25+
2626
geometry.encode_geometries
2727
geometry.shapely_to_cf
2828
geometry.cf_to_shapely
2929
geometry.GeometryNames
3030

31+
32+
Groupers
33+
--------
34+
.. autosummary::
35+
:toctree: generated/
36+
groupers.FlagGrouper
37+
3138
.. currentmodule:: xarray
3239

3340
DataArray

doc/flags.md

Lines changed: 32 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -60,6 +60,38 @@ You can also check whether a DataArray has the appropriate attributes to be reco
6060
da.cf.is_flag_variable
6161
```
6262

63+
## GroupBy
64+
65+
Flag variables, such as the one above, are a natural fit for GroupBy operations.
66+
cf-xarray provides a `FlagGrouper` that understands the `flag_meanings` and `flag_values` attributes.
67+
68+
Let's load an example dataset where the `flag_var` array has the needed attributes.
69+
70+
```{code-cell}
71+
import cf_xarray as cfxr
72+
import numpy as np
73+
74+
from cf_xarray.datasets import flag_excl
75+
76+
ds = flag_excl.to_dataset().set_coords('flag_var')
77+
ds["foo"] = ("time", np.arange(8))
78+
ds.flag_var
79+
```
80+
81+
Now use the {py:class}`~cf_xarray.groupers.FlagGrouper` to group by this flag variable:
82+
83+
```{code-cell}
84+
from cf_xarray.groupers import FlagGrouper
85+
86+
ds.groupby(flag_var=FlagGrouper()).mean()
87+
```
88+
89+
Note how the output coordinate has the values from `flag_meanings`!
90+
91+
```{seealso}
92+
See the Xarray docs on using [Grouper objects](https://docs.xarray.dev/en/stable/user-guide/groupby.html#grouper-objects).
93+
```
94+
6395
## Flag Masks
6496

6597
```{warning}

0 commit comments

Comments
 (0)