Skip to content

Commit e628079

Browse files
authored
Rework plotting heuristics. (#251)
1 parent 30a577d commit e628079

File tree

2 files changed

+26
-15
lines changed

2 files changed

+26
-15
lines changed

cf_xarray/accessor.py

Lines changed: 23 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -736,7 +736,7 @@ def check_results(names, key):
736736
)
737737

738738

739-
def _possible_x_y_plot(obj, key):
739+
def _possible_x_y_plot(obj, key, skip=None):
740740
"""Guesses a name for an x/y variable if possible."""
741741
# in priority order
742742
x_criteria = [
@@ -759,11 +759,20 @@ def _get_possible(accessor, criteria):
759759
from xarray.core.utils import is_scalar
760760

761761
for attr, key in criteria:
762-
value = getattr(accessor, attr).get(key)
763-
if not value or len(value) > 1:
762+
values = getattr(accessor, attr).get(key)
763+
ax_coord_name = getattr(accessor, attr).get(key)
764+
if not values:
764765
continue
765-
if not is_scalar(accessor._obj[value[0]]):
766-
return value[0]
766+
elif ax_coord_name:
767+
values = [v for v in values if v in ax_coord_name]
768+
769+
values = [v for v in values if v != skip]
770+
if len(values) == 1 and not is_scalar(accessor._obj[values[0]]):
771+
return values[0]
772+
else:
773+
for v in values:
774+
if not is_scalar(accessor._obj[v]):
775+
return v
767776
return None
768777

769778
if key == "x":
@@ -825,9 +834,9 @@ def _plot_decorator(self, func):
825834

826835
@functools.wraps(func)
827836
def _plot_wrapper(*args, **kwargs):
828-
def _process_x_or_y(kwargs, key):
837+
def _process_x_or_y(kwargs, key, skip=None):
829838
if key not in kwargs:
830-
kwargs[key] = _possible_x_y_plot(self._obj, key)
839+
kwargs[key] = _possible_x_y_plot(self._obj, key, skip)
831840

832841
value = kwargs.get(key)
833842
if value:
@@ -847,13 +856,15 @@ def _process_x_or_y(kwargs, key):
847856
and (kwargs.get("hue") or self._obj.ndim == 1)
848857
)
849858
if is_line_plot:
850-
if not kwargs.get("hue"):
851-
kwargs = _process_x_or_y(kwargs, "x")
859+
hue = kwargs.get("hue")
860+
if "x" not in kwargs and "y" not in kwargs:
861+
kwargs = _process_x_or_y(kwargs, "x", skip=hue)
852862
if not kwargs.get("x"):
853-
kwargs = _process_x_or_y(kwargs, "y")
863+
kwargs = _process_x_or_y(kwargs, "y", skip=hue)
864+
854865
else:
855-
kwargs = _process_x_or_y(kwargs, "x")
856-
kwargs = _process_x_or_y(kwargs, "y")
866+
kwargs = _process_x_or_y(kwargs, "x", skip=kwargs.get("y"))
867+
kwargs = _process_x_or_y(kwargs, "y", skip=kwargs.get("x"))
857868

858869
return func(*args, **kwargs)
859870

cf_xarray/tests/test_accessor.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -496,9 +496,9 @@ def test_dataarray_plot():
496496
np.testing.assert_equal(rv[0].get_xdata(), obj.lon.data)
497497
plt.close()
498498

499-
# rv = obj.cf.isel(T=1, Y=[0, 1, 2]).cf.plot(hue="Y")
500-
# np.testing.assert_equal(rv[0].get_xdata(), obj.lon.data)
501-
# plt.close()
499+
rv = obj.cf.isel(T=1, Y=[0, 1, 2]).cf.plot(hue="Y")
500+
np.testing.assert_equal(rv[0].get_xdata(), obj.lon.data)
501+
plt.close()
502502

503503
rv = obj.cf.isel(T=1, Y=[0, 1, 2]).cf.plot.line()
504504
np.testing.assert_equal(rv[0].get_xdata(), obj.lon.data)

0 commit comments

Comments
 (0)