|
2 | 2 | import inspect
|
3 | 3 | import itertools
|
4 | 4 | import warnings
|
5 |
| -from collections import ChainMap |
| 5 | +from collections import ChainMap, defaultdict |
6 | 6 | from typing import (
|
7 | 7 | Any,
|
8 | 8 | Callable,
|
@@ -159,6 +159,18 @@ def _is_datetime_like(da: DataArray) -> bool:
|
159 | 159 | return False
|
160 | 160 |
|
161 | 161 |
|
def invert_mappings(*mappings):
    """Invert several ``key -> iterable of names`` mappings into one mapping.

    Parameters
    ----------
    *mappings : Mapping
        Any number of mappings, each associating a key with an iterable of
        names.

    Returns
    -------
    defaultdict
        Maps every name to the set of keys under which it appeared in any of
        ``mappings``. Collecting keys in a set deduplicates them, which avoids
        clashes between standard_name and coordinate names. Because the result
        is a ``defaultdict(set)``, looking up a name that was never seen
        yields (and inserts) an empty set instead of raising ``KeyError``.
    """
    merged = defaultdict(set)
    for mapping in mappings:
        for key, names in mapping.items():
            for name in names:
                # add() mutates in place; no need to build a one-element set
                merged[name].add(key)
    return merged
| 172 | + |
| 173 | + |
162 | 174 | # Type for Mapper functions
|
163 | 175 | Mapper = Callable[[Union[DataArray, Dataset], str], List[str]]
|
164 | 176 |
|
@@ -503,23 +515,29 @@ def _getattr(
|
503 | 515 | if isinstance(attribute, Mapping):
|
504 | 516 | if not attribute:
|
505 | 517 | return dict(attribute)
|
506 |
| - # attributes like chunks / sizes |
| 518 | + |
507 | 519 | newmap = dict()
|
508 |
| - unused_keys = set(attribute.keys()) |
509 |
| - for key in _AXIS_NAMES + _COORD_NAMES: |
510 |
| - value = set(apply_mapper(_get_axis_coord, obj, key, error=False)) |
511 |
| - unused_keys -= value |
512 |
| - if value: |
513 |
| - good_values = value & set(obj.dims) |
514 |
| - if not good_values: |
515 |
| - continue |
516 |
| - if len(good_values) > 1: |
| 520 | + inverted = invert_mappings( |
| 521 | + accessor.axes, |
| 522 | + accessor.coordinates, |
| 523 | + accessor.cell_measures, |
| 524 | + accessor.standard_names, |
| 525 | + ) |
| 526 | + unused_keys = set(attribute.keys()) - set(inverted) |
| 527 | + for key, value in attribute.items(): |
| 528 | + for name in inverted[key]: |
| 529 | + if name in newmap: |
517 | 530 | raise AttributeError(
|
518 |
| - f"cf_xarray can't wrap attribute {attr!r} because there are multiple values for {key!r} viz. {good_values!r}. " |
519 |
| - f"There is no unique mapping from {key!r} to a value in {attr!r}." |
| 531 | + f"cf_xarray can't wrap attribute {attr!r} because there are multiple values for {name!r}. " |
| 532 | + f"There is no unique mapping from {name!r} to a value in {attr!r}." |
520 | 533 | )
|
521 |
| - newmap.update({key: attribute[good_values.pop()]}) |
| 534 | + newmap.update(dict.fromkeys(inverted[key], value)) |
522 | 535 | newmap.update({key: attribute[key] for key in unused_keys})
|
| 536 | + |
| 537 | + skip = {"data_vars": ["coords"], "coords": None} |
| 538 | + if attr in ["coords", "data_vars"]: |
| 539 | + for key in newmap: |
| 540 | + newmap[key] = _getitem(accessor, key, skip=skip[attr]) |
523 | 541 | return newmap
|
524 | 542 |
|
525 | 543 | elif isinstance(attribute, Callable): # type: ignore
|
@@ -548,6 +566,123 @@ def wrapper(*args, **kwargs):
|
548 | 566 | return wrapper
|
549 | 567 |
|
550 | 568 |
|
def _getitem(
    accessor: "CFAccessor",
    key: Union[str, List[str]],
    skip: Union[List[str], None] = None,
) -> Union[DataArray, Dataset]:
    """
    Index into obj using key. Attaches CF associated variables.

    Parameters
    ----------
    accessor: CFAccessor
        Accessor whose wrapped object (``accessor._obj``) is indexed.
    key: str, List[str]
        A single name, or a list of names. A list is only accepted when the
        wrapped object is not a DataArray.
    skip: List[str], optional
        Any of ["coords", "measures"]. A listed category is not consulted when
        interpreting ``key``; this avoids clashes with special coord names.

    Returns
    -------
    DataArray or Dataset
        A DataArray when ``key`` is a single string, with the associated
        variables (except bounds) attached as coordinates. Otherwise a Dataset
        holding all matched variables, with coordinate and associated
        variables set as coordinates.

    Raises
    ------
    KeyError
        If a list of keys is used with a DataArray, or if a lookup on the
        underlying object fails with ``KeyError``.
    ValueError
        If a single string key resolves to more than one variable.
    """

    obj = accessor._obj
    kind = type(obj).__name__
    scalar_key = isinstance(key, str)

    if isinstance(obj, DataArray) and not scalar_key:
        raise KeyError(
            f"Cannot use a list of keys with DataArrays. Expected a single string. Received {key!r} instead."
        )

    if scalar_key:
        key = (key,)  # type: ignore

    if skip is None:
        skip = []

    # loop-invariant: names that are interpreted as axes / coordinates
    special_names = _AXIS_NAMES + _COORD_NAMES

    def check_results(names, k):
        if scalar_key and len(names) > 1:
            raise ValueError(
                f"Receive multiple variables for key {k!r}: {names}. "
                f"Expected only one. Please pass a list [{k!r}] "
                f"instead to get all variables matching {k!r}."
            )

    varnames: List[Hashable] = []
    coords: List[Hashable] = []
    successful = dict.fromkeys(key, False)
    for k in key:
        if "coords" not in skip and k in special_names:
            names = _get_axis_coord(obj, k)
            check_results(names, k)
            successful[k] = bool(names)
            coords.extend(names)
        elif "measures" not in skip and k in accessor._get_all_cell_measures():
            measure = _get_measure(obj, k)
            check_results(measure, k)
            successful[k] = bool(measure)
            if measure:
                varnames.extend(measure)
        else:
            stdnames = set(_get_with_standard_name(obj, k))
            check_results(stdnames, k)
            successful[k] = bool(stdnames)
            objcoords = set(obj.coords)
            varnames.extend(stdnames - objcoords)
            coords.extend(stdnames & objcoords)

    # these are not special names but could be variable names in underlying object
    # we allow this so that we can return variables with appropriate CF auxiliary variables
    varnames.extend([k for k, v in successful.items() if not v])
    allnames = varnames + coords

    try:
        for name in allnames:
            extravars = accessor.get_associated_variable_names(name)
            # we cannot return bounds variables with scalar keys
            if scalar_key:
                extravars.pop("bounds")
            coords.extend(itertools.chain(*extravars.values()))

        if isinstance(obj, DataArray):
            ds = obj._to_temp_dataset()
        else:
            ds = obj

        if scalar_key:
            if len(allnames) == 1:
                da: DataArray = ds.reset_coords()[allnames[0]]  # type: ignore
                if allnames[0] in coords:
                    coords.remove(allnames[0])
                for k1 in coords:
                    da.coords[k1] = ds.variables[k1]
                return da
            else:
                raise ValueError(
                    f"Received scalar key {key[0]!r} but multiple results: {allnames!r}. "
                    f"Please pass a list instead (['{key[0]}']) to get back a Dataset "
                    f"with {allnames!r}."
                )

        # Only list keys reach this point, and list keys were rejected for
        # DataArrays above, so no DataArray-specific handling is needed here.
        ds = ds.reset_coords()[varnames + coords]
        return ds.set_coords(coords)

    except KeyError as err:
        # Report what the caller asked for, not whichever key the loop above
        # happened to visit last.
        requested = key[0] if scalar_key else list(key)
        raise KeyError(
            f"{kind}.cf does not understand the key {requested!r}. "
            f"Use {kind}.cf.describe() to see a list of key names that can be interpreted."
        ) from err
| 684 | + |
| 685 | + |
551 | 686 | class _CFWrappedClass:
|
552 | 687 | """
|
553 | 688 | This class is used to wrap any class in _WRAPPED_CLASSES.
|
@@ -1061,104 +1196,7 @@ def get_associated_variable_names(self, name: Hashable) -> Dict[str, List[str]]:
|
1061 | 1196 | return coords
|
1062 | 1197 |
|
def __getitem__(self, key: Union[str, List[str]]) -> Union[DataArray, Dataset]:
    """
    Index into the wrapped object using ``key``.

    Parameters
    ----------
    key: str, List[str]
        A single name or a list of names, passed unchanged to ``_getitem``.

    Returns
    -------
    DataArray or Dataset
        Whatever ``_getitem`` returns for this accessor and ``key``; no
        ``skip`` list is supplied, so ``_getitem`` uses its default.
    """
    return _getitem(self, key)
1162 | 1200 |
|
1163 | 1201 | def _maybe_to_dataset(self, obj=None) -> Dataset:
|
1164 | 1202 | if obj is None:
|
|
0 commit comments