Skip to content

Commit c626d74

Browse files
committed
Error if multiple results for scalar "standard_name" key
1 parent c1d3197 commit c626d74

File tree

2 files changed

+15
-11
lines changed

2 files changed

+15
-11
lines changed

cf_xarray/accessor.py

Lines changed: 12 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -984,32 +984,34 @@ def __getitem__(self, key: Union[str, List[str]]):
984984
)
985985

986986
if scalar_key:
987-
axis_coord_mapper = _get_axis_coord_single
988987
key = (key,) # type: ignore
989-
else:
990-
axis_coord_mapper = _get_axis_coord
988+
989+
def check_results(names, k):
990+
if scalar_key and len(names) > 1:
991+
raise ValueError(
992+
f"Receive multiple variables for key {k!r}: {names}. "
993+
f"Expected only one. Please pass a list [{k!r}] "
994+
f"instead to get all variables matching {k!r}."
995+
)
991996

992997
varnames: List[Hashable] = []
993998
coords: List[Hashable] = []
994999
successful = dict.fromkeys(key, False)
9951000
for k in key:
9961001
if k in _AXIS_NAMES + _COORD_NAMES:
997-
try:
998-
names = axis_coord_mapper(self._obj, k)
999-
except KeyError as e:
1000-
raise KeyError(
1001-
f"Receive multiple variables for key {k!r}. Expected only one. Please pass a list [{k!r}] instead to get all variables matching {k!r}."
1002-
)
1003-
raise e
1002+
names = _get_axis_coord(self._obj, k)
1003+
check_results(names, k)
10041004
successful[k] = bool(names)
10051005
coords.extend(names)
10061006
elif k in _CELL_MEASURES:
10071007
measure = _get_measure(self._obj, k)
1008+
check_results(measure, k)
10081009
successful[k] = bool(measure)
10091010
if measure:
10101011
varnames.extend(measure)
10111012
elif not isinstance(self._obj, DataArray):
10121013
stdnames = _get_with_standard_name(self._obj, k)
1014+
check_results(stdnames, k)
10131015
successful[k] = bool(stdnames)
10141016
varnames.extend(stdnames)
10151017
coords.extend(list(set(stdnames) & set(self._obj.coords)))

cf_xarray/tests/test_accessor.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -50,7 +50,9 @@ def test_getitem_standard_name():
5050

5151
ds = airds.copy(deep=True)
5252
ds["air2"] = ds.air
53-
actual = ds.cf["air_temperature"]
53+
with pytest.raises(ValueError):
54+
ds.cf["air_temperature"]
55+
actual = ds.cf[["air_temperature"]]
5456
expected = ds[["air", "air2"]]
5557
assert_identical(actual, expected)
5658

0 commit comments

Comments
 (0)