Skip to content

Commit 8231bfc

Browse files
authored
Add sgrid axes parsing (#421)
* Add sgrid axes parsing
* Update cf_xarray/sgrid.py
* Fix repr
* Add more sgrid examples
* Fix ROMS detection
* Propagate the grid attribute too.
* Fix.
* Fix repr
* Cleanup formatting.py
* Fix repr
* Add docs.
* Update docs.
* Add whats-new
* last edits
1 parent 1d8fe9d commit 8231bfc

File tree

9 files changed

+294
-48
lines changed

9 files changed

+294
-48
lines changed

cf_xarray/__init__.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,4 @@
1+
from . import sgrid # noqa
12
from .accessor import CFAccessor # noqa
23
from .coding import ( # noqa
34
decode_compress_to_multi_index,

cf_xarray/accessor.py

Lines changed: 34 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -30,7 +30,9 @@
3030
from xarray.core.rolling import Coarsen, Rolling
3131
from xarray.core.weighted import Weighted
3232

33+
from . import sgrid
3334
from .criteria import (
35+
_DSG_ROLES,
3436
cf_role_criteria,
3537
coordinate_criteria,
3638
grid_mapping_var_criteria,
@@ -40,8 +42,9 @@
4042
_format_coordinates,
4143
# _format_conventions,
4244
_format_data_vars,
45+
_format_dsg_roles,
4346
_format_flags,
44-
_format_roles,
47+
_format_sgrid,
4548
_maybe_panel,
4649
)
4750
from .helpers import _guess_bounds_1d, _guess_bounds_2d, bounds_to_vertices
@@ -313,6 +316,11 @@ def _get_axis_coord(obj: DataArray | Dataset, key: str) -> list[str]:
313316
units = getattr(var.data, "units", None)
314317
if units in expected:
315318
results.update((coord,))
319+
320+
if key in _AXIS_NAMES and "grid_topology" in obj.cf.cf_roles:
321+
sgrid_axes = sgrid.parse_axes(obj)
322+
results.update((search_in | set(obj.dims)) & sgrid_axes[key])
323+
316324
return list(results)
317325

318326

@@ -474,7 +482,7 @@ def _get_coords(obj: DataArray | Dataset, key: Hashable) -> list[Hashable]:
474482
One or more of ('X', 'Y', 'Z', 'T', 'longitude', 'latitude', 'vertical', 'time',
475483
'area', 'volume'), or arbitrary measures, or standard names present in .coords
476484
"""
477-
return [k for k in _get_all(obj, key) if k in obj.coords]
485+
return [k for k in _get_all(obj, key) if k in obj.coords or k in obj.dims]
478486

479487

480488
def _variables(func: F) -> F:
@@ -1435,12 +1443,22 @@ def _generate_repr(self, rich=False):
14351443
_format_flags(self, rich), title="Flag Variable", rich=rich
14361444
)
14371445

1438-
if self.cf_roles:
1439-
yield _maybe_panel(
1440-
_format_roles(self, dims, rich),
1441-
title="Discrete Sampling Geometry",
1442-
rich=rich,
1443-
)
1446+
roles = self.cf_roles
1447+
if roles:
1448+
if any(role in roles for role in _DSG_ROLES):
1449+
yield _maybe_panel(
1450+
_format_dsg_roles(self, dims, rich),
1451+
title="Discrete Sampling Geometry",
1452+
rich=rich,
1453+
)
1454+
1455+
if "grid_topology" in self.cf_roles:
1456+
axes = sgrid.parse_axes(self._obj)
1457+
yield _maybe_panel(
1458+
_format_sgrid(self, axes, rich),
1459+
title="SGRID",
1460+
rich=rich,
1461+
)
14441462

14451463
yield _maybe_panel(
14461464
_format_coordinates(self, dims, coords, rich),
@@ -1642,6 +1660,7 @@ def get_associated_variable_names(
16421660
3. "cell_measures"
16431661
4. "coordinates"
16441662
5. "grid_mapping"
1663+
6. "grid"
16451664
to a list of variable names referred to in the appropriate attribute
16461665
16471666
Parameters
@@ -1654,15 +1673,18 @@ def get_associated_variable_names(
16541673
Returns
16551674
-------
16561675
names : dict
1657-
Dictionary with keys "ancillary_variables", "cell_measures", "coordinates", "bounds".
1676+
Dictionary with keys "ancillary_variables", "cell_measures", "coordinates", "bounds",
1677+
"grid_mapping", "grid".
16581678
"""
16591679
keys = [
16601680
"ancillary_variables",
16611681
"cell_measures",
16621682
"coordinates",
16631683
"bounds",
16641684
"grid_mapping",
1685+
"grid",
16651686
]
1687+
16661688
coords: dict[str, list[Hashable]] = {k: [] for k in keys}
16671689
attrs_or_encoding = ChainMap(self._obj[name].attrs, self._obj[name].encoding)
16681690

@@ -1704,6 +1726,9 @@ def get_associated_variable_names(
17041726
if dbounds:
17051727
coords["bounds"].append(dbounds)
17061728

1729+
if "grid" in attrs_or_encoding:
1730+
coords["grid"] = [attrs_or_encoding["grid"]]
1731+
17071732
if "grid_mapping" in attrs_or_encoding:
17081733
coords["grid_mapping"] = [attrs_or_encoding["grid_mapping"]]
17091734

cf_xarray/criteria.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -11,13 +11,13 @@
1111

1212
from typing import Any, Mapping, MutableMapping, Tuple
1313

14+
_DSG_ROLES = ["timeseries_id", "profile_id", "trajectory_id"]
15+
1416
cf_role_criteria: Mapping[str, Mapping[str, str]] = {
1517
k: {"cf_role": k}
1618
for k in (
1719
# CF Discrete sampling geometry
18-
"timeseries_id",
19-
"profile_id",
20-
"trajectory_id",
20+
*_DSG_ROLES,
2121
# SGRID
2222
"grid_topology",
2323
# UGRID

cf_xarray/datasets.py

Lines changed: 45 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -655,3 +655,48 @@ def _create_inexact_bounds():
655655
"trajectory": ("trajectory", [0, 1], {"cf_role": "trajectory_id"}),
656656
},
657657
)
658+
659+
660+
sgrid_roms = xr.Dataset()
661+
sgrid_roms["grid"] = xr.DataArray(
662+
0,
663+
attrs=dict(
664+
cf_role="grid_topology",
665+
topology_dimension=2,
666+
node_dimensions="xi_psi eta_psi",
667+
face_dimensions="xi_rho: xi_psi (padding: both) eta_rho: eta_psi (padding: both)",
668+
edge1_dimensions="xi_u: xi_psi eta_u: eta_psi (padding: both)",
669+
edge2_dimensions="xi_v: xi_psi (padding: both) eta_v: eta_psi",
670+
node_coordinates="lon_psi lat_psi",
671+
face_coordinates="lon_rho lat_rho",
672+
edge1_coordinates="lon_u lat_u",
673+
edge2_coordinates="lon_v lat_v",
674+
vertical_dimensions="s_rho: s_w (padding: none)",
675+
),
676+
)
677+
sgrid_roms["u"] = (("xi_u", "eta_u"), np.ones((2, 2)), {"grid": "grid"})
678+
679+
sgrid_delft = xr.Dataset()
680+
sgrid_delft["grid"] = xr.DataArray(
681+
0,
682+
attrs=dict(
683+
cf_role="grid_topology",
684+
topology_dimension=2,
685+
node_dimensions="inode jnode",
686+
face_dimensions="icell: inode (padding: none) jcell: jnode (padding: none)",
687+
node_coordinates="node_lon node_lat",
688+
),
689+
)
690+
691+
692+
sgrid_delft3 = xr.Dataset()
693+
sgrid_delft3["grid"] = xr.DataArray(
694+
0,
695+
attrs=dict(
696+
cf_role="grid_topology",
697+
topology_dimension=3,
698+
node_dimensions="inode jnode knode",
699+
volume_dimensions="iface: inode (padding: none) jface: jnode (padding: none) kface: knode (padding: none)",
700+
node_coordinates="node_lon node_lat node_elevation",
701+
),
702+
)

cf_xarray/formatting.py

Lines changed: 62 additions & 35 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
import warnings
2+
from functools import partial
23
from typing import Dict, Hashable, Iterable, List
34

45
STAR = " * "
@@ -35,6 +36,7 @@ def make_text_section(
3536
subtitle: str,
3637
attr: str,
3738
dims=None,
39+
valid_keys=None,
3840
valid_values=None,
3941
default_keys=None,
4042
rich: bool = False,
@@ -46,10 +48,16 @@ def make_text_section(
4648
dims = []
4749
with warnings.catch_warnings():
4850
warnings.simplefilter("ignore")
49-
try:
50-
vardict: Dict[str, Iterable[Hashable]] = getattr(accessor, attr, {})
51-
except ValueError:
52-
vardict = {}
51+
if isinstance(attr, str):
52+
try:
53+
vardict: Dict[str, Iterable[Hashable]] = getattr(accessor, attr, {})
54+
except ValueError:
55+
vardict = {}
56+
else:
57+
assert isinstance(attr, dict)
58+
vardict = attr
59+
if valid_keys:
60+
vardict = {k: v for k, v in vardict.items() if k in valid_keys}
5361

5462
# Sort keys if there aren't extra keys,
5563
# preserve default keys order otherwise.
@@ -142,53 +150,72 @@ def _format_flags(accessor, rich):
142150
return _print_rows("Flag Meanings", rows, rich)
143151

144152

145-
def _format_roles(accessor, dims, rich):
146-
yield make_text_section(accessor, "CF Roles", "cf_roles", dims=dims, rich=rich)
153+
def _format_dsg_roles(accessor, dims, rich):
154+
from .criteria import _DSG_ROLES
155+
156+
yield make_text_section(
157+
accessor,
158+
"CF Roles",
159+
"cf_roles",
160+
dims=dims,
161+
valid_keys=_DSG_ROLES,
162+
rich=rich,
163+
)
147164

148165

149166
def _format_coordinates(accessor, dims, coords, rich):
150167
from .accessor import _AXIS_NAMES, _CELL_MEASURES, _COORD_NAMES
151168

152-
yield make_text_section(
153-
accessor, "CF Axes", "axes", dims, coords, _AXIS_NAMES, rich=rich
169+
section = partial(
170+
make_text_section, accessor=accessor, dims=dims, valid_values=coords, rich=rich
154171
)
155-
yield make_text_section(
156-
accessor, "CF Coordinates", "coordinates", dims, coords, _COORD_NAMES, rich=rich
157-
)
158-
yield make_text_section(
159-
accessor,
160-
"Cell Measures",
161-
"cell_measures",
162-
dims,
163-
coords,
164-
_CELL_MEASURES,
165-
rich=rich,
166-
)
167-
yield make_text_section(
168-
accessor, "Standard Names", "standard_names", dims, coords, rich=rich
172+
173+
yield section(subtitle="CF Axes", attr="axes", default_keys=_AXIS_NAMES)
174+
yield section(
175+
subtitle="CF Coordinates", attr="coordinates", default_keys=_COORD_NAMES
169176
)
170-
yield make_text_section(accessor, "Bounds", "bounds", dims, coords, rich=rich)
171-
yield make_text_section(
172-
accessor, "Grid Mappings", "grid_mapping_names", dims, coords, rich=rich
177+
yield section(
178+
subtitle="Cell Measures", attr="cell_measures", default_keys=_CELL_MEASURES
173179
)
180+
yield section(subtitle="Standard Names", attr="standard_names")
181+
yield section(subtitle="Bounds", attr="bounds")
182+
yield section(subtitle="Grid Mappings", attr="grid_mapping_names")
174183

175184

176185
def _format_data_vars(accessor, data_vars, rich):
177186
from .accessor import _CELL_MEASURES
178187

179-
yield make_text_section(
180-
accessor,
181-
"Cell Measures",
182-
"cell_measures",
183-
None,
184-
data_vars,
185-
_CELL_MEASURES,
188+
section = partial(
189+
make_text_section,
190+
accessor=accessor,
191+
dims=None,
192+
valid_values=data_vars,
186193
rich=rich,
187194
)
195+
196+
yield section(
197+
subtitle="Cell Measures", attr="cell_measures", default_keys=_CELL_MEASURES
198+
)
199+
yield section(subtitle="Standard Names", attr="standard_names")
200+
yield section(subtitle="Bounds", attr="bounds")
201+
yield section(subtitle="Grid Mappings", attr="grid_mapping_names")
202+
203+
204+
def _format_sgrid(accessor, axes, rich):
188205
yield make_text_section(
189-
accessor, "Standard Names", "standard_names", None, data_vars, rich=rich
206+
accessor,
207+
"CF role",
208+
"cf_roles",
209+
valid_keys=["grid_topology"],
210+
rich=rich,
190211
)
191-
yield make_text_section(accessor, "Bounds", "bounds", None, data_vars, rich=rich)
212+
192213
yield make_text_section(
193-
accessor, "Grid Mappings", "grid_mapping_names", None, data_vars, rich=rich
214+
accessor,
215+
"Axes",
216+
axes,
217+
accessor._obj.dims,
218+
valid_values=accessor._obj.dims,
219+
default_keys=axes.keys(),
220+
rich=rich,
194221
)

cf_xarray/sgrid.py

Lines changed: 41 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,41 @@
1+
import re

# Attributes of the grid-topology variable that are scanned for
# "<dim>: <node_dim> (padding: <type>)" entries when collecting axis dimensions.
SGRID_DIM_ATTRS = [
    "face_dimensions",
    "volume_dimensions",
    # the following are optional and should be redundant with the above
    # at least for dimension names
    # "face1_dimensions",
    # "face2_dimensions",
    # "face3_dimensions",
    "edge1_dimensions",
    "edge2_dimensions",
    # "edge3_dimensions",
]

# One "<dim>: <node_dim>" entry, optionally followed by "(padding: <type>)".
# Names exclude whitespace, colons and parentheses so that irregular spacing
# between entries can never leak into a captured dimension name.
_DIM_PAIR = re.compile(
    r"([^\s:()]+)\s*:\s*([^\s:()]+)"
    r"(?:\s*\(padding:\s*([^\s()]+)\s*\))?"
)


def _parse_dim_pairs(grid, attr, expected):
    """Return the ``(dim, node_dim)`` pairs listed in ``grid.attrs[attr]``.

    Raises
    ------
    ValueError
        If the number of parsed pairs differs from ``expected``.
    """
    text = grid.attrs[attr]
    pairs = [match[:2] for match in _DIM_PAIR.findall(text)]
    if len(pairs) != expected:
        raise ValueError(
            f"Expected {expected} dimension pair(s) in attribute {attr!r}; "
            f"parsed {pairs!r} from {text!r}."
        )
    return pairs


def parse_axes(ds):
    """Map axis names to dimension names using the grid-topology variable.

    The variable is looked up through ``ds.cf.cf_roles["grid_topology"]``.
    Its ``node_dimensions`` attribute seeds one set of dimension names per
    axis ("X", "Y", "Z", truncated to ``topology_dimension`` entries); every
    attribute listed in ``SGRID_DIM_ATTRS`` that is present then adds both
    names of each of its pairs to the axis at the same position.  For a
    2D topology, a ``vertical_dimensions`` attribute defines the "Z" axis.

    Parameters
    ----------
    ds
        Object exposing ``ds.cf.cf_roles`` and ``ds[name].attrs``.

    Returns
    -------
    dict
        Axis name -> set of dimension names.

    Raises
    ------
    ValueError
        If there is not exactly one grid-topology variable, or if a
        dimensions attribute does not hold the expected number of pairs.
    """
    grid_vars = list(ds.cf.cf_roles["grid_topology"])
    if len(grid_vars) != 1:
        raise ValueError(
            "Expected exactly one variable with cf_role='grid_topology'; "
            f"found {grid_vars!r}."
        )
    grid = ds[grid_vars[0]]

    ndim = grid.attrs["topology_dimension"]
    axes_names = ["X", "Y", "Z"][:ndim]

    # str.split() without an argument tolerates repeated whitespace, which
    # would otherwise produce empty names and shift the axis assignment.
    node_dims = grid.attrs["node_dimensions"].split()
    axes = {ax: {dim} for ax, dim in zip(axes_names, node_dims)}

    for attr in SGRID_DIM_ATTRS:
        if attr not in grid.attrs:
            continue
        # Pairs are positional: the i-th pair belongs to the i-th axis.
        for ax, pair in zip(axes_names, _parse_dim_pairs(grid, attr, ndim)):
            axes[ax].update(pair)

    if ndim == 2 and "vertical_dimensions" in grid.attrs:
        (pair,) = _parse_dim_pairs(grid, "vertical_dimensions", 1)
        axes["Z"] = set(pair)

    return axes

0 commit comments

Comments
 (0)