Skip to content

Commit ca10015

Browse files
committed
fix colorbar for scatter plots
1 parent 8fcfe1a commit ca10015

File tree

2 files changed

+71
-15
lines changed

2 files changed

+71
-15
lines changed

src/matplotgl/image.py

Lines changed: 10 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
# SPDX-License-Identifier: BSD-3-Clause
22

33
import matplotlib as mpl
4+
import matplotlib.colors as cm
45
import numpy as np
56
import pythreejs as p3
67

@@ -13,6 +14,7 @@ def __init__(
1314
array: np.ndarray,
1415
extent: list[float] | None = None,
1516
cmap: str = "viridis",
17+
norm: str = "linear",
1618
zorder: float = 0,
1719
):
1820
self.axes = None
@@ -22,7 +24,9 @@ def __init__(
2224
extent if extent is not None else [0, array.shape[1], 0, array.shape[0]]
2325
)
2426
self._zorder = zorder
25-
self._norm = Normalizer(vmin=np.min(self._array), vmax=np.max(self._array))
27+
self._norm = Normalizer(
28+
vmin=np.min(self._array), vmax=np.max(self._array), norm=norm
29+
)
2630
self._cmap = mpl.colormaps[cmap].copy()
2731
self._texture = p3.DataTexture(
2832
data=self._make_colors(), format="RGBFormat", type="FloatType"
@@ -97,12 +101,13 @@ def set_extent(self, extent: list[float]) -> None:
97101

98102
def set_cmap(self, cmap: str) -> None:
99103
self._cmap = mpl.colormaps[cmap].copy()
100-
self._texture.data = self._make_colors()
104+
# self._texture.data = self._make_colors()
105+
self._update_colors()
101106
if self._colorbar is not None:
102107
self._colorbar.update()
103108

104109
@property
105-
def cmap(self) -> mpl.colormaps:
110+
def cmap(self) -> cm.Colormap:
106111
return self._cmap
107112

108113
@cmap.setter
@@ -121,6 +126,7 @@ def norm(self, norm: Normalizer | str) -> None:
121126
)
122127
else:
123128
self._norm = norm
124-
self._texture.data = self._make_colors()
129+
# self._texture.data = self._make_colors()
130+
self._update_colors()
125131
if self._colorbar is not None:
126132
self._colorbar.update()

src/matplotgl/points.py

Lines changed: 61 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -3,10 +3,11 @@
33
import warnings
44

55
import matplotlib as mpl
6-
import matplotlib.colors as mplc
6+
import matplotlib.colors as cm
77
import numpy as np
88
import pythreejs as p3
99

10+
from .norm import Normalizer
1011
from .utils import find_limits, fix_empty_range
1112

1213
SHADER_LIBRARY = {
@@ -51,7 +52,17 @@
5152

5253

5354
class Points:
54-
def __init__(self, x, y, c="C0", s=3, marker="s", zorder=0, cmap="viridis") -> None:
55+
def __init__(
56+
self,
57+
x,
58+
y,
59+
c="C0",
60+
s=3,
61+
marker="s",
62+
zorder=0,
63+
cmap="viridis",
64+
norm: str = "linear",
65+
) -> None:
5566
self.axes = None
5667
self._x = np.asarray(x)
5768
self._y = np.asarray(y)
@@ -61,18 +72,21 @@ def __init__(self, x, y, c="C0", s=3, marker="s", zorder=0, cmap="viridis") -> N
6172

6273
if not isinstance(c, str) or not np.isscalar(s) or marker != "s":
6374
if isinstance(c, str):
64-
rgba = mplc.LinearSegmentedColormap.from_list("tmp", [c, c])(
65-
np.ones_like(self._x)
66-
)
75+
self._c = np.ones_like(self._x)
76+
self._norm = Normalizer(vmin=1, vmax=1)
77+
self._cmap = cm.LinearSegmentedColormap.from_list("tmp", [c, c])
78+
# (
79+
# np.ones_like(self._x)
80+
# )
6781
else:
6882
self._c = np.asarray(c)
69-
self.norm = mpl.colors.Normalize(
70-
vmin=np.min(self._c), vmax=np.max(self._c)
83+
self._norm = Normalizer(
84+
vmin=np.min(self._c), vmax=np.max(self._c), norm=norm
7185
)
72-
self.cmap = mpl.colormaps[cmap].copy()
73-
rgba = self.cmap(self.norm(self._c))
86+
self._cmap = mpl.colormaps[cmap].copy()
87+
# rgba = self.cmap(self.norm(self._c))
7488

75-
colors = rgba[:, :3].astype(np.float32) # Take only RGB, drop alpha
89+
colors = self._make_colors()
7690

7791
if np.isscalar(s):
7892
sizes = np.full_like(self._x, s, dtype=np.float32)
@@ -123,10 +137,16 @@ def __init__(self, x, y, c="C0", s=3, marker="s", zorder=0, cmap="viridis") -> N
123137
}
124138
)
125139

126-
self._material = p3.PointsMaterial(color=mplc.to_hex(c), size=s)
140+
self._material = p3.PointsMaterial(color=cm.to_hex(c), size=s)
127141

128142
self._points = p3.Points(geometry=self._geometry, material=self._material)
129143

144+
def _make_colors(self) -> np.ndarray:
145+
return self._cmap(self.norm(self._c))[..., :3].astype("float32")
146+
147+
def _update_colors(self) -> None:
148+
self._geometry.attributes["customColor"].array = self._make_colors()
149+
130150
def get_bbox(self):
131151
pad = 0.03
132152
left, right = fix_empty_range(find_limits(self._x, scale=self._xscale, pad=pad))
@@ -170,3 +190,33 @@ def _set_xscale(self, scale):
170190
def _set_yscale(self, scale):
171191
self._yscale = scale
172192
self._update()
193+
194+
def set_cmap(self, cmap: str) -> None:
195+
self._cmap = mpl.colormaps[cmap].copy()
196+
self._update_colors()
197+
if self._colorbar is not None:
198+
self._colorbar.update()
199+
200+
@property
201+
def cmap(self) -> cm.Colormap:
202+
return self._cmap
203+
204+
@cmap.setter
205+
def cmap(self, cmap: str) -> None:
206+
self.set_cmap(cmap)
207+
208+
@property
209+
def norm(self) -> Normalizer:
210+
return self._norm
211+
212+
@norm.setter
213+
def norm(self, norm: Normalizer | str) -> None:
214+
if isinstance(norm, str):
215+
self._norm = Normalizer(
216+
vmin=np.min(self._c), vmax=np.max(self._c), norm=norm
217+
)
218+
else:
219+
self._norm = norm
220+
self._update_colors()
221+
if self._colorbar is not None:
222+
self._colorbar.update()

0 commit comments

Comments
 (0)