Skip to content

Commit d27d633

Browse files
committed
Add GridMapping dataclass
1 parent d6ae4b1 commit d27d633

File tree

7 files changed

+415
-22
lines changed

7 files changed

+415
-22
lines changed

cf_xarray/__init__.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@
33

44
from . import geometry as geometry
55
from . import sgrid # noqa
6-
from .accessor import CFAccessor # noqa
6+
from .accessor import CFAccessor, GridMapping # noqa
77
from .coding import ( # noqa
88
decode_compress_to_multi_index,
99
encode_multi_index_as_compress,

cf_xarray/accessor.py

Lines changed: 215 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@
1313
MutableMapping,
1414
Sequence,
1515
)
16+
from dataclasses import dataclass
1617
from datetime import datetime
1718
from typing import (
1819
Any,
@@ -82,6 +83,65 @@
8283

8384
FlagParam = namedtuple("FlagParam", ["flag_mask", "flag_value"])
8485

86+
87+
@dataclass(frozen=True, kw_only=True)
88+
class GridMapping:
89+
"""
90+
Represents a CF grid mapping with its properties and associated coordinate variables.
91+
92+
Attributes
93+
----------
94+
name : str
95+
The CF grid mapping name (e.g., ``'latitude_longitude'``, ``'transverse_mercator'``)
96+
crs : pyproj.CRS
97+
The coordinate reference system object
98+
array : xarray.DataArray
99+
The grid mapping variable as a DataArray containing the CRS parameters
100+
coordinates : tuple[str, ...]
101+
Names of coordinate variables associated with this grid mapping. For grid mappings
102+
that are explicitly listed with coordinates in the grid_mapping attribute
103+
(e.g., ``'spatial_ref: crs_4326: latitude longitude'``), this contains those coordinates.
104+
For grid mappings (e.g. ``spatial_ref``) that don't explicitly specify coordinates,
105+
this falls back to the dimension names of the data variable that references
106+
this grid mapping.
107+
"""
108+
109+
name: str
110+
crs: Any # pyproj.CRS when available, None otherwise
111+
array: xr.DataArray
112+
coordinates: tuple[str, ...]
113+
114+
def __repr__(self) -> str:
115+
# Short CRS representation
116+
if self.crs is not None:
117+
# Try to get EPSG code first, fallback to shorter description
118+
try:
119+
if hasattr(self.crs, "to_epsg") and self.crs.to_epsg():
120+
crs_repr = f"<CRS: EPSG:{self.crs.to_epsg()}>"
121+
else:
122+
# Use the name if available, otherwise authority:code
123+
crs_name = getattr(self.crs, "name", str(self.crs)[:50] + "...")
124+
crs_repr = f"<CRS: {crs_name}>"
125+
except Exception:
126+
# Fallback to generic representation
127+
crs_repr = "<CRS>"
128+
else:
129+
crs_repr = "None"
130+
131+
# Short array representation - name and shape
132+
array_repr = f"<DataArray '{self.array.name}' {self.array.shape}>"
133+
134+
# Format coordinates nicely
135+
coords_repr = f"({', '.join(repr(c) for c in self.coordinates)})"
136+
137+
return (
138+
f"GridMapping(name={self.name!r}, "
139+
f"crs={crs_repr}, "
140+
f"array={array_repr}, "
141+
f"coordinates={coords_repr})"
142+
)
143+
144+
85145
#: Classes wrapped by cf_xarray.
86146
_WRAPPED_CLASSES = (Resample, GroupBy, Rolling, Coarsen, Weighted)
87147

@@ -2406,6 +2466,160 @@ def add_canonical_attributes(
24062466

24072467
return obj
24082468

2469+
def _create_grid_mapping(
2470+
self,
2471+
var_name: str,
2472+
obj_dataset: Dataset,
2473+
grid_mapping_dict: dict[str, list[str]],
2474+
) -> GridMapping:
2475+
"""
2476+
Create a GridMapping dataclass instance from a grid mapping variable.
2477+
2478+
Parameters
2479+
----------
2480+
var_name : str
2481+
Name of the grid mapping variable
2482+
obj_dataset : Dataset
2483+
Dataset containing the grid mapping variable
2484+
grid_mapping_dict : dict[str, list[str]]
2485+
Dictionary mapping grid mapping variable names to their coordinate variables
2486+
2487+
Returns
2488+
-------
2489+
GridMapping
2490+
GridMapping dataclass instance
2491+
2492+
Notes
2493+
-----
2494+
Assumes pyproj is available (should be checked by caller).
2495+
"""
2496+
from pyproj import (
2497+
CRS, # Safe to import since grid_mappings property checks availability
2498+
)
2499+
2500+
var = obj_dataset._variables[var_name]
2501+
2502+
# Create DataArray from Variable, preserving the name
2503+
# Use reset_coords(drop=True) to avoid coordinate conflicts
2504+
if var_name in obj_dataset.coords:
2505+
da = obj_dataset.coords[var_name].reset_coords(drop=True)
2506+
else:
2507+
da = obj_dataset[var_name].reset_coords(drop=True)
2508+
2509+
# Get the CF grid mapping name from the variable's attributes
2510+
cf_name = var.attrs.get("grid_mapping_name", var_name)
2511+
2512+
# Create CRS from the grid mapping variable
2513+
try:
2514+
crs = CRS.from_cf(var.attrs)
2515+
except Exception:
2516+
# If CRS creation fails, use None
2517+
crs = None
2518+
2519+
# Get associated coordinate variables, fallback to dimension names
2520+
coordinates = grid_mapping_dict.get(var_name, [])
2521+
if not coordinates:
2522+
# For DataArrays, find the data variable that references this grid mapping
2523+
for _data_var_name, data_var in obj_dataset.data_vars.items():
2524+
if "grid_mapping" in data_var.attrs:
2525+
gm_attr = data_var.attrs["grid_mapping"]
2526+
if var_name in gm_attr:
2527+
coordinates = list(data_var.dims)
2528+
break
2529+
2530+
return GridMapping(
2531+
name=cf_name, crs=crs, array=da, coordinates=tuple(coordinates)
2532+
)
2533+
2534+
@property
2535+
def grid_mappings(self) -> tuple[GridMapping, ...]:
2536+
"""
2537+
Return a tuple of GridMapping objects for all grid mappings in this object.
2538+
2539+
For DataArrays, the order in the tuple matches the order that grid mappings appear
2540+
in the grid_mapping attribute string.
2541+
2542+
Parameters
2543+
----------
2544+
None
2545+
2546+
Returns
2547+
-------
2548+
tuple[GridMapping, ...]
2549+
Tuple of GridMapping dataclass instances, each containing:
2550+
- name: CF grid mapping name
2551+
- crs: pyproj.CRS object
2552+
- array: xarray.DataArray containing the grid mapping variable
2553+
- coordinates: tuple of coordinate variable names
2554+
2555+
Raises
2556+
------
2557+
ImportError
2558+
If pyproj is not available. This property requires pyproj for CRS creation.
2559+
2560+
Examples
2561+
--------
2562+
>>> ds.cf.grid_mappings
2563+
(GridMapping(name='latitude_longitude', crs=<CRS: EPSG:4326>, ...),)
2564+
2565+
Notes
2566+
-----
2567+
This property requires pyproj to be installed for creating CRS objects from
2568+
CF grid mapping parameters. Install with: ``conda install pyproj`` or
2569+
``pip install pyproj``.
2570+
"""
2571+
# Check pyproj availability upfront
2572+
try:
2573+
import pyproj # noqa: F401
2574+
except ImportError:
2575+
raise ImportError(
2576+
"pyproj is required for .cf.grid_mappings property. "
2577+
"Install with: conda install pyproj or pip install pyproj"
2578+
) from None
2579+
# For DataArrays, preserve order from grid_mapping attribute
2580+
if isinstance(self._obj, DataArray) and "grid_mapping" in self._obj.attrs:
2581+
grid_mapping_dict = _parse_grid_mapping_attribute(
2582+
self._obj.attrs["grid_mapping"]
2583+
)
2584+
# Get grid mappings in the order they appear in the string
2585+
ordered_var_names = list(grid_mapping_dict.keys())
2586+
else:
2587+
# For Datasets, look for grid_mapping attributes in data variables
2588+
grid_mapping_dict = {}
2589+
ordered_var_names = []
2590+
2591+
# Search all data variables for grid_mapping attributes
2592+
for _var_name, var in self._obj.data_vars.items():
2593+
if "grid_mapping" in var.attrs:
2594+
parsed = _parse_grid_mapping_attribute(var.attrs["grid_mapping"])
2595+
grid_mapping_dict.update(parsed)
2596+
# Add variables in order they appear in this grid_mapping string
2597+
for gm_var in parsed.keys():
2598+
if gm_var not in ordered_var_names:
2599+
ordered_var_names.append(gm_var)
2600+
2601+
# If no grid_mapping attributes found in data vars, try grid_mapping_names property
2602+
if not ordered_var_names and hasattr(self, "grid_mapping_names"):
2603+
grid_mapping_names = self.grid_mapping_names
2604+
for var_names in grid_mapping_names.values():
2605+
ordered_var_names.extend(var_names)
2606+
2607+
if not ordered_var_names:
2608+
return ()
2609+
2610+
grid_mappings = []
2611+
obj_dataset = self._maybe_to_dataset()
2612+
2613+
for var_name in ordered_var_names:
2614+
if var_name not in obj_dataset._variables:
2615+
continue
2616+
2617+
grid_mappings.append(
2618+
self._create_grid_mapping(var_name, obj_dataset, grid_mapping_dict)
2619+
)
2620+
2621+
return tuple(grid_mappings)
2622+
24092623

24102624
@xr.register_dataset_accessor("cf")
24112625
class CFDatasetAccessor(CFAccessor):
@@ -3009,8 +3223,7 @@ def grid_mapping_names(self) -> dict[str, list[str]]:
30093223
grid_mapping_var = da.coords[grid_mapping_var_name]
30103224
if gmn := grid_mapping_var.attrs.get("grid_mapping_name"):
30113225
results[gmn].append(grid_mapping_var_name)
3012-
3013-
return results
3226+
return dict(results)
30143227

30153228
@property
30163229
def grid_mapping_name(self) -> str:

cf_xarray/tests/test_accessor.py

Lines changed: 110 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1130,6 +1130,116 @@ def test_multiple_grid_mapping_attribute():
11301130
assert "foo" in result.data_vars
11311131

11321132

1133+
@requires_pyproj
1134+
def test_grid_mappings_property():
1135+
"""Test the .cf.grid_mappings property on both Dataset and DataArray."""
1136+
from ..datasets import hrrrds
1137+
1138+
ds = hrrrds
1139+
1140+
# Test Dataset
1141+
grid_mappings = ds.cf.grid_mappings
1142+
assert len(grid_mappings) == 3
1143+
1144+
# Check that all expected grid mapping names are present
1145+
gm_names = {gm.name for gm in grid_mappings}
1146+
expected_names = {
1147+
"latitude_longitude",
1148+
"lambert_azimuthal_equal_area",
1149+
"transverse_mercator",
1150+
}
1151+
assert gm_names == expected_names
1152+
1153+
# Test specific properties of each grid mapping
1154+
for gm in grid_mappings:
1155+
assert gm.crs is not None # Should have pyproj CRS
1156+
assert isinstance(gm.array, xr.DataArray) # Should be DataArray, not Variable
1157+
assert isinstance(gm.coordinates, tuple)
1158+
assert gm.array.name is not None # DataArray should preserve name
1159+
1160+
# Check specific coordinate associations
1161+
if gm.name == "latitude_longitude":
1162+
assert gm.coordinates == ("latitude", "longitude")
1163+
elif gm.name == "transverse_mercator":
1164+
assert gm.coordinates == ("x27700", "y27700")
1165+
elif gm.name == "lambert_azimuthal_equal_area":
1166+
assert gm.coordinates == (
1167+
"x",
1168+
"y",
1169+
) # Falls back to data variable dimensions
1170+
1171+
# Test DataArray
1172+
da = ds.foo
1173+
da_grid_mappings = da.cf.grid_mappings
1174+
assert len(da_grid_mappings) == 3
1175+
1176+
# DataArray should have the same grid mappings as Dataset
1177+
da_names = {gm.name for gm in da_grid_mappings}
1178+
assert da_names == expected_names
1179+
1180+
# Check that coordinates are populated for DataArray too
1181+
for gm in da_grid_mappings:
1182+
assert len(gm.coordinates) > 0 # Should never be empty now
1183+
if gm.name == "lambert_azimuthal_equal_area":
1184+
assert gm.coordinates == ("x", "y")
1185+
1186+
1187+
@requires_pyproj
1188+
def test_grid_mappings_coordinates_attribute():
1189+
"""Test that coordinates attribute is always populated correctly for DataArray grid mappings."""
1190+
from ..datasets import hrrrds
1191+
1192+
ds = hrrrds
1193+
1194+
# Focus on DataArray access
1195+
da = ds.foo
1196+
grid_mappings = da.cf.grid_mappings
1197+
assert len(grid_mappings) == 3
1198+
1199+
# Verify order preservation for DataArray (should match grid_mapping attribute order)
1200+
expected_order = [
1201+
"lambert_azimuthal_equal_area",
1202+
"latitude_longitude",
1203+
"transverse_mercator",
1204+
]
1205+
actual_order = [gm.name for gm in grid_mappings]
1206+
assert actual_order == expected_order, (
1207+
f"DataArray order {actual_order} doesn't match expected {expected_order}"
1208+
)
1209+
1210+
for gm in grid_mappings:
1211+
# Coordinates should never be empty
1212+
assert len(gm.coordinates) > 0, (
1213+
f"Grid mapping '{gm.name}' has empty coordinates"
1214+
)
1215+
1216+
# All coordinates should be strings
1217+
assert all(isinstance(coord, str) for coord in gm.coordinates), (
1218+
f"Grid mapping '{gm.name}' has non-string coordinates: {gm.coordinates}"
1219+
)
1220+
1221+
# Test specific expected coordinates for each grid mapping
1222+
if gm.name == "latitude_longitude":
1223+
# Explicitly listed in grid_mapping attribute: "crs_4326: latitude longitude"
1224+
assert gm.coordinates == ("latitude", "longitude"), (
1225+
f"Expected ('latitude', 'longitude'), got {gm.coordinates}"
1226+
)
1227+
elif gm.name == "transverse_mercator":
1228+
# Explicitly listed in grid_mapping attribute: "crs_27700: x27700 y27700"
1229+
assert gm.coordinates == ("x27700", "y27700"), (
1230+
f"Expected ('x27700', 'y27700'), got {gm.coordinates}"
1231+
)
1232+
elif gm.name == "lambert_azimuthal_equal_area":
1233+
# Not explicitly listed, should fallback to DataArray dimensions
1234+
assert gm.coordinates == ("x", "y"), (
1235+
f"Expected ('x', 'y') from DataArray dimensions, got {gm.coordinates}"
1236+
)
1237+
# Verify these are actually the DataArray's dimensions
1238+
assert gm.coordinates == da.dims, (
1239+
f"Fallback coordinates {gm.coordinates} don't match DataArray dims {da.dims}"
1240+
)
1241+
1242+
11331243
def test_bad_grid_mapping_attribute():
11341244
ds = rotds.copy(deep=False)
11351245
ds.temp.attrs["grid_mapping"] = "foo"

ci/doc.yml

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,7 @@ dependencies:
1010
- xarray
1111
- sphinx
1212
- sphinx-copybutton
13+
- sphinx-autobuild
1314
- numpydoc
1415
- sphinx-autosummary-accessors
1516
- ipython
@@ -22,5 +23,6 @@ dependencies:
2223
- shapely
2324
- furo>=2024
2425
- myst-nb
26+
- pyproj
2527
- pip:
2628
- -e ../

doc/Makefile

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@
44
# You can set these variables from the command line, and also
55
# from the environment for the first two.
66
SPHINXOPTS ?=
7-
SPHINXBUILD ?= sphinx-build
7+
SPHINXBUILD ?= sphinx-autobuild
88
SOURCEDIR = .
99
BUILDDIR = _build
1010

0 commit comments

Comments
 (0)