Skip to content

Commit 5410abf

Browse files
committed
Add GridMapping dataclass
1 parent d6ae4b1 commit 5410abf

File tree

7 files changed

+411
-22
lines changed

7 files changed

+411
-22
lines changed

cf_xarray/__init__.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@
33

44
from . import geometry as geometry
55
from . import sgrid # noqa
6-
from .accessor import CFAccessor # noqa
6+
from .accessor import CFAccessor, GridMapping # noqa
77
from .coding import ( # noqa
88
decode_compress_to_multi_index,
99
encode_multi_index_as_compress,

cf_xarray/accessor.py

Lines changed: 211 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@
1313
MutableMapping,
1414
Sequence,
1515
)
16+
from dataclasses import dataclass
1617
from datetime import datetime
1718
from typing import (
1819
Any,
@@ -82,6 +83,61 @@
8283

8384
FlagParam = namedtuple("FlagParam", ["flag_mask", "flag_value"])
8485

86+
87+
@dataclass(frozen=True, kw_only=True)
88+
class GridMapping:
89+
"""
90+
Represents a CF grid mapping with its properties and associated coordinate variables.
91+
92+
Attributes
93+
----------
94+
name : str
95+
The CF grid mapping name (e.g., ``'latitude_longitude'``, ``'transverse_mercator'``)
96+
crs : pyproj.CRS
97+
The coordinate reference system object
98+
array : xarray.DataArray
99+
The grid mapping variable as a DataArray containing the CRS parameters
100+
coordinates : tuple[str, ...]
101+
Names of coordinate variables associated with this grid mapping. For grid mappings
102+
that are explicitly listed with coordinates in the grid_mapping attribute
103+
(e.g., ``'spatial_ref: crs_4326: latitude longitude'``), this contains those coordinates.
104+
For grid mappings (e.g. ``spatial_ref``) that don't explicitly specify coordinates,
105+
this falls back to the dimension names of the data variable that references
106+
this grid mapping.
107+
"""
108+
109+
name: str
110+
crs: Any # really pyproj.CRS
111+
array: xr.DataArray
112+
coordinates: tuple[str, ...]
113+
114+
def __repr__(self) -> str:
115+
# Try to get EPSG code first, fallback to shorter description
116+
try:
117+
if hasattr(self.crs, "to_epsg") and self.crs.to_epsg():
118+
crs_repr = f"<CRS: EPSG:{self.crs.to_epsg()}>"
119+
else:
120+
# Use the name if available, otherwise authority:code
121+
crs_name = getattr(self.crs, "name", str(self.crs)[:50] + "...")
122+
crs_repr = f"<CRS: {crs_name}>"
123+
except Exception:
124+
# Fallback to generic representation
125+
crs_repr = "<CRS>"
126+
127+
# Short array representation - name and shape
128+
array_repr = f"<DataArray '{self.array.name}' {self.array.shape}>"
129+
130+
# Format coordinates nicely
131+
coords_repr = f"({', '.join(repr(c) for c in self.coordinates)})"
132+
133+
return (
134+
f"GridMapping(name={self.name!r}, "
135+
f"crs={crs_repr}, "
136+
f"array={array_repr}, "
137+
f"coordinates={coords_repr})"
138+
)
139+
140+
85141
#: Classes wrapped by cf_xarray.
86142
_WRAPPED_CLASSES = (Resample, GroupBy, Rolling, Coarsen, Weighted)
87143

@@ -2406,6 +2462,160 @@ def add_canonical_attributes(
24062462

24072463
return obj
24082464

2465+
def _create_grid_mapping(
    self,
    var_name: str,
    obj_dataset: Dataset,
    grid_mapping_dict: dict[str, list[str]],
) -> GridMapping:
    """
    Create a GridMapping dataclass instance from a grid mapping variable.

    Parameters
    ----------
    var_name : str
        Name of the grid mapping variable
    obj_dataset : Dataset
        Dataset containing the grid mapping variable
    grid_mapping_dict : dict[str, list[str]]
        Dictionary mapping grid mapping variable names to their coordinate variables

    Returns
    -------
    GridMapping
        GridMapping dataclass instance. Its ``crs`` is ``None`` when a CRS
        cannot be built from the variable's attributes.

    Raises
    ------
    KeyError
        If ``var_name`` is not a variable of ``obj_dataset``.

    Notes
    -----
    Assumes pyproj is available (should be checked by caller).
    """
    # Imported here so pyproj stays an optional dependency of the module.
    from pyproj import CRS

    var = obj_dataset._variables[var_name]

    # Create DataArray from Variable, preserving the name
    # Use reset_coords(drop=True) to avoid coordinate conflicts
    if var_name in obj_dataset.coords:
        da = obj_dataset.coords[var_name].reset_coords(drop=True)
    else:
        da = obj_dataset[var_name].reset_coords(drop=True)

    # Get the CF grid mapping name from the variable's attributes
    cf_name = var.attrs.get("grid_mapping_name", var_name)

    # Create CRS from the grid mapping variable. This is deliberately
    # best-effort: a variable whose attributes cannot be turned into a CRS
    # still yields a GridMapping, just without a CRS object.
    try:
        crs = CRS.from_cf(var.attrs)
    except Exception:
        crs = None

    # Get associated coordinate variables, fallback to dimension names
    coordinates = grid_mapping_dict.get(var_name, [])
    if not coordinates:
        # Use the dimensions of the first data variable that references this
        # grid mapping. Compare against the parsed variable names rather than
        # the raw attribute string, so that e.g. "crs" does not match a
        # reference to "crs_4326".
        for data_var in obj_dataset.data_vars.values():
            gm_attr = data_var.attrs.get("grid_mapping")
            if gm_attr is None:
                continue
            if var_name in _parse_grid_mapping_attribute(gm_attr):
                coordinates = list(data_var.dims)
                break

    return GridMapping(
        name=cf_name, crs=crs, array=da, coordinates=tuple(coordinates)
    )
2529+
2530+
@property
def grid_mappings(self) -> tuple[GridMapping, ...]:
    """
    Return a tuple of GridMapping objects for all grid mappings in this object.

    For DataArrays, the order in the tuple matches the order that grid mappings appear
    in the grid_mapping attribute string. For Datasets, grid mappings are ordered by
    first appearance while scanning the data variables' grid_mapping attributes.

    Returns
    -------
    tuple[GridMapping, ...]
        Tuple of GridMapping dataclass instances, each containing:
        - name: CF grid mapping name
        - crs: pyproj.CRS object
        - array: xarray.DataArray containing the grid mapping variable
        - coordinates: tuple of coordinate variable names

        The tuple is empty when no grid mapping is found. Referenced grid
        mapping variables that are not present on the object are skipped.

    Raises
    ------
    ImportError
        If pyproj is not available. This property requires pyproj for CRS creation.

    Examples
    --------
    >>> ds.cf.grid_mappings
    (GridMapping(name='latitude_longitude', crs=<CRS: EPSG:4326>, ...),)

    Notes
    -----
    This property requires pyproj to be installed for creating CRS objects from
    CF grid mapping parameters. Install with: ``conda install pyproj`` or
    ``pip install pyproj``.
    """
    # Check pyproj availability upfront
    try:
        import pyproj  # noqa: F401
    except ImportError:
        raise ImportError(
            "pyproj is required for .cf.grid_mappings property. "
            "Install with: conda install pyproj or pip install pyproj"
        ) from None

    obj = self._obj
    is_dataarray = isinstance(obj, DataArray)
    grid_mapping_dict: dict[str, list[str]] = {}

    if is_dataarray and "grid_mapping" in obj.attrs:
        # For DataArrays, preserve order from grid_mapping attribute
        grid_mapping_dict = _parse_grid_mapping_attribute(obj.attrs["grid_mapping"])
        ordered_var_names = list(grid_mapping_dict)
    else:
        # Only Datasets expose ``data_vars``; a DataArray without a
        # grid_mapping attribute has nothing to scan here.
        if not is_dataarray:
            for var in obj.data_vars.values():
                if "grid_mapping" not in var.attrs:
                    continue
                parsed = _parse_grid_mapping_attribute(var.attrs["grid_mapping"])
                for gm_var, coords in parsed.items():
                    # Do not let a later bare reference (no coordinates)
                    # erase coordinates declared by an earlier variable.
                    if coords or gm_var not in grid_mapping_dict:
                        grid_mapping_dict[gm_var] = coords

        # dicts preserve insertion order, i.e. order of first appearance
        ordered_var_names = list(grid_mapping_dict)

        # If no grid_mapping attributes found in data vars, try grid_mapping_names property
        if not ordered_var_names and hasattr(self, "grid_mapping_names"):
            for var_names in self.grid_mapping_names.values():
                ordered_var_names.extend(var_names)

    if not ordered_var_names:
        return ()

    obj_dataset = self._maybe_to_dataset()

    return tuple(
        self._create_grid_mapping(var_name, obj_dataset, grid_mapping_dict)
        for var_name in ordered_var_names
        # Skip references to variables that do not exist on the object.
        if var_name in obj_dataset._variables
    )
2618+
24092619

24102620
@xr.register_dataset_accessor("cf")
24112621
class CFDatasetAccessor(CFAccessor):
@@ -3009,8 +3219,7 @@ def grid_mapping_names(self) -> dict[str, list[str]]:
30093219
grid_mapping_var = da.coords[grid_mapping_var_name]
30103220
if gmn := grid_mapping_var.attrs.get("grid_mapping_name"):
30113221
results[gmn].append(grid_mapping_var_name)
3012-
3013-
return results
3222+
return dict(results)
30143223

30153224
@property
30163225
def grid_mapping_name(self) -> str:

cf_xarray/tests/test_accessor.py

Lines changed: 110 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1130,6 +1130,116 @@ def test_multiple_grid_mapping_attribute():
11301130
assert "foo" in result.data_vars
11311131

11321132

1133+
@requires_pyproj
def test_grid_mappings_property():
    """Check ``.cf.grid_mappings`` on a Dataset and on one of its DataArrays."""
    from ..datasets import hrrrds

    expected_names = {
        "latitude_longitude",
        "lambert_azimuthal_equal_area",
        "transverse_mercator",
    }
    # Coordinates expected per grid mapping; the lambert entry has no explicit
    # coordinates and falls back to the data variable's dimensions.
    expected_coordinates = {
        "latitude_longitude": ("latitude", "longitude"),
        "transverse_mercator": ("x27700", "y27700"),
        "lambert_azimuthal_equal_area": ("x", "y"),
    }

    # --- Dataset ---
    ds_mappings = hrrrds.cf.grid_mappings
    assert len(ds_mappings) == 3
    assert {mapping.name for mapping in ds_mappings} == expected_names

    for mapping in ds_mappings:
        assert mapping.crs is not None  # a pyproj CRS was created
        assert isinstance(mapping.array, xr.DataArray)  # DataArray, not Variable
        assert isinstance(mapping.coordinates, tuple)
        assert mapping.array.name is not None  # name is preserved
        assert mapping.coordinates == expected_coordinates[mapping.name]

    # --- DataArray: same grid mappings as the Dataset ---
    da_mappings = hrrrds.foo.cf.grid_mappings
    assert len(da_mappings) == 3
    assert {mapping.name for mapping in da_mappings} == expected_names

    for mapping in da_mappings:
        assert len(mapping.coordinates) > 0  # never empty
        if mapping.name == "lambert_azimuthal_equal_area":
            assert mapping.coordinates == ("x", "y")
1186+
1187+
@requires_pyproj
def test_grid_mappings_coordinates_attribute():
    """Check order and ``coordinates`` of the grid mappings returned for a DataArray."""
    from ..datasets import hrrrds

    da = hrrrds.foo
    mappings = da.cf.grid_mappings
    assert len(mappings) == 3

    # Order must follow the DataArray's grid_mapping attribute.
    expected_order = [
        "lambert_azimuthal_equal_area",
        "latitude_longitude",
        "transverse_mercator",
    ]
    actual_order = [mapping.name for mapping in mappings]
    assert actual_order == expected_order, (
        f"DataArray order {actual_order} doesn't match expected {expected_order}"
    )

    # Generic invariants: non-empty and made of strings only.
    for gm in mappings:
        assert len(gm.coordinates) > 0, (
            f"Grid mapping '{gm.name}' has empty coordinates"
        )
        assert all(isinstance(coord, str) for coord in gm.coordinates), (
            f"Grid mapping '{gm.name}' has non-string coordinates: {gm.coordinates}"
        )

    # The order assertion above guarantees all three names are present.
    coords_by_name = {gm.name: gm.coordinates for gm in mappings}

    # Explicitly listed in grid_mapping attribute: "crs_4326: latitude longitude"
    latlon = coords_by_name["latitude_longitude"]
    assert latlon == ("latitude", "longitude"), (
        f"Expected ('latitude', 'longitude'), got {latlon}"
    )

    # Explicitly listed in grid_mapping attribute: "crs_27700: x27700 y27700"
    tmerc = coords_by_name["transverse_mercator"]
    assert tmerc == ("x27700", "y27700"), (
        f"Expected ('x27700', 'y27700'), got {tmerc}"
    )

    # Not explicitly listed, so it falls back to the DataArray's dimensions.
    laea = coords_by_name["lambert_azimuthal_equal_area"]
    assert laea == ("x", "y"), (
        f"Expected ('x', 'y') from DataArray dimensions, got {laea}"
    )
    assert laea == da.dims, (
        f"Fallback coordinates {laea} don't match DataArray dims {da.dims}"
    )
1241+
1242+
11331243
def test_bad_grid_mapping_attribute():
11341244
ds = rotds.copy(deep=False)
11351245
ds.temp.attrs["grid_mapping"] = "foo"

ci/doc.yml

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,7 @@ dependencies:
1010
- xarray
1111
- sphinx
1212
- sphinx-copybutton
13+
- sphinx-autobuild
1314
- numpydoc
1415
- sphinx-autosummary-accessors
1516
- ipython
@@ -22,5 +23,6 @@ dependencies:
2223
- shapely
2324
- furo>=2024
2425
- myst-nb
26+
- pyproj
2527
- pip:
2628
- -e ../

doc/Makefile

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@
44
# You can set these variables from the command line, and also
55
# from the environment for the first two.
66
SPHINXOPTS ?=
7-
SPHINXBUILD ?= sphinx-build
7+
SPHINXBUILD ?= sphinx-autobuild
88
SOURCEDIR = .
99
BUILDDIR = _build
1010

0 commit comments

Comments
 (0)