4
4
import textwrap
5
5
import warnings
6
6
from collections import ChainMap
7
- from contextlib import suppress
8
7
from typing import (
8
+ Any ,
9
9
Callable ,
10
10
Hashable ,
11
11
Iterable ,
12
12
List ,
13
13
Mapping ,
14
14
MutableMapping ,
15
- Optional ,
16
15
Set ,
17
16
Tuple ,
18
17
Union ,
113
112
coordinate_criteria ["long_name" ] = coordinate_criteria ["standard_name" ]
114
113
115
114
# Type for Mapper functions
116
- Mapper = Callable [[Union [xr .DataArray , xr .Dataset ], str ], List [Optional [str ]]]
117
-
118
-
119
- def _strip_none_list (lst : List [Optional [str ]]) -> List [str ]:
120
- """ The mappers can return [None]. Strip that when necessary. Keeps mypy happy."""
121
- return [item for item in lst if item != [None ]] # type: ignore
115
+ Mapper = Callable [[Union [xr .DataArray , xr .Dataset ], str ], List [str ]]
122
116
123
117
124
118
def apply_mapper (
125
119
mapper : Mapper ,
126
120
obj : Union [xr .DataArray , xr .Dataset ],
127
121
key : str ,
128
122
error : bool = True ,
129
- default : str = None ,
130
- ) -> List [Optional [ str ] ]:
123
+ default : Any = None ,
124
+ ) -> List [Any ]:
131
125
"""
132
126
Applies a mapping function; does error handling / returning defaults.
133
127
"""
134
-
135
128
try :
136
129
results = mapper (obj , key )
137
130
except Exception as e :
138
131
if error :
139
132
raise e
140
133
else :
141
- results = None # type: ignore
134
+ if default :
135
+ results = [default ] # type: ignore
136
+ else :
137
+ results = []
142
138
143
- if not results :
144
- if error :
145
- raise KeyError (f"Attributes to select { key !r} not found!" )
146
- else :
147
- return [default ]
148
- else :
149
- return list (results )
139
+ return results
150
140
151
141
152
142
def _get_axis_coord_single (
153
143
var : Union [xr .DataArray , xr .Dataset ], key : str ,
154
- ) -> List [Optional [ str ] ]:
144
+ ) -> List [str ]:
155
145
""" Helper method for when we really want only one result per key. """
156
146
results = _get_axis_coord (var , key )
157
147
if len (results ) > 1 :
@@ -163,9 +153,7 @@ def _get_axis_coord_single(
163
153
return results
164
154
165
155
166
- def _get_axis_coord (
167
- var : Union [xr .DataArray , xr .Dataset ], key : str ,
168
- ) -> List [Optional [str ]]:
156
+ def _get_axis_coord (var : Union [xr .DataArray , xr .Dataset ], key : str ,) -> List [str ]:
169
157
"""
170
158
Translate from axis or coord name to variable name
171
159
@@ -225,13 +213,13 @@ def _get_measure_variable(
225
213
da : xr .DataArray , key : str , error : bool = True , default : str = None
226
214
) -> List [DataArray ]:
227
215
""" tiny wrapper since xarray does not support providing str for weights."""
228
- varnames = _strip_none_list ( apply_mapper (_get_measure , da , key , error , default ) )
216
+ varnames = apply_mapper (_get_measure , da , key , error , default )
229
217
if len (varnames ) > 1 :
230
218
raise ValueError (f"Multiple measures found for key { key !r} : { varnames !r} ." )
231
219
return [da [varnames [0 ]]]
232
220
233
221
234
- def _get_measure (da : Union [xr .DataArray , xr .Dataset ], key : str ) -> List [Optional [ str ] ]:
222
+ def _get_measure (da : Union [xr .DataArray , xr .Dataset ], key : str ) -> List [str ]:
235
223
"""
236
224
Translate from cell measures ("area" or "volume") to appropriate variable name.
237
225
This function interprets the ``cell_measures`` attribute on DataArrays.
@@ -269,7 +257,10 @@ def _get_measure(da: Union[xr.DataArray, xr.Dataset], key: str) -> List[Optional
269
257
if len (strings ) % 2 != 0 :
270
258
raise ValueError (f"attrs['cell_measures'] = { attr !r} is malformed." )
271
259
measures = dict (zip (strings [slice (0 , None , 2 )], strings [slice (1 , None , 2 )]))
272
- return [measures .get (key , None )]
260
+ results = measures .get (key , [])
261
+ if isinstance (results , str ):
262
+ return [results ]
263
+ return results
273
264
274
265
275
266
#: Default mappers for common keys.
@@ -383,10 +374,10 @@ def _getattr(
383
374
newmap = dict ()
384
375
unused_keys = set (attribute .keys ())
385
376
for key in _AXIS_NAMES + _COORD_NAMES :
386
- value = apply_mapper (_get_axis_coord , obj , key , error = False )
387
- unused_keys -= set ( value )
388
- if value != [ None ] :
389
- good_values = set ( value ) & set (obj .dims )
377
+ value = set ( apply_mapper (_get_axis_coord , obj , key , error = False ) )
378
+ unused_keys -= value
379
+ if value :
380
+ good_values = value & set (obj .dims )
390
381
if not good_values :
391
382
continue
392
383
if len (good_values ) > 1 :
@@ -592,10 +583,10 @@ def _rewrite_values(self, kwargs, key_mappers: dict, var_kws):
592
583
# these are valid for .sel, .isel, .coarsen
593
584
key_mappers .update (dict .fromkeys (var_kws , _get_axis_coord ))
594
585
595
- for key , mapper in key_mappers .items ():
596
- value = kwargs .get (key , None )
586
+ for key , value in kwargs .items ():
587
+ mapper = key_mappers .get (key , None )
597
588
598
- if value is not None :
589
+ if mapper is not None :
599
590
if isinstance (value , str ):
600
591
value = [value ]
601
592
@@ -709,13 +700,13 @@ def get_valid_keys(self) -> Set[str]:
709
700
varnames = [
710
701
key
711
702
for key in _AXIS_NAMES + _COORD_NAMES
712
- if apply_mapper (_get_axis_coord , self ._obj , key , error = False ) != [ None ]
703
+ if apply_mapper (_get_axis_coord , self ._obj , key , error = False )
713
704
]
714
- with suppress ( NotImplementedError ):
705
+ if not isinstance ( self . _obj , xr . Dataset ):
715
706
measures = [
716
707
key
717
708
for key in _CELL_MEASURES
718
- if apply_mapper (_get_measure , self ._obj , key , error = False ) != [ None ]
709
+ if apply_mapper (_get_measure , self ._obj , key , error = False )
719
710
]
720
711
if measures :
721
712
varnames .extend (measures )
@@ -742,11 +733,11 @@ def __getitem__(self, key: Union[str, List[str]]):
742
733
successful = dict .fromkeys (key , False )
743
734
for k in key :
744
735
if k in _AXIS_NAMES + _COORD_NAMES :
745
- names = _strip_none_list ( _get_axis_coord (self ._obj , k ) )
736
+ names = _get_axis_coord (self ._obj , k )
746
737
successful [k ] = bool (names )
747
738
coords .extend (names )
748
739
elif k in _CELL_MEASURES :
749
- measure = _strip_none_list ( _get_measure (self ._obj , k ) )
740
+ measure = _get_measure (self ._obj , k )
750
741
successful [k ] = bool (measure )
751
742
if measure :
752
743
varnames .extend (measure )
@@ -778,7 +769,7 @@ def __getitem__(self, key: Union[str, List[str]]):
778
769
for measure in _CELL_MEASURES
779
770
if measure in attrs_or_encoding ["cell_measures" ]
780
771
]
781
- coords .extend (_strip_none_list ( * measures ) )
772
+ coords .extend (* measures )
782
773
783
774
if (
784
775
isinstance (self ._obj , xr .Dataset )
@@ -793,7 +784,7 @@ def __getitem__(self, key: Union[str, List[str]]):
793
784
ds = self ._obj
794
785
795
786
if scalar_key and len (varnames ) == 1 :
796
- da = ds [varnames [0 ]].reset_coords (drop = True )
787
+ da : xr . DataArray = ds [varnames [0 ]].reset_coords (drop = True ) # type: ignore
797
788
failed = []
798
789
for k1 in coords :
799
790
if k1 not in ds .variables :
0 commit comments