Skip to content

Commit f7a1061

Browse files
committed
use custom built svg instead of rendering colorbar using matplotlib
1 parent 136a8cc commit f7a1061

File tree

3 files changed

+116
-54
lines changed

3 files changed

+116
-54
lines changed

src/matplotgl/axes.py

Lines changed: 27 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -352,13 +352,15 @@ def _make_xticks(self):
352352

353353
# width = f"calc({self.width}px + 0.5em)"
354354

355-
bottom_string = (
356-
f'<svg height="calc(1.2em + {tick_length}px + {label_offset}px)" '
357-
f'width="{self.width}"><line x1="0" y1="0" '
358-
# f'width="calc(0.5em + {self.width}px)"><line x1="0" y1="0" '
359-
f'x2="{self.width}" y2="0" '
360-
f'style="stroke:black;stroke-width:{self._spine_linewidth}" />'
361-
)
355+
bottom_string = [
356+
(
357+
f'<svg height="calc(1.2em + {tick_length}px + {label_offset}px)" '
358+
f'width="{self.width}"><line x1="0" y1="0" '
359+
# f'width="calc(0.5em + {self.width}px)"><line x1="0" y1="0" '
360+
f'x2="{self.width}" y2="0" '
361+
f'style="stroke:black;stroke-width:{self._spine_linewidth}" />'
362+
)
363+
]
362364

363365
self._margins["topspine"].value = (
364366
f'<svg height="{self._thin_margin}px" width="{self.width}">'
@@ -371,11 +373,11 @@ def _make_xticks(self):
371373
if tick < 0 or tick > 1.0:
372374
continue
373375
x = tick * self.width
374-
bottom_string += (
376+
bottom_string.append(
375377
f'<line x1="{x}" y1="0" x2="{x}" y2="{tick_length}" '
376378
'style="stroke:black;stroke-width:1" />'
377379
)
378-
bottom_string += (
380+
bottom_string.append(
379381
f'<text x="{x}" y="{tick_length + label_offset}" '
380382
'text-anchor="middle" dominant-baseline="hanging">'
381383
f"{html_to_svg(latex_to_html(label), baseline='hanging')}</text>"
@@ -390,13 +392,13 @@ def _make_xticks(self):
390392
if tick < 0 or tick > 1.0:
391393
continue
392394
x = tick * self.width
393-
bottom_string += (
395+
bottom_string.append(
394396
f'<line x1="{x}" y1="0" x2="{x}" y2="{tick_length * 0.7}" '
395397
'style="stroke:black;stroke-width:0.5" />'
396398
)
397399

398-
bottom_string += "</svg></div>"
399-
self._margins["bottomspine"].value = bottom_string
400+
bottom_string.append("</svg></div>")
401+
self._margins["bottomspine"].value = "".join(bottom_string)
400402

401403
def _make_yticks(self):
402404
"""
@@ -424,12 +426,14 @@ def _make_yticks(self):
424426
width2 = f"calc({max_length}px)"
425427
width3 = f"calc({max_length}px + {tick_length * 0.3}px + {label_offset}px)"
426428

427-
left_string = (
428-
f'<svg height="{self.height}" width="{width}">'
429-
f'<line x1="{width}" y1="0" '
430-
f'x2="{width}" y2="{self.height}" '
431-
f'style="stroke:black;stroke-width:{self._spine_linewidth}" />'
432-
)
429+
left_string = [
430+
(
431+
f'<svg height="{self.height}" width="{width}">'
432+
f'<line x1="{width}" y1="0" '
433+
f'x2="{width}" y2="{self.height}" '
434+
f'style="stroke:black;stroke-width:{self._spine_linewidth}" />'
435+
)
436+
]
433437

434438
self._margins["rightspine"].value = (
435439
f'<svg height="{self.height}" width="{self._thin_margin}">'
@@ -441,13 +445,13 @@ def _make_yticks(self):
441445
if tick < 0 or tick > 1.0:
442446
continue
443447
y = self.height - (tick * self.height)
444-
left_string += (
448+
left_string.append(
445449
f'<line x1="{width}" y1="{y}" '
446450
f'x2="{width1}" y2="{y}" '
447451
'style="stroke:black;stroke-width:1" />'
448452
)
449453

450-
left_string += (
454+
left_string.append(
451455
f'<text x="{width2}" '
452456
f'y="{y}" text-anchor="end" dominant-baseline="middle">'
453457
f"{html_to_svg(latex_to_html(label), baseline='middle')}</text>"
@@ -462,14 +466,14 @@ def _make_yticks(self):
462466
if tick < 0 or tick > 1.0:
463467
continue
464468
y = self.height - (tick * self.height)
465-
left_string += (
469+
left_string.append(
466470
f'<line x1="{width}" y1="{y}" '
467471
f'x2="{width3}" y2="{y}" '
468472
'style="stroke:black;stroke-width:0.5" />'
469473
)
470474

471-
left_string += "</svg></div>"
472-
self._margins["leftspine"].value = left_string
475+
left_string.append("</svg></div>")
476+
self._margins["leftspine"].value = "".join(left_string)
473477

474478
def get_xlim(self):
475479
return self._xmin, self._xmax

src/matplotgl/colorbar.py

Lines changed: 88 additions & 29 deletions
Original file line numberDiff line numberDiff line change
@@ -1,40 +1,99 @@
1-
from matplotlib.colorbar import ColorbarBase
2-
from matplotlib.pyplot import Figure
3-
4-
5-
def fig_to_bytes(fig) -> bytes:
6-
"""
7-
Convert a Matplotlib figure to svg bytes.
1+
# SPDX-License-Identifier: BSD-3-Clause
82

9-
Parameters
10-
----------
11-
fig:
12-
The figure to be converted.
13-
"""
14-
from io import BytesIO
15-
16-
buf = BytesIO()
17-
fig.savefig(buf, format="svg", bbox_inches="tight")
18-
buf.seek(0)
19-
return buf.getvalue()
3+
import matplotlib as mpl
4+
import numpy as np
5+
from matplotlib import colors as cm
6+
from matplotlib.pyplot import Figure
207

8+
from .utils import html_to_svg, latex_to_html
219

22-
# def make_colorbar(mappable, height_inches: float) -> str:
23-
# fig = Figure(figsize=(height_inches * 0.2, height_inches))
24-
# cax = fig.add_axes([0.05, 0.02, 0.2, 0.98])
25-
# ColorbarBase(cax, cmap=mappable.cmap) # , norm=self.normalizer)
26-
# return fig_to_bytes(fig).decode()
10+
mpl.use("Agg")
2711

2812

2913
class Colorbar:
30-
def __init__(self, widget, mappable, height_inches: float):
14+
def __init__(self, widget, mappable, height: int | float):
3115
self._mappable = mappable
32-
self._height_inches = height_inches
16+
self._height = height
3317
self._widget = widget
3418

3519
def update(self):
3620
self._mappable._update_colors()
37-
fig = Figure(figsize=(self._height_inches * 0.2, self._height_inches))
38-
cax = fig.add_axes([0.05, 0.02, 0.2, 0.98])
39-
ColorbarBase(cax, cmap=self._mappable.cmap, norm=self._mappable.norm._norm)
40-
self._widget.value = fig_to_bytes(fig).decode()
21+
fig = Figure()
22+
cax = fig.add_subplot(111)
23+
norm = self._mappable._norm._norm
24+
cax.set_ylim(norm.vmin, norm.vmax)
25+
cax.set_yscale('log' if isinstance(norm, cm.LogNorm) else 'linear')
26+
27+
# Generate colors
28+
n_colors = 128
29+
segment_height = self._height / n_colors
30+
bar_width = 18
31+
bar_left = 20
32+
bar_top = 3
33+
34+
# Build SVG
35+
svg_parts = [
36+
f'<svg width="{100}" height="{self._height}">',
37+
]
38+
39+
# Draw color segments
40+
for i in range(n_colors):
41+
y = (n_colors - i - 1) * segment_height
42+
color = cm.to_hex(self._mappable.cmap(i / (n_colors - 1)))
43+
svg_parts.append(
44+
f' <rect x="{bar_left}" y="{y + bar_top}" width="{bar_width}" '
45+
f'height="{segment_height}" fill="{color}" />'
46+
)
47+
48+
# Draw colorbar border
49+
svg_parts.append(
50+
f' <rect x="{bar_left}" y="{bar_top}" width="{bar_width}" '
51+
f'height="{self._height}" fill="none" stroke="black" stroke-width="1"/>'
52+
)
53+
54+
yticks = cax.get_yticks()
55+
ylabels = cax.get_yticklabels()
56+
ytexts = [lab.get_text() for lab in ylabels]
57+
tick_length = 6
58+
label_offset = 3
59+
60+
xy = np.vstack((np.zeros_like(yticks), yticks)).T
61+
62+
inv_trans_axes = cax.transAxes.inverted()
63+
trans_data = cax.transData
64+
yticks_axes = inv_trans_axes.transform(trans_data.transform(xy))[:, 1]
65+
66+
for tick, label in zip(yticks_axes, ytexts, strict=True):
67+
if tick < 0 or tick > 1.0:
68+
continue
69+
y = self._height - (tick * self._height) + bar_top
70+
svg_parts.append(
71+
f'<line x1="{bar_left+bar_width}" y1="{y}" '
72+
f'x2="{bar_left+bar_width+tick_length}" y2="{y}" '
73+
'style="stroke:black;stroke-width:1" />'
74+
)
75+
76+
svg_parts.append(
77+
f'<text x="{bar_left+bar_width+tick_length+label_offset}" '
78+
f'y="{y}" text-anchor="start" dominant-baseline="middle">'
79+
f"{html_to_svg(latex_to_html(label), baseline='middle')}</text>"
80+
)
81+
82+
minor_ticks = cax.yaxis.get_minorticklocs()
83+
if len(minor_ticks) > 0:
84+
xy = np.vstack((np.zeros_like(minor_ticks), minor_ticks)).T
85+
yticks_axes = inv_trans_axes.transform(trans_data.transform(xy))[:, 1]
86+
87+
for tick in yticks_axes:
88+
if tick < 0 or tick > 1.0:
89+
continue
90+
y = self._height - (tick * self._height) + bar_top
91+
svg_parts.append(
92+
f'<line x1="{bar_left+bar_width}" y1="{y}" '
93+
f'x2="{bar_left+bar_width+tick_length * 0.6}" y2="{y}" '
94+
'style="stroke:black;stroke-width:0.5" />'
95+
)
96+
97+
svg_parts.append('</svg>')
98+
99+
self._widget.value = '\n'.join(svg_parts)

src/matplotgl/figure.py

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -106,10 +106,9 @@ def colorbar(self, mappable, ax=None):
106106
cb = Colorbar(
107107
widget=ax._margins["colorbar"],
108108
mappable=mappable,
109-
height_inches=ax.height / self._dpi,
109+
height=ax.height,
110110
)
111111
cb.update()
112112
mappable._colorbar = cb
113113
mappable._norm._colorbar = cb
114114
return cb
115-
# ax._margins["colorbar"].value = cb_svg

0 commit comments

Comments
 (0)