Skip to content

Commit 6e472e8

Browse files
authored
Merge pull request #7 from scipp/loglog
Add semilogx, semilogy, and loglog, and fix some log scale issues
2 parents 667d0e5 + 92d9aa0 commit 6e472e8

File tree

3 files changed

+180
-43
lines changed

3 files changed

+180
-43
lines changed

src/matplotgl/axes.py

Lines changed: 68 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,14 @@
1515
from .widgets import ClickableHTML
1616

1717

18+
def min_with_none(a, b):
19+
return a if b is None else min(a, b)
20+
21+
22+
def max_with_none(a, b):
23+
return a if b is None else max(a, b)
24+
25+
1826
class Axes(ipw.GridBox):
1927
def __init__(self, *, ax: MplAxes, figure=None) -> None:
2028
self.background_color = "#ffffff"
@@ -290,35 +298,43 @@ def height(self, h):
290298
# self._margins["rightspine"].height = h
291299

292300
def autoscale(self):
293-
xmin = np.inf
294-
xmax = -np.inf
295-
ymin = np.inf
296-
ymax = -np.inf
301+
xmin = None
302+
xmax = None
303+
ymin = None
304+
ymax = None
297305
for artist in self._artists:
298306
lims = artist.get_bbox()
299-
xmin = min(lims["left"], xmin)
300-
xmax = max(lims["right"], xmax)
301-
ymin = min(lims["bottom"], ymin)
302-
ymax = max(lims["top"], ymax)
303-
self._xmin = xmin
304-
self._xmax = xmax
305-
self._ymin = ymin
306-
self._ymax = ymax
307-
308-
# self._background_mesh.geometry = p3.BoxGeometry(
309-
# width=2 * (self._xmax - self._xmin),
310-
# height=2 * (self._ymax - self._ymin),
311-
# widthSegments=1,
312-
# heightSegments=1,
313-
# )
307+
xmin = min_with_none(lims["left"], xmin)
308+
xmax = max_with_none(lims["right"], xmax)
309+
ymin = min_with_none(lims["bottom"], ymin)
310+
ymax = max_with_none(lims["top"], ymax)
311+
self._xmin = (
312+
xmin
313+
if xmin is not None
314+
else (0.0 if self.get_xscale() == "linear" else 1.0)
315+
)
316+
self._xmax = (
317+
xmax
318+
if xmax is not None
319+
else (1.0 if self.get_xscale() == "linear" else 10.0)
320+
)
321+
self._ymin = (
322+
ymin
323+
if ymin is not None
324+
else (0.0 if self.get_yscale() == "linear" else 1.0)
325+
)
326+
self._ymax = (
327+
ymax
328+
if ymax is not None
329+
else (1.0 if self.get_yscale() == "linear" else 10.0)
330+
)
331+
314332
self._background_mesh.geometry = p3.PlaneGeometry(
315333
width=2 * (self._xmax - self._xmin),
316334
height=2 * (self._ymax - self._ymin),
317335
widthSegments=1,
318336
heightSegments=1,
319337
)
320-
# self._background_mesh.geometry.width = 2 * (self._xmax - self._xmin)
321-
# self._background_mesh.geometry.height = 2 * (self._ymax - self._ymin)
322338

323339
self._background_mesh.position = [
324340
0.5 * (self._xmin + self._xmax),
@@ -523,6 +539,10 @@ def get_xscale(self):
523539
return self._ax.get_xscale()
524540

525541
def set_xscale(self, scale):
542+
if scale not in ("linear", "log"):
543+
raise ValueError("Scale must be 'linear' or 'log'")
544+
if scale == self.get_xscale():
545+
return
526546
self._ax.set_xscale(scale)
527547
for artist in self._artists:
528548
artist._set_xscale(scale)
@@ -533,6 +553,10 @@ def get_yscale(self):
533553
return self._ax.get_yscale()
534554

535555
def set_yscale(self, scale):
556+
if scale not in ("linear", "log"):
557+
raise ValueError("Scale must be 'linear' or 'log'")
558+
if scale == self.get_yscale():
559+
return
536560
self._ax.set_yscale(scale)
537561
for artist in self._artists:
538562
artist._set_yscale(scale)
@@ -661,13 +685,35 @@ def get_title(self):
661685
def plot(self, *args, color=None, **kwargs):
662686
if color is None:
663687
color = f"C{len(self.lines)}"
664-
line = Line(*args, color=color, **kwargs)
688+
line = Line(
689+
*args,
690+
color=color,
691+
xscale=self.get_xscale(),
692+
yscale=self.get_yscale(),
693+
**kwargs,
694+
)
665695
line.axes = self
666696
self.lines.append(line)
667697
self.add_artist(line)
668698
self.autoscale()
669699
return line
670700

701+
def semilogx(self, *args, **kwargs):
702+
out = self.plot(*args, **kwargs)
703+
self.set_xscale("log")
704+
return out
705+
706+
def semilogy(self, *args, **kwargs):
707+
out = self.plot(*args, **kwargs)
708+
self.set_yscale("log")
709+
return out
710+
711+
def loglog(self, *args, **kwargs):
712+
out = self.plot(*args, **kwargs)
713+
self.set_xscale("log")
714+
self.set_yscale("log")
715+
return out
716+
671717
def scatter(self, *args, c=None, **kwargs):
672718
if c is None:
673719
c = f"C{len(self.collections)}"

src/matplotgl/line.py

Lines changed: 24 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -10,19 +10,27 @@
1010

1111

1212
class Line:
13-
def __init__(self, x, y, fmt="-", color="C0", ls="solid", lw=1, ms=5, zorder=0):
13+
def __init__(
14+
self,
15+
x,
16+
y,
17+
fmt="-",
18+
color="C0",
19+
ls="solid",
20+
lw=1,
21+
ms=5,
22+
zorder=0,
23+
xscale="linear",
24+
yscale="linear",
25+
):
1426
self.axes = None
15-
self._xscale = "linear"
16-
self._yscale = "linear"
27+
self._xscale = xscale
28+
self._yscale = yscale
1729
self._x = np.asarray(x)
1830
self._y = np.asarray(y)
1931
self._zorder = zorder
20-
self._line_geometry = p3.LineGeometry(
21-
positions=np.array(
22-
[self._x, self._y, np.full_like(self._x, self._zorder - 50)],
23-
dtype="float32",
24-
).T
25-
)
32+
pos = self._make_positions()
33+
self._line_geometry = p3.LineGeometry(positions=pos)
2634

2735
self._color = mplc.to_hex(color)
2836
self._line = None
@@ -39,16 +47,7 @@ def __init__(self, x, y, fmt="-", color="C0", ls="solid", lw=1, ms=5, zorder=0):
3947
if "o" in fmt:
4048
self._vertices_geometry = p3.BufferGeometry(
4149
attributes={
42-
"position": p3.BufferAttribute(
43-
array=np.array(
44-
[
45-
self._x,
46-
self._y,
47-
np.full_like(self._x, self._zorder - 50),
48-
],
49-
dtype="float32",
50-
).T
51-
),
50+
"position": p3.BufferAttribute(array=pos),
5251
}
5352
)
5453
self._vertices_material = p3.PointsMaterial(color=self._color, size=ms)
@@ -70,14 +69,18 @@ def get(self):
7069
out.append(self._vertices)
7170
return p3.Group(children=out) if len(out) > 1 else out[0]
7271

73-
def _update(self):
72+
def _make_positions(self):
7473
with warnings.catch_warnings(category=RuntimeWarning, action="ignore"):
7574
xx = self._x if self._xscale == "linear" else np.log10(self._x)
7675
yy = self._y if self._yscale == "linear" else np.log10(self._y)
7776
pos = np.array(
78-
[xx, yy, np.full_like(xx, self._zorder - 50)],
77+
[xx, yy, np.full_like(xx, self._zorder)],
7978
dtype="float32",
8079
).T
80+
return pos
81+
82+
def _update(self):
83+
pos = self._make_positions()
8184
if self._line is not None:
8285
self._line_geometry.positions = pos
8386
if self._vertices is not None:

tests/plot_test.py

Lines changed: 88 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
# SPDX-License-Identifier: BSD-3-Clause
22

33
import numpy as np
4+
import pytest
45

56
import matplotgl.pyplot as plt
67

@@ -60,3 +61,90 @@ def test_imshow():
6061
im = ax.images[0]
6162
assert np.allclose(im._array, data)
6263
assert im.get_extent() == [0, 10, 0, 5]
64+
65+
66+
def test_set_xscale_log():
67+
_, ax = plt.subplots()
68+
x = np.arange(50.0)
69+
y = np.sin(0.2 * x)
70+
71+
ax.plot(x, y, lw=2)
72+
ax.set_xscale('log')
73+
74+
assert ax.get_xscale() == 'log'
75+
76+
77+
def test_set_yscale_log():
78+
_, ax = plt.subplots()
79+
x = np.arange(50.0)
80+
y = np.sin(0.2 * x)
81+
82+
ax.plot(x, y, lw=2)
83+
ax.set_yscale('log')
84+
85+
assert ax.get_yscale() == 'log'
86+
87+
88+
def test_set_xscale_invalid():
89+
_, ax = plt.subplots()
90+
with pytest.raises(ValueError, match="Scale must be 'linear' or 'log'"):
91+
ax.set_xscale('invalid_scale')
92+
93+
94+
def test_set_yscale_invalid():
95+
_, ax = plt.subplots()
96+
with pytest.raises(ValueError, match="Scale must be 'linear' or 'log'"):
97+
ax.set_yscale('invalid_scale')
98+
99+
100+
def test_set_xscale_log_before_plot():
101+
_, ax = plt.subplots()
102+
x = np.arange(50.0)
103+
y = np.sin(0.2 * x)
104+
105+
ax.set_xscale('log')
106+
ax.plot(x, y, lw=2)
107+
108+
assert ax.get_xscale() == 'log'
109+
110+
111+
def test_set_yscale_log_before_plot():
112+
_, ax = plt.subplots()
113+
x = np.arange(50.0)
114+
y = np.sin(0.2 * x)
115+
116+
ax.set_yscale('log')
117+
ax.plot(x, y, lw=2)
118+
119+
assert ax.get_yscale() == 'log'
120+
121+
122+
def test_semilogx():
123+
_, ax = plt.subplots()
124+
x = np.arange(1.0, 50.0)
125+
y = np.sin(0.2 * x)
126+
127+
ax.semilogx(x, y, lw=2)
128+
129+
assert ax.get_xscale() == 'log'
130+
131+
132+
def test_semilogy():
133+
_, ax = plt.subplots()
134+
x = np.arange(50.0)
135+
y = np.exp(0.1 * x)
136+
137+
ax.semilogy(x, y, lw=2)
138+
139+
assert ax.get_yscale() == 'log'
140+
141+
142+
def test_loglog():
143+
_, ax = plt.subplots()
144+
x = np.arange(1.0, 50.0)
145+
y = np.exp(0.1 * x)
146+
147+
ax.loglog(x, y, lw=2)
148+
149+
assert ax.get_xscale() == 'log'
150+
assert ax.get_yscale() == 'log'

0 commit comments

Comments
 (0)