112
112
coordinate_criteria ["long_name" ] = coordinate_criteria ["standard_name" ]
113
113
114
114
# Type for Mapper functions
115
- Mapper = Callable [[Union [xr . DataArray , xr . Dataset ], str ], List [str ]]
115
+ Mapper = Callable [[Union [DataArray , Dataset ], str ], List [str ]]
116
116
117
117
118
118
def apply_mapper (
119
119
mapper : Mapper ,
120
- obj : Union [xr . DataArray , xr . Dataset ],
120
+ obj : Union [DataArray , Dataset ],
121
121
key : str ,
122
122
error : bool = True ,
123
123
default : Any = None ,
@@ -139,9 +139,7 @@ def apply_mapper(
139
139
return results
140
140
141
141
142
- def _get_axis_coord_single (
143
- var : Union [xr .DataArray , xr .Dataset ], key : str ,
144
- ) -> List [str ]:
142
+ def _get_axis_coord_single (var : Union [DataArray , Dataset ], key : str ,) -> List [str ]:
145
143
""" Helper method for when we really want only one result per key. """
146
144
results = _get_axis_coord (var , key )
147
145
if len (results ) > 1 :
@@ -153,7 +151,7 @@ def _get_axis_coord_single(
153
151
return results
154
152
155
153
156
- def _get_axis_coord (var : Union [xr . DataArray , xr . Dataset ], key : str ,) -> List [str ]:
154
+ def _get_axis_coord (var : Union [DataArray , Dataset ], key : str ,) -> List [str ]:
157
155
"""
158
156
Translate from axis or coord name to variable name
159
157
@@ -210,7 +208,7 @@ def _get_axis_coord(var: Union[xr.DataArray, xr.Dataset], key: str,) -> List[str
210
208
211
209
212
210
def _get_measure_variable (
213
- da : xr . DataArray , key : str , error : bool = True , default : str = None
211
+ da : DataArray , key : str , error : bool = True , default : str = None
214
212
) -> List [DataArray ]:
215
213
""" tiny wrapper since xarray does not support providing str for weights."""
216
214
varnames = apply_mapper (_get_measure , da , key , error , default )
@@ -219,7 +217,7 @@ def _get_measure_variable(
219
217
return [da [varnames [0 ]]]
220
218
221
219
222
- def _get_measure (da : Union [xr . DataArray , xr . Dataset ], key : str ) -> List [str ]:
220
+ def _get_measure (da : Union [DataArray , Dataset ], key : str ) -> List [str ]:
223
221
"""
224
222
Translate from cell measures ("area" or "volume") to appropriate variable name.
225
223
This function interprets the ``cell_measures`` attribute on DataArrays.
@@ -279,7 +277,7 @@ def _get_measure(da: Union[xr.DataArray, xr.Dataset], key: str) -> List[str]:
279
277
}
280
278
281
279
282
- def _filter_by_standard_names (ds : xr . Dataset , name : Union [str , List [str ]]) -> List [str ]:
280
+ def _filter_by_standard_names (ds : Dataset , name : Union [str , List [str ]]) -> List [str ]:
283
281
""" returns a list of variable names with standard names matching name. """
284
282
if isinstance (name , str ):
285
283
name = [name ]
@@ -295,7 +293,7 @@ def _filter_by_standard_names(ds: xr.Dataset, name: Union[str, List[str]]) -> Li
295
293
return varnames
296
294
297
295
298
- def _get_list_standard_names (obj : xr . Dataset ) -> List [str ]:
296
+ def _get_list_standard_names (obj : Dataset ) -> List [str ]:
299
297
"""
300
298
Returns a sorted list of standard names in Dataset.
301
299
@@ -365,7 +363,12 @@ def _getattr(
365
363
An extra decorator, if necessary. This is used by _CFPlotMethods to set default
366
364
kwargs based on CF attributes.
367
365
"""
368
- attribute : Union [Mapping , Callable ] = getattr (obj , attr )
366
+ try :
367
+ attribute : Union [Mapping , Callable ] = getattr (obj , attr )
368
+ except AttributeError :
369
+ raise AttributeError (
370
+ f"{ attr !r} is not a valid attribute on the underlying xarray object."
371
+ )
369
372
370
373
if isinstance (attribute , Mapping ):
371
374
if not attribute :
@@ -545,16 +548,14 @@ def _process_signature(self, func: Callable, args, kwargs, key_mappers):
545
548
arguments = self ._rewrite_values (
546
549
bound .arguments , key_mappers , tuple (var_kws )
547
550
)
548
- else :
549
- arguments = {}
550
-
551
- if arguments :
552
551
# now unwrap the **indexers_kwargs type arguments
553
552
# so that xarray can parse it :)
554
553
for kw in var_kws :
555
554
value = arguments .pop (kw , None )
556
555
if value :
557
556
arguments .update (** value )
557
+ else :
558
+ arguments = {}
558
559
559
560
return arguments
560
561
@@ -583,45 +584,41 @@ def _rewrite_values(self, kwargs, key_mappers: dict, var_kws):
583
584
# these are valid for .sel, .isel, .coarsen
584
585
key_mappers .update (dict .fromkeys (var_kws , _get_axis_coord ))
585
586
586
- for key , value in kwargs .items ():
587
- mapper = key_mappers .get (key , None )
588
-
589
- if mapper is not None :
590
- if isinstance (value , str ):
591
- value = [value ]
592
-
593
- if isinstance (value , dict ):
594
- # this for things like isel where **kwargs captures things like T=5
595
- # .sel, .isel, .rolling
596
- # Account for multiple names matching the key.
597
- # e.g. .isel(X=5) → .isel(xi_rho=5, xi_u=5, xi_v=5, xi_psi=5)
598
- # where xi_* have attrs["axis"] = "X"
599
- updates [key ] = ChainMap (
600
- * [
601
- dict .fromkeys (
602
- apply_mapper (mapper , self ._obj , k , False , k ), v
603
- )
604
- for k , v in value .items ()
605
- ]
606
- )
587
+ for key in set (key_mappers ) & set (kwargs ):
588
+ value = kwargs [key ]
589
+ mapper = key_mappers [key ]
590
+
591
+ if isinstance (value , str ):
592
+ value = [value ]
593
+
594
+ if isinstance (value , dict ):
595
+ # this for things like isel where **kwargs captures things like T=5
596
+ # .sel, .isel, .rolling
597
+ # Account for multiple names matching the key.
598
+ # e.g. .isel(X=5) → .isel(xi_rho=5, xi_u=5, xi_v=5, xi_psi=5)
599
+ # where xi_* have attrs["axis"] = "X"
600
+ updates [key ] = ChainMap (
601
+ * [
602
+ dict .fromkeys (apply_mapper (mapper , self ._obj , k , False , k ), v )
603
+ for k , v in value .items ()
604
+ ]
605
+ )
607
606
608
- elif value is Ellipsis :
609
- pass
607
+ elif value is Ellipsis :
608
+ pass
610
609
610
+ else :
611
+ # things like sum which have dim
612
+ newvalue = [apply_mapper (mapper , self ._obj , v , False , v ) for v in value ]
613
+ # Mappers return list by default
614
+ # for input dim=["lat", "X"], newvalue=[["lat"], ["lon"]],
615
+ # so we deal with that here.
616
+ unpacked = list (itertools .chain (* newvalue ))
617
+ if len (unpacked ) == 1 :
618
+ # handle 'group'
619
+ updates [key ] = unpacked [0 ]
611
620
else :
612
- # things like sum which have dim
613
- newvalue = [
614
- apply_mapper (mapper , self ._obj , v , False , v ) for v in value
615
- ]
616
- # Mappers return list by default
617
- # for input dim=["lat", "X"], newvalue=[["lat"], ["lon"]],
618
- # so we deal with that here.
619
- unpacked = list (itertools .chain (* newvalue ))
620
- if len (unpacked ) == 1 :
621
- # handle 'group'
622
- updates [key ] = unpacked [0 ]
623
- else :
624
- updates [key ] = unpacked
621
+ updates [key ] = unpacked
625
622
626
623
kwargs .update (updates )
627
624
@@ -670,13 +667,13 @@ def describe(self):
670
667
671
668
text += "\n Cell Measures:\n "
672
669
for measure in _CELL_MEASURES :
673
- if isinstance (self ._obj , xr . Dataset ):
670
+ if isinstance (self ._obj , Dataset ):
674
671
text += f"\t { measure } : unsupported\n "
675
672
else :
676
673
text += f"\t { measure } : { apply_mapper (_get_measure , self ._obj , measure , error = False )} \n "
677
674
678
675
text += "\n Standard Names:\n "
679
- if isinstance (self ._obj , xr . DataArray ):
676
+ if isinstance (self ._obj , DataArray ):
680
677
text += "\t unsupported\n "
681
678
else :
682
679
stdnames = _get_list_standard_names (self ._obj )
@@ -702,7 +699,7 @@ def get_valid_keys(self) -> Set[str]:
702
699
for key in _AXIS_NAMES + _COORD_NAMES
703
700
if apply_mapper (_get_axis_coord , self ._obj , key , error = False )
704
701
]
705
- if not isinstance (self ._obj , xr . Dataset ):
702
+ if not isinstance (self ._obj , Dataset ):
706
703
measures = [
707
704
key
708
705
for key in _CELL_MEASURES
@@ -711,16 +708,45 @@ def get_valid_keys(self) -> Set[str]:
711
708
if measures :
712
709
varnames .extend (measures )
713
710
714
- if not isinstance (self ._obj , xr . DataArray ):
711
+ if not isinstance (self ._obj , DataArray ):
715
712
varnames .extend (_get_list_standard_names (self ._obj ))
716
713
return set (varnames )
717
714
715
+ def get_associated_variable_names (self , name : Hashable ) -> List [Hashable ]:
716
+ """
717
+ Returns a list of variable names referred to in the following attributes
718
+ 1. "coordinates"
719
+ 2. "cell_measures"
720
+ 3. "ancillary_variables"
721
+ """
722
+ coords = []
723
+ attrs_or_encoding = ChainMap (self ._obj [name ].attrs , self ._obj [name ].encoding )
724
+
725
+ if "coordinates" in attrs_or_encoding :
726
+ coords .extend (attrs_or_encoding ["coordinates" ].split (" " ))
727
+
728
+ if "cell_measures" in attrs_or_encoding :
729
+ measures = [
730
+ _get_measure (self ._obj [name ], measure )
731
+ for measure in _CELL_MEASURES
732
+ if measure in attrs_or_encoding ["cell_measures" ]
733
+ ]
734
+ coords .extend (* measures )
735
+
736
+ if (
737
+ isinstance (self ._obj , Dataset )
738
+ and "ancillary_variables" in attrs_or_encoding
739
+ ):
740
+ anames = attrs_or_encoding ["ancillary_variables" ].split (" " )
741
+ coords .extend (anames )
742
+ return coords
743
+
718
744
def __getitem__ (self , key : Union [str , List [str ]]):
719
745
720
746
kind = str (type (self ._obj ).__name__ )
721
747
scalar_key = isinstance (key , str )
722
748
723
- if isinstance (self ._obj , xr . DataArray ) and not scalar_key :
749
+ if isinstance (self ._obj , DataArray ) and not scalar_key :
724
750
raise KeyError (
725
751
f"Cannot use a list of keys with DataArrays. Expected a single string. Received { key !r} instead."
726
752
)
@@ -741,7 +767,7 @@ def __getitem__(self, key: Union[str, List[str]]):
741
767
successful [k ] = bool (measure )
742
768
if measure :
743
769
varnames .extend (measure )
744
- elif not isinstance (self ._obj , xr . DataArray ):
770
+ elif not isinstance (self ._obj , DataArray ):
745
771
stdnames = _filter_by_standard_names (self ._obj , k )
746
772
successful [k ] = bool (stdnames )
747
773
varnames .extend (stdnames )
@@ -752,39 +778,16 @@ def __getitem__(self, key: Union[str, List[str]]):
752
778
varnames .extend ([k for k , v in successful .items () if not v ])
753
779
754
780
try :
755
- # TODO: make this a get_auxiliary_variables function
756
- # 1. set coordinate variables referred to in "coordinates" attribute
757
- # 2. set measures variables as coordinates
758
- # 3. set ancillary variables as coordinates
759
781
for name in varnames :
760
- attrs_or_encoding = ChainMap (
761
- self ._obj [name ].attrs , self ._obj [name ].encoding
762
- )
763
- if "coordinates" in attrs_or_encoding :
764
- coords .extend (attrs_or_encoding ["coordinates" ].split (" " ))
765
-
766
- if "cell_measures" in attrs_or_encoding :
767
- measures = [
768
- _get_measure (self ._obj [name ], measure )
769
- for measure in _CELL_MEASURES
770
- if measure in attrs_or_encoding ["cell_measures" ]
771
- ]
772
- coords .extend (* measures )
773
-
774
- if (
775
- isinstance (self ._obj , xr .Dataset )
776
- and "ancillary_variables" in attrs_or_encoding
777
- ):
778
- anames = attrs_or_encoding ["ancillary_variables" ].split (" " )
779
- coords .extend (anames )
782
+ coords .extend (self .get_associated_variable_names (name ))
780
783
781
- if isinstance (self ._obj , xr . DataArray ):
784
+ if isinstance (self ._obj , DataArray ):
782
785
ds = self ._obj ._to_temp_dataset ()
783
786
else :
784
787
ds = self ._obj
785
788
786
789
if scalar_key and len (varnames ) == 1 :
787
- da : xr . DataArray = ds [varnames [0 ]].reset_coords (drop = True ) # type: ignore
790
+ da : DataArray = ds [varnames [0 ]].reset_coords (drop = True ) # type: ignore
788
791
failed = []
789
792
for k1 in coords :
790
793
if k1 not in ds .variables :
@@ -821,18 +824,18 @@ def __getitem__(self, key: Union[str, List[str]]):
821
824
f"Use { kind } .cf.describe() to see a list of key names that can be interpreted."
822
825
)
823
826
824
- def _maybe_to_dataset (self , obj = None ) -> xr . Dataset :
827
+ def _maybe_to_dataset (self , obj = None ) -> Dataset :
825
828
if obj is None :
826
829
obj = self ._obj
827
- if isinstance (self ._obj , xr . DataArray ):
830
+ if isinstance (self ._obj , DataArray ):
828
831
return obj ._to_temp_dataset ()
829
832
else :
830
833
return obj
831
834
832
835
def _maybe_to_dataarray (self , obj = None ):
833
836
if obj is None :
834
837
obj = self ._obj
835
- if isinstance (self ._obj , xr . DataArray ):
838
+ if isinstance (self ._obj , DataArray ):
836
839
return self ._obj ._from_temp_dataset (obj )
837
840
else :
838
841
return obj
@@ -879,8 +882,8 @@ def add_bounds(self, dims: Union[Hashable, Iterable[Hashable]]):
879
882
return self ._maybe_to_dataarray (obj )
880
883
881
884
def rename_like (
882
- self , other : Union [xr . DataArray , xr . Dataset ]
883
- ) -> Union [xr . DataArray , xr . Dataset ]:
885
+ self , other : Union [DataArray , Dataset ]
886
+ ) -> Union [DataArray , Dataset ]:
884
887
"""
885
888
Renames variables in object to match names of like-variables in ``other``.
886
889
0 commit comments