Skip to content

Commit 667d0e5

Browse files
authored
Merge pull request #3 from scipp/colorbar-refactor
Make custom svg for colorbar instead of rendering using matplotlib
2 parents 136a8cc + ccb935d commit 667d0e5

File tree

9 files changed

+424
-75
lines changed

9 files changed

+424
-75
lines changed

docs/user-guide/imshow.ipynb

Lines changed: 32 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -19,10 +19,18 @@
1919
"import numpy as np"
2020
]
2121
},
22+
{
23+
"cell_type": "markdown",
24+
"id": "2",
25+
"metadata": {},
26+
"source": [
27+
"## Basic imshow"
28+
]
29+
},
2230
{
2331
"cell_type": "code",
2432
"execution_count": null,
25-
"id": "2",
33+
"id": "3",
2634
"metadata": {},
2735
"outputs": [],
2836
"source": [
@@ -34,6 +42,29 @@
3442
"\n",
3543
"fig"
3644
]
45+
},
46+
{
47+
"cell_type": "markdown",
48+
"id": "4",
49+
"metadata": {},
50+
"source": [
51+
"## With a colorbar"
52+
]
53+
},
54+
{
55+
"cell_type": "code",
56+
"execution_count": null,
57+
"id": "5",
58+
"metadata": {},
59+
"outputs": [],
60+
"source": [
61+
"fig, ax = plt.subplots()\n",
62+
"\n",
63+
"im = ax.imshow(np.random.random((20, 20)))\n",
64+
"fig.colorbar(im)\n",
65+
"\n",
66+
"fig"
67+
]
3768
}
3869
],
3970
"metadata": {

src/matplotgl/axes.py

Lines changed: 27 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -352,13 +352,15 @@ def _make_xticks(self):
352352

353353
# width = f"calc({self.width}px + 0.5em)"
354354

355-
bottom_string = (
356-
f'<svg height="calc(1.2em + {tick_length}px + {label_offset}px)" '
357-
f'width="{self.width}"><line x1="0" y1="0" '
358-
# f'width="calc(0.5em + {self.width}px)"><line x1="0" y1="0" '
359-
f'x2="{self.width}" y2="0" '
360-
f'style="stroke:black;stroke-width:{self._spine_linewidth}" />'
361-
)
355+
bottom_string = [
356+
(
357+
f'<svg height="calc(1.2em + {tick_length}px + {label_offset}px)" '
358+
f'width="{self.width}"><line x1="0" y1="0" '
359+
# f'width="calc(0.5em + {self.width}px)"><line x1="0" y1="0" '
360+
f'x2="{self.width}" y2="0" '
361+
f'style="stroke:black;stroke-width:{self._spine_linewidth}" />'
362+
)
363+
]
362364

363365
self._margins["topspine"].value = (
364366
f'<svg height="{self._thin_margin}px" width="{self.width}">'
@@ -371,11 +373,11 @@ def _make_xticks(self):
371373
if tick < 0 or tick > 1.0:
372374
continue
373375
x = tick * self.width
374-
bottom_string += (
376+
bottom_string.append(
375377
f'<line x1="{x}" y1="0" x2="{x}" y2="{tick_length}" '
376378
'style="stroke:black;stroke-width:1" />'
377379
)
378-
bottom_string += (
380+
bottom_string.append(
379381
f'<text x="{x}" y="{tick_length + label_offset}" '
380382
'text-anchor="middle" dominant-baseline="hanging">'
381383
f"{html_to_svg(latex_to_html(label), baseline='hanging')}</text>"
@@ -390,13 +392,13 @@ def _make_xticks(self):
390392
if tick < 0 or tick > 1.0:
391393
continue
392394
x = tick * self.width
393-
bottom_string += (
395+
bottom_string.append(
394396
f'<line x1="{x}" y1="0" x2="{x}" y2="{tick_length * 0.7}" '
395397
'style="stroke:black;stroke-width:0.5" />'
396398
)
397399

398-
bottom_string += "</svg></div>"
399-
self._margins["bottomspine"].value = bottom_string
400+
bottom_string.append("</svg></div>")
401+
self._margins["bottomspine"].value = "".join(bottom_string)
400402

401403
def _make_yticks(self):
402404
"""
@@ -424,12 +426,14 @@ def _make_yticks(self):
424426
width2 = f"calc({max_length}px)"
425427
width3 = f"calc({max_length}px + {tick_length * 0.3}px + {label_offset}px)"
426428

427-
left_string = (
428-
f'<svg height="{self.height}" width="{width}">'
429-
f'<line x1="{width}" y1="0" '
430-
f'x2="{width}" y2="{self.height}" '
431-
f'style="stroke:black;stroke-width:{self._spine_linewidth}" />'
432-
)
429+
left_string = [
430+
(
431+
f'<svg height="{self.height}" width="{width}">'
432+
f'<line x1="{width}" y1="0" '
433+
f'x2="{width}" y2="{self.height}" '
434+
f'style="stroke:black;stroke-width:{self._spine_linewidth}" />'
435+
)
436+
]
433437

434438
self._margins["rightspine"].value = (
435439
f'<svg height="{self.height}" width="{self._thin_margin}">'
@@ -441,13 +445,13 @@ def _make_yticks(self):
441445
if tick < 0 or tick > 1.0:
442446
continue
443447
y = self.height - (tick * self.height)
444-
left_string += (
448+
left_string.append(
445449
f'<line x1="{width}" y1="{y}" '
446450
f'x2="{width1}" y2="{y}" '
447451
'style="stroke:black;stroke-width:1" />'
448452
)
449453

450-
left_string += (
454+
left_string.append(
451455
f'<text x="{width2}" '
452456
f'y="{y}" text-anchor="end" dominant-baseline="middle">'
453457
f"{html_to_svg(latex_to_html(label), baseline='middle')}</text>"
@@ -462,14 +466,14 @@ def _make_yticks(self):
462466
if tick < 0 or tick > 1.0:
463467
continue
464468
y = self.height - (tick * self.height)
465-
left_string += (
469+
left_string.append(
466470
f'<line x1="{width}" y1="{y}" '
467471
f'x2="{width3}" y2="{y}" '
468472
'style="stroke:black;stroke-width:0.5" />'
469473
)
470474

471-
left_string += "</svg></div>"
472-
self._margins["leftspine"].value = left_string
475+
left_string.append("</svg></div>")
476+
self._margins["leftspine"].value = "".join(left_string)
473477

474478
def get_xlim(self):
475479
return self._xmin, self._xmax

src/matplotgl/colorbar.py

Lines changed: 88 additions & 29 deletions
Original file line numberDiff line numberDiff line change
@@ -1,40 +1,99 @@
1-
from matplotlib.colorbar import ColorbarBase
2-
from matplotlib.pyplot import Figure
3-
4-
5-
def fig_to_bytes(fig) -> bytes:
6-
"""
7-
Convert a Matplotlib figure to svg bytes.
1+
# SPDX-License-Identifier: BSD-3-Clause
82

9-
Parameters
10-
----------
11-
fig:
12-
The figure to be converted.
13-
"""
14-
from io import BytesIO
15-
16-
buf = BytesIO()
17-
fig.savefig(buf, format="svg", bbox_inches="tight")
18-
buf.seek(0)
19-
return buf.getvalue()
3+
import matplotlib as mpl
4+
import numpy as np
5+
from matplotlib import colors as cm
6+
from matplotlib.pyplot import Figure
207

8+
from .utils import html_to_svg, latex_to_html
219

22-
# def make_colorbar(mappable, height_inches: float) -> str:
23-
# fig = Figure(figsize=(height_inches * 0.2, height_inches))
24-
# cax = fig.add_axes([0.05, 0.02, 0.2, 0.98])
25-
# ColorbarBase(cax, cmap=mappable.cmap) # , norm=self.normalizer)
26-
# return fig_to_bytes(fig).decode()
10+
mpl.use("Agg")
2711

2812

2913
class Colorbar:
30-
def __init__(self, widget, mappable, height_inches: float):
14+
def __init__(self, widget, mappable, height: float):
3115
self._mappable = mappable
32-
self._height_inches = height_inches
16+
self._height = height
3317
self._widget = widget
3418

3519
def update(self):
3620
self._mappable._update_colors()
37-
fig = Figure(figsize=(self._height_inches * 0.2, self._height_inches))
38-
cax = fig.add_axes([0.05, 0.02, 0.2, 0.98])
39-
ColorbarBase(cax, cmap=self._mappable.cmap, norm=self._mappable.norm._norm)
40-
self._widget.value = fig_to_bytes(fig).decode()
21+
fig = Figure()
22+
cax = fig.add_subplot(111)
23+
norm = self._mappable._norm._norm
24+
cax.set_ylim(norm.vmin, norm.vmax)
25+
cax.set_yscale('log' if isinstance(norm, cm.LogNorm) else 'linear')
26+
27+
# Generate colors
28+
n_colors = 128
29+
segment_height = self._height / n_colors
30+
bar_width = 18
31+
bar_left = 20
32+
bar_top = 3
33+
34+
# Build SVG
35+
svg_parts = [
36+
f'<svg width="{100}" height="{self._height}">',
37+
]
38+
39+
# Draw color segments
40+
for i in range(n_colors):
41+
y = (n_colors - i - 1) * segment_height
42+
color = cm.to_hex(self._mappable.cmap(i / (n_colors - 1)))
43+
svg_parts.append(
44+
f' <rect x="{bar_left}" y="{y + bar_top}" width="{bar_width}" '
45+
f'height="{segment_height}" fill="{color}" />'
46+
)
47+
48+
# Draw colorbar border
49+
svg_parts.append(
50+
f' <rect x="{bar_left}" y="{bar_top}" width="{bar_width}" '
51+
f'height="{self._height}" fill="none" stroke="black" stroke-width="1"/>'
52+
)
53+
54+
yticks = cax.get_yticks()
55+
ylabels = cax.get_yticklabels()
56+
ytexts = [lab.get_text() for lab in ylabels]
57+
tick_length = 6
58+
label_offset = 3
59+
60+
xy = np.vstack((np.zeros_like(yticks), yticks)).T
61+
62+
inv_trans_axes = cax.transAxes.inverted()
63+
trans_data = cax.transData
64+
yticks_axes = inv_trans_axes.transform(trans_data.transform(xy))[:, 1]
65+
66+
for tick, label in zip(yticks_axes, ytexts, strict=True):
67+
if tick < 0 or tick > 1.0:
68+
continue
69+
y = self._height - (tick * self._height) + bar_top
70+
svg_parts.append(
71+
f'<line x1="{bar_left+bar_width}" y1="{y}" '
72+
f'x2="{bar_left+bar_width+tick_length}" y2="{y}" '
73+
'style="stroke:black;stroke-width:1" />'
74+
)
75+
76+
svg_parts.append(
77+
f'<text x="{bar_left+bar_width+tick_length+label_offset}" '
78+
f'y="{y}" text-anchor="start" dominant-baseline="middle">'
79+
f"{html_to_svg(latex_to_html(label), baseline='middle')}</text>"
80+
)
81+
82+
minor_ticks = cax.yaxis.get_minorticklocs()
83+
if len(minor_ticks) > 0:
84+
xy = np.vstack((np.zeros_like(minor_ticks), minor_ticks)).T
85+
yticks_axes = inv_trans_axes.transform(trans_data.transform(xy))[:, 1]
86+
87+
for tick in yticks_axes:
88+
if tick < 0 or tick > 1.0:
89+
continue
90+
y = self._height - (tick * self._height) + bar_top
91+
svg_parts.append(
92+
f'<line x1="{bar_left+bar_width}" y1="{y}" '
93+
f'x2="{bar_left+bar_width+tick_length * 0.6}" y2="{y}" '
94+
'style="stroke:black;stroke-width:0.5" />'
95+
)
96+
97+
svg_parts.append('</svg>')
98+
99+
self._widget.value = '\n'.join(svg_parts)

src/matplotgl/figure.py

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -106,10 +106,9 @@ def colorbar(self, mappable, ax=None):
106106
cb = Colorbar(
107107
widget=ax._margins["colorbar"],
108108
mappable=mappable,
109-
height_inches=ax.height / self._dpi,
109+
height=ax.height,
110110
)
111111
cb.update()
112112
mappable._colorbar = cb
113113
mappable._norm._colorbar = cb
114114
return cb
115-
# ax._margins["colorbar"].value = cb_svg

src/matplotgl/image.py

Lines changed: 8 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
# SPDX-License-Identifier: BSD-3-Clause
22

33
import matplotlib as mpl
4+
import matplotlib.colors as cm
45
import numpy as np
56
import pythreejs as p3
67

@@ -13,6 +14,7 @@ def __init__(
1314
array: np.ndarray,
1415
extent: list[float] | None = None,
1516
cmap: str = "viridis",
17+
norm: str = "linear",
1618
zorder: float = 0,
1719
):
1820
self.axes = None
@@ -22,7 +24,9 @@ def __init__(
2224
extent if extent is not None else [0, array.shape[1], 0, array.shape[0]]
2325
)
2426
self._zorder = zorder
25-
self._norm = Normalizer(vmin=np.min(self._array), vmax=np.max(self._array))
27+
self._norm = Normalizer(
28+
vmin=np.min(self._array), vmax=np.max(self._array), norm=norm
29+
)
2630
self._cmap = mpl.colormaps[cmap].copy()
2731
self._texture = p3.DataTexture(
2832
data=self._make_colors(), format="RGBFormat", type="FloatType"
@@ -97,12 +101,12 @@ def set_extent(self, extent: list[float]) -> None:
97101

98102
def set_cmap(self, cmap: str) -> None:
99103
self._cmap = mpl.colormaps[cmap].copy()
100-
self._texture.data = self._make_colors()
104+
self._update_colors()
101105
if self._colorbar is not None:
102106
self._colorbar.update()
103107

104108
@property
105-
def cmap(self) -> mpl.colormaps:
109+
def cmap(self) -> cm.Colormap:
106110
return self._cmap
107111

108112
@cmap.setter
@@ -121,6 +125,6 @@ def norm(self, norm: Normalizer | str) -> None:
121125
)
122126
else:
123127
self._norm = norm
124-
self._texture.data = self._make_colors()
128+
self._update_colors()
125129
if self._colorbar is not None:
126130
self._colorbar.update()

0 commit comments

Comments
 (0)