Skip to content

Commit 55500d7

Browse files
authored
Merge pull request #8 from scipp/logscale-in-other-artists
Logscale fixes in scatter, pcolormesh and imshow
2 parents 6e472e8 + 37cead2 commit 55500d7

File tree

5 files changed

+65
-65
lines changed

5 files changed

+65
-65
lines changed

src/matplotgl/axes.py

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -345,7 +345,7 @@ def autoscale(self):
345345

346346
def add_artist(self, artist):
347347
self._artists.append(artist)
348-
self.scene.add(artist.get())
348+
self.scene.add(artist._as_object3d())
349349

350350
def get_figure(self):
351351
return self._fig
@@ -717,7 +717,9 @@ def loglog(self, *args, **kwargs):
717717
def scatter(self, *args, c=None, **kwargs):
718718
if c is None:
719719
c = f"C{len(self.collections)}"
720-
coll = Points(*args, c=c, **kwargs)
720+
coll = Points(
721+
*args, c=c, xscale=self.get_xscale(), yscale=self.get_yscale(), **kwargs
722+
)
721723
coll.axes = self
722724
self.collections.append(coll)
723725
self.add_artist(coll)
@@ -733,7 +735,7 @@ def imshow(self, *args, **kwargs):
733735
return image
734736

735737
def pcolormesh(self, *args, **kwargs):
736-
mesh = Mesh(*args, **kwargs)
738+
mesh = Mesh(*args, xscale=self.get_xscale(), yscale=self.get_yscale(), **kwargs)
737739
mesh.axes = self
738740
self.collections.append(mesh)
739741
self.add_artist(mesh)

src/matplotgl/image.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -63,7 +63,7 @@ def get_bbox(self) -> dict[str, float]:
6363
def _update_colors(self) -> None:
6464
self._texture.data = self._make_colors()
6565

66-
def get(self) -> p3.Object3D:
66+
def _as_object3d(self) -> p3.Object3D:
6767
return self._image
6868

6969
def _set_xscale(self, scale: str) -> None:

src/matplotgl/line.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -61,7 +61,7 @@ def get_bbox(self):
6161
bottom, top = fix_empty_range(find_limits(self._y, scale=self._yscale, pad=pad))
6262
return {"left": left, "right": right, "bottom": bottom, "top": top}
6363

64-
def get(self):
64+
def _as_object3d(self) -> p3.Object3D:
6565
out = []
6666
if self._line is not None:
6767
out.append(self._line)

src/matplotgl/mesh.py

Lines changed: 11 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,14 @@
1212

1313

1414
class Mesh:
15-
def __init__(self, *args, cmap: str = "viridis", norm: str = "linear"):
15+
def __init__(
16+
self,
17+
*args,
18+
cmap: str = "viridis",
19+
norm: str = "linear",
20+
xscale="linear",
21+
yscale="linear",
22+
):
1623
if len(args) not in (1, 3):
1724
raise ValueError(
1825
f"Invalid number of arguments: expected 1 or 3. Got {len(args)}"
@@ -28,8 +35,8 @@ def __init__(self, *args, cmap: str = "viridis", norm: str = "linear"):
2835

2936
self.axes = None
3037
self._colorbar = None
31-
self._xscale = "linear"
32-
self._yscale = "linear"
38+
self._xscale = xscale
39+
self._yscale = yscale
3340

3441
self._x = np.asarray(x)
3542
self._y = np.asarray(y)
@@ -126,7 +133,7 @@ def get_bbox(self) -> dict[str, float]:
126133
bottom, top = fix_empty_range(find_limits(self._y, scale=self._yscale, pad=pad))
127134
return {"left": left, "right": right, "bottom": bottom, "top": top}
128135

129-
def get(self) -> p3.Object3D:
136+
def _as_object3d(self) -> p3.Object3D:
130137
return self._mesh
131138

132139
def get_xdata(self) -> np.ndarray:

src/matplotgl/points.py

Lines changed: 47 additions & 56 deletions
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,22 @@
1010
from .norm import Normalizer
1111
from .utils import find_limits, fix_empty_range
1212

13-
SHADER_LIBRARY = {
13+
# Custom vertex shader for variable size and color
14+
VERTEX_SHADER = """
15+
attribute float size;
16+
attribute vec3 customColor;
17+
varying vec3 vColor;
18+
19+
void main() {
20+
vColor = customColor;
21+
vec4 mvPosition = modelViewMatrix * vec4(position, 1.0);
22+
gl_PointSize = size;
23+
gl_Position = projectionMatrix * mvPosition;
24+
}
25+
"""
26+
27+
# Custom fragment shaders for different markers
28+
FRAGMENT_SHADERS = {
1429
"o": """
1530
varying vec3 vColor;
1631
@@ -62,29 +77,31 @@ def __init__(
6277
zorder=0,
6378
cmap="viridis",
6479
norm: str = "linear",
80+
xscale="linear",
81+
yscale="linear",
6582
) -> None:
6683
self.axes = None
6784
self._x = np.asarray(x)
6885
self._y = np.asarray(y)
69-
self._xscale = "linear"
70-
self._yscale = "linear"
86+
self._xscale = xscale
87+
self._yscale = yscale
7188
self._zorder = zorder
7289

90+
self._geometry = p3.BufferGeometry(
91+
attributes={"position": p3.BufferAttribute(array=self._make_positions())}
92+
)
93+
7394
if not isinstance(c, str) or not np.isscalar(s) or marker != "s":
7495
if isinstance(c, str):
7596
self._c = np.ones_like(self._x)
7697
self._norm = Normalizer(vmin=1, vmax=1)
7798
self._cmap = cm.LinearSegmentedColormap.from_list("tmp", [c, c])
78-
# (
79-
# np.ones_like(self._x)
80-
# )
8199
else:
82100
self._c = np.asarray(c)
83101
self._norm = Normalizer(
84102
vmin=np.min(self._c), vmax=np.max(self._c), norm=norm
85103
)
86104
self._cmap = mpl.colormaps[cmap].copy()
87-
# rgba = self.cmap(self.norm(self._c))
88105

89106
colors = self._make_colors()
90107

@@ -93,103 +110,77 @@ def __init__(
93110
else:
94111
sizes = np.asarray(s, dtype=np.float32)
95112

96-
# Custom vertex shader for variable size and color
97-
vertex_shader = """
98-
attribute float size;
99-
attribute vec3 customColor;
100-
varying vec3 vColor;
101-
102-
void main() {
103-
vColor = customColor;
104-
vec4 mvPosition = modelViewMatrix * vec4(position, 1.0);
105-
gl_PointSize = size;
106-
gl_Position = projectionMatrix * mvPosition;
107-
}
108-
"""
109-
110-
self._geometry = p3.BufferGeometry(
111-
attributes={
112-
"position": p3.BufferAttribute(
113-
array=np.array(
114-
[self._x, self._y, np.full_like(self._x, self._zorder)],
115-
dtype="float32",
116-
).T
117-
),
113+
self._geometry.attributes.update(
114+
{
118115
"customColor": p3.BufferAttribute(array=colors),
119116
"size": p3.BufferAttribute(array=sizes),
120117
}
121118
)
122119
# Create ShaderMaterial with custom shaders
123120
self._material = p3.ShaderMaterial(
124-
vertexShader=vertex_shader,
125-
fragmentShader=SHADER_LIBRARY[marker],
121+
vertexShader=VERTEX_SHADER,
122+
fragmentShader=FRAGMENT_SHADERS[marker],
126123
transparent=True,
127124
)
128125
else:
129-
self._geometry = p3.BufferGeometry(
130-
attributes={
131-
"position": p3.BufferAttribute(
132-
array=np.array(
133-
[self._x, self._y, np.full_like(self._x, self._zorder)],
134-
dtype="float32",
135-
).T
136-
),
137-
}
138-
)
139-
140126
self._material = p3.PointsMaterial(color=cm.to_hex(c), size=s)
141127

142128
self._points = p3.Points(geometry=self._geometry, material=self._material)
143129

130+
def _make_positions(self) -> np.ndarray:
131+
with warnings.catch_warnings(category=RuntimeWarning, action="ignore"):
132+
xx = self._x if self._xscale == "linear" else np.log10(self._x)
133+
yy = self._y if self._yscale == "linear" else np.log10(self._y)
134+
return np.array([xx, yy, np.full_like(xx, self._zorder)], dtype="float32").T
135+
144136
def _make_colors(self) -> np.ndarray:
145137
return self._cmap(self.norm(self._c))[..., :3].astype("float32")
146138

147139
def _update_colors(self) -> None:
148140
self._geometry.attributes["customColor"].array = self._make_colors()
149141

142+
def _update_positions(self):
143+
self._geometry.attributes["position"].array = self._make_positions()
144+
150145
def get_bbox(self):
151146
pad = 0.03
152147
left, right = fix_empty_range(find_limits(self._x, scale=self._xscale, pad=pad))
153148
bottom, top = fix_empty_range(find_limits(self._y, scale=self._yscale, pad=pad))
154149
return {"left": left, "right": right, "bottom": bottom, "top": top}
155150

156-
def _update(self):
157-
with warnings.catch_warnings(category=RuntimeWarning, action="ignore"):
158-
xx = self._x if self._xscale == "linear" else np.log10(self._x)
159-
yy = self._y if self._yscale == "linear" else np.log10(self._y)
160-
self._geometry.attributes["position"].array = np.array(
161-
[xx, yy, np.full_like(xx, self._zorder)], dtype="float32"
162-
).T
163-
164-
def get(self):
151+
def _as_object3d(self) -> p3.Object3D:
165152
return self._points
166153

167154
def get_xdata(self) -> np.ndarray:
168155
return self._x
169156

170157
def set_xdata(self, x):
171158
self._x = np.asarray(x)
172-
self._update()
159+
self._update_positions()
173160

174161
def get_ydata(self) -> np.ndarray:
175162
return self._y
176163

177164
def set_ydata(self, y):
178165
self._y = np.asarray(y)
179-
self._update()
166+
self._update_positions()
180167

181168
def set_data(self, xy):
182169
self._x = np.asarray(xy[:, 0])
183170
self._y = np.asarray(xy[:, 1])
184-
self._update()
171+
self._update_positions()
185172

186173
def _set_xscale(self, scale):
187174
self._xscale = scale
188-
self._update()
175+
self._update_positions()
189176

190177
def _set_yscale(self, scale):
191178
self._yscale = scale
192-
self._update()
179+
self._update_positions()
180+
181+
def set_array(self, c: np.ndarray):
182+
self._c = np.asarray(c)
183+
self._update_colors()
193184

194185
def set_cmap(self, cmap: str) -> None:
195186
self._cmap = mpl.colormaps[cmap].copy()

0 commit comments

Comments
 (0)