Skip to content

Commit 1d8fe9d

Browse files
authored
Add rich repr (#409)
* Add rich repr Closes #393 * Fixes. * Cleanup + add tests * Try again * Fix tests * revert unnecessary change * Right justify * de-emphasize empty rows; highlight CF names more * Refactor * Small cleanup * Small tweaks + solarized colors * fix tests. * fix typing * Fix test * Add back type ignores. Revert "fix typing" This reverts commit 0f9c57e. * Update docs * Add as optional dep * Add image. * WIP conventions * Merge upstream/main * Add whats-new note
1 parent 5356069 commit 1d8fe9d

File tree

11 files changed

+409
-190
lines changed

11 files changed

+409
-190
lines changed

cf_xarray/accessor.py

Lines changed: 44 additions & 74 deletions
Original file line numberDiff line numberDiff line change
@@ -36,6 +36,14 @@
3636
grid_mapping_var_criteria,
3737
regex,
3838
)
39+
from .formatting import (
40+
_format_coordinates,
41+
# _format_conventions,
42+
_format_data_vars,
43+
_format_flags,
44+
_format_roles,
45+
_maybe_panel,
46+
)
3947
from .helpers import _guess_bounds_1d, _guess_bounds_2d, bounds_to_vertices
4048
from .options import OPTIONS
4149
from .utils import (
@@ -602,6 +610,11 @@ def _getattr(
602610
An extra decorator, if necessary. This is used by _CFPlotMethods to set default
603611
kwargs based on CF attributes.
604612
"""
613+
614+
# UGH. this seems unavoidable because I'm overriding getattr
615+
if attr in ["_repr_html_", "__rich__", "__rich_console__"]:
616+
raise AttributeError
617+
605618
try:
606619
attribute: Mapping | Callable = getattr(obj, attr)
607620
except AttributeError:
@@ -669,7 +682,9 @@ def wrapper(*args, **kwargs):
669682

670683
return result
671684

672-
wrapper.__doc__ = _build_docstring(func) + wrapper.__doc__
685+
# handle rich
686+
if wrapper.__doc__:
687+
wrapper.__doc__ = _build_docstring(func) + wrapper.__doc__
673688

674689
return wrapper
675690

@@ -1399,90 +1414,45 @@ def describe(self):
13991414
print(repr(self))
14001415

14011416
def __repr__(self):
1417+
return ("".join(self._generate_repr(rich=False))).rstrip()
14021418

1403-
coords = self._obj.coords
1404-
dims = self._obj.dims
1405-
1406-
def make_text_section(subtitle, attr, valid_values=None, default_keys=None):
1407-
1408-
with warnings.catch_warnings():
1409-
warnings.simplefilter("ignore")
1410-
try:
1411-
vardict = getattr(self, attr, {})
1412-
except ValueError:
1413-
vardict = {}
1414-
1415-
star = " * "
1416-
tab = len(star) * " "
1417-
subtitle = f"- {subtitle}:"
1418-
1419-
# Sort keys if there aren't extra keys,
1420-
# preserve default keys order otherwise.
1421-
default_keys = [] if not default_keys else list(default_keys)
1422-
extra_keys = list(set(vardict) - set(default_keys))
1423-
ordered_keys = sorted(vardict) if extra_keys else default_keys
1424-
vardict = {key: vardict[key] for key in ordered_keys if key in vardict}
1425-
1426-
# Keep only valid values (e.g., coords or data_vars)
1427-
if valid_values is not None:
1428-
vardict = {
1429-
key: set(value).intersection(valid_values)
1430-
for key, value in vardict.items()
1431-
if set(value).intersection(valid_values)
1432-
}
1419+
def __rich__(self):
1420+
from rich.console import Group
14331421

1434-
# Star for keys with dims only, tab otherwise
1435-
rows = [
1436-
f"{star if set(value) <= set(dims) else tab}{key}: {sorted(value)}"
1437-
for key, value in vardict.items()
1438-
]
1422+
return Group(*self._generate_repr(rich=True))
14391423

1440-
# Append missing default keys followed by n/a
1441-
if default_keys:
1442-
missing_keys = [key for key in default_keys if key not in vardict]
1443-
if missing_keys:
1444-
rows += [tab + ", ".join(missing_keys) + ": n/a"]
1445-
elif not rows:
1446-
rows = [tab + "n/a"]
1447-
1448-
# Add subtitle to the first row, align other rows
1449-
rows = [
1450-
"\n" + subtitle + row if i == 0 else len(subtitle) * " " + row
1451-
for i, row in enumerate(rows)
1452-
]
1424+
def _generate_repr(self, rich=False):
1425+
dims = self._obj.dims
1426+
coords = self._obj.coords
14531427

1454-
return "\n".join(rows) + "\n"
1428+
# if self._obj._attrs:
1429+
# conventions = self._obj.attrs.pop("Conventions", None)
1430+
# if conventions:
1431+
# yield _format_conventions(conventions, rich)
14551432

14561433
if isinstance(self._obj, DataArray) and self._obj.cf.is_flag_variable:
1457-
flag_dict = create_flag_dict(self._obj)
1458-
text = f"CF Flag variable with mapping:\n\t{flag_dict!r}\n\n"
1459-
else:
1460-
text = ""
1434+
yield _maybe_panel(
1435+
_format_flags(self, rich), title="Flag Variable", rich=rich
1436+
)
14611437

14621438
if self.cf_roles:
1463-
text += make_text_section("CF Roles", "cf_roles")
1464-
text += "\n"
1465-
1466-
text += "Coordinates:"
1467-
text += make_text_section("CF Axes", "axes", coords, _AXIS_NAMES)
1468-
text += make_text_section("CF Coordinates", "coordinates", coords, _COORD_NAMES)
1469-
text += make_text_section(
1470-
"Cell Measures", "cell_measures", coords, _CELL_MEASURES
1439+
yield _maybe_panel(
1440+
_format_roles(self, dims, rich),
1441+
title="Discrete Sampling Geometry",
1442+
rich=rich,
1443+
)
1444+
1445+
yield _maybe_panel(
1446+
_format_coordinates(self, dims, coords, rich),
1447+
title="Coordinates",
1448+
rich=rich,
14711449
)
1472-
text += make_text_section("Standard Names", "standard_names", coords)
1473-
text += make_text_section("Bounds", "bounds", coords)
14741450
if isinstance(self._obj, Dataset):
1475-
text += make_text_section("Grid Mappings", "grid_mapping_names", coords)
1476-
data_vars = self._obj.data_vars
1477-
text += "\nData Variables:"
1478-
text += make_text_section(
1479-
"Cell Measures", "cell_measures", data_vars, _CELL_MEASURES
1451+
yield _maybe_panel(
1452+
_format_data_vars(self, self._obj.data_vars, rich),
1453+
title="Data Variables",
1454+
rich=rich,
14801455
)
1481-
text += make_text_section("Standard Names", "standard_names", data_vars)
1482-
text += make_text_section("Bounds", "bounds", data_vars)
1483-
text += make_text_section("Grid Mappings", "grid_mapping_names", data_vars)
1484-
1485-
return text
14861456

14871457
def keys(self) -> set[Hashable]:
14881458
"""

cf_xarray/formatting.py

Lines changed: 194 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,194 @@
1+
import warnings
2+
from typing import Dict, Hashable, Iterable, List
3+
4+
STAR = " * "
5+
TAB = len(STAR) * " "
6+
7+
8+
def _format_missing_row(row: str, rich: bool) -> str:
9+
if rich:
10+
return f"[grey62]{row}[/grey62]"
11+
else:
12+
return row
13+
14+
15+
def _format_varname(name, rich: bool):
16+
return name
17+
18+
19+
def _format_subtitle(name: str, rich: bool) -> str:
20+
if rich:
21+
return f"[bold]{name}[/bold]"
22+
else:
23+
return name
24+
25+
26+
def _format_cf_name(name: str, rich: bool) -> str:
27+
if rich:
28+
return f"[color(33)]{name}[/color(33)]"
29+
else:
30+
return name
31+
32+
33+
def make_text_section(
34+
accessor,
35+
subtitle: str,
36+
attr: str,
37+
dims=None,
38+
valid_values=None,
39+
default_keys=None,
40+
rich: bool = False,
41+
):
42+
43+
from .accessor import sort_maybe_hashable
44+
45+
if dims is None:
46+
dims = []
47+
with warnings.catch_warnings():
48+
warnings.simplefilter("ignore")
49+
try:
50+
vardict: Dict[str, Iterable[Hashable]] = getattr(accessor, attr, {})
51+
except ValueError:
52+
vardict = {}
53+
54+
# Sort keys if there aren't extra keys,
55+
# preserve default keys order otherwise.
56+
default_keys = [] if not default_keys else list(default_keys)
57+
extra_keys = list(set(vardict) - set(default_keys))
58+
ordered_keys = sorted(vardict) if extra_keys else default_keys
59+
vardict = {key: vardict[key] for key in ordered_keys if key in vardict}
60+
61+
# Keep only valid values (e.g., coords or data_vars)
62+
if valid_values is not None:
63+
vardict = {
64+
key: set(value).intersection(valid_values)
65+
for key, value in vardict.items()
66+
if set(value).intersection(valid_values)
67+
}
68+
69+
# Star for keys with dims only, tab otherwise
70+
rows = [
71+
(
72+
f"{STAR if dims and set(value) <= set(dims) else TAB}"
73+
f"{_format_cf_name(key, rich)}: "
74+
f"{_format_varname(sort_maybe_hashable(value), rich)}"
75+
)
76+
for key, value in vardict.items()
77+
]
78+
79+
# Append missing default keys followed by n/a
80+
if default_keys:
81+
missing_keys = [key for key in default_keys if key not in vardict]
82+
if missing_keys:
83+
rows.append(
84+
_format_missing_row(TAB + ", ".join(missing_keys) + ": n/a", rich)
85+
)
86+
elif not rows:
87+
rows.append(_format_missing_row(TAB + "n/a", rich))
88+
89+
return _print_rows(subtitle, rows, rich)
90+
91+
92+
def _print_rows(subtitle: str, rows: List[str], rich: bool):
93+
subtitle = f"{subtitle.rjust(20)}:"
94+
95+
# Add subtitle to the first row, align other rows
96+
rows = [
97+
_format_subtitle(subtitle, rich=rich) + row
98+
if i == 0
99+
else len(subtitle) * " " + row
100+
for i, row in enumerate(rows)
101+
]
102+
103+
return "\n".join(rows) + "\n\n"
104+
105+
106+
def _format_conventions(string: str, rich: bool):
107+
row = _print_rows(
108+
subtitle="Conventions",
109+
rows=[_format_cf_name(TAB + string, rich=rich)],
110+
rich=rich,
111+
)
112+
if rich:
113+
row = row.rstrip()
114+
return row
115+
116+
117+
def _maybe_panel(textgen, title: str, rich: bool):
118+
text = "".join(textgen)
119+
if rich:
120+
from rich.panel import Panel
121+
122+
return Panel(
123+
f"[color(241)]{text.rstrip()}[/color(241)]",
124+
expand=True,
125+
title_align="left",
126+
title=f"[bold][color(244)]{title}[/bold][/color(244)]",
127+
highlight=True,
128+
width=100,
129+
)
130+
else:
131+
return title + ":\n" + text
132+
133+
134+
def _format_flags(accessor, rich):
135+
from .accessor import create_flag_dict
136+
137+
flag_dict = create_flag_dict(accessor._obj)
138+
rows = [
139+
f"{TAB}{_format_varname(v, rich)}: {_format_cf_name(k, rich)}"
140+
for k, v in flag_dict.items()
141+
]
142+
return _print_rows("Flag Meanings", rows, rich)
143+
144+
145+
def _format_roles(accessor, dims, rich):
146+
yield make_text_section(accessor, "CF Roles", "cf_roles", dims=dims, rich=rich)
147+
148+
149+
def _format_coordinates(accessor, dims, coords, rich):
150+
from .accessor import _AXIS_NAMES, _CELL_MEASURES, _COORD_NAMES
151+
152+
yield make_text_section(
153+
accessor, "CF Axes", "axes", dims, coords, _AXIS_NAMES, rich=rich
154+
)
155+
yield make_text_section(
156+
accessor, "CF Coordinates", "coordinates", dims, coords, _COORD_NAMES, rich=rich
157+
)
158+
yield make_text_section(
159+
accessor,
160+
"Cell Measures",
161+
"cell_measures",
162+
dims,
163+
coords,
164+
_CELL_MEASURES,
165+
rich=rich,
166+
)
167+
yield make_text_section(
168+
accessor, "Standard Names", "standard_names", dims, coords, rich=rich
169+
)
170+
yield make_text_section(accessor, "Bounds", "bounds", dims, coords, rich=rich)
171+
yield make_text_section(
172+
accessor, "Grid Mappings", "grid_mapping_names", dims, coords, rich=rich
173+
)
174+
175+
176+
def _format_data_vars(accessor, data_vars, rich):
177+
from .accessor import _CELL_MEASURES
178+
179+
yield make_text_section(
180+
accessor,
181+
"Cell Measures",
182+
"cell_measures",
183+
None,
184+
data_vars,
185+
_CELL_MEASURES,
186+
rich=rich,
187+
)
188+
yield make_text_section(
189+
accessor, "Standard Names", "standard_names", None, data_vars, rich=rich
190+
)
191+
yield make_text_section(accessor, "Bounds", "bounds", None, data_vars, rich=rich)
192+
yield make_text_section(
193+
accessor, "Grid Mappings", "grid_mapping_names", None, data_vars, rich=rich
194+
)

cf_xarray/tests/__init__.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -67,4 +67,5 @@ def LooseVersion(vstring):
6767
has_scipy, requires_scipy = _importorskip("scipy")
6868
has_shapely, requires_shapely = _importorskip("shapely")
6969
has_pint, requires_pint = _importorskip("pint")
70+
_, requires_rich = _importorskip("rich")
7071
has_regex, requires_regex = _importorskip("regex")

0 commit comments

Comments
 (0)