Skip to content

Commit eea8f3b

Browse files
committed
Add remaining plot tests
1 parent 4067e5e commit eea8f3b

File tree

2 files changed

+53
-0
lines changed

2 files changed

+53
-0
lines changed

tests/data/test_dataset.py

Lines changed: 41 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,10 @@
11
"""Test the dataset module."""
22

3+
import matplotlib.pyplot as plt
34
import pandas as pd
45
import pytest
6+
from matplotlib.axes import Axes
7+
from matplotlib.patches import Rectangle
58
from numpy.testing import assert_equal
69
from pandas.testing import assert_frame_equal
710

@@ -108,3 +111,41 @@ def test_repr(self, scale):
108111

109112
dataset = DataLoader.load_dataset('dino', scale=scale)
110113
assert repr(dataset) == (f'<Dataset name=dino scaled={scale is not None}>')
114+
115+
@pytest.mark.parametrize(
116+
('ax', 'show_bounds', 'title', 'alpha'),
117+
[
118+
(None, True, None, 1),
119+
(None, True, 'Custom title', 0.75),
120+
(plt.subplots()[1], False, 'default', 0.5),
121+
],
122+
)
123+
def test_plot(self, ax, show_bounds, title, alpha):
124+
"""Test the plot() method."""
125+
dataset = DataLoader.load_dataset('dino')
126+
ax = dataset.plot(ax=ax, show_bounds=show_bounds, title=title, alpha=alpha)
127+
128+
assert isinstance(ax, Axes)
129+
assert pytest.approx(ax.get_aspect()) == 1.0
130+
131+
if title is None:
132+
assert not ax.get_title()
133+
elif title == 'default':
134+
assert ax.get_title() == repr(dataset)
135+
else:
136+
assert ax.get_title() == title
137+
138+
assert ax.collections[0].get_alpha() == alpha
139+
140+
points_expected = dataset.data.shape[0]
141+
points_plotted = sum(
142+
collection.get_offsets().data.shape[0] for collection in ax.collections
143+
)
144+
assert points_expected == points_plotted
145+
146+
if show_bounds:
147+
assert sum(isinstance(patch, Rectangle) for patch in ax.patches) == 3
148+
assert ax.patches[0] != ax.patches[1] != ax.patches[2]
149+
150+
assert len(labels := ax.texts) == 3
151+
assert all(label.get_text().endswith('BOUNDS') for label in labels)

tests/shapes/bases/test_line_collection.py

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,9 @@
33
import itertools
44
import re
55

6+
import matplotlib.pyplot as plt
67
import pytest
8+
from matplotlib.axes import Axes
79

810
from data_morph.shapes.bases.line_collection import LineCollection
911

@@ -49,3 +51,13 @@ def test_repr(self, line_collection):
4951
)
5052
is not None
5153
)
54+
55+
@pytest.mark.parametrize('existing_ax', [True, False])
56+
def test_plot(self, line_collection, existing_ax):
57+
"""Test the plot() method is working."""
58+
input_ax = plt.subplots()[1] if existing_ax else None
59+
result_ax = line_collection.plot(input_ax)
60+
61+
assert isinstance(result_ax, Axes)
62+
assert len(result_ax.lines) == len(line_collection.lines)
63+
assert pytest.approx(result_ax.get_aspect()) == 1.0

0 commit comments

Comments
 (0)