1
- import copy
2
- import inspect
3
1
import functools
2
+ import inspect
3
+ from typing import Union
4
4
5
5
import xarray as xr
6
-
6
+ from xarray import DataArray , Dataset
7
7
8
8
_WRAPPED_CLASSES = (
9
9
xr .core .resample .Resample ,
@@ -18,69 +18,89 @@ def _get_axis_name_mapping(da: xr.DataArray):
18
18
return {"X" : "lon" , "Y" : "lat" , "T" : "time" }
19
19
20
20
21
- class _CFWrapped :
22
- def __init__ (self , towrap , accessor ):
21
+ def _getattr (
22
+ obj : Union [DataArray , Dataset ],
23
+ attr : str ,
24
+ accessor : "CFAccessor" ,
25
+ wrap_classes = False ,
26
+ keys = ("dim" ,),
27
+ ):
28
+ """
29
+ Common getattr functionality.
30
+
31
+ Parameters
32
+ ----------
33
+
34
+ obj : DataArray, Dataset
35
+ attr : Name of attribute in obj that will be shadowed.
36
+ accessor : High level accessor object: CFAccessor
37
+ wrap_classes: bool
38
+ Should we wrap the return value with _CFWrappedClass?
39
+ Only True for the high level CFAccessor.
40
+ Facilitates code reuse for _CFWrappedClass and _CFWrapppedPlotMethods
41
+ For both of thos, wrap_classes is False.
42
+ """
43
+ func = getattr (obj , attr )
44
+
45
+ @functools .wraps (func )
46
+ def wrapper (* args , ** kwargs ):
47
+ arguments = accessor ._process_signature (func , args , kwargs , keys = keys )
48
+ rv = func (** arguments )
49
+ if wrap_classes and isinstance (rv , _WRAPPED_CLASSES ):
50
+ return _CFWrappedClass (obj , rv , accessor )
51
+ else :
52
+ return rv
53
+
54
+ return wrapper
55
+
56
+
57
+ class _CFWrappedClass :
58
+ def __init__ (self , obj : Union [DataArray , Dataset ], towrap , accessor : "CFAccessor" ):
59
+ """
60
+
61
+ Parameters
62
+ ----------
63
+
64
+ obj : DataArray, Dataset
65
+ towrap : Resample, GroupBy, Coarsen, Rolling, Weighted
66
+ Instance of xarray class that is being wrapped.
67
+ accessor : CFAccessor
68
+ """
69
+ self ._obj = obj
23
70
self .wrapped = towrap
24
71
self .accessor = accessor
25
- self ._can_wrap_classes = False
26
72
27
73
def __repr__ (self ):
28
74
return "--- CF-xarray wrapped \n " + repr (self .wrapped )
29
75
30
76
def __getattr__ (self , attr ):
31
- func = getattr (self .wrapped , attr )
32
-
33
- @functools .wraps (func )
34
- def wrapper (* args , ** kwargs ):
35
- arguments = self .accessor ._process_signature (func , args , kwargs )
36
- rv = func (** arguments )
37
- return rv
38
-
39
- return wrapper
77
+ return _getattr (obj = self ._obj , attr = attr , accessor = self .accessor )
40
78
41
79
42
80
class _CFWrappedPlotMethods :
43
81
def __init__ (self , obj , accessor ):
44
82
self ._obj = obj
45
83
self .accessor = accessor
46
- self ._can_wrap_classes = False
84
+ self ._keys = ( "x" , "y" , "hue" , "col" , "row" )
47
85
48
86
def __call__ (self , * args , ** kwargs ):
49
- func = self ._obj .plot # (*args, **kwargs)
50
-
51
- @functools .wraps (func )
52
- def wrapper (* args , ** kwargs ):
53
- arguments = self .accessor ._process_signature (
54
- func , args , kwargs , keys = ("x" , "y" , "hue" , "col" , "row" )
55
- )
56
- print (arguments )
57
- rv = func (** arguments )
58
- return rv
59
-
60
- return wrapper (* args , ** kwargs )
87
+ return _getattr (
88
+ obj = self ._obj , attr = "plot" , accessor = self .accessor , keys = self ._keys
89
+ )
61
90
62
91
def __getattr__ (self , attr ):
63
- func = getattr (self ._obj .plot , attr )
64
-
65
- @functools .wraps (func )
66
- def wrapper (* args , ** kwargs ):
67
- arguments = self .accessor ._process_signature (
68
- func , args , kwargs , keys = ("x" , "y" , "hue" , "col" , "row" )
69
- )
70
- rv = func (** arguments )
71
- return rv
72
-
73
- return wrapper
92
+ return _getattr (
93
+ obj = self ._obj .plot , attr = attr , accessor = self .accessor , keys = self ._keys
94
+ )
74
95
75
96
76
97
@xr .register_dataarray_accessor ("cf" )
77
98
class CFAccessor :
78
99
def __init__ (self , da ):
79
100
self ._obj = da
80
101
self ._coords = _get_axis_name_mapping (da )
81
- self ._can_wrap_classes = True
82
102
83
- def _process_signature (self , func , args , kwargs , keys = ( "dim" ,) ):
103
+ def _process_signature (self , func , args , kwargs , keys ):
84
104
sig = inspect .signature (func , follow_wrapped = False )
85
105
86
106
# Catch things like .isel(T=5).
@@ -101,6 +121,7 @@ def _process_signature(self, func, args, kwargs, keys=("dim",)):
101
121
102
122
if arguments :
103
123
# now unwrap the **indexers_kwargs type arguments
124
+ # so that xarray can parse it :)
104
125
for kw in var_kws :
105
126
value = arguments .pop (kw , None )
106
127
if value :
@@ -128,8 +149,10 @@ def _rewrite_values_with_axis_names(self, kwargs, keys, var_kws):
128
149
129
150
kwargs .update (updates )
130
151
131
- # maybe the keys I'm looking for are in kwargs.
132
- # This happens with DataArray.plot() for example, where the signature is obscured.
152
+ # maybe the keys we are looking for are in kwargs.
153
+ # For example, this happens with DataArray.plot(),
154
+ # where the signature is obscured and kwargs is
155
+ # kwargs = {"x": "X", "col": "T"}
133
156
for vkw in var_kws :
134
157
if vkw in kwargs :
135
158
maybe_update = {
@@ -141,19 +164,8 @@ def _rewrite_values_with_axis_names(self, kwargs, keys, var_kws):
141
164
142
165
return kwargs
143
166
144
- def __getattr__ (self , name ):
145
- func = getattr (self ._obj , name )
146
-
147
- @functools .wraps (func )
148
- def wrapper (* args , ** kwargs ):
149
- arguments = self ._process_signature (func , args , kwargs )
150
- rv = func (** arguments )
151
- if isinstance (rv , _WRAPPED_CLASSES ):
152
- return _CFWrapped (rv , self )
153
- else :
154
- return rv
155
-
156
- return wrapper
167
+ def __getattr__ (self , attr ):
168
+ return _getattr (obj = self ._obj , attr = attr , accessor = self , wrap_classes = True )
157
169
158
170
@property
159
171
def plot (self ):
0 commit comments