|
1 | | -from matplotlib.colorbar import ColorbarBase |
2 | | -from matplotlib.pyplot import Figure |
3 | | - |
4 | | - |
5 | | -def fig_to_bytes(fig) -> bytes: |
6 | | - """ |
7 | | - Convert a Matplotlib figure to svg bytes. |
| 1 | +# SPDX-License-Identifier: BSD-3-Clause |
8 | 2 |
|
9 | | - Parameters |
10 | | - ---------- |
11 | | - fig: |
12 | | - The figure to be converted. |
13 | | - """ |
14 | | - from io import BytesIO |
15 | | - |
16 | | - buf = BytesIO() |
17 | | - fig.savefig(buf, format="svg", bbox_inches="tight") |
18 | | - buf.seek(0) |
19 | | - return buf.getvalue() |
| 3 | +import matplotlib as mpl |
| 4 | +import numpy as np |
| 5 | +from matplotlib import colors as cm |
| 6 | +from matplotlib.pyplot import Figure |
20 | 7 |
|
| 8 | +from .utils import html_to_svg, latex_to_html |
21 | 9 |
|
22 | | -# def make_colorbar(mappable, height_inches: float) -> str: |
23 | | -# fig = Figure(figsize=(height_inches * 0.2, height_inches)) |
24 | | -# cax = fig.add_axes([0.05, 0.02, 0.2, 0.98]) |
25 | | -# ColorbarBase(cax, cmap=mappable.cmap) # , norm=self.normalizer) |
26 | | -# return fig_to_bytes(fig).decode() |
| 10 | +mpl.use("Agg") |
27 | 11 |
|
28 | 12 |
|
29 | 13 | class Colorbar: |
30 | | - def __init__(self, widget, mappable, height_inches: float): |
| 14 | + def __init__(self, widget, mappable, height: int | float): |
31 | 15 | self._mappable = mappable |
32 | | - self._height_inches = height_inches |
| 16 | + self._height = height |
33 | 17 | self._widget = widget |
34 | 18 |
|
35 | 19 | def update(self): |
36 | 20 | self._mappable._update_colors() |
37 | | - fig = Figure(figsize=(self._height_inches * 0.2, self._height_inches)) |
38 | | - cax = fig.add_axes([0.05, 0.02, 0.2, 0.98]) |
39 | | - ColorbarBase(cax, cmap=self._mappable.cmap, norm=self._mappable.norm._norm) |
40 | | - self._widget.value = fig_to_bytes(fig).decode() |
| 21 | + fig = Figure() |
| 22 | + cax = fig.add_subplot(111) |
| 23 | + norm = self._mappable._norm._norm |
| 24 | + cax.set_ylim(norm.vmin, norm.vmax) |
| 25 | + cax.set_yscale('log' if isinstance(norm, cm.LogNorm) else 'linear') |
| 26 | + |
| 27 | + # Generate colors |
| 28 | + n_colors = 128 |
| 29 | + segment_height = self._height / n_colors |
| 30 | + bar_width = 18 |
| 31 | + bar_left = 20 |
| 32 | + bar_top = 3 |
| 33 | + |
| 34 | + # Build SVG |
| 35 | + svg_parts = [ |
| 36 | + f'<svg width="{100}" height="{self._height}">', |
| 37 | + ] |
| 38 | + |
| 39 | + # Draw color segments |
| 40 | + for i in range(n_colors): |
| 41 | + y = (n_colors - i - 1) * segment_height |
| 42 | + color = cm.to_hex(self._mappable.cmap(i / (n_colors - 1))) |
| 43 | + svg_parts.append( |
| 44 | + f' <rect x="{bar_left}" y="{y + bar_top}" width="{bar_width}" ' |
| 45 | + f'height="{segment_height}" fill="{color}" />' |
| 46 | + ) |
| 47 | + |
| 48 | + # Draw colorbar border |
| 49 | + svg_parts.append( |
| 50 | + f' <rect x="{bar_left}" y="{bar_top}" width="{bar_width}" ' |
| 51 | + f'height="{self._height}" fill="none" stroke="black" stroke-width="1"/>' |
| 52 | + ) |
| 53 | + |
| 54 | + yticks = cax.get_yticks() |
| 55 | + ylabels = cax.get_yticklabels() |
| 56 | + ytexts = [lab.get_text() for lab in ylabels] |
| 57 | + tick_length = 6 |
| 58 | + label_offset = 3 |
| 59 | + |
| 60 | + xy = np.vstack((np.zeros_like(yticks), yticks)).T |
| 61 | + |
| 62 | + inv_trans_axes = cax.transAxes.inverted() |
| 63 | + trans_data = cax.transData |
| 64 | + yticks_axes = inv_trans_axes.transform(trans_data.transform(xy))[:, 1] |
| 65 | + |
| 66 | + for tick, label in zip(yticks_axes, ytexts, strict=True): |
| 67 | + if tick < 0 or tick > 1.0: |
| 68 | + continue |
| 69 | + y = self._height - (tick * self._height) + bar_top |
| 70 | + svg_parts.append( |
| 71 | + f'<line x1="{bar_left+bar_width}" y1="{y}" ' |
| 72 | + f'x2="{bar_left+bar_width+tick_length}" y2="{y}" ' |
| 73 | + 'style="stroke:black;stroke-width:1" />' |
| 74 | + ) |
| 75 | + |
| 76 | + svg_parts.append( |
| 77 | + f'<text x="{bar_left+bar_width+tick_length+label_offset}" ' |
| 78 | + f'y="{y}" text-anchor="start" dominant-baseline="middle">' |
| 79 | + f"{html_to_svg(latex_to_html(label), baseline='middle')}</text>" |
| 80 | + ) |
| 81 | + |
| 82 | + minor_ticks = cax.yaxis.get_minorticklocs() |
| 83 | + if len(minor_ticks) > 0: |
| 84 | + xy = np.vstack((np.zeros_like(minor_ticks), minor_ticks)).T |
| 85 | + yticks_axes = inv_trans_axes.transform(trans_data.transform(xy))[:, 1] |
| 86 | + |
| 87 | + for tick in yticks_axes: |
| 88 | + if tick < 0 or tick > 1.0: |
| 89 | + continue |
| 90 | + y = self._height - (tick * self._height) + bar_top |
| 91 | + svg_parts.append( |
| 92 | + f'<line x1="{bar_left+bar_width}" y1="{y}" ' |
| 93 | + f'x2="{bar_left+bar_width+tick_length * 0.6}" y2="{y}" ' |
| 94 | + 'style="stroke:black;stroke-width:0.5" />' |
| 95 | + ) |
| 96 | + |
| 97 | + svg_parts.append('</svg>') |
| 98 | + |
| 99 | + self._widget.value = '\n'.join(svg_parts) |
0 commit comments