Skip to content

Commit 78f1683

Browse files
authored
Rework attribute rewriting, .cf.data_vars, .cf.coords (#130)
* Fix .cf.data_vars * Return "CF DataArrays" in .cf.data_vars,.cf.coords * Fix DataArray.cf["standard_name"] Closes #129 Closes #126
1 parent 3dc7f3c commit 78f1683

File tree

3 files changed

+191
-115
lines changed

3 files changed

+191
-115
lines changed

cf_xarray/accessor.py

Lines changed: 150 additions & 112 deletions
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@
22
import inspect
33
import itertools
44
import warnings
5-
from collections import ChainMap
5+
from collections import ChainMap, defaultdict
66
from typing import (
77
Any,
88
Callable,
@@ -159,6 +159,18 @@ def _is_datetime_like(da: DataArray) -> bool:
159159
return False
160160

161161

162+
def invert_mappings(*mappings):
    """
    Invert and merge mappings of ``key -> iterable of names``.

    Parameters
    ----------
    *mappings : Mapping
        Any number of mappings whose values are iterables of hashable names.

    Returns
    -------
    defaultdict
        Maps every name found in the values of ``mappings`` to the set of keys
        it was listed under. Keys are collected in a set, so a key that shows up
        for the same name in several mappings is recorded only once; this avoids
        clashes between standard_name and coordinate names.
        The result is a ``defaultdict(set)``: looking up a name that was never
        seen yields an empty set instead of raising ``KeyError``.
    """
    merged = defaultdict(set)
    for mapping in mappings:
        for k, v in mapping.items():
            for name in v:
                # set.add deduplicates without building a temporary list + set
                merged[name].add(k)
    return merged
172+
173+
162174
# Type for Mapper functions
163175
Mapper = Callable[[Union[DataArray, Dataset], str], List[str]]
164176

@@ -503,23 +515,29 @@ def _getattr(
503515
if isinstance(attribute, Mapping):
504516
if not attribute:
505517
return dict(attribute)
506-
# attributes like chunks / sizes
518+
507519
newmap = dict()
508-
unused_keys = set(attribute.keys())
509-
for key in _AXIS_NAMES + _COORD_NAMES:
510-
value = set(apply_mapper(_get_axis_coord, obj, key, error=False))
511-
unused_keys -= value
512-
if value:
513-
good_values = value & set(obj.dims)
514-
if not good_values:
515-
continue
516-
if len(good_values) > 1:
520+
inverted = invert_mappings(
521+
accessor.axes,
522+
accessor.coordinates,
523+
accessor.cell_measures,
524+
accessor.standard_names,
525+
)
526+
unused_keys = set(attribute.keys()) - set(inverted)
527+
for key, value in attribute.items():
528+
for name in inverted[key]:
529+
if name in newmap:
517530
raise AttributeError(
518-
f"cf_xarray can't wrap attribute {attr!r} because there are multiple values for {key!r} viz. {good_values!r}. "
519-
f"There is no unique mapping from {key!r} to a value in {attr!r}."
531+
f"cf_xarray can't wrap attribute {attr!r} because there are multiple values for {name!r}. "
532+
f"There is no unique mapping from {name!r} to a value in {attr!r}."
520533
)
521-
newmap.update({key: attribute[good_values.pop()]})
534+
newmap.update(dict.fromkeys(inverted[key], value))
522535
newmap.update({key: attribute[key] for key in unused_keys})
536+
537+
skip = {"data_vars": ["coords"], "coords": None}
538+
if attr in ["coords", "data_vars"]:
539+
for key in newmap:
540+
newmap[key] = _getitem(accessor, key, skip=skip[attr])
523541
return newmap
524542

525543
elif isinstance(attribute, Callable): # type: ignore
@@ -548,6 +566,123 @@ def wrapper(*args, **kwargs):
548566
return wrapper
549567

550568

569+
def _getitem(
    accessor: "CFAccessor", key: Union[str, List[str]], skip: List[str] = None
) -> Union[DataArray, Dataset]:
    """
    Index into obj using key. Attaches CF associated variables.

    Parameters
    ----------
    accessor: CFAccessor
        Accessor whose wrapped object (``accessor._obj``) is indexed.
    key: str, List[str]
        A single name, or a list of names. Lists are only accepted when the
        wrapped object is not a DataArray.
    skip: List[str], optional
        Any of ["coords", "measures"]. Keys are not interpreted as members of
        the listed categories; this avoids clashes with special coord names.

    Returns
    -------
    DataArray or Dataset
        A DataArray when ``key`` is a string, otherwise a Dataset.

    Raises
    ------
    KeyError
        If a list of keys is used with a DataArray, or if a requested
        variable cannot be found in the underlying object.
    ValueError
        If ``key`` is a string but matches more than one variable.
    """

    obj = accessor._obj
    kind = str(type(obj).__name__)
    scalar_key = isinstance(key, str)

    if isinstance(obj, DataArray) and not scalar_key:
        raise KeyError(
            f"Cannot use a list of keys with DataArrays. Expected a single string. Received {key!r} instead."
        )

    if scalar_key:
        key = (key,)  # type: ignore

    if skip is None:
        skip = []

    def check_results(names, k):
        # a scalar key must resolve to at most one variable
        if scalar_key and len(names) > 1:
            raise ValueError(
                f"Receive multiple variables for key {k!r}: {names}. "
                f"Expected only one. Please pass a list [{k!r}] "
                f"instead to get all variables matching {k!r}."
            )

    varnames: List[Hashable] = []
    coords: List[Hashable] = []
    successful = dict.fromkeys(key, False)
    for k in key:
        if "coords" not in skip and k in _AXIS_NAMES + _COORD_NAMES:
            names = _get_axis_coord(obj, k)
            check_results(names, k)
            successful[k] = bool(names)
            coords.extend(names)
        elif "measures" not in skip and k in accessor._get_all_cell_measures():
            measure = _get_measure(obj, k)
            check_results(measure, k)
            successful[k] = bool(measure)
            if measure:
                varnames.extend(measure)
        else:
            stdnames = set(_get_with_standard_name(obj, k))
            check_results(stdnames, k)
            successful[k] = bool(stdnames)
            objcoords = set(obj.coords)
            varnames.extend(stdnames - objcoords)
            coords.extend(stdnames & objcoords)

    # these are not special names but could be variable names in underlying object
    # we allow this so that we can return variables with appropriate CF auxiliary variables
    varnames.extend([k for k, v in successful.items() if not v])
    allnames = varnames + coords

    try:
        for name in allnames:
            extravars = accessor.get_associated_variable_names(name)
            # we cannot return bounds variables with scalar keys
            if scalar_key:
                # tolerate a missing "bounds" entry: a bare pop would raise a
                # KeyError that the handler below would misreport as a bad key
                extravars.pop("bounds", None)
            coords.extend(itertools.chain(*extravars.values()))

        if isinstance(obj, DataArray):
            ds = obj._to_temp_dataset()
        else:
            ds = obj

        if scalar_key:
            if len(allnames) == 1:
                da: DataArray = ds.reset_coords()[allnames[0]]  # type: ignore
                if allnames[0] in coords:
                    coords.remove(allnames[0])
                for k1 in coords:
                    da.coords[k1] = ds.variables[k1]
                return da
            else:
                raise ValueError(
                    f"Received scalar key {key[0]!r} but multiple results: {allnames!r}. "
                    f"Please pass a list instead (['{key[0]}']) to get back a Dataset "
                    f"with {allnames!r}."
                )

        # Only list keys reach this point: scalar keys returned or raised
        # above, and DataArrays reject list keys at the top of the function,
        # so no DataArray-specific handling is needed here.
        ds = ds.reset_coords()[varnames + coords]
        return ds.set_coords(coords)

    except KeyError as err:
        # k is the last key visited by the lookup loop above
        raise KeyError(
            f"{kind}.cf does not understand the key {k!r}. "
            f"Use {kind}.cf.describe() to see a list of key names that can be interpreted."
        ) from err
684+
685+
551686
class _CFWrappedClass:
552687
"""
553688
This class is used to wrap any class in _WRAPPED_CLASSES.
@@ -1061,104 +1196,7 @@ def get_associated_variable_names(self, name: Hashable) -> Dict[str, List[str]]:
10611196
return coords
10621197

10631198
def __getitem__(self, key: Union[str, List[str]]):
1064-
1065-
kind = str(type(self._obj).__name__)
1066-
scalar_key = isinstance(key, str)
1067-
1068-
if isinstance(self._obj, DataArray) and not scalar_key:
1069-
raise KeyError(
1070-
f"Cannot use a list of keys with DataArrays. Expected a single string. Received {key!r} instead."
1071-
)
1072-
1073-
if scalar_key:
1074-
key = (key,) # type: ignore
1075-
1076-
def check_results(names, k):
1077-
if scalar_key and len(names) > 1:
1078-
raise ValueError(
1079-
f"Receive multiple variables for key {k!r}: {names}. "
1080-
f"Expected only one. Please pass a list [{k!r}] "
1081-
f"instead to get all variables matching {k!r}."
1082-
)
1083-
1084-
varnames: List[Hashable] = []
1085-
coords: List[Hashable] = []
1086-
successful = dict.fromkeys(key, False)
1087-
for k in key:
1088-
if k in _AXIS_NAMES + _COORD_NAMES:
1089-
names = _get_axis_coord(self._obj, k)
1090-
check_results(names, k)
1091-
successful[k] = bool(names)
1092-
coords.extend(names)
1093-
elif k in self._get_all_cell_measures():
1094-
measure = _get_measure(self._obj, k)
1095-
check_results(measure, k)
1096-
successful[k] = bool(measure)
1097-
if measure:
1098-
varnames.extend(measure)
1099-
elif not isinstance(self._obj, DataArray):
1100-
stdnames = set(_get_with_standard_name(self._obj, k))
1101-
check_results(stdnames, k)
1102-
successful[k] = bool(stdnames)
1103-
objcoords = set(self._obj.coords)
1104-
varnames.extend(stdnames - objcoords)
1105-
coords.extend(stdnames & objcoords)
1106-
1107-
# these are not special names but could be variable names in underlying object
1108-
# we allow this so that we can return variables with appropriate CF auxiliary variables
1109-
varnames.extend([k for k, v in successful.items() if not v])
1110-
allnames = varnames + coords
1111-
1112-
try:
1113-
for name in allnames:
1114-
extravars = self.get_associated_variable_names(name)
1115-
# we cannot return bounds variables with scalar keys
1116-
if scalar_key:
1117-
extravars.pop("bounds")
1118-
coords.extend(itertools.chain(*extravars.values()))
1119-
1120-
if isinstance(self._obj, DataArray):
1121-
ds = self._obj._to_temp_dataset()
1122-
else:
1123-
ds = self._obj
1124-
1125-
if scalar_key:
1126-
if len(allnames) == 1:
1127-
da: DataArray = ds.reset_coords()[allnames[0]] # type: ignore
1128-
if allnames[0] in coords:
1129-
coords.remove(allnames[0])
1130-
for k1 in coords:
1131-
da.coords[k1] = ds.variables[k1]
1132-
return da
1133-
else:
1134-
raise ValueError(
1135-
f"Received scalar key {key[0]!r} but multiple results: {allnames!r}. "
1136-
f"Please pass a list instead (['{key[0]}']) to get back a Dataset "
1137-
f"with {allnames!r}."
1138-
)
1139-
1140-
ds = ds.reset_coords()[varnames + coords]
1141-
if isinstance(self._obj, DataArray):
1142-
if scalar_key and len(ds.variables) == 1:
1143-
# single dimension coordinates
1144-
assert coords
1145-
assert not varnames
1146-
1147-
return ds[coords[0]]
1148-
1149-
elif scalar_key and len(ds.variables) > 1:
1150-
raise NotImplementedError(
1151-
"Not sure what to return when given scalar key for DataArray and it has multiple values. "
1152-
"Please open an issue."
1153-
)
1154-
1155-
return ds.set_coords(coords)
1156-
1157-
except KeyError:
1158-
raise KeyError(
1159-
f"{kind}.cf does not understand the key {k!r}. "
1160-
f"Use {kind}.cf.describe() to see a list of key names that can be interpreted."
1161-
)
1199+
return _getitem(self, key)
11621200

11631201
def _maybe_to_dataset(self, obj=None) -> Dataset:
11641202
if obj is None:

cf_xarray/tests/test_accessor.py

Lines changed: 35 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,12 @@
1919
objects = datasets + dataarrays
2020

2121

22+
def assert_dicts_identical(dict1, dict2):
    """
    Assert that two dict-likes have equal key sets and that the values stored
    under each key pass ``assert_identical``.
    """
    # compare the key sets first so a mismatch fails before any value lookup
    assert dict1.keys() == dict2.keys()
    for name, left in dict1.items():
        assert_identical(left, dict2[name])
26+
27+
2228
def test_describe(capsys):
2329
airds.cf.describe()
2430
actual = capsys.readouterr().out
@@ -280,7 +286,10 @@ def test_dataarray_getitem():
280286
with pytest.raises(KeyError):
281287
air.cf[["longitude"]]
282288
with pytest.raises(KeyError):
283-
air.cf[["longitude", "latitude"]],
289+
air.cf[["longitude", "latitude"]]
290+
291+
air["cell_area"].attrs["standard_name"] = "area_grid_cell"
292+
assert_identical(air.cf["area_grid_cell"], air.cell_area.reset_coords(drop=True))
284293

285294

286295
@pytest.mark.parametrize("obj", dataarrays)
@@ -512,7 +521,7 @@ def test_guess_coord_axis():
512521
assert dsnew.y1.attrs == {"axis": "Y"}
513522

514523

515-
def test_dicts():
524+
def test_attributes():
516525
actual = airds.cf.sizes
517526
expected = {"X": 50, "Y": 25, "T": 4, "longitude": 50, "latitude": 25, "time": 4}
518527
assert actual == expected
@@ -543,6 +552,30 @@ def test_dicts():
543552
expected = {"lon": 50, "Y": 25, "T": 4, "latitude": 25, "time": 4}
544553
assert actual == expected
545554

555+
actual = popds.cf.data_vars
556+
expected = {
557+
"sea_water_x_velocity": popds.cf["UVEL"],
558+
"sea_water_potential_temperature": popds.cf["TEMP"],
559+
}
560+
assert_dicts_identical(actual, expected)
561+
562+
actual = multiple.cf.data_vars
563+
expected = dict(multiple.data_vars)
564+
assert_dicts_identical(actual, expected)
565+
566+
# check that data_vars contains ancillary variables
567+
assert_identical(anc.cf.data_vars["specific_humidity"], anc.cf["specific_humidity"])
568+
569+
# clash between var name and "special" CF name
570+
# Regression test for #126
571+
data = np.random.rand(4, 3)
572+
times = pd.date_range("2000-01-01", periods=4)
573+
locs = [30, 60, 90]
574+
coords = [("time", times, {"axis": "T"}), ("space", locs)]
575+
foo = xr.DataArray(data, coords, dims=["time", "space"])
576+
ds1 = xr.Dataset({"T": foo})
577+
assert_identical(ds1.cf.data_vars["T"], ds1["T"])
578+
546579

547580
def test_missing_variable_in_coordinates():
548581
airds.air.attrs["coordinates"] = "lat lon time"

doc/whats-new.rst

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,12 @@ What's New
66
v0.4.1 (unreleased)
77
===================
88

9-
- Support for using ``standard_name`` in more functions. By `Deepak Cherian`_
9+
- Support for using ``standard_name`` in more functions. (:pr:`128`) By `Deepak Cherian`_
10+
- Allow ``DataArray.cf[]`` with standard names. By `Deepak Cherian`_
11+
- Rewrite the ``values`` of ``.cf.coords`` and ``.cf.data_vars`` with objects returned
12+
by ``.cf.__getitem__``. This allows extraction of DataArrays when there are clashes
13+
between DataArray names and "special" CF names like ``T``.
14+
(:issue:`129`, :pr:`130`). By `Deepak Cherian`_
1015

1116
v0.4.0 (Jan 22, 2021)
1217
=====================

0 commit comments

Comments
 (0)