Skip to content

Commit 859e76b

Browse files
kthyngdcherian
andauthored
Support custom vocabularies/criteria (#234)
Co-authored-by: Deepak Cherian <[email protected]> Co-authored-by: dcherian <[email protected]>
1 parent 2fa2939 commit 859e76b

File tree

4 files changed

+148
-3
lines changed

4 files changed

+148
-3
lines changed

cf_xarray/accessor.py

Lines changed: 69 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -65,6 +65,13 @@
6565
ATTRS["time"] = ATTRS["T"]
6666
ATTRS["vertical"] = ATTRS["Z"]
6767

68+
OPTIONS: MutableMapping[str, Any] = {"custom_criteria": []}
69+
70+
71+
def set_options(custom_criteria):
72+
OPTIONS["custom_criteria"] = always_iterable(custom_criteria, allowed=(tuple, list))
73+
74+
6875
# Type for Mapper functions
6976
Mapper = Callable[[Union[DataArray, Dataset], str], List[str]]
7077

@@ -170,6 +177,55 @@ def _get_groupby_time_accessor(var: Union[DataArray, Dataset], key: str) -> List
170177
return []
171178

172179

180+
def _get_custom_criteria(
181+
obj: Union[DataArray, Dataset], key: str, criteria=None
182+
) -> List[str]:
183+
"""
184+
Translate from axis, coord, or custom name to variable name optionally
185+
using ``custom_criteria``
186+
187+
Parameters
188+
----------
189+
obj : DataArray, Dataset
190+
key : str
191+
key to check for.
192+
criteria : dict, optional
193+
Criteria to use to map from variable to attributes describing the
194+
variable. An example is coordinate_criteria which maps coordinates to
195+
their attributes and attribute values. If user has defined
196+
custom_criteria, this will be used by default.
197+
198+
Returns
199+
-------
200+
List[str], Variable name(s) in parent xarray object that matches axis, coordinate, or custom `key`
201+
202+
"""
203+
204+
if isinstance(obj, DataArray):
205+
obj = obj._to_temp_dataset()
206+
207+
if criteria is None:
208+
if not OPTIONS["custom_criteria"]:
209+
return []
210+
criteria = OPTIONS["custom_criteria"]
211+
212+
if criteria is not None:
213+
criteria = always_iterable(criteria, allowed=(tuple, list, set))
214+
215+
criteria = ChainMap(*criteria)
216+
217+
results: Set = set()
218+
if key in criteria:
219+
for criterion, patterns in criteria[key].items():
220+
for var in obj.variables:
221+
if re.match(patterns, obj[var].attrs.get(criterion, "")):
222+
results.update((var,))
223+
# also check name specifically since not in attributes
224+
elif criterion == "name" and re.match(patterns, var):
225+
results.update((var,))
226+
return list(results)
227+
228+
173229
def _get_axis_coord(var: Union[DataArray, Dataset], key: str) -> List[str]:
174230
"""
175231
Translate from axis or coord name to variable name
@@ -314,7 +370,12 @@ def _get_all(obj: Union[DataArray, Dataset], key: str) -> List[str]:
314370
One or more of ('X', 'Y', 'Z', 'T', 'longitude', 'latitude', 'vertical', 'time',
315371
'area', 'volume'), or arbitrary measures, or standard names
316372
"""
317-
all_mappers = (_get_axis_coord, _get_measure, _get_with_standard_name)
373+
all_mappers = (
374+
_get_custom_criteria,
375+
_get_axis_coord,
376+
_get_measure,
377+
_get_with_standard_name,
378+
)
318379
results = apply_mapper(all_mappers, obj, key, error=False, default=None)
319380
return results
320381

@@ -586,6 +647,8 @@ def check_results(names, key):
586647
measures = []
587648
warnings.warn("Ignoring bad cell_measures attribute.", UserWarning)
588649

650+
custom_criteria = ChainMap(*OPTIONS["custom_criteria"])
651+
589652
varnames: List[Hashable] = []
590653
coords: List[Hashable] = []
591654
successful = dict.fromkeys(key, False)
@@ -602,6 +665,11 @@ def check_results(names, key):
602665
successful[k] = bool(measure)
603666
if measure:
604667
varnames.extend(measure)
668+
elif k in custom_criteria:
669+
names = _get_all(obj, k)
670+
check_results(names, k)
671+
successful[k] = bool(names)
672+
varnames.extend(names)
605673
else:
606674
stdnames = set(_get_with_standard_name(obj, k))
607675
objcoords = set(obj.coords)

cf_xarray/tests/test_accessor.py

Lines changed: 76 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1237,6 +1237,82 @@ def test_cmip6_attrs():
12371237
assert da.cf.axes["Y"] == ["nlat"]
12381238

12391239

1240+
def test_custom_criteria():
1241+
my_custom_criteria = {
1242+
"ssh": {
1243+
"standard_name": "sea_surface_elev*|sea_surface_height",
1244+
"name": "sea_surface_elevation$", # variable name
1245+
},
1246+
"salt": {
1247+
"standard_name": "salinity",
1248+
"name": "sal*",
1249+
},
1250+
"wind_speed": {
1251+
"standard_name": "wind_speed$",
1252+
},
1253+
}
1254+
my_custom_criteria2 = {"temp": {"name": "temperature"}}
1255+
cf_xarray.accessor.set_options(my_custom_criteria)
1256+
my_custom_criteria_list = [my_custom_criteria, my_custom_criteria2]
1257+
my_custom_criteria_tuple = (my_custom_criteria, my_custom_criteria2)
1258+
1259+
# Match by name regex match
1260+
ds = xr.Dataset()
1261+
ds["salinity"] = ("dim", np.arange(10))
1262+
assert_identical(ds.cf["salt"], ds["salinity"])
1263+
1264+
# Match by standard_name regex match
1265+
ds = xr.Dataset()
1266+
ds["elev"] = ("dim", np.arange(10), {"standard_name": "sea_surface_elevBLAH"})
1267+
assert_identical(ds.cf["ssh"], ds["elev"])
1268+
1269+
# Match by standard_name exact match
1270+
ds = xr.Dataset()
1271+
ds["salinity"] = ("dim", np.arange(10), {"standard_name": "salinity"})
1272+
assert_identical(ds.cf["salt"], ds["salinity"])
1273+
1274+
# If not exact name, won't match
1275+
ds = xr.Dataset()
1276+
ds["sea_surface_elevation123"] = ("dim", np.arange(10))
1277+
# Since this will not match, this should error
1278+
with pytest.raises(KeyError):
1279+
ds.cf["ssh"]
1280+
1281+
# will select only one variable here since exact match
1282+
ds = xr.Dataset()
1283+
ds["winds"] = ("dim", np.arange(10), {"standard_name": "wind_speed"})
1284+
ds["gusts"] = ("dim", np.arange(10), {"standard_name": "wind_speed_of_gust"})
1285+
assert_identical(ds.cf["wind_speed"], ds["winds"])
1286+
1287+
# Match by exact name
1288+
ds = xr.Dataset()
1289+
ds["sea_surface_elevation"] = ("dim", np.arange(10))
1290+
ds["sea_surface_height"] = (
1291+
"dim",
1292+
np.arange(10),
1293+
{"standard_name": "sea_surface_elevBLAH"},
1294+
)
1295+
# Since there are two variables, this should error
1296+
with pytest.raises(KeyError):
1297+
ds.cf["ssh"]
1298+
# But the following should work instead given the two ssh variables
1299+
assert_identical(
1300+
ds.cf[["ssh"]], ds[["sea_surface_elevation", "sea_surface_height"]]
1301+
)
1302+
1303+
# test criteria list of dicts
1304+
cf_xarray.accessor.set_options(my_custom_criteria_list)
1305+
ds = xr.Dataset()
1306+
ds["temperature"] = ("dim", np.arange(10))
1307+
assert_identical(ds.cf["temp"], ds["temperature"])
1308+
1309+
# test criteria tuple of dicts
1310+
cf_xarray.accessor.set_options(my_custom_criteria_tuple)
1311+
ds = xr.Dataset()
1312+
ds["temperature"] = ("dim", np.arange(10))
1313+
assert_identical(ds.cf["temp"], ds["temperature"])
1314+
1315+
12401316
def test_cf_standard_name_table_version():
12411317

12421318
url = (

cf_xarray/utils.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -58,8 +58,8 @@ def invert_mappings(*mappings):
5858
return merged
5959

6060

61-
def always_iterable(obj: Any) -> Iterable:
62-
return [obj] if not isinstance(obj, (tuple, list, set, dict)) else obj
61+
def always_iterable(obj: Any, allowed=(tuple, list, set, dict)) -> Iterable:
62+
return [obj] if not isinstance(obj, allowed) else obj
6363

6464

6565
def parse_cf_standard_name_table(source=None):

doc/whats-new.rst

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,7 @@ v0.5.3 (unreleased)
1010
- Begin adding support for units with a unit registry for pint arrays. :pr:`197`.
1111
By `Jon Thielen`_ and `Justus Magin`_.
1212
- :py:meth:`Dataset.cf.rename_like` also updates the ``bounds`` and ``cell_measures`` attributes. By `Mattia Almansi`_.
13+
- Support of custom vocabularies/criteria: user can input criteria for identifying variables by their name and attributes to be able to refer to them by custom names like `ds.cf["ssh"]`. :pr:`234`. By `Kristen Thyng`_ and `Deepak Cherian`_.
1314

1415
v0.5.2 (May 11, 2021)
1516
=====================

0 commit comments

Comments
 (0)