Skip to content

Commit 8c68924

Browse files
committed
cleanup
Avoid [None] in lists returned by mappers. Instead use [[]]. This makes things a little more pythonic.
1 parent c8fe2d7 commit 8c68924

File tree

2 files changed

+33
-42
lines changed

2 files changed

+33
-42
lines changed

cf_xarray/accessor.py

Lines changed: 31 additions & 40 deletions
Original file line numberDiff line numberDiff line change
@@ -4,15 +4,14 @@
44
import textwrap
55
import warnings
66
from collections import ChainMap
7-
from contextlib import suppress
87
from typing import (
8+
Any,
99
Callable,
1010
Hashable,
1111
Iterable,
1212
List,
1313
Mapping,
1414
MutableMapping,
15-
Optional,
1615
Set,
1716
Tuple,
1817
Union,
@@ -113,45 +112,36 @@
113112
coordinate_criteria["long_name"] = coordinate_criteria["standard_name"]
114113

115114
# Type for Mapper functions
116-
Mapper = Callable[[Union[xr.DataArray, xr.Dataset], str], List[Optional[str]]]
117-
118-
119-
def _strip_none_list(lst: List[Optional[str]]) -> List[str]:
120-
""" The mappers can return [None]. Strip that when necessary. Keeps mypy happy."""
121-
return [item for item in lst if item != [None]] # type: ignore
115+
Mapper = Callable[[Union[xr.DataArray, xr.Dataset], str], List[str]]
122116

123117

124118
def apply_mapper(
125119
mapper: Mapper,
126120
obj: Union[xr.DataArray, xr.Dataset],
127121
key: str,
128122
error: bool = True,
129-
default: str = None,
130-
) -> List[Optional[str]]:
123+
default: Any = None,
124+
) -> List[Any]:
131125
"""
132126
Applies a mapping function; does error handling / returning defaults.
133127
"""
134-
135128
try:
136129
results = mapper(obj, key)
137130
except Exception as e:
138131
if error:
139132
raise e
140133
else:
141-
results = None # type: ignore
134+
if default:
135+
results = [default] # type: ignore
136+
else:
137+
results = []
142138

143-
if not results:
144-
if error:
145-
raise KeyError(f"Attributes to select {key!r} not found!")
146-
else:
147-
return [default]
148-
else:
149-
return list(results)
139+
return results
150140

151141

152142
def _get_axis_coord_single(
153143
var: Union[xr.DataArray, xr.Dataset], key: str,
154-
) -> List[Optional[str]]:
144+
) -> List[str]:
155145
""" Helper method for when we really want only one result per key. """
156146
results = _get_axis_coord(var, key)
157147
if len(results) > 1:
@@ -163,9 +153,7 @@ def _get_axis_coord_single(
163153
return results
164154

165155

166-
def _get_axis_coord(
167-
var: Union[xr.DataArray, xr.Dataset], key: str,
168-
) -> List[Optional[str]]:
156+
def _get_axis_coord(var: Union[xr.DataArray, xr.Dataset], key: str,) -> List[str]:
169157
"""
170158
Translate from axis or coord name to variable name
171159
@@ -225,13 +213,13 @@ def _get_measure_variable(
225213
da: xr.DataArray, key: str, error: bool = True, default: str = None
226214
) -> List[DataArray]:
227215
""" tiny wrapper since xarray does not support providing str for weights."""
228-
varnames = _strip_none_list(apply_mapper(_get_measure, da, key, error, default))
216+
varnames = apply_mapper(_get_measure, da, key, error, default)
229217
if len(varnames) > 1:
230218
raise ValueError(f"Multiple measures found for key {key!r}: {varnames!r}.")
231219
return [da[varnames[0]]]
232220

233221

234-
def _get_measure(da: Union[xr.DataArray, xr.Dataset], key: str) -> List[Optional[str]]:
222+
def _get_measure(da: Union[xr.DataArray, xr.Dataset], key: str) -> List[str]:
235223
"""
236224
Translate from cell measures ("area" or "volume") to appropriate variable name.
237225
This function interprets the ``cell_measures`` attribute on DataArrays.
@@ -269,7 +257,10 @@ def _get_measure(da: Union[xr.DataArray, xr.Dataset], key: str) -> List[Optional
269257
if len(strings) % 2 != 0:
270258
raise ValueError(f"attrs['cell_measures'] = {attr!r} is malformed.")
271259
measures = dict(zip(strings[slice(0, None, 2)], strings[slice(1, None, 2)]))
272-
return [measures.get(key, None)]
260+
results = measures.get(key, [])
261+
if isinstance(results, str):
262+
return [results]
263+
return results
273264

274265

275266
#: Default mappers for common keys.
@@ -383,10 +374,10 @@ def _getattr(
383374
newmap = dict()
384375
unused_keys = set(attribute.keys())
385376
for key in _AXIS_NAMES + _COORD_NAMES:
386-
value = apply_mapper(_get_axis_coord, obj, key, error=False)
387-
unused_keys -= set(value)
388-
if value != [None]:
389-
good_values = set(value) & set(obj.dims)
377+
value = set(apply_mapper(_get_axis_coord, obj, key, error=False))
378+
unused_keys -= value
379+
if value:
380+
good_values = value & set(obj.dims)
390381
if not good_values:
391382
continue
392383
if len(good_values) > 1:
@@ -592,10 +583,10 @@ def _rewrite_values(self, kwargs, key_mappers: dict, var_kws):
592583
# these are valid for .sel, .isel, .coarsen
593584
key_mappers.update(dict.fromkeys(var_kws, _get_axis_coord))
594585

595-
for key, mapper in key_mappers.items():
596-
value = kwargs.get(key, None)
586+
for key, value in kwargs.items():
587+
mapper = key_mappers.get(key, None)
597588

598-
if value is not None:
589+
if mapper is not None:
599590
if isinstance(value, str):
600591
value = [value]
601592

@@ -709,13 +700,13 @@ def get_valid_keys(self) -> Set[str]:
709700
varnames = [
710701
key
711702
for key in _AXIS_NAMES + _COORD_NAMES
712-
if apply_mapper(_get_axis_coord, self._obj, key, error=False) != [None]
703+
if apply_mapper(_get_axis_coord, self._obj, key, error=False)
713704
]
714-
with suppress(NotImplementedError):
705+
if not isinstance(self._obj, xr.Dataset):
715706
measures = [
716707
key
717708
for key in _CELL_MEASURES
718-
if apply_mapper(_get_measure, self._obj, key, error=False) != [None]
709+
if apply_mapper(_get_measure, self._obj, key, error=False)
719710
]
720711
if measures:
721712
varnames.extend(measures)
@@ -742,11 +733,11 @@ def __getitem__(self, key: Union[str, List[str]]):
742733
successful = dict.fromkeys(key, False)
743734
for k in key:
744735
if k in _AXIS_NAMES + _COORD_NAMES:
745-
names = _strip_none_list(_get_axis_coord(self._obj, k))
736+
names = _get_axis_coord(self._obj, k)
746737
successful[k] = bool(names)
747738
coords.extend(names)
748739
elif k in _CELL_MEASURES:
749-
measure = _strip_none_list(_get_measure(self._obj, k))
740+
measure = _get_measure(self._obj, k)
750741
successful[k] = bool(measure)
751742
if measure:
752743
varnames.extend(measure)
@@ -778,7 +769,7 @@ def __getitem__(self, key: Union[str, List[str]]):
778769
for measure in _CELL_MEASURES
779770
if measure in attrs_or_encoding["cell_measures"]
780771
]
781-
coords.extend(_strip_none_list(*measures))
772+
coords.extend(*measures)
782773

783774
if (
784775
isinstance(self._obj, xr.Dataset)
@@ -793,7 +784,7 @@ def __getitem__(self, key: Union[str, List[str]]):
793784
ds = self._obj
794785

795786
if scalar_key and len(varnames) == 1:
796-
da = ds[varnames[0]].reset_coords(drop=True)
787+
da: xr.DataArray = ds[varnames[0]].reset_coords(drop=True) # type: ignore
797788
failed = []
798789
for k1 in coords:
799790
if k1 not in ds.variables:

cf_xarray/tests/test_accessor.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -57,9 +57,9 @@ def test_describe(capsys):
5757
airds.cf.describe()
5858
actual = capsys.readouterr().out
5959
expected = (
60-
"Axes:\n\tX: ['lon']\n\tY: ['lat']\n\tZ: [None]\n\tT: ['time']\n"
60+
"Axes:\n\tX: ['lon']\n\tY: ['lat']\n\tZ: []\n\tT: ['time']\n"
6161
"\nCoordinates:\n\tlongitude: ['lon']\n\tlatitude: ['lat']"
62-
"\n\tvertical: [None]\n\ttime: ['time']\n"
62+
"\n\tvertical: []\n\ttime: ['time']\n"
6363
"\nCell Measures:\n\tarea: unsupported\n\tvolume: unsupported\n"
6464
"\nStandard Names:\n\t['air_temperature', 'latitude', 'longitude', 'time']\n"
6565
)

0 commit comments

Comments
 (0)