Skip to content

Commit c0e26b8

Browse files
committed
add basic tests
1 parent 85be38a commit c0e26b8

File tree

2 files changed

+61
-1
lines changed

2 files changed

+61
-1
lines changed

pyproject.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -60,7 +60,7 @@ addopts = """
6060
testpaths = "tests"
6161
filterwarnings = [
6262
"error",
63-
]
63+
'ignore:\n Sentinel is not a public part of the traitlets API:DeprecationWarning',]
6464

6565
[tool.ruff]
6666
line-length = 88

tests/plot_test.py

Lines changed: 60 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,60 @@
1+
import numpy as np
2+
3+
import matplotgl.pyplot as plt
4+
5+
6+
def test_plot_one_line():
7+
fig, ax = plt.subplots()
8+
points_per_line = 100
9+
x = np.linspace(0, 10, points_per_line)
10+
y = np.sin(x) * np.exp(-x / 5) + np.random.uniform(-0.1, 0.1, size=points_per_line)
11+
ax.plot(x, y)
12+
13+
assert len(fig.axes) == 1
14+
assert len(ax.lines) == 1
15+
lines = ax.lines[0]
16+
assert np.allclose(lines.get_xdata(), x)
17+
assert np.allclose(lines.get_ydata(), y)
18+
19+
20+
def test_plot_multiple_lines():
21+
fig, ax = plt.subplots()
22+
points_per_line = 100
23+
x = np.linspace(0, 10, points_per_line)
24+
y = []
25+
for i in range(4):
26+
y.append(
27+
np.sin(x + i) * np.exp(-x / 5)
28+
+ np.random.uniform(-0.1, 0.1, size=points_per_line)
29+
)
30+
ax.plot(x, y[i])
31+
32+
assert len(fig.axes) == 1
33+
assert len(ax.lines) == 4
34+
for i, line in enumerate(ax.lines):
35+
assert np.allclose(line.get_xdata(), x)
36+
assert np.allclose(line.get_ydata(), y[i])
37+
38+
39+
def test_scatter():
40+
fig, ax = plt.subplots()
41+
x, y = np.random.normal(size=(2, 1000))
42+
ax.scatter(x, y)
43+
44+
assert len(fig.axes) == 1
45+
assert len(ax.collections) == 1
46+
scatter = ax.collections[0]
47+
assert np.allclose(scatter.get_xdata(), x)
48+
assert np.allclose(scatter.get_ydata(), y)
49+
50+
51+
def test_imshow():
52+
fig, ax = plt.subplots()
53+
data = np.random.rand(200, 300)
54+
ax.imshow(data, cmap='viridis', extent=[0, 10, 0, 5])
55+
56+
assert len(fig.axes) == 1
57+
assert len(ax.images) == 1
58+
im = ax.images[0]
59+
assert np.allclose(im._array, data)
60+
assert im.get_extent() == [0, 10, 0, 5]

0 commit comments

Comments
 (0)