@@ -178,7 +178,7 @@ def _get_axis_coord(
178
178
179
179
if key not in _COORD_NAMES and key not in _AXIS_NAMES :
180
180
if error :
181
- raise KeyError (f"Did not understand { key } " )
181
+ raise KeyError (f"Did not understand key { key !r } " )
182
182
else :
183
183
return [default ]
184
184
@@ -589,6 +589,12 @@ def __getitem__(self, key: Union[str, List[str]]):
589
589
590
590
kind = str (type (self ._obj ).__name__ )
591
591
scalar_key = isinstance (key , str )
592
+
593
+ if isinstance (self ._obj , xr .DataArray ) and not scalar_key :
594
+ raise KeyError (
595
+ f"Cannot use a list of keys with DataArrays. Expected a single string. Received { key !r} instead."
596
+ )
597
+
592
598
if scalar_key :
593
599
key = (key ,) # type: ignore
594
600
@@ -599,7 +605,6 @@ def __getitem__(self, key: Union[str, List[str]]):
599
605
if k in _AXIS_NAMES + _COORD_NAMES :
600
606
names = _get_axis_coord (self ._obj , k )
601
607
successful [k ] = bool (names )
602
- varnames .extend (_strip_none_list (names ))
603
608
coords .extend (_strip_none_list (names ))
604
609
elif k in _CELL_MEASURES :
605
610
if isinstance (self ._obj , xr .Dataset ):
@@ -615,12 +620,11 @@ def __getitem__(self, key: Union[str, List[str]]):
615
620
stdnames = _filter_by_standard_names (self ._obj , k )
616
621
successful [k ] = bool (stdnames )
617
622
varnames .extend (stdnames )
618
- coords .extend (list (set (stdnames ). intersection ( set (self ._obj .coords ) )))
623
+ coords .extend (list (set (stdnames ) & set (self ._obj .coords )))
619
624
620
625
# these are not special names but could be variable names in underlying object
621
626
# we allow this so that we can return variables with appropriate CF auxiliary variables
622
627
varnames .extend ([k for k , v in successful .items () if not v ])
623
- assert len (varnames ) > 0
624
628
625
629
try :
626
630
# TODO: make this a get_auxiliary_variables function
@@ -643,20 +647,29 @@ def __getitem__(self, key: Union[str, List[str]]):
643
647
ds = self ._obj ._to_temp_dataset ()
644
648
else :
645
649
ds = self ._obj
646
- ds = ds .reset_coords ()[varnames ]
650
+
651
+ if scalar_key and len (varnames ) == 1 :
652
+ da = ds [varnames [0 ]]
653
+ for k1 in coords :
654
+ da .coords [k1 ] = ds .variables [k1 ]
655
+ return da
656
+
657
+ ds = ds .reset_coords ()[varnames + coords ]
647
658
if isinstance (self ._obj , DataArray ):
648
659
if scalar_key and len (ds .variables ) == 1 :
649
660
# single dimension coordinates
650
- return ds [list (ds .variables .keys ())[0 ]].squeeze (drop = True )
651
- elif scalar_key and len (ds .coords ) > 1 :
661
+ assert coords
662
+ assert not varnames
663
+
664
+ return ds [coords [0 ]]
665
+
666
+ elif scalar_key and len (ds .variables ) > 1 :
652
667
raise NotImplementedError (
653
668
"Not sure what to return when given scalar key for DataArray and it has multiple values. "
654
669
"Please open an issue."
655
670
)
656
- elif not scalar_key :
657
- return ds .set_coords (coords )
658
- else :
659
- return ds .set_coords (coords )
671
+
672
+ return ds .set_coords (coords )
660
673
661
674
except KeyError :
662
675
raise KeyError (
0 commit comments