Skip to content

Commit c54d9b3

Browse files
committed
More cleanups.
1 parent 8c68924 commit c54d9b3

File tree

1 file changed

+91
-88
lines changed

1 file changed

+91
-88
lines changed

cf_xarray/accessor.py

Lines changed: 91 additions & 88 deletions
Original file line numberDiff line numberDiff line change
@@ -112,12 +112,12 @@
112112
coordinate_criteria["long_name"] = coordinate_criteria["standard_name"]
113113

114114
# Type for Mapper functions
115-
Mapper = Callable[[Union[xr.DataArray, xr.Dataset], str], List[str]]
115+
Mapper = Callable[[Union[DataArray, Dataset], str], List[str]]
116116

117117

118118
def apply_mapper(
119119
mapper: Mapper,
120-
obj: Union[xr.DataArray, xr.Dataset],
120+
obj: Union[DataArray, Dataset],
121121
key: str,
122122
error: bool = True,
123123
default: Any = None,
@@ -139,9 +139,7 @@ def apply_mapper(
139139
return results
140140

141141

142-
def _get_axis_coord_single(
143-
var: Union[xr.DataArray, xr.Dataset], key: str,
144-
) -> List[str]:
142+
def _get_axis_coord_single(var: Union[DataArray, Dataset], key: str,) -> List[str]:
145143
""" Helper method for when we really want only one result per key. """
146144
results = _get_axis_coord(var, key)
147145
if len(results) > 1:
@@ -153,7 +151,7 @@ def _get_axis_coord_single(
153151
return results
154152

155153

156-
def _get_axis_coord(var: Union[xr.DataArray, xr.Dataset], key: str,) -> List[str]:
154+
def _get_axis_coord(var: Union[DataArray, Dataset], key: str,) -> List[str]:
157155
"""
158156
Translate from axis or coord name to variable name
159157
@@ -210,7 +208,7 @@ def _get_axis_coord(var: Union[xr.DataArray, xr.Dataset], key: str,) -> List[str
210208

211209

212210
def _get_measure_variable(
213-
da: xr.DataArray, key: str, error: bool = True, default: str = None
211+
da: DataArray, key: str, error: bool = True, default: str = None
214212
) -> List[DataArray]:
215213
""" tiny wrapper since xarray does not support providing str for weights."""
216214
varnames = apply_mapper(_get_measure, da, key, error, default)
@@ -219,7 +217,7 @@ def _get_measure_variable(
219217
return [da[varnames[0]]]
220218

221219

222-
def _get_measure(da: Union[xr.DataArray, xr.Dataset], key: str) -> List[str]:
220+
def _get_measure(da: Union[DataArray, Dataset], key: str) -> List[str]:
223221
"""
224222
Translate from cell measures ("area" or "volume") to appropriate variable name.
225223
This function interprets the ``cell_measures`` attribute on DataArrays.
@@ -279,7 +277,7 @@ def _get_measure(da: Union[xr.DataArray, xr.Dataset], key: str) -> List[str]:
279277
}
280278

281279

282-
def _filter_by_standard_names(ds: xr.Dataset, name: Union[str, List[str]]) -> List[str]:
280+
def _filter_by_standard_names(ds: Dataset, name: Union[str, List[str]]) -> List[str]:
283281
""" returns a list of variable names with standard names matching name. """
284282
if isinstance(name, str):
285283
name = [name]
@@ -295,7 +293,7 @@ def _filter_by_standard_names(ds: xr.Dataset, name: Union[str, List[str]]) -> Li
295293
return varnames
296294

297295

298-
def _get_list_standard_names(obj: xr.Dataset) -> List[str]:
296+
def _get_list_standard_names(obj: Dataset) -> List[str]:
299297
"""
300298
Returns a sorted list of standard names in Dataset.
301299
@@ -365,7 +363,12 @@ def _getattr(
365363
An extra decorator, if necessary. This is used by _CFPlotMethods to set default
366364
kwargs based on CF attributes.
367365
"""
368-
attribute: Union[Mapping, Callable] = getattr(obj, attr)
366+
try:
367+
attribute: Union[Mapping, Callable] = getattr(obj, attr)
368+
except AttributeError:
369+
raise AttributeError(
370+
f"{attr!r} is not a valid attribute on the underlying xarray object."
371+
)
369372

370373
if isinstance(attribute, Mapping):
371374
if not attribute:
@@ -545,16 +548,14 @@ def _process_signature(self, func: Callable, args, kwargs, key_mappers):
545548
arguments = self._rewrite_values(
546549
bound.arguments, key_mappers, tuple(var_kws)
547550
)
548-
else:
549-
arguments = {}
550-
551-
if arguments:
552551
# now unwrap the **indexers_kwargs type arguments
553552
# so that xarray can parse it :)
554553
for kw in var_kws:
555554
value = arguments.pop(kw, None)
556555
if value:
557556
arguments.update(**value)
557+
else:
558+
arguments = {}
558559

559560
return arguments
560561

@@ -583,45 +584,41 @@ def _rewrite_values(self, kwargs, key_mappers: dict, var_kws):
583584
# these are valid for .sel, .isel, .coarsen
584585
key_mappers.update(dict.fromkeys(var_kws, _get_axis_coord))
585586

586-
for key, value in kwargs.items():
587-
mapper = key_mappers.get(key, None)
588-
589-
if mapper is not None:
590-
if isinstance(value, str):
591-
value = [value]
592-
593-
if isinstance(value, dict):
594-
# this for things like isel where **kwargs captures things like T=5
595-
# .sel, .isel, .rolling
596-
# Account for multiple names matching the key.
597-
# e.g. .isel(X=5) → .isel(xi_rho=5, xi_u=5, xi_v=5, xi_psi=5)
598-
# where xi_* have attrs["axis"] = "X"
599-
updates[key] = ChainMap(
600-
*[
601-
dict.fromkeys(
602-
apply_mapper(mapper, self._obj, k, False, k), v
603-
)
604-
for k, v in value.items()
605-
]
606-
)
587+
for key in set(key_mappers) & set(kwargs):
588+
value = kwargs[key]
589+
mapper = key_mappers[key]
590+
591+
if isinstance(value, str):
592+
value = [value]
593+
594+
if isinstance(value, dict):
595+
# this for things like isel where **kwargs captures things like T=5
596+
# .sel, .isel, .rolling
597+
# Account for multiple names matching the key.
598+
# e.g. .isel(X=5) → .isel(xi_rho=5, xi_u=5, xi_v=5, xi_psi=5)
599+
# where xi_* have attrs["axis"] = "X"
600+
updates[key] = ChainMap(
601+
*[
602+
dict.fromkeys(apply_mapper(mapper, self._obj, k, False, k), v)
603+
for k, v in value.items()
604+
]
605+
)
607606

608-
elif value is Ellipsis:
609-
pass
607+
elif value is Ellipsis:
608+
pass
610609

610+
else:
611+
# things like sum which have dim
612+
newvalue = [apply_mapper(mapper, self._obj, v, False, v) for v in value]
613+
# Mappers return list by default
614+
# for input dim=["lat", "X"], newvalue=[["lat"], ["lon"]],
615+
# so we deal with that here.
616+
unpacked = list(itertools.chain(*newvalue))
617+
if len(unpacked) == 1:
618+
# handle 'group'
619+
updates[key] = unpacked[0]
611620
else:
612-
# things like sum which have dim
613-
newvalue = [
614-
apply_mapper(mapper, self._obj, v, False, v) for v in value
615-
]
616-
# Mappers return list by default
617-
# for input dim=["lat", "X"], newvalue=[["lat"], ["lon"]],
618-
# so we deal with that here.
619-
unpacked = list(itertools.chain(*newvalue))
620-
if len(unpacked) == 1:
621-
# handle 'group'
622-
updates[key] = unpacked[0]
623-
else:
624-
updates[key] = unpacked
621+
updates[key] = unpacked
625622

626623
kwargs.update(updates)
627624

@@ -670,13 +667,13 @@ def describe(self):
670667

671668
text += "\nCell Measures:\n"
672669
for measure in _CELL_MEASURES:
673-
if isinstance(self._obj, xr.Dataset):
670+
if isinstance(self._obj, Dataset):
674671
text += f"\t{measure}: unsupported\n"
675672
else:
676673
text += f"\t{measure}: {apply_mapper(_get_measure, self._obj, measure, error=False)}\n"
677674

678675
text += "\nStandard Names:\n"
679-
if isinstance(self._obj, xr.DataArray):
676+
if isinstance(self._obj, DataArray):
680677
text += "\tunsupported\n"
681678
else:
682679
stdnames = _get_list_standard_names(self._obj)
@@ -702,7 +699,7 @@ def get_valid_keys(self) -> Set[str]:
702699
for key in _AXIS_NAMES + _COORD_NAMES
703700
if apply_mapper(_get_axis_coord, self._obj, key, error=False)
704701
]
705-
if not isinstance(self._obj, xr.Dataset):
702+
if not isinstance(self._obj, Dataset):
706703
measures = [
707704
key
708705
for key in _CELL_MEASURES
@@ -711,16 +708,45 @@ def get_valid_keys(self) -> Set[str]:
711708
if measures:
712709
varnames.extend(measures)
713710

714-
if not isinstance(self._obj, xr.DataArray):
711+
if not isinstance(self._obj, DataArray):
715712
varnames.extend(_get_list_standard_names(self._obj))
716713
return set(varnames)
717714

715+
def get_associated_variable_names(self, name: Hashable) -> List[Hashable]:
716+
"""
717+
Returns a list of variable names referred to in the following attributes
718+
1. "coordinates"
719+
2. "cell_measures"
720+
3. "ancillary_variables"
721+
"""
722+
coords = []
723+
attrs_or_encoding = ChainMap(self._obj[name].attrs, self._obj[name].encoding)
724+
725+
if "coordinates" in attrs_or_encoding:
726+
coords.extend(attrs_or_encoding["coordinates"].split(" "))
727+
728+
if "cell_measures" in attrs_or_encoding:
729+
measures = [
730+
_get_measure(self._obj[name], measure)
731+
for measure in _CELL_MEASURES
732+
if measure in attrs_or_encoding["cell_measures"]
733+
]
734+
coords.extend(*measures)
735+
736+
if (
737+
isinstance(self._obj, Dataset)
738+
and "ancillary_variables" in attrs_or_encoding
739+
):
740+
anames = attrs_or_encoding["ancillary_variables"].split(" ")
741+
coords.extend(anames)
742+
return coords
743+
718744
def __getitem__(self, key: Union[str, List[str]]):
719745

720746
kind = str(type(self._obj).__name__)
721747
scalar_key = isinstance(key, str)
722748

723-
if isinstance(self._obj, xr.DataArray) and not scalar_key:
749+
if isinstance(self._obj, DataArray) and not scalar_key:
724750
raise KeyError(
725751
f"Cannot use a list of keys with DataArrays. Expected a single string. Received {key!r} instead."
726752
)
@@ -741,7 +767,7 @@ def __getitem__(self, key: Union[str, List[str]]):
741767
successful[k] = bool(measure)
742768
if measure:
743769
varnames.extend(measure)
744-
elif not isinstance(self._obj, xr.DataArray):
770+
elif not isinstance(self._obj, DataArray):
745771
stdnames = _filter_by_standard_names(self._obj, k)
746772
successful[k] = bool(stdnames)
747773
varnames.extend(stdnames)
@@ -752,39 +778,16 @@ def __getitem__(self, key: Union[str, List[str]]):
752778
varnames.extend([k for k, v in successful.items() if not v])
753779

754780
try:
755-
# TODO: make this a get_auxiliary_variables function
756-
# 1. set coordinate variables referred to in "coordinates" attribute
757-
# 2. set measures variables as coordinates
758-
# 3. set ancillary variables as coordinates
759781
for name in varnames:
760-
attrs_or_encoding = ChainMap(
761-
self._obj[name].attrs, self._obj[name].encoding
762-
)
763-
if "coordinates" in attrs_or_encoding:
764-
coords.extend(attrs_or_encoding["coordinates"].split(" "))
765-
766-
if "cell_measures" in attrs_or_encoding:
767-
measures = [
768-
_get_measure(self._obj[name], measure)
769-
for measure in _CELL_MEASURES
770-
if measure in attrs_or_encoding["cell_measures"]
771-
]
772-
coords.extend(*measures)
773-
774-
if (
775-
isinstance(self._obj, xr.Dataset)
776-
and "ancillary_variables" in attrs_or_encoding
777-
):
778-
anames = attrs_or_encoding["ancillary_variables"].split(" ")
779-
coords.extend(anames)
782+
coords.extend(self.get_associated_variable_names(name))
780783

781-
if isinstance(self._obj, xr.DataArray):
784+
if isinstance(self._obj, DataArray):
782785
ds = self._obj._to_temp_dataset()
783786
else:
784787
ds = self._obj
785788

786789
if scalar_key and len(varnames) == 1:
787-
da: xr.DataArray = ds[varnames[0]].reset_coords(drop=True) # type: ignore
790+
da: DataArray = ds[varnames[0]].reset_coords(drop=True) # type: ignore
788791
failed = []
789792
for k1 in coords:
790793
if k1 not in ds.variables:
@@ -821,18 +824,18 @@ def __getitem__(self, key: Union[str, List[str]]):
821824
f"Use {kind}.cf.describe() to see a list of key names that can be interpreted."
822825
)
823826

824-
def _maybe_to_dataset(self, obj=None) -> xr.Dataset:
827+
def _maybe_to_dataset(self, obj=None) -> Dataset:
825828
if obj is None:
826829
obj = self._obj
827-
if isinstance(self._obj, xr.DataArray):
830+
if isinstance(self._obj, DataArray):
828831
return obj._to_temp_dataset()
829832
else:
830833
return obj
831834

832835
def _maybe_to_dataarray(self, obj=None):
833836
if obj is None:
834837
obj = self._obj
835-
if isinstance(self._obj, xr.DataArray):
838+
if isinstance(self._obj, DataArray):
836839
return self._obj._from_temp_dataset(obj)
837840
else:
838841
return obj
@@ -879,8 +882,8 @@ def add_bounds(self, dims: Union[Hashable, Iterable[Hashable]]):
879882
return self._maybe_to_dataarray(obj)
880883

881884
def rename_like(
882-
self, other: Union[xr.DataArray, xr.Dataset]
883-
) -> Union[xr.DataArray, xr.Dataset]:
885+
self, other: Union[DataArray, Dataset]
886+
) -> Union[DataArray, Dataset]:
884887
"""
885888
Renames variables in object to match names of like-variables in ``other``.
886889

0 commit comments

Comments
 (0)