Skip to content

Commit fef87df

Browse files
Descanongepre-commit-ci[bot]dcherian
authored
Add support for 'flag_masks' (#354)
* Implement for 'flag_masks' Change create_flag_dict() to account for flag_masks. Add extract_flags() to retrieve boolean masks of any type of flag. Change implementation of __eq__, __ne__, isin Change __repr__ to list flags depending on their type (mutually exclusive or independent). Do not list the corresponding value anymore (it becomes complicated when mixing both types...). * make isin name same as original variable * Fix extract_flag for only mutually excl flags * Use original error message * Simplify casting to integers * Fix typo * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * Remove math import * Make flags a property * Streamline _extract_flags * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * Fix isin Fix if no test_elements is in flag_meanings Would return the opposite * Add unit tests * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * Fix argument typing * Remove list conversion Not useful, and does not respect typing * Readd basin tests * Minor edits * Fix types * Minor edits * Fix tests. * Revert to old eq, ne * Update rich repr * minimize diff * Fix test * Fix test * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * add pint to mypy ignore * ignore typing errors * Use NamedTuple * comment * Add docs. * more doc * even more doc --------- Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> Co-authored-by: Deepak Cherian <[email protected]> Co-authored-by: dcherian <[email protected]>
1 parent fa09644 commit fef87df

File tree

6 files changed

+377
-64
lines changed

6 files changed

+377
-64
lines changed

cf_xarray/accessor.py

Lines changed: 138 additions & 28 deletions
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,7 @@
55
import itertools
66
import re
77
import warnings
8-
from collections import ChainMap
8+
from collections import ChainMap, namedtuple
99
from datetime import datetime
1010
from typing import (
1111
Any,
@@ -58,6 +58,8 @@
5858
parse_cf_standard_name_table,
5959
)
6060

61+
FlagParam = namedtuple("FlagParam", ["flag_mask", "flag_value"])
62+
6163
#: Classes wrapped by cf_xarray.
6264
_WRAPPED_CLASSES = (Resample, GroupBy, Rolling, Coarsen, Weighted)
6365

@@ -1057,18 +1059,39 @@ def __getattr__(self, attr):
10571059
)
10581060

10591061

1060-
def create_flag_dict(da):
1062+
def create_flag_dict(da) -> Mapping[Hashable, FlagParam]:
1063+
"""
1064+
Return possible flag meanings and associated bitmask/values.
1065+
1066+
The mapping values are a tuple containing a bitmask and a value. Either
1067+
can be None.
1068+
If only a bitmask: Independent flags.
1069+
If only a value: Mutually exclusive flags.
1070+
If both: Mix of independent and mutually exclusive flags.
1071+
"""
10611072
if not da.cf.is_flag_variable:
10621073
raise ValueError(
1063-
"Comparisons are only supported for DataArrays that represent CF flag variables."
1064-
".attrs must contain 'flag_values' and 'flag_meanings'"
1074+
"Comparisons are only supported for DataArrays that represent "
1075+
"CF flag variables. .attrs must contain 'flag_meanings' and "
1076+
"'flag_values' or 'flag_masks'."
10651077
)
10661078

10671079
flag_meanings = da.attrs["flag_meanings"].split(" ")
1068-
flag_values = da.attrs["flag_values"]
1069-
# TODO: assert flag_values is iterable
1070-
assert len(flag_values) == len(flag_meanings)
1071-
return dict(zip(flag_meanings, flag_values))
1080+
n_flag = len(flag_meanings)
1081+
1082+
flag_values = da.attrs.get("flag_values", [None] * n_flag)
1083+
flag_masks = da.attrs.get("flag_masks", [None] * n_flag)
1084+
1085+
if not (n_flag == len(flag_values) == len(flag_masks)):
1086+
raise ValueError(
1087+
"Not as many flag meanings as values or masks. "
1088+
"Please check the flag_meanings, flag_values, flag_masks attributes "
1089+
)
1090+
1091+
flag_params = tuple(
1092+
FlagParam(mask, value) for mask, value in zip(flag_masks, flag_values)
1093+
)
1094+
return dict(zip(flag_meanings, flag_params))
10721095

10731096

10741097
class CFAccessor:
@@ -1084,36 +1107,40 @@ def __setstate__(self, d):
10841107
self.__dict__ = d
10851108

10861109
def _assert_valid_other_comparison(self, other):
1110+
# TODO cache this property
10871111
flag_dict = create_flag_dict(self._obj)
10881112
if other not in flag_dict:
10891113
raise ValueError(
10901114
f"Did not find flag value meaning [{other}] in known flag meanings: [{flag_dict.keys()!r}]"
10911115
)
1116+
if flag_dict[other].flag_mask is not None:
1117+
raise NotImplementedError(
1118+
"Only equals and not-equals comparisons with flag masks are supported."
1119+
" Please open an issue."
1120+
)
10921121
return flag_dict
10931122

1094-
def __eq__(self, other):
1123+
def __eq__(self, other) -> DataArray: # type: ignore
10951124
"""
10961125
Compare flag values against `other`.
10971126
10981127
`other` must be in the 'flag_meanings' attribute.
10991128
`other` is mapped to the corresponding entry in the 'flag_values' and/or 'flag_masks' attribute, and then
11001129
compared.
11011130
"""
1102-
flag_dict = self._assert_valid_other_comparison(other)
1103-
return self._obj == flag_dict[other]
1131+
return self._extract_flags([other])[other].rename(self._obj.name)
11041132

1105-
def __ne__(self, other):
1133+
def __ne__(self, other) -> DataArray: # type: ignore
11061134
"""
11071135
Compare flag values against `other`.
11081136
11091137
`other` must be in the 'flag_meanings' attribute.
11101138
`other` is mapped to the corresponding entry in the 'flag_values' and/or 'flag_masks' attribute, and then
11111139
compared.
11121140
"""
1113-
flag_dict = self._assert_valid_other_comparison(other)
1114-
return self._obj != flag_dict[other]
1141+
return ~self._extract_flags([other])[other].rename(self._obj.name)
11151142

1116-
def __lt__(self, other):
1143+
def __lt__(self, other) -> DataArray:
11171144
"""
11181145
Compare flag values against `other`.
11191146
@@ -1122,9 +1149,9 @@ def __lt__(self, other):
11221149
compared.
11231150
"""
11241151
flag_dict = self._assert_valid_other_comparison(other)
1125-
return self._obj < flag_dict[other]
1152+
return self._obj < flag_dict[other].flag_value
11261153

1127-
def __le__(self, other):
1154+
def __le__(self, other) -> DataArray:
11281155
"""
11291156
Compare flag values against `other`.
11301157
@@ -1133,9 +1160,9 @@ def __le__(self, other):
11331160
compared.
11341161
"""
11351162
flag_dict = self._assert_valid_other_comparison(other)
1136-
return self._obj <= flag_dict[other]
1163+
return self._obj <= flag_dict[other].flag_value
11371164

1138-
def __gt__(self, other):
1165+
def __gt__(self, other) -> DataArray:
11391166
"""
11401167
Compare flag values against `other`.
11411168
@@ -1144,9 +1171,9 @@ def __gt__(self, other):
11441171
compared.
11451172
"""
11461173
flag_dict = self._assert_valid_other_comparison(other)
1147-
return self._obj > flag_dict[other]
1174+
return self._obj > flag_dict[other].flag_value
11481175

1149-
def __ge__(self, other):
1176+
def __ge__(self, other) -> DataArray:
11501177
"""
11511178
Compare flag values against `other`.
11521179
@@ -1155,9 +1182,9 @@ def __ge__(self, other):
11551182
compared.
11561183
"""
11571184
flag_dict = self._assert_valid_other_comparison(other)
1158-
return self._obj >= flag_dict[other]
1185+
return self._obj >= flag_dict[other].flag_value
11591186

1160-
def isin(self, test_elements):
1187+
def isin(self, test_elements) -> DataArray:
11611188
"""Test each value in the array for whether it is in test_elements.
11621189
11631190
Parameters
@@ -1177,14 +1204,15 @@ def isin(self, test_elements):
11771204
raise ValueError(
11781205
".cf.isin is only supported on DataArrays that contain CF flag attributes."
11791206
)
1207+
# TODO cache this property
11801208
flag_dict = create_flag_dict(self._obj)
11811209
mapped_test_elements = []
11821210
for elem in test_elements:
11831211
if elem not in flag_dict:
11841212
raise ValueError(
11851213
f"Did not find flag value meaning [{elem}] in known flag meanings: [{flag_dict.keys()!r}]"
11861214
)
1187-
mapped_test_elements.append(flag_dict[elem])
1215+
mapped_test_elements.append(flag_dict[elem].flag_value)
11881216
return self._obj.isin(mapped_test_elements)
11891217

11901218
def _drop_missing_variables(self, variables: list[Hashable]) -> list[Hashable]:
@@ -2753,22 +2781,104 @@ def __getitem__(self, key: Hashable | Iterable[Hashable]) -> DataArray:
27532781

27542782
return _getitem(self, key)
27552783

2784+
@property
2785+
def flags(self) -> Dataset:
2786+
"""
2787+
Dataset containing boolean masks of available flags.
2788+
"""
2789+
return self._extract_flags()
2790+
2791+
def _extract_flags(self, flags: Sequence[Hashable] | None = None) -> Dataset:
2792+
"""
2793+
Return dataset of boolean mask(s) corresponding to `flags`.
2794+
2795+
Parameters
2796+
----------
2797+
flags : Sequence[Hashable], optional
2798+
Flags to extract. If None (the default), return all flags in
2799+
`flag_meanings`.
2800+
"""
2801+
# TODO cache this property
2802+
flag_dict = create_flag_dict(self._obj)
2803+
2804+
if flags is None:
2805+
flags = tuple(flag_dict.keys())
2806+
2807+
out = {} # Output arrays
2808+
2809+
masks = [] # Bitmasks and values for asked flags
2810+
values = []
2811+
flags_reduced = [] # Flags left after removing mutually excl. flags
2812+
for flag in flags:
2813+
if flag not in flag_dict:
2814+
raise ValueError(
2815+
f"Did not find flag value meaning [{flag}] in known flag meanings:"
2816+
f" [{flag_dict.keys()!r}]"
2817+
)
2818+
mask, value = flag_dict[flag]
2819+
if mask is None:
2820+
out[flag] = self._obj == value
2821+
else:
2822+
masks.append(mask)
2823+
values.append(value)
2824+
flags_reduced.append(flag)
2825+
2826+
if len(masks) > 0: # If independent masks are left
2827+
# We cast both masks and flag variable as integers to make the
2828+
# bitwise comparison. We could probably restrict the integer size
2829+
# but it's difficult to do so safely for mixed type flags.
2830+
bit_mask = DataArray(masks, dims=["_mask"]).astype("i")
2831+
x = self._obj.astype("i")
2832+
bit_comp = x & bit_mask
2833+
2834+
for i, (flag, value) in enumerate(zip(flags_reduced, values)):
2835+
bit = bit_comp.isel(_mask=i)
2836+
if value is not None:
2837+
out[flag] = bit == value
2838+
else:
2839+
out[flag] = bit.astype(bool)
2840+
2841+
return Dataset(out)
2842+
2843+
def isin(self, test_elements):
2844+
"""
2845+
Test each value in the array for whether it is in test_elements.
2846+
2847+
Parameters
2848+
----------
2849+
test_elements : array_like, 1D
2850+
The values against which to test each value of `element`.
2851+
2852+
Returns
2853+
-------
2854+
isin : DataArray
2855+
Has the same type and shape as this object, but with a bool dtype.
2856+
"""
2857+
flags_masks = self.flags.drop_vars(
2858+
[v for v in self.flags.data_vars if v not in test_elements]
2859+
)
2860+
if len(flags_masks) == 0:
2861+
out = self.copy().astype(bool)
2862+
out.attrs = {}
2863+
out[:] = False
2864+
return out
2865+
# Merge into a single DataArray
2866+
flags_masks = xr.concat(flags_masks.data_vars.values(), dim="_flags")
2867+
return flags_masks.any(dim="_flags").rename(self._obj.name)
2868+
27562869
@property
27572870
def is_flag_variable(self) -> bool:
27582871
"""
27592872
Returns True if the DataArray satisfies CF conventions for flag variables.
27602873
2761-
.. warning::
2762-
Flag masks are not supported yet.
2763-
27642874
Returns
27652875
-------
27662876
bool
27672877
"""
27682878
if (
27692879
isinstance(self._obj, DataArray)
27702880
and "flag_meanings" in self._obj.attrs
2771-
and "flag_values" in self._obj.attrs
2881+
and ("flag_values" in self._obj.attrs or "flag_masks" in self._obj.attrs)
27722882
):
27732883
return True
27742884
else:

cf_xarray/datasets.py

Lines changed: 39 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -466,7 +466,7 @@ def _create_inexact_bounds():
466466
)
467467
)
468468

469-
469+
# Same as flag_excl but easier to read
470470
basin = xr.DataArray(
471471
[1, 2, 1, 1, 2, 2, 3, 3, 3, 3],
472472
dims=("time",),
@@ -479,6 +479,44 @@ def _create_inexact_bounds():
479479
)
480480

481481

482+
flag_excl = xr.DataArray(
483+
np.array([1, 1, 2, 1, 2, 3, 3, 2], np.uint8),
484+
dims=("time",),
485+
coords={"time": [0, 1, 2, 3, 4, 5, 6, 7]},
486+
attrs={
487+
"flag_values": [1, 2, 3],
488+
"flag_meanings": "flag_1 flag_2 flag_3",
489+
"standard_name": "flag_mutual_exclusive",
490+
},
491+
name="flag_var",
492+
)
493+
494+
495+
flag_indep = xr.DataArray(
496+
np.array([0, 1, 2, 3, 4, 5, 6, 7], dtype=np.uint8),
497+
dims=("time",),
498+
attrs={
499+
"flag_masks": [1, 2, 4],
500+
"flag_meanings": "flag_1 flag_2 flag_4",
501+
"standard_name": "flag_independent",
502+
},
503+
name="flag_var",
504+
)
505+
506+
507+
flag_mix = xr.DataArray(
508+
np.array([4, 8, 13, 5, 10, 14, 7, 3], np.uint8),
509+
dims=("time",),
510+
attrs={
511+
"flag_values": [1, 2, 4, 8, 12],
512+
"flag_masks": [1, 2, 12, 12, 12],
513+
"flag_meanings": "flag_1 flag_2 flag_3 flag_4 flag_5",
514+
"standard_name": "flag_mix",
515+
},
516+
name="flag_var",
517+
)
518+
519+
482520
ambig = xr.Dataset(
483521
data_vars={},
484522
coords={

0 commit comments

Comments
 (0)