Skip to content

Commit 13847b2

Browse files
committed
cleanup
1 parent 9a1f2f0 commit 13847b2

File tree

1 file changed

+68
-56
lines changed

1 file changed

+68
-56
lines changed

cf_xarray/accessor.py

Lines changed: 68 additions & 56 deletions
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,9 @@
1-
import copy
2-
import inspect
31
import functools
2+
import inspect
3+
from typing import Union
44

55
import xarray as xr
6-
6+
from xarray import DataArray, Dataset
77

88
_WRAPPED_CLASSES = (
99
xr.core.resample.Resample,
@@ -18,69 +18,89 @@ def _get_axis_name_mapping(da: xr.DataArray):
1818
return {"X": "lon", "Y": "lat", "T": "time"}
1919

2020

21-
class _CFWrapped:
22-
def __init__(self, towrap, accessor):
21+
def _getattr(
22+
obj: Union[DataArray, Dataset],
23+
attr: str,
24+
accessor: "CFAccessor",
25+
wrap_classes=False,
26+
keys=("dim",),
27+
):
28+
"""
29+
Common getattr functionality.
30+
31+
Parameters
32+
----------
33+
34+
obj : DataArray, Dataset
35+
attr : Name of attribute in obj that will be shadowed.
36+
accessor : High level accessor object: CFAccessor
37+
wrap_classes: bool
38+
Should we wrap the return value with _CFWrappedClass?
39+
Only True for the high level CFAccessor.
40+
Facilitates code reuse for _CFWrappedClass and _CFWrapppedPlotMethods
41+
For both of thos, wrap_classes is False.
42+
"""
43+
func = getattr(obj, attr)
44+
45+
@functools.wraps(func)
46+
def wrapper(*args, **kwargs):
47+
arguments = accessor._process_signature(func, args, kwargs, keys=keys)
48+
rv = func(**arguments)
49+
if wrap_classes and isinstance(rv, _WRAPPED_CLASSES):
50+
return _CFWrappedClass(obj, rv, accessor)
51+
else:
52+
return rv
53+
54+
return wrapper
55+
56+
57+
class _CFWrappedClass:
58+
def __init__(self, obj: Union[DataArray, Dataset], towrap, accessor: "CFAccessor"):
59+
"""
60+
61+
Parameters
62+
----------
63+
64+
obj : DataArray, Dataset
65+
towrap : Resample, GroupBy, Coarsen, Rolling, Weighted
66+
Instance of xarray class that is being wrapped.
67+
accessor : CFAccessor
68+
"""
69+
self._obj = obj
2370
self.wrapped = towrap
2471
self.accessor = accessor
25-
self._can_wrap_classes = False
2672

2773
def __repr__(self):
2874
return "--- CF-xarray wrapped \n" + repr(self.wrapped)
2975

3076
def __getattr__(self, attr):
31-
func = getattr(self.wrapped, attr)
32-
33-
@functools.wraps(func)
34-
def wrapper(*args, **kwargs):
35-
arguments = self.accessor._process_signature(func, args, kwargs)
36-
rv = func(**arguments)
37-
return rv
38-
39-
return wrapper
77+
return _getattr(obj=self._obj, attr=attr, accessor=self.accessor)
4078

4179

4280
class _CFWrappedPlotMethods:
4381
def __init__(self, obj, accessor):
4482
self._obj = obj
4583
self.accessor = accessor
46-
self._can_wrap_classes = False
84+
self._keys = ("x", "y", "hue", "col", "row")
4785

4886
def __call__(self, *args, **kwargs):
49-
func = self._obj.plot # (*args, **kwargs)
50-
51-
@functools.wraps(func)
52-
def wrapper(*args, **kwargs):
53-
arguments = self.accessor._process_signature(
54-
func, args, kwargs, keys=("x", "y", "hue", "col", "row")
55-
)
56-
print(arguments)
57-
rv = func(**arguments)
58-
return rv
59-
60-
return wrapper(*args, **kwargs)
87+
return _getattr(
88+
obj=self._obj, attr="plot", accessor=self.accessor, keys=self._keys
89+
)
6190

6291
def __getattr__(self, attr):
63-
func = getattr(self._obj.plot, attr)
64-
65-
@functools.wraps(func)
66-
def wrapper(*args, **kwargs):
67-
arguments = self.accessor._process_signature(
68-
func, args, kwargs, keys=("x", "y", "hue", "col", "row")
69-
)
70-
rv = func(**arguments)
71-
return rv
72-
73-
return wrapper
92+
return _getattr(
93+
obj=self._obj.plot, attr=attr, accessor=self.accessor, keys=self._keys
94+
)
7495

7596

7697
@xr.register_dataarray_accessor("cf")
7798
class CFAccessor:
7899
def __init__(self, da):
79100
self._obj = da
80101
self._coords = _get_axis_name_mapping(da)
81-
self._can_wrap_classes = True
82102

83-
def _process_signature(self, func, args, kwargs, keys=("dim",)):
103+
def _process_signature(self, func, args, kwargs, keys):
84104
sig = inspect.signature(func, follow_wrapped=False)
85105

86106
# Catch things like .isel(T=5).
@@ -101,6 +121,7 @@ def _process_signature(self, func, args, kwargs, keys=("dim",)):
101121

102122
if arguments:
103123
# now unwrap the **indexers_kwargs type arguments
124+
# so that xarray can parse it :)
104125
for kw in var_kws:
105126
value = arguments.pop(kw, None)
106127
if value:
@@ -128,8 +149,10 @@ def _rewrite_values_with_axis_names(self, kwargs, keys, var_kws):
128149

129150
kwargs.update(updates)
130151

131-
# maybe the keys I'm looking for are in kwargs.
132-
# This happens with DataArray.plot() for example, where the signature is obscured.
152+
# maybe the keys we are looking for are in kwargs.
153+
# For example, this happens with DataArray.plot(),
154+
# where the signature is obscured and kwargs is
155+
# kwargs = {"x": "X", "col": "T"}
133156
for vkw in var_kws:
134157
if vkw in kwargs:
135158
maybe_update = {
@@ -141,19 +164,8 @@ def _rewrite_values_with_axis_names(self, kwargs, keys, var_kws):
141164

142165
return kwargs
143166

144-
def __getattr__(self, name):
145-
func = getattr(self._obj, name)
146-
147-
@functools.wraps(func)
148-
def wrapper(*args, **kwargs):
149-
arguments = self._process_signature(func, args, kwargs)
150-
rv = func(**arguments)
151-
if isinstance(rv, _WRAPPED_CLASSES):
152-
return _CFWrapped(rv, self)
153-
else:
154-
return rv
155-
156-
return wrapper
167+
def __getattr__(self, attr):
168+
return _getattr(obj=self._obj, attr=attr, accessor=self, wrap_classes=True)
157169

158170
@property
159171
def plot(self):

0 commit comments

Comments
 (0)