|
7 | 7 | from typing import (
|
8 | 8 | Any,
|
9 | 9 | Callable,
|
| 10 | + Dict, |
10 | 11 | Hashable,
|
11 | 12 | Iterable,
|
12 | 13 | List,
|
@@ -896,42 +897,72 @@ def get_standard_names(self) -> List[str]:
|
896 | 897 | ]
|
897 | 898 | )
|
898 | 899 |
|
899 |
| - def get_associated_variable_names(self, name: Hashable) -> List[Hashable]: |
| 900 | + def get_associated_variable_names(self, name: Hashable) -> Dict[str, List[str]]: |
900 | 901 | """
|
901 |
| - Returns a list of variable names referred to in the following attributes |
902 |
| - 1. "coordinates" |
903 |
| - 2. "cell_measures" |
904 |
| - 3. "ancillary_variables" |
| 902 | + Returns a dict mapping |
| 903 | + 1. "ancillary_variables" |
| 904 | + 2. "bounds" |
| 905 | + 3. "cell_measures" |
| 906 | + 4. "coordinates" |
| 907 | + to a list of variable names referred to in the appropriate attribute |
| 908 | +
|
| 909 | + Parameters |
| 910 | + ---------- |
| 911 | +
|
| 912 | + name: Hashable |
| 913 | +
|
| 914 | + Returns |
| 915 | + ------ |
| 916 | +
|
| 917 | + Dict with keys "ancillary_variables", "cell_measures", "coordinates", "bounds" |
905 | 918 | """
|
906 |
| - coords = [] |
| 919 | + keys = ["ancillary_variables", "cell_measures", "coordinates", "bounds"] |
| 920 | + coords: Dict[str, List[str]] = {k: [] for k in keys} |
907 | 921 | attrs_or_encoding = ChainMap(self._obj[name].attrs, self._obj[name].encoding)
|
908 | 922 |
|
909 | 923 | if "coordinates" in attrs_or_encoding:
|
910 |
| - coords.extend(attrs_or_encoding["coordinates"].split(" ")) |
| 924 | + coords["coordinates"] = attrs_or_encoding["coordinates"].split(" ") |
911 | 925 |
|
912 | 926 | if "cell_measures" in attrs_or_encoding:
|
913 |
| - measures = [ |
914 |
| - _get_measure(self._obj[name], measure) |
915 |
| - for measure in _CELL_MEASURES |
916 |
| - if measure in attrs_or_encoding["cell_measures"] |
917 |
| - ] |
918 |
| - coords.extend(*measures) |
| 927 | + coords["cell_measures"] = list( |
| 928 | + itertools.chain( |
| 929 | + *[ |
| 930 | + _get_measure(self._obj[name], measure) |
| 931 | + for measure in _CELL_MEASURES |
| 932 | + if measure in attrs_or_encoding["cell_measures"] |
| 933 | + ] |
| 934 | + ) |
| 935 | + ) |
919 | 936 |
|
920 | 937 | if (
|
921 | 938 | isinstance(self._obj, Dataset)
|
922 | 939 | and "ancillary_variables" in attrs_or_encoding
|
923 | 940 | ):
|
924 |
| - anames = attrs_or_encoding["ancillary_variables"].split(" ") |
925 |
| - coords.extend(anames) |
| 941 | + coords["ancillary_variables"] = attrs_or_encoding[ |
| 942 | + "ancillary_variables" |
| 943 | + ].split(" ") |
| 944 | + |
| 945 | + if "bounds" in attrs_or_encoding: |
| 946 | + coords["bounds"] = [attrs_or_encoding["bounds"]] |
926 | 947 |
|
927 |
| - missing = set(coords) - set(self._maybe_to_dataset().variables) |
| 948 | + for dim in self._obj[name].dims: |
| 949 | + dbounds = self._obj[dim].attrs.get("bounds", None) |
| 950 | + if dbounds: |
| 951 | + coords["bounds"].append(dbounds) |
| 952 | + |
| 953 | + allvars = itertools.chain(*coords.values()) |
| 954 | + missing = set(allvars) - set(self._maybe_to_dataset().variables) |
928 | 955 | if missing:
|
929 | 956 | warnings.warn(
|
930 | 957 | f"Variables {missing!r} not found in object but are referred to in the CF attributes.",
|
931 | 958 | UserWarning,
|
932 | 959 | )
|
933 |
| - for m in missing: |
934 |
| - coords.remove(m) |
| 960 | + for k, v in coords.items(): |
| 961 | + for m in missing: |
| 962 | + if m in v: |
| 963 | + v.remove(m) |
| 964 | + coords[k] = v |
| 965 | + |
935 | 966 | return coords
|
936 | 967 |
|
937 | 968 | def __getitem__(self, key: Union[str, List[str]]):
|
@@ -981,8 +1012,12 @@ def __getitem__(self, key: Union[str, List[str]]):
|
981 | 1012 | allnames = varnames + coords
|
982 | 1013 |
|
983 | 1014 | try:
|
984 |
| - for name in varnames: |
985 |
| - coords.extend(self.get_associated_variable_names(name)) |
| 1015 | + for name in allnames: |
| 1016 | + extravars = self.get_associated_variable_names(name) |
| 1017 | + # we cannot return bounds variables with scalar keys |
| 1018 | + if scalar_key: |
| 1019 | + extravars.pop("bounds") |
| 1020 | + coords.extend(itertools.chain(*extravars.values())) |
986 | 1021 |
|
987 | 1022 | if isinstance(self._obj, DataArray):
|
988 | 1023 | ds = self._obj._to_temp_dataset()
|
@@ -1036,47 +1071,6 @@ def _maybe_to_dataarray(self, obj=None):
|
1036 | 1071 | else:
|
1037 | 1072 | return obj
|
1038 | 1073 |
|
1039 |
| - def add_bounds(self, dims: Union[Hashable, Iterable[Hashable]]): |
1040 |
| - """ |
1041 |
| - Returns a new object with bounds variables. The bounds values are guessed assuming |
1042 |
| - equal spacing on either side of a coordinate label. |
1043 |
| -
|
1044 |
| - Parameters |
1045 |
| - ---------- |
1046 |
| - dims: Hashable or Iterable[Hashable] |
1047 |
| - Either a single dimension name or a list of dimension names. |
1048 |
| -
|
1049 |
| - Returns |
1050 |
| - ------- |
1051 |
| - DataArray or Dataset with bounds variables added and appropriate "bounds" attribute set. |
1052 |
| -
|
1053 |
| - Notes |
1054 |
| - ----- |
1055 |
| -
|
1056 |
| - The bounds variables are automatically named f"{dim}_bounds" where ``dim`` |
1057 |
| - is a dimension name. |
1058 |
| - """ |
1059 |
| - if isinstance(dims, Hashable): |
1060 |
| - dimensions = (dims,) |
1061 |
| - else: |
1062 |
| - dimensions = dims |
1063 |
| - |
1064 |
| - bad_dims: Set[Hashable] = set(dimensions) - set(self._obj.dims) |
1065 |
| - if bad_dims: |
1066 |
| - raise ValueError( |
1067 |
| - f"{bad_dims!r} are not dimensions in the underlying object." |
1068 |
| - ) |
1069 |
| - |
1070 |
| - obj = self._maybe_to_dataset(self._obj.copy(deep=True)) |
1071 |
| - for dim in dimensions: |
1072 |
| - bname = f"{dim}_bounds" |
1073 |
| - if bname in obj.variables: |
1074 |
| - raise ValueError(f"Bounds variable name {bname!r} will conflict!") |
1075 |
| - obj.coords[bname] = _guess_bounds_dim(obj[dim].reset_coords(drop=True)) |
1076 |
| - obj[dim].attrs["bounds"] = bname |
1077 |
| - |
1078 |
| - return self._maybe_to_dataarray(obj) |
1079 |
| - |
1080 | 1074 | def rename_like(
|
1081 | 1075 | self, other: Union[DataArray, Dataset]
|
1082 | 1076 | ) -> Union[DataArray, Dataset]:
|
@@ -1169,7 +1163,66 @@ def guess_coord_axis(self, verbose: bool = False) -> Union[DataArray, Dataset]:
|
1169 | 1163 |
|
@xr.register_dataset_accessor("cf")
class CFDatasetAccessor(CFAccessor):
    """
    Accessor registered under the name "cf" on xarray Datasets.

    Adds bounds-related helpers (``get_bounds``, ``add_bounds``) on top of
    the functionality inherited from ``CFAccessor``.
    """

    def get_bounds(self, key: str) -> DataArray:
        """
        Get bounds variable corresponding to key.

        Parameters
        ----------
        key: str
            Name of variable whose bounds are desired

        Returns
        -------
        DataArray

        Raises
        ------
        KeyError
            If the variable has no "bounds" attribute.
        """
        # ``key`` is passed through ``apply_mapper``; with ``error=False`` and
        # ``default=[key]`` this presumably falls back to ``key`` itself when
        # no mapping applies -- NOTE(review): confirm against ``apply_mapper``.
        name = apply_mapper(
            _get_axis_coord_single, self._obj, key, error=False, default=[key]
        )[0]
        attrs = self._obj[name].attrs
        if "bounds" not in attrs:
            # Same exception type as the plain dict lookup, but with a message
            # that names the offending variable.
            raise KeyError(f"Variable {name!r} has no 'bounds' attribute.")
        bounds = attrs["bounds"]
        obj = self._maybe_to_dataset()
        return obj[bounds]

    def add_bounds(self, dims: Union[Hashable, Iterable[Hashable]]):
        """
        Returns a new object with bounds variables. The bounds values are guessed assuming
        equal spacing on either side of a coordinate label.

        Parameters
        ----------
        dims: Hashable or Iterable[Hashable]
            Either a single dimension name or a list of dimension names.

        Returns
        -------
        DataArray or Dataset with bounds variables added and appropriate "bounds" attribute set.

        Raises
        ------
        ValueError
            If any requested name is not a dimension of the underlying object,
            or if a variable named f"{dim}_bounds" already exists.

        Notes
        -----

        The bounds variables are automatically named f"{dim}_bounds" where ``dim``
        is a dimension name.
        """
        # A single hashable name is wrapped in a tuple so both call styles
        # share the loop below.
        if isinstance(dims, Hashable):
            dimensions = (dims,)
        else:
            dimensions = dims

        bad_dims: Set[Hashable] = set(dimensions) - set(self._obj.dims)
        if bad_dims:
            raise ValueError(
                f"{bad_dims!r} are not dimensions in the underlying object."
            )

        # Work on a deep copy so the wrapped object is left untouched.
        obj = self._maybe_to_dataset(self._obj.copy(deep=True))
        for dim in dimensions:
            bname = f"{dim}_bounds"
            if bname in obj.variables:
                raise ValueError(f"Bounds variable name {bname!r} will conflict!")
            obj.coords[bname] = _guess_bounds_dim(obj[dim].reset_coords(drop=True))
            obj[dim].attrs["bounds"] = bname

        return self._maybe_to_dataarray(obj)
1173 | 1226 |
|
1174 | 1227 |
|
1175 | 1228 | @xr.register_dataarray_accessor("cf")
|
|
0 commit comments