5
5
import itertools
6
6
import re
7
7
import warnings
8
- from collections import ChainMap
8
+ from collections import ChainMap , namedtuple
9
9
from datetime import datetime
10
10
from typing import (
11
11
Any ,
58
58
parse_cf_standard_name_table ,
59
59
)
60
60
61
+ FlagParam = namedtuple ("FlagParam" , ["flag_mask" , "flag_value" ])
62
+
61
63
#: Classes wrapped by cf_xarray.
62
64
_WRAPPED_CLASSES = (Resample , GroupBy , Rolling , Coarsen , Weighted )
63
65
@@ -1057,18 +1059,39 @@ def __getattr__(self, attr):
1057
1059
)
1058
1060
1059
1061
1060
- def create_flag_dict (da ):
1062
+ def create_flag_dict (da ) -> Mapping [Hashable , FlagParam ]:
1063
+ """
1064
+ Return possible flag meanings and associated bitmask/values.
1065
+
1066
+ The mapping values are a tuple containing a bitmask and a value. Either
1067
+ can be None.
1068
+ If only a bitmask: Independent flags.
1069
+ If only a value: Mutually exclusive flags.
1070
+ If both: Mix of independent and mutually exclusive flags.
1071
+ """
1061
1072
if not da .cf .is_flag_variable :
1062
1073
raise ValueError (
1063
- "Comparisons are only supported for DataArrays that represent CF flag variables."
1064
- ".attrs must contain 'flag_values' and 'flag_meanings'"
1074
+ "Comparisons are only supported for DataArrays that represent "
1075
+ "CF flag variables. .attrs must contain 'flag_meanings' and "
1076
+ "'flag_values' or 'flag_masks'."
1065
1077
)
1066
1078
1067
1079
flag_meanings = da .attrs ["flag_meanings" ].split (" " )
1068
- flag_values = da .attrs ["flag_values" ]
1069
- # TODO: assert flag_values is iterable
1070
- assert len (flag_values ) == len (flag_meanings )
1071
- return dict (zip (flag_meanings , flag_values ))
1080
+ n_flag = len (flag_meanings )
1081
+
1082
+ flag_values = da .attrs .get ("flag_values" , [None ] * n_flag )
1083
+ flag_masks = da .attrs .get ("flag_masks" , [None ] * n_flag )
1084
+
1085
+ if not (n_flag == len (flag_values ) == len (flag_masks )):
1086
+ raise ValueError (
1087
+ "Not as many flag meanings as values or masks. "
1088
+ "Please check the flag_meanings, flag_values, flag_masks attributes "
1089
+ )
1090
+
1091
+ flag_params = tuple (
1092
+ FlagParam (mask , value ) for mask , value in zip (flag_masks , flag_values )
1093
+ )
1094
+ return dict (zip (flag_meanings , flag_params ))
1072
1095
1073
1096
1074
1097
class CFAccessor :
@@ -1084,36 +1107,40 @@ def __setstate__(self, d):
1084
1107
self .__dict__ = d
1085
1108
1086
1109
def _assert_valid_other_comparison (self , other ):
1110
+ # TODO cache this property
1087
1111
flag_dict = create_flag_dict (self ._obj )
1088
1112
if other not in flag_dict :
1089
1113
raise ValueError (
1090
1114
f"Did not find flag value meaning [{ other } ] in known flag meanings: [{ flag_dict .keys ()!r} ]"
1091
1115
)
1116
+ if flag_dict [other ].flag_mask is not None :
1117
+ raise NotImplementedError (
1118
+ "Only equals and not-equals comparisons with flag masks are supported."
1119
+ " Please open an issue."
1120
+ )
1092
1121
return flag_dict
1093
1122
1094
- def __eq__ (self , other ):
1123
+ def __eq__ (self , other ) -> DataArray : # type: ignore
1095
1124
"""
1096
1125
Compare flag values against `other`.
1097
1126
1098
1127
`other` must be in the 'flag_meanings' attribute.
1099
1128
`other` is mapped to the corresponding value in the 'flag_values' attribute, and then
1100
1129
compared.
1101
1130
"""
1102
- flag_dict = self ._assert_valid_other_comparison (other )
1103
- return self ._obj == flag_dict [other ]
1131
+ return self ._extract_flags ([other ])[other ].rename (self ._obj .name )
1104
1132
1105
- def __ne__ (self , other ):
1133
+ def __ne__ (self , other ) -> DataArray : # type: ignore
1106
1134
"""
1107
1135
Compare flag values against `other`.
1108
1136
1109
1137
`other` must be in the 'flag_meanings' attribute.
1110
1138
`other` is mapped to the corresponding value in the 'flag_values' attribute, and then
1111
1139
compared.
1112
1140
"""
1113
- flag_dict = self ._assert_valid_other_comparison (other )
1114
- return self ._obj != flag_dict [other ]
1141
+ return ~ self ._extract_flags ([other ])[other ].rename (self ._obj .name )
1115
1142
1116
- def __lt__ (self , other ):
1143
+ def __lt__ (self , other ) -> DataArray :
1117
1144
"""
1118
1145
Compare flag values against `other`.
1119
1146
@@ -1122,9 +1149,9 @@ def __lt__(self, other):
1122
1149
compared.
1123
1150
"""
1124
1151
flag_dict = self ._assert_valid_other_comparison (other )
1125
- return self ._obj < flag_dict [other ]
1152
+ return self ._obj < flag_dict [other ]. flag_value
1126
1153
1127
- def __le__ (self , other ):
1154
+ def __le__ (self , other ) -> DataArray :
1128
1155
"""
1129
1156
Compare flag values against `other`.
1130
1157
@@ -1133,9 +1160,9 @@ def __le__(self, other):
1133
1160
compared.
1134
1161
"""
1135
1162
flag_dict = self ._assert_valid_other_comparison (other )
1136
- return self ._obj <= flag_dict [other ]
1163
+ return self ._obj <= flag_dict [other ]. flag_value
1137
1164
1138
- def __gt__ (self , other ):
1165
+ def __gt__ (self , other ) -> DataArray :
1139
1166
"""
1140
1167
Compare flag values against `other`.
1141
1168
@@ -1144,9 +1171,9 @@ def __gt__(self, other):
1144
1171
compared.
1145
1172
"""
1146
1173
flag_dict = self ._assert_valid_other_comparison (other )
1147
- return self ._obj > flag_dict [other ]
1174
+ return self ._obj > flag_dict [other ]. flag_value
1148
1175
1149
- def __ge__ (self , other ):
1176
+ def __ge__ (self , other ) -> DataArray :
1150
1177
"""
1151
1178
Compare flag values against `other`.
1152
1179
@@ -1155,9 +1182,9 @@ def __ge__(self, other):
1155
1182
compared.
1156
1183
"""
1157
1184
flag_dict = self ._assert_valid_other_comparison (other )
1158
- return self ._obj >= flag_dict [other ]
1185
+ return self ._obj >= flag_dict [other ]. flag_value
1159
1186
1160
- def isin (self , test_elements ):
1187
+ def isin (self , test_elements ) -> DataArray :
1161
1188
"""Test each value in the array for whether it is in test_elements.
1162
1189
1163
1190
Parameters
@@ -1177,14 +1204,15 @@ def isin(self, test_elements):
1177
1204
raise ValueError (
1178
1205
".cf.isin is only supported on DataArrays that contain CF flag attributes."
1179
1206
)
1207
+ # TODO cache this property
1180
1208
flag_dict = create_flag_dict (self ._obj )
1181
1209
mapped_test_elements = []
1182
1210
for elem in test_elements :
1183
1211
if elem not in flag_dict :
1184
1212
raise ValueError (
1185
1213
f"Did not find flag value meaning [{ elem } ] in known flag meanings: [{ flag_dict .keys ()!r} ]"
1186
1214
)
1187
- mapped_test_elements .append (flag_dict [elem ])
1215
+ mapped_test_elements .append (flag_dict [elem ]. flag_value )
1188
1216
return self ._obj .isin (mapped_test_elements )
1189
1217
1190
1218
def _drop_missing_variables (self , variables : list [Hashable ]) -> list [Hashable ]:
@@ -2753,22 +2781,104 @@ def __getitem__(self, key: Hashable | Iterable[Hashable]) -> DataArray:
2753
2781
2754
2782
return _getitem (self , key )
2755
2783
2784
+ @property
2785
+ def flags (self ) -> Dataset :
2786
+ """
2787
+ Dataset containing boolean masks of available flags.
2788
+ """
2789
+ return self ._extract_flags ()
2790
+
2791
+ def _extract_flags (self , flags : Sequence [Hashable ] | None = None ) -> Dataset :
2792
+ """
2793
+ Return dataset of boolean mask(s) corresponding to `flags`.
2794
+
2795
+ Parameters
2796
+ ----------
2797
+ flags: Sequence[str]
2798
+ Flags to extract. If empty (string or list), return all flags in
2799
+ `flag_meanings`.
2800
+ """
2801
+ # TODO cache this property
2802
+ flag_dict = create_flag_dict (self ._obj )
2803
+
2804
+ if flags is None :
2805
+ flags = tuple (flag_dict .keys ())
2806
+
2807
+ out = {} # Output arrays
2808
+
2809
+ masks = [] # Bitmasks and values for asked flags
2810
+ values = []
2811
+ flags_reduced = [] # Flags left after removing mutually excl. flags
2812
+ for flag in flags :
2813
+ if flag not in flag_dict :
2814
+ raise ValueError (
2815
+ f"Did not find flag value meaning [{ flag } ] in known flag meanings:"
2816
+ f" [{ flag_dict .keys ()!r} ]"
2817
+ )
2818
+ mask , value = flag_dict [flag ]
2819
+ if mask is None :
2820
+ out [flag ] = self ._obj == value
2821
+ else :
2822
+ masks .append (mask )
2823
+ values .append (value )
2824
+ flags_reduced .append (flag )
2825
+
2826
+ if len (masks ) > 0 : # If independant masks are left
2827
+ # We cast both masks and flag variable as integers to make the
2828
+ # bitwise comparison. We could probably restrict the integer size
2829
+ # but it's difficult to make it safely for mixed type flags.
2830
+ bit_mask = DataArray (masks , dims = ["_mask" ]).astype ("i" )
2831
+ x = self ._obj .astype ("i" )
2832
+ bit_comp = x & bit_mask
2833
+
2834
+ for i , (flag , value ) in enumerate (zip (flags_reduced , values )):
2835
+ bit = bit_comp .isel (_mask = i )
2836
+ if value is not None :
2837
+ out [flag ] = bit == value
2838
+ else :
2839
+ out [flag ] = bit .astype (bool )
2840
+
2841
+ return Dataset (out )
2842
+
2843
+ def isin (self , test_elements ):
2844
+ """
2845
+ Test each value in the array for whether it is in test_elements.
2846
+
2847
+ Parameters
2848
+ ----------
2849
+ test_elements : array_like, 1D
2850
+ The values against which to test each value of `element`.
2851
+
2852
+ Returns
2853
+ -------
2854
+ isin : DataArray
2855
+ Has the same type and shape as this object, but with a bool dtype.
2856
+ """
2857
+ flags_masks = self .flags .drop_vars (
2858
+ [v for v in self .flags .data_vars if v not in test_elements ]
2859
+ )
2860
+ if len (flags_masks ) == 0 :
2861
+ out = self .copy ().astype (bool )
2862
+ out .attrs = {}
2863
+ out [:] = False
2864
+ return out
2865
+ # Merge into a single DataArray
2866
+ flags_masks = xr .concat (flags_masks .data_vars .values (), dim = "_flags" )
2867
+ return flags_masks .any (dim = "_flags" ).rename (self ._obj .name )
2868
+
2756
2869
@property
2757
2870
def is_flag_variable (self ) -> bool :
2758
2871
"""
2759
2872
Returns True if the DataArray satisfies CF conventions for flag variables.
2760
2873
2761
- .. warning::
2762
- Flag masks are not supported yet.
2763
-
2764
2874
Returns
2765
2875
-------
2766
2876
bool
2767
2877
"""
2768
2878
if (
2769
2879
isinstance (self ._obj , DataArray )
2770
2880
and "flag_meanings" in self ._obj .attrs
2771
- and "flag_values" in self ._obj .attrs
2881
+ and ( "flag_values" in self ._obj .attrs or "flag_masks" in self . _obj . attrs )
2772
2882
):
2773
2883
return True
2774
2884
else :
0 commit comments