Skip to content

Commit 9a1f2f0

Browse files
committed
plotting methods work.
This was painful because inspect.signature(DataArray.plot) just returns **kwargs. To get the signature for plot.contour and friends, I needed to use inspect.signature(DataArray.plot.contour, follow_wrapped=False)
1 parent 11d9c9f commit 9a1f2f0

File tree

1 file changed

+62
-16
lines changed

1 file changed

+62
-16
lines changed

cf_xarray/accessor.py

Lines changed: 62 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,7 @@ class _CFWrapped:
2222
def __init__(self, towrap, accessor):
2323
self.wrapped = towrap
2424
self.accessor = accessor
25+
self._can_wrap_classes = False
2526

2627
def __repr__(self):
2728
return "--- CF-xarray wrapped \n" + repr(self.wrapped)
@@ -38,14 +39,49 @@ def wrapper(*args, **kwargs):
3839
return wrapper
3940

4041

42+
class _CFWrappedPlotMethods:
43+
def __init__(self, obj, accessor):
44+
self._obj = obj
45+
self.accessor = accessor
46+
self._can_wrap_classes = False
47+
48+
def __call__(self, *args, **kwargs):
49+
func = self._obj.plot # (*args, **kwargs)
50+
51+
@functools.wraps(func)
52+
def wrapper(*args, **kwargs):
53+
arguments = self.accessor._process_signature(
54+
func, args, kwargs, keys=("x", "y", "hue", "col", "row")
55+
)
56+
print(arguments)
57+
rv = func(**arguments)
58+
return rv
59+
60+
return wrapper(*args, **kwargs)
61+
62+
def __getattr__(self, attr):
63+
func = getattr(self._obj.plot, attr)
64+
65+
@functools.wraps(func)
66+
def wrapper(*args, **kwargs):
67+
arguments = self.accessor._process_signature(
68+
func, args, kwargs, keys=("x", "y", "hue", "col", "row")
69+
)
70+
rv = func(**arguments)
71+
return rv
72+
73+
return wrapper
74+
75+
4176
@xr.register_dataarray_accessor("cf")
4277
class CFAccessor:
4378
def __init__(self, da):
4479
self._obj = da
4580
self._coords = _get_axis_name_mapping(da)
81+
self._can_wrap_classes = True
4682

47-
def _process_signature(self, func, args, kwargs):
48-
sig = inspect.signature(func)
83+
def _process_signature(self, func, args, kwargs, keys=("dim",)):
84+
sig = inspect.signature(func, follow_wrapped=False)
4985

5086
# Catch things like .isel(T=5).
5187
# This assigns indexers_kwargs=dict(T=5).
@@ -55,10 +91,13 @@ def _process_signature(self, func, args, kwargs):
5591
if sig.parameters[param].kind is inspect.Parameter.VAR_KEYWORD:
5692
var_kws.append(param)
5793

58-
bound = sig.bind(*args, **kwargs)
59-
arguments = self._rewrite_values_with_axis_names(
60-
bound.arguments, ["dim",] + var_kws
61-
)
94+
if args or kwargs:
95+
bound = sig.bind(*args, **kwargs)
96+
arguments = self._rewrite_values_with_axis_names(
97+
bound.arguments, keys, tuple(var_kws)
98+
)
99+
else:
100+
arguments = {}
62101

63102
if arguments:
64103
# now unwrap the **indexers_kwargs type arguments
@@ -69,10 +108,10 @@ def _process_signature(self, func, args, kwargs):
69108

70109
return arguments
71110

72-
def _rewrite_values_with_axis_names(self, kwargs, keys):
111+
def _rewrite_values_with_axis_names(self, kwargs, keys, var_kws):
73112
""" rewrites 'dim' for example. """
74113
updates = {}
75-
for key in keys:
114+
for key in tuple(keys) + tuple(var_kws):
76115
value = kwargs.get(key, None)
77116
if value:
78117
if isinstance(value, str):
@@ -88,6 +127,18 @@ def _rewrite_values_with_axis_names(self, kwargs, keys):
88127
updates[key] = updates[key][0]
89128

90129
kwargs.update(updates)
130+
131+
# maybe the keys I'm looking for are in kwargs.
132+
# This happens with DataArray.plot() for example, where the signature is obscured.
133+
for vkw in var_kws:
134+
if vkw in kwargs:
135+
maybe_update = {
136+
k: self._coords.get(v, v)
137+
for k, v in kwargs[vkw].items()
138+
if k in keys
139+
}
140+
kwargs[vkw].update(maybe_update)
141+
91142
return kwargs
92143

93144
def __getattr__(self, name):
@@ -104,11 +155,6 @@ def wrapper(*args, **kwargs):
104155

105156
return wrapper
106157

107-
def plot(self, *args, **kwargs):
108-
if args:
109-
raise ValueError("cf.plot can only be called with keyword arguments.")
110-
111-
kwargs = self._rewrite_values_with_axis_names(
112-
kwargs, ("x", "y", "hue", "col", "row")
113-
)
114-
return self._obj.plot(*args, **kwargs)
158+
@property
159+
def plot(self):
160+
return _CFWrappedPlotMethods(self._obj, self)

0 commit comments

Comments
 (0)