diff --git a/cf_xarray/accessor.py b/cf_xarray/accessor.py index 52d525b5..8611c322 100644 --- a/cf_xarray/accessor.py +++ b/cf_xarray/accessor.py @@ -1,6 +1,7 @@ from __future__ import annotations import functools +import importlib import inspect import itertools import re @@ -48,6 +49,14 @@ ) +if importlib.util.find_spec("cartopy"): + # pyproj is a dep of cartopy + import cartopy.crs + import pyproj +else: + pyproj = None + + from . import parametric, sgrid from .criteria import ( _DSG_ROLES, @@ -76,6 +85,7 @@ always_iterable, emit_user_level_warning, invert_mappings, + is_latitude_longitude, parse_cell_methods_attr, parse_cf_standard_name_table, ) @@ -1092,6 +1102,10 @@ def _plot_wrapper(*args, **kwargs): func.__name__ == "wrapper" and (kwargs.get("hue") or self._obj.ndim == 1) ) + is_grid_plot = (func.__name__ in ["contour", "countourf", "pcolormsh"]) or ( + func.__name__ == "wrapper" + and (self._obj.ndim - sum(["col" in kwargs, "row" in kwargs])) == 2 + ) if is_line_plot: hue = kwargs.get("hue") if "x" not in kwargs and "y" not in kwargs: @@ -1101,6 +1115,20 @@ def _plot_wrapper(*args, **kwargs): else: kwargs = self._process_x_or_y(kwargs, "x", skip=kwargs.get("y")) kwargs = self._process_x_or_y(kwargs, "y", skip=kwargs.get("x")) + if is_grid_plot and pyproj is not None: + from cartopy.mpl.geoaxes import GeoAxes + + ax = kwargs.get("ax") + if ax is None or isinstance(ax, GeoAxes): + try: + kwargs["transform"] = self._obj.cf.cartopy_crs + except ValueError: + pass + else: + if ax is None: + kwargs.setdefault("subplot_kws", {}).setdefault( + "projection", kwargs["transform"] + ) # Now set some nice properties kwargs = self._set_axis_props(kwargs, "x") @@ -2745,6 +2773,24 @@ def grid_mapping_names(self) -> dict[str, list[str]]: results[v].append(k) return results + @property + def cartopy_crs(self): + """Cartopy CRS of the dataset's grid mapping.""" + if pyproj is None: + raise ImportError( + "`crs` accessor requires optional packages `pyproj` and `cartopy`." + ) + gmaps = list(itertools.chain(*self.grid_mapping_names.values())) + if len(gmaps) > 1: + raise ValueError("Multiple grid mappings found.") + if len(gmaps) == 0: + if is_latitude_longitude(self._obj): + return cartopy.crs.PlateCarree() + raise ValueError( + "No grid mapping found and dataset guessed as not latitude_longitude." + ) + return cartopy.crs.Projection(pyproj.CRS.from_cf(self._obj[gmaps[0]].attrs)) + def decode_vertical_coords( self, *, outnames: dict[str, str] | None = None, prefix: str | None = None ) -> None: @@ -2899,6 +2945,21 @@ def formula_terms(self) -> dict[str, str]: # numpydoc ignore=SS06 terms[key] = value return terms + def _get_grid_mapping(self, ignore_missing=False) -> DataArray | None: + da = self._obj + + attrs_or_encoding = ChainMap(da.attrs, da.encoding) + grid_mapping = attrs_or_encoding.get("grid_mapping", None) + if not grid_mapping: + if ignore_missing: + return None + raise ValueError("No 'grid_mapping' attribute present.") + + if grid_mapping not in da._coords: + raise ValueError(f"Grid Mapping variable {grid_mapping} not present.") + + return da[grid_mapping] + @property def grid_mapping_name(self) -> str: """ @@ -2919,20 +2980,25 @@ def grid_mapping_name(self) -> str: >>> rotds.cf["temp"].cf.grid_mapping_name 'rotated_latitude_longitude' """ + grid_mapping_var = self._get_grid_mapping() + return grid_mapping_var.attrs["grid_mapping_name"] - da = self._obj - - attrs_or_encoding = ChainMap(da.attrs, da.encoding) - grid_mapping = attrs_or_encoding.get("grid_mapping", None) - if not grid_mapping: - raise ValueError("No 'grid_mapping' attribute present.") - - if grid_mapping not in da._coords: - raise ValueError(f"Grid Mapping variable {grid_mapping} not present.") - - grid_mapping_var = da[grid_mapping] + @property + def cartopy_crs(self): + """Cartopy CRS of the dataset's grid mapping.""" + if pyproj is None: + raise ImportError( + "`crs` accessor requires optional packages `pyproj` and `cartopy`." + ) - return grid_mapping_var.attrs["grid_mapping_name"] + grid_mapping_var = self._get_grid_mapping(ignore_missing=True) + if grid_mapping_var is None: + if is_latitude_longitude(self._obj): + return cartopy.crs.PlateCarree() + raise ValueError( + "No grid mapping found and dataset guesses as not latitude_longitude." + ) + return cartopy.crs.Projection(pyproj.CRS.from_cf(grid_mapping_var.attrs)) def __getitem__(self, key: Hashable | Iterable[Hashable]) -> DataArray: """ diff --git a/cf_xarray/tests/__init__.py b/cf_xarray/tests/__init__.py index 8c83df3a..8665233e 100644 --- a/cf_xarray/tests/__init__.py +++ b/cf_xarray/tests/__init__.py @@ -69,3 +69,4 @@ def LooseVersion(vstring): has_pooch, requires_pooch = _importorskip("pooch") _, requires_rich = _importorskip("rich") has_regex, requires_regex = _importorskip("regex") +has_cartopy, requires_cartopy = _importorskip("cartopy") diff --git a/cf_xarray/tests/test_accessor.py b/cf_xarray/tests/test_accessor.py index c3f7005a..3cf02cb0 100644 --- a/cf_xarray/tests/test_accessor.py +++ b/cf_xarray/tests/test_accessor.py @@ -41,6 +41,7 @@ ) from . import ( raise_if_dask_computes, + requires_cartopy, requires_cftime, requires_pint, requires_pooch, @@ -1084,6 +1085,34 @@ def test_bad_grid_mapping_attribute(): ds.cf.get_associated_variable_names("temp", error=False) +@requires_cartopy +def test_crs() -> None: + import cartopy.crs as ccrs + from pyproj import CRS + + # Dataset with explicit grid mapping + # ccrs.RotatedPole is not the same as CRS.from_cf(rotated_pole)... + # They are equivalent though, but specified differently + exp = ccrs.Projection(CRS.from_cf(rotds.rotated_pole.attrs)) + assert rotds.cf.crs == exp + with pytest.raises( + ValueError, match="Grid Mapping variable rotated_pole not present" + ): + rotds.temp.cf.crs + assert rotds.set_coords("rotated_pole").temp.cf.crs == exp + + # Dataset with regular latlon (no grid mapping ) + exp = ccrs.PlateCarree() + assert forecast.cf.crs == exp + assert forecast.sst.cf.crs == exp + + # Dataset with no grid mapping specified but not on latlon (error) + with pytest.raises(ValueError, match="No grid mapping found"): + mollwds.cf.crs + with pytest.raises(ValueError, match="No grid mapping found"): + mollwds.lon_bounds.cf.crs + + def test_docstring() -> None: assert "One of ('X'" in airds.cf.groupby.__doc__ assert "Time variable accessor e.g. 'T.month'" in airds.cf.groupby.__doc__ diff --git a/cf_xarray/utils.py b/cf_xarray/utils.py index bdc2605e..e4292b24 100644 --- a/cf_xarray/utils.py +++ b/cf_xarray/utils.py @@ -193,3 +193,18 @@ def emit_user_level_warning(message, category=None): """Emit a warning at the user level by inspecting the stack trace.""" stacklevel = find_stack_level() warnings.warn(message, category=category, stacklevel=stacklevel) + + +def is_latitude_longitude(ds): + """ + A dataset is probably using the latitude_longitude grid mapping implicitly if + - it has both longitude and latitude coordinates + - they are 1D (so either a list of points or a regular grid) + """ + coords = ds.cf.coordinates + return ( + "longitude" in coords + and "latitude" in coords + and ds[coords["longitude"][0]].ndim == 1 + and ds[coords["latitude"][0]].ndim == 1 + ) diff --git a/ci/environment.yml b/ci/environment.yml index bcfb0780..83f13c93 100644 --- a/ci/environment.yml +++ b/ci/environment.yml @@ -6,6 +6,7 @@ dependencies: - pytest - pytest-xdist - dask + - cartopy - flox - lxml - matplotlib-base @@ -13,6 +14,7 @@ dependencies: - pandas - pint - pooch + - pyproj - regex - rich - pooch diff --git a/pyproject.toml b/pyproject.toml index 2269b133..93a470d1 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -137,6 +137,7 @@ enable_error_code = ["ignore-without-code", "redundant-expr", "truthy-bool"] [[tool.mypy.overrides]] module=[ + "cartopy", "cftime", "pandas", "pooch",