Skip to content

Commit 9d32c18

Browse files
authored
Automatically choose x,y for plots (#148)
1 parent 654a709 commit 9d32c18

File tree

3 files changed

+126
-23
lines changed

3 files changed

+126
-23
lines changed

cf_xarray/accessor.py

Lines changed: 62 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -655,6 +655,42 @@ def check_results(names, k):
655655
)
656656

657657

658+
def _possible_x_y_plot(obj, key):
659+
"""Guesses a name for an x/y variable if possible."""
660+
# in priority order
661+
x_criteria = [
662+
("coordinates", "longitude"),
663+
("axes", "X"),
664+
("coordinates", "time"),
665+
("axes", "T"),
666+
]
667+
y_criteria = [
668+
("coordinates", "vertical"),
669+
("axes", "Z"),
670+
("coordinates", "latitude"),
671+
("axes", "Y"),
672+
]
673+
674+
def _get_possible(accessor, criteria):
675+
# is_scalar depends on NON_NUMPY_SUPPORTED_TYPES
676+
# importing a private function seems better than
677+
# maintaining that variable!
678+
from xarray.core.utils import is_scalar
679+
680+
for attr, key in criteria:
681+
value = getattr(accessor, attr).get(key)
682+
if not value or len(value) > 1:
683+
continue
684+
if not is_scalar(accessor._obj[value[0]]):
685+
return value[0]
686+
return None
687+
688+
if key == "x":
689+
return _get_possible(obj.cf, x_criteria)
690+
elif key == "y":
691+
return _get_possible(obj.cf, y_criteria)
692+
693+
658694
class _CFWrappedClass:
659695
"""
660696
This class is used to wrap any class in _WRAPPED_CLASSES.
@@ -705,27 +741,34 @@ def _plot_decorator(self, func):
705741

706742
@functools.wraps(func)
707743
def _plot_wrapper(*args, **kwargs):
708-
if "x" in kwargs:
709-
if kwargs["x"] in valid_keys:
710-
xvar = self.accessor[kwargs["x"]]
711-
else:
712-
xvar = self._obj[kwargs["x"]]
713-
if "positive" in xvar.attrs:
714-
if xvar.attrs["positive"] == "down":
715-
kwargs.setdefault("xincrease", False)
716-
else:
717-
kwargs.setdefault("xincrease", True)
744+
def _process_x_or_y(kwargs, key):
745+
if key not in kwargs:
746+
kwargs[key] = _possible_x_y_plot(self._obj, key)
718747

719-
if "y" in kwargs:
720-
if kwargs["y"] in valid_keys:
721-
yvar = self.accessor[kwargs["y"]]
722-
else:
723-
yvar = self._obj[kwargs["y"]]
724-
if "positive" in yvar.attrs:
725-
if yvar.attrs["positive"] == "down":
726-
kwargs.setdefault("yincrease", False)
748+
value = kwargs.get(key)
749+
if value:
750+
if value in valid_keys:
751+
var = self.accessor[value]
727752
else:
728-
kwargs.setdefault("yincrease", True)
753+
var = self._obj[value]
754+
if "positive" in var.attrs:
755+
if var.attrs["positive"] == "down":
756+
kwargs.setdefault(f"{key}increase", False)
757+
else:
758+
kwargs.setdefault(f"{key}increase", True)
759+
return kwargs
760+
761+
is_line_plot = (func.__name__ == "line") or (
762+
func.__name__ == "wrapper" and kwargs.get("hue")
763+
)
764+
if is_line_plot:
765+
if not kwargs.get("hue"):
766+
kwargs = _process_x_or_y(kwargs, "x")
767+
if not kwargs.get("x"):
768+
kwargs = _process_x_or_y(kwargs, "y")
769+
else:
770+
kwargs = _process_x_or_y(kwargs, "x")
771+
kwargs = _process_x_or_y(kwargs, "y")
729772

730773
return func(*args, **kwargs)
731774

cf_xarray/tests/test_accessor.py

Lines changed: 63 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -293,15 +293,20 @@ def test_dataarray_getitem():
293293
assert_identical(air.cf["area_grid_cell"], air.cell_area.reset_coords(drop=True))
294294

295295

296-
@pytest.mark.parametrize("obj", dataarrays)
297-
def test_dataarray_plot(obj):
296+
def test_dataarray_plot():
297+
298+
obj = airds.air
298299

299-
rv = obj.isel(time=1).cf.plot(x="X", y="Y")
300+
rv = obj.isel(time=1).transpose("lon", "lat").cf.plot()
300301
assert isinstance(rv, mpl.collections.QuadMesh)
302+
assert all(v > 180 for v in rv.axes.get_xlim())
303+
assert all(v < 200 for v in rv.axes.get_ylim())
301304
plt.close()
302305

303-
rv = obj.isel(time=1).cf.plot.contourf(x="X", y="Y")
306+
rv = obj.isel(time=1).transpose("lon", "lat").cf.plot.contourf()
304307
assert isinstance(rv, mpl.contour.QuadContourSet)
308+
assert all(v > 180 for v in rv.axes.get_xlim())
309+
assert all(v < 200 for v in rv.axes.get_ylim())
305310
plt.close()
306311

307312
rv = obj.cf.plot(x="X", y="Y", col="T")
@@ -316,6 +321,29 @@ def test_dataarray_plot(obj):
316321
assert all([isinstance(line, mpl.lines.Line2D) for line in rv])
317322
plt.close()
318323

324+
# set y automatically
325+
rv = obj.isel(time=0, lon=1).cf.plot.line()
326+
np.testing.assert_equal(rv[0].get_ydata(), obj.lat.data)
327+
plt.close()
328+
329+
# don't set y automatically
330+
rv = obj.isel(time=0, lon=1).cf.plot.line(x="lat")
331+
np.testing.assert_equal(rv[0].get_xdata(), obj.lat.data)
332+
plt.close()
333+
334+
# various line plots and automatic guessing
335+
rv = obj.cf.isel(T=1, Y=[0, 1, 2]).cf.plot.line()
336+
np.testing.assert_equal(rv[0].get_xdata(), obj.lon.data)
337+
plt.close()
338+
339+
# rv = obj.cf.isel(T=1, Y=[0, 1, 2]).cf.plot(hue="Y")
340+
# np.testing.assert_equal(rv[0].get_xdata(), obj.lon.data)
341+
# plt.close()
342+
343+
rv = obj.cf.isel(T=1, Y=[0, 1, 2]).cf.plot.line()
344+
np.testing.assert_equal(rv[0].get_xdata(), obj.lon.data)
345+
plt.close()
346+
319347
obj = obj.copy(deep=True)
320348
obj.time.attrs.clear()
321349
rv = obj.cf.plot(x="X", y="Y", col="time")
@@ -714,3 +742,34 @@ def test_drop_dims(ds):
714742
# Axis and coordinate
715743
for cf_name in ["X", "longitude"]:
716744
assert_identical(ds.drop_dims("lon"), ds.cf.drop_dims(cf_name))
745+
746+
747+
def test_possible_x_y_plot():
748+
from ..accessor import _possible_x_y_plot
749+
750+
# choose axes
751+
assert _possible_x_y_plot(airds.air.isel(time=1), "x") == "lon"
752+
assert _possible_x_y_plot(airds.air.isel(time=1), "y") == "lat"
753+
assert _possible_x_y_plot(airds.air.isel(lon=1), "y") == "lat"
754+
assert _possible_x_y_plot(airds.air.isel(lon=1), "x") == "time"
755+
756+
# choose coordinates over axes
757+
assert _possible_x_y_plot(popds.UVEL, "x") == "ULONG"
758+
assert _possible_x_y_plot(popds.UVEL, "y") == "ULAT"
759+
assert _possible_x_y_plot(popds.TEMP, "x") == "TLONG"
760+
assert _possible_x_y_plot(popds.TEMP, "y") == "TLAT"
761+
762+
assert _possible_x_y_plot(popds.UVEL.drop_vars("ULONG"), "x") == "nlon"
763+
764+
# choose X over T, Y over Z
765+
def makeds(*dims):
766+
coords = {dim: (dim, np.arange(3), {"axis": dim}) for dim in dims}
767+
return xr.DataArray(np.zeros((3, 3)), dims=dims, coords=coords)
768+
769+
yzds = makeds("Y", "Z")
770+
assert _possible_x_y_plot(yzds, "y") == "Z"
771+
assert _possible_x_y_plot(yzds, "x") is None
772+
773+
xtds = makeds("X", "T")
774+
assert _possible_x_y_plot(xtds, "y") is None
775+
assert _possible_x_y_plot(xtds, "x") == "X"

doc/whats-new.rst

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@ What's New
66
v0.4.1 (unreleased)
77
===================
88

9+
- Automatically set ``x`` or ``y`` for :py:attr:`DataArray.cf.plot`. By `Deepak Cherian`_.
910
- Added scripts to document :ref:`criteria` with tables. By `Mattia Almansi`_.
1011
- Support for ``.drop()``, ``.drop_vars()``, ``.drop_sel()``, ``.drop_dims()``, ``.set_coords()``, ``.reset_coords()``. By `Mattia Almansi`_.
1112
- Support for using ``standard_name`` in more functions. (:pr:`128`) By `Deepak Cherian`_

0 commit comments

Comments
 (0)