Skip to content

Commit 7c338c3

Browse files
authored
Set xincrease/yincrease in more cases. (#287)
1 parent 7f95068 commit 7c338c3

File tree

2 files changed

+34
-25
lines changed

2 files changed

+34
-25
lines changed

cf_xarray/accessor.py

Lines changed: 29 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -822,48 +822,54 @@ def __init__(self, obj, accessor):
822822
self.accessor = accessor
823823
self._keys = ("x", "y", "hue", "col", "row")
824824

825+
def _process_x_or_y(self, kwargs, key, skip=None):
826+
"""Choose a default 'x' or 'y' variable name."""
827+
if key not in kwargs:
828+
kwargs[key] = _possible_x_y_plot(self._obj, key, skip)
829+
return kwargs
830+
831+
def _set_axis_props(self, kwargs, key):
832+
value = kwargs.get(key)
833+
if value:
834+
if value in self.accessor.keys():
835+
var = self.accessor[value]
836+
else:
837+
var = self._obj[value]
838+
if "positive" in var.attrs:
839+
if var.attrs["positive"] == "down":
840+
kwargs.setdefault(f"{key}increase", False)
841+
else:
842+
kwargs.setdefault(f"{key}increase", True)
843+
return kwargs
844+
825845
def _plot_decorator(self, func):
826846
"""
827847
This decorator is used to set default kwargs on plotting functions.
828848
For now, this can
829849
1. set ``xincrease`` and ``yincrease``.
830850
2. automatically set ``x`` or ``y``.
831851
"""
832-
valid_keys = self.accessor.keys()
833852

834853
@functools.wraps(func)
835854
def _plot_wrapper(*args, **kwargs):
836-
def _process_x_or_y(kwargs, key, skip=None):
837-
if key not in kwargs:
838-
kwargs[key] = _possible_x_y_plot(self._obj, key, skip)
839-
840-
value = kwargs.get(key)
841-
if value:
842-
if value in valid_keys:
843-
var = self.accessor[value]
844-
else:
845-
var = self._obj[value]
846-
if "positive" in var.attrs:
847-
if var.attrs["positive"] == "down":
848-
kwargs.setdefault(f"{key}increase", False)
849-
else:
850-
kwargs.setdefault(f"{key}increase", True)
851-
return kwargs
852-
855+
# First choose 'x' or 'y' if possible
853856
is_line_plot = (func.__name__ == "line") or (
854857
func.__name__ == "wrapper"
855858
and (kwargs.get("hue") or self._obj.ndim == 1)
856859
)
857860
if is_line_plot:
858861
hue = kwargs.get("hue")
859862
if "x" not in kwargs and "y" not in kwargs:
860-
kwargs = _process_x_or_y(kwargs, "x", skip=hue)
863+
kwargs = self._process_x_or_y(kwargs, "x", skip=hue)
861864
if not kwargs.get("x"):
862-
kwargs = _process_x_or_y(kwargs, "y", skip=hue)
863-
865+
kwargs = self._process_x_or_y(kwargs, "y", skip=hue)
864866
else:
865-
kwargs = _process_x_or_y(kwargs, "x", skip=kwargs.get("y"))
866-
kwargs = _process_x_or_y(kwargs, "y", skip=kwargs.get("x"))
867+
kwargs = self._process_x_or_y(kwargs, "x", skip=kwargs.get("y"))
868+
kwargs = self._process_x_or_y(kwargs, "y", skip=kwargs.get("x"))
869+
870+
# Now set some nice properties
871+
kwargs = self._set_axis_props(kwargs, "x")
872+
kwargs = self._set_axis_props(kwargs, "y")
867873

868874
return func(*args, **kwargs)
869875

cf_xarray/tests/test_accessor.py

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -509,7 +509,7 @@ def test_dataarray_getitem():
509509

510510
def test_dataarray_plot():
511511

512-
obj = airds.air
512+
obj = airds.air.copy(deep=True)
513513

514514
rv = obj.isel(time=1).transpose("lon", "lat").cf.plot()
515515
assert isinstance(rv, mpl.collections.QuadMesh)
@@ -554,15 +554,18 @@ def test_dataarray_plot():
554554
np.testing.assert_equal(rv[0].get_xdata(), obj.lon.data)
555555
plt.close()
556556

557+
obj.lon.attrs["positive"] = "down"
557558
rv = obj.cf.isel(T=1, Y=[0, 1, 2]).cf.plot(hue="Y")
558559
np.testing.assert_equal(rv[0].get_xdata(), obj.lon.data)
560+
xlim = rv[0].axes.get_xlim()
561+
assert xlim[0] > xlim[1]
559562
plt.close()
563+
del obj.lon.attrs["positive"]
560564

561565
rv = obj.cf.isel(T=1, Y=[0, 1, 2]).cf.plot.line()
562566
np.testing.assert_equal(rv[0].get_xdata(), obj.lon.data)
563567
plt.close()
564568

565-
obj = obj.copy(deep=True)
566569
obj.time.attrs.clear()
567570
rv = obj.cf.plot(x="X", y="Y", col="time")
568571
assert isinstance(rv, xr.plot.FacetGrid)

0 commit comments

Comments
 (0)