Skip to content

Commit 50d96d8

Browse files
committed
Don't accept list of keys for DataArray.cf[...].
DataArray[List] fails with pure xarray, so we copy that behaviour.
1 parent d454cb4 commit 50d96d8

File tree

2 files changed

+28
-16
lines changed

2 files changed

+28
-16
lines changed

cf_xarray/accessor.py

Lines changed: 24 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -178,7 +178,7 @@ def _get_axis_coord(
178178

179179
if key not in _COORD_NAMES and key not in _AXIS_NAMES:
180180
if error:
181-
raise KeyError(f"Did not understand {key}")
181+
raise KeyError(f"Did not understand key {key!r}")
182182
else:
183183
return [default]
184184

@@ -589,6 +589,12 @@ def __getitem__(self, key: Union[str, List[str]]):
589589

590590
kind = str(type(self._obj).__name__)
591591
scalar_key = isinstance(key, str)
592+
593+
if isinstance(self._obj, xr.DataArray) and not scalar_key:
594+
raise KeyError(
595+
f"Cannot use a list of keys with DataArrays. Expected a single string. Received {key!r} instead."
596+
)
597+
592598
if scalar_key:
593599
key = (key,) # type: ignore
594600

@@ -599,7 +605,6 @@ def __getitem__(self, key: Union[str, List[str]]):
599605
if k in _AXIS_NAMES + _COORD_NAMES:
600606
names = _get_axis_coord(self._obj, k)
601607
successful[k] = bool(names)
602-
varnames.extend(_strip_none_list(names))
603608
coords.extend(_strip_none_list(names))
604609
elif k in _CELL_MEASURES:
605610
if isinstance(self._obj, xr.Dataset):
@@ -615,12 +620,11 @@ def __getitem__(self, key: Union[str, List[str]]):
615620
stdnames = _filter_by_standard_names(self._obj, k)
616621
successful[k] = bool(stdnames)
617622
varnames.extend(stdnames)
618-
coords.extend(list(set(stdnames).intersection(set(self._obj.coords))))
623+
coords.extend(list(set(stdnames) & set(self._obj.coords)))
619624

620625
# these are not special names but could be variable names in underlying object
621626
# we allow this so that we can return variables with appropriate CF auxiliary variables
622627
varnames.extend([k for k, v in successful.items() if not v])
623-
assert len(varnames) > 0
624628

625629
try:
626630
# TODO: make this a get_auxiliary_variables function
@@ -643,20 +647,29 @@ def __getitem__(self, key: Union[str, List[str]]):
643647
ds = self._obj._to_temp_dataset()
644648
else:
645649
ds = self._obj
646-
ds = ds.reset_coords()[varnames]
650+
651+
if scalar_key and len(varnames) == 1:
652+
da = ds[varnames[0]]
653+
for k1 in coords:
654+
da.coords[k1] = ds.variables[k1]
655+
return da
656+
657+
ds = ds.reset_coords()[varnames + coords]
647658
if isinstance(self._obj, DataArray):
648659
if scalar_key and len(ds.variables) == 1:
649660
# single dimension coordinates
650-
return ds[list(ds.variables.keys())[0]].squeeze(drop=True)
651-
elif scalar_key and len(ds.coords) > 1:
661+
assert coords
662+
assert not varnames
663+
664+
return ds[coords[0]]
665+
666+
elif scalar_key and len(ds.variables) > 1:
652667
raise NotImplementedError(
653668
"Not sure what to return when given scalar key for DataArray and it has multiple values. "
654669
"Please open an issue."
655670
)
656-
elif not scalar_key:
657-
return ds.set_coords(coords)
658-
else:
659-
return ds.set_coords(coords)
671+
672+
return ds.set_coords(coords)
660673

661674
except KeyError:
662675
raise KeyError(

cf_xarray/tests/test_accessor.py

Lines changed: 4 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -32,7 +32,7 @@ def test_describe():
3232

3333
def test_getitem_standard_name():
3434
actual = airds.cf["air_temperature"]
35-
expected = airds[["air"]]
35+
expected = airds["air"]
3636
assert_identical(actual, expected)
3737

3838
ds = airds.copy(deep=True)
@@ -162,11 +162,10 @@ def test_dataarray_getitem():
162162
air.name = None
163163

164164
assert_identical(air.cf["longitude"], air["lon"])
165-
assert_identical(air.cf[["longitude"]], air["lon"].reset_coords())
166-
assert_identical(
165+
with pytest.raises(KeyError):
166+
air.cf[["longitude"]]
167+
with pytest.raises(KeyError):
167168
air.cf[["longitude", "latitude"]],
168-
air.to_dataset(name="air").drop_vars("cell_area")[["lon", "lat"]],
169-
)
170169

171170

172171
@pytest.mark.parametrize("obj", dataarrays)

0 commit comments

Comments
 (0)