Skip to content

Commit ee2f904

Browse files
FBumannclaude
andcommitted
feat(imshow): support facet_row for subplot rows (plotly>=6.7)
px.imshow gained facet_row support in plotly 6.7.0. Add it to the imshow slot order (y, x, facet_col, facet_row, animation_frame), consistent with all other plot types. On older plotly versions, using facet_row raises an informative error; imshow without facet_row keeps working. Note: 4D DataArrays now auto-assign their fourth dimension to facet_row instead of animation_frame. Co-Authored-By: Claude Opus 4.8 (1M context) <noreply@anthropic.com>
1 parent c889e9a commit ee2f904

6 files changed

Lines changed: 167 additions & 9 deletions

File tree

docs/examples/dimensions.ipynb

Lines changed: 11 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -134,7 +134,7 @@
134134
"| line | x, color, line_dash, facet_col, facet_row, animation_frame |\n",
135135
"| scatter | x, color, symbol, facet_col, facet_row, animation_frame |\n",
136136
"| bar | x, color, facet_col, facet_row, animation_frame |\n",
137-
"| imshow | x, y, facet_col, facet_row, animation_frame |"
137+
"| imshow | y, x, facet_col, facet_row, animation_frame |"
138138
]
139139
},
140140
{
@@ -186,6 +186,16 @@
186186
"xpx(data_3d).line(x=\"year\", facet_col=\"metric\", facet_row=\"country\")"
187187
]
188188
},
189+
{
190+
"cell_type": "code",
191+
"execution_count": null,
192+
"metadata": {},
193+
"outputs": [],
194+
"source": [
195+
"# Heatmap grids: imshow supports facet_col and facet_row (requires plotly>=6.7)\n",
196+
"xpx(data_3d).imshow(facet_row=\"metric\")"
197+
]
198+
},
189199
{
190200
"cell_type": "markdown",
191201
"metadata": {},

tests/test_accessor.py

Lines changed: 96 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,8 @@
99
import xarray as xr
1010

1111
import xarray_plotly # noqa: F401 - registers accessor
12-
from xarray_plotly import xpx
12+
from xarray_plotly import plotting, xpx
13+
from xarray_plotly.plotting import _imshow_supports_facet_row
1314

1415

1516
class TestXpxFunction:
@@ -403,6 +404,100 @@ def test_imshow_animation_consistent_bounds(self) -> None:
403404
assert coloraxis.cmax == 70.0
404405

405406

407+
requires_imshow_facet_row = pytest.mark.skipif(
408+
not _imshow_supports_facet_row(),
409+
reason="facet_row in px.imshow requires plotly>=6.7.0",
410+
)
411+
412+
413+
class TestImshowFaceting:
414+
"""Tests for imshow facet_col and facet_row."""
415+
416+
@pytest.fixture(autouse=True)
417+
def setup(self) -> None:
418+
"""Set up test data."""
419+
self.da_3d = xr.DataArray(
420+
np.random.rand(4, 5, 3),
421+
dims=["lat", "lon", "scenario"],
422+
coords={"scenario": ["a", "b", "c"]},
423+
name="temperature",
424+
)
425+
self.da_4d = xr.DataArray(
426+
np.random.rand(4, 5, 2, 3),
427+
dims=["lat", "lon", "scenario", "year"],
428+
coords={"scenario": ["low", "high"], "year": [2020, 2021, 2022]},
429+
name="temperature",
430+
)
431+
432+
def test_imshow_facet_col(self) -> None:
433+
"""Test imshow with facet_col creates one subplot per value."""
434+
fig = self.da_3d.plotly.imshow()
435+
assert len(fig.data) == 3
436+
xaxes = [k for k in fig.layout if k.startswith("xaxis")]
437+
assert len(xaxes) == 3
438+
439+
@requires_imshow_facet_row
440+
def test_imshow_facet_row_explicit(self) -> None:
441+
"""Test imshow with explicit facet_row creates one subplot row per value."""
442+
fig = self.da_3d.plotly.imshow(facet_col=None, facet_row="scenario")
443+
assert len(fig.data) == 3
444+
yaxes = [k for k in fig.layout if k.startswith("yaxis")]
445+
assert len(yaxes) == 3
446+
annotations = {a.text for a in fig.layout.annotations}
447+
assert annotations == {"scenario=a", "scenario=b", "scenario=c"}
448+
449+
@requires_imshow_facet_row
450+
def test_imshow_facet_row_auto_4d(self) -> None:
451+
"""Test that a 4D array auto-assigns facet_col and facet_row."""
452+
fig = self.da_4d.plotly.imshow()
453+
# 2 facet columns (scenario) x 3 facet rows (year)
454+
assert len(fig.data) == 6
455+
annotations = {a.text for a in fig.layout.annotations}
456+
assert annotations == {
457+
"scenario=low",
458+
"scenario=high",
459+
"year=2020",
460+
"year=2021",
461+
"year=2022",
462+
}
463+
464+
@requires_imshow_facet_row
465+
def test_imshow_facet_grid_consistent_bounds(self) -> None:
466+
"""Test that facet grid subplots share global color bounds."""
467+
da = xr.DataArray(
468+
np.arange(24, dtype=float).reshape(2, 2, 2, 3),
469+
dims=["y", "x", "scenario", "year"],
470+
)
471+
fig = da.plotly.imshow()
472+
coloraxis = fig.layout.coloraxis
473+
assert coloraxis.cmin == 0.0
474+
assert coloraxis.cmax == 23.0
475+
476+
@requires_imshow_facet_row
477+
def test_imshow_facet_grid_with_animation(self) -> None:
478+
"""Test imshow with facet_col, facet_row, and animation_frame together."""
479+
da = xr.DataArray(
480+
np.random.rand(4, 5, 2, 3, 6),
481+
dims=["lat", "lon", "scenario", "year", "time"],
482+
name="temperature",
483+
)
484+
fig = da.plotly.imshow()
485+
assert len(fig.data) == 6
486+
assert len(fig.frames) == 6
487+
488+
def test_imshow_facet_row_unsupported_error(self, monkeypatch: pytest.MonkeyPatch) -> None:
489+
"""Test informative error when plotly is too old for facet_row."""
490+
monkeypatch.setattr(plotting, "_imshow_supports_facet_row", lambda: False)
491+
with pytest.raises(ValueError, match=r"facet_row for imshow requires plotly>=6\.7\.0"):
492+
self.da_3d.plotly.imshow(facet_col=None, facet_row="scenario")
493+
494+
def test_imshow_facet_row_none_on_old_plotly(self, monkeypatch: pytest.MonkeyPatch) -> None:
495+
"""Test that imshow still works on old plotly when facet_row is not used."""
496+
monkeypatch.setattr(plotting, "_imshow_supports_facet_row", lambda: False)
497+
fig = self.da_3d.plotly.imshow(facet_row=None)
498+
assert isinstance(fig, go.Figure)
499+
500+
406501
class TestColorsParameter:
407502
"""Tests for the unified colors parameter."""
408503

tests/test_common.py

Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,22 @@ def test_auto_assignment_imshow(self) -> None:
2020
slots = assign_slots(["lat", "lon"], "imshow")
2121
assert slots == {"y": "lat", "x": "lon"}
2222

23+
def test_auto_assignment_imshow_4d(self) -> None:
24+
"""Test that the fourth dimension fills facet_row for imshow."""
25+
slots = assign_slots(["lat", "lon", "scenario", "year"], "imshow")
26+
assert slots == {"y": "lat", "x": "lon", "facet_col": "scenario", "facet_row": "year"}
27+
28+
def test_auto_assignment_imshow_5d(self) -> None:
29+
"""Test that the fifth dimension fills animation_frame for imshow."""
30+
slots = assign_slots(["lat", "lon", "scenario", "year", "time"], "imshow")
31+
assert slots == {
32+
"y": "lat",
33+
"x": "lon",
34+
"facet_col": "scenario",
35+
"facet_row": "year",
36+
"animation_frame": "time",
37+
}
38+
2339
def test_auto_assignment_scatter(self) -> None:
2440
"""Test automatic positional assignment for scatter plots."""
2541
slots = assign_slots(["x_dim", "color_dim"], "scatter")
@@ -124,5 +140,6 @@ def test_imshow_slot_order(self) -> None:
124140
"y",
125141
"x",
126142
"facet_col",
143+
"facet_row",
127144
"animation_frame",
128145
)

xarray_plotly/accessor.py

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -302,14 +302,15 @@ def imshow(
302302
x: SlotValue = auto,
303303
y: SlotValue = auto,
304304
facet_col: SlotValue = auto,
305+
facet_row: SlotValue = auto,
305306
animation_frame: SlotValue = auto,
306307
robust: bool = False,
307308
colors: Colors = None,
308309
**px_kwargs: Any,
309310
) -> go.Figure:
310311
"""Create an interactive heatmap image.
311312
312-
Slot order: y (rows) -> x (columns) -> facet_col -> animation_frame
313+
Slot order: y (rows) -> x (columns) -> facet_col -> facet_row -> animation_frame
313314
314315
Note:
315316
**Difference from px.imshow**: Color bounds are computed from the
@@ -320,7 +321,9 @@ def imshow(
320321
x: Dimension for x-axis (columns). Default: second dimension.
321322
y: Dimension for y-axis (rows). Default: first dimension.
322323
facet_col: Dimension for subplot columns. Default: third dimension.
323-
animation_frame: Dimension for animation. Default: fourth dimension.
324+
facet_row: Dimension for subplot rows. Default: fourth dimension.
325+
Requires plotly>=6.7.0.
326+
animation_frame: Dimension for animation. Default: fifth dimension.
324327
robust: If True, use 2nd/98th percentiles for color bounds (handles outliers).
325328
colors: Color scale name (e.g., "Viridis", "RdBu"). See module docs.
326329
**px_kwargs: Additional arguments passed to `plotly.express.imshow()`.
@@ -334,6 +337,7 @@ def imshow(
334337
x=x,
335338
y=y,
336339
facet_col=facet_col,
340+
facet_row=facet_row,
337341
animation_frame=animation_frame,
338342
robust=robust,
339343
colors=colors,

xarray_plotly/config.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -43,7 +43,7 @@
4343
"facet_row",
4444
"animation_frame",
4545
),
46-
"imshow": ("y", "x", "facet_col", "animation_frame"),
46+
"imshow": ("y", "x", "facet_col", "facet_row", "animation_frame"),
4747
"box": ("x", "color", "facet_col", "facet_row", "animation_frame"),
4848
"pie": ("names", "facet_col", "facet_row"),
4949
}

xarray_plotly/plotting.py

Lines changed: 36 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@
44

55
from __future__ import annotations
66

7+
import inspect
78
import warnings
89
from typing import TYPE_CHECKING, Any
910

@@ -614,12 +615,21 @@ def scatter(
614615
)
615616

616617

618+
def _imshow_supports_facet_row() -> bool:
619+
"""Check whether the installed plotly version supports facet_row in px.imshow.
620+
621+
Support was added in plotly 6.7.0.
622+
"""
623+
return "facet_row" in inspect.signature(px.imshow).parameters
624+
625+
617626
def imshow(
618627
darray: DataArray,
619628
*,
620629
x: SlotValue = auto,
621630
y: SlotValue = auto,
622631
facet_col: SlotValue = auto,
632+
facet_row: SlotValue = auto,
623633
animation_frame: SlotValue = auto,
624634
robust: bool = False,
625635
colors: Colors = None,
@@ -629,7 +639,7 @@ def imshow(
629639
Create an interactive heatmap from a DataArray.
630640
631641
Both x and y are dimensions. Dimensions fill slots in order:
632-
y (rows) -> x (columns) -> facet_col -> animation_frame
642+
y (rows) -> x (columns) -> facet_col -> facet_row -> animation_frame
633643
634644
.. note::
635645
**Difference from plotly.express.imshow**: By default, color bounds
@@ -649,8 +659,12 @@ def imshow(
649659
Dimension for y-axis (rows). Default: first dimension.
650660
facet_col
651661
Dimension for subplot columns. Default: third dimension.
662+
facet_row
663+
Dimension for subplot rows. Default: fourth dimension.
664+
Requires plotly>=6.7.0. Note: ``facet_col_wrap`` is ignored by
665+
plotly when ``facet_row`` is set.
652666
animation_frame
653-
Dimension for animation. Default: fourth dimension.
667+
Dimension for animation. Default: fifth dimension.
654668
robust
655669
If True, compute color bounds using 2nd and 98th percentiles
656670
for robustness against outliers. Default: False (uses min/max).
@@ -674,12 +688,29 @@ def imshow(
674688
y=y,
675689
x=x,
676690
facet_col=facet_col,
691+
facet_row=facet_row,
677692
animation_frame=animation_frame,
678693
)
679694

680-
# Transpose to: y (rows), x (cols), facet_col, animation_frame
695+
facet_row_kwargs: dict[str, Any] = {}
696+
if slots.get("facet_row") is not None:
697+
if not _imshow_supports_facet_row():
698+
import plotly
699+
700+
msg = (
701+
f"facet_row for imshow requires plotly>=6.7.0 "
702+
f"(installed: {plotly.__version__}). "
703+
"Upgrade plotly, or pass facet_row=None and use "
704+
"facet_col/animation_frame instead."
705+
)
706+
raise ValueError(msg)
707+
facet_row_kwargs["facet_row"] = slots["facet_row"]
708+
709+
# Transpose to: y (rows), x (cols), facet_col, facet_row, animation_frame
681710
transpose_order = [
682-
slots[k] for k in ("y", "x", "facet_col", "animation_frame") if slots.get(k) is not None
711+
slots[k]
712+
for k in ("y", "x", "facet_col", "facet_row", "animation_frame")
713+
if slots.get(k) is not None
683714
]
684715
plot_data = darray.transpose(*transpose_order) if transpose_order else darray
685716

@@ -701,6 +732,7 @@ def imshow(
701732
plot_data,
702733
facet_col=slots.get("facet_col"),
703734
animation_frame=slots.get("animation_frame"),
735+
**facet_row_kwargs,
704736
**px_kwargs,
705737
)
706738

0 commit comments

Comments
 (0)